diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 92ad25b9..b22c2815 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -3,7 +3,7 @@ import platform import time from copy import copy, deepcopy -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, fields from textwrap import dedent from typing import Literal from warnings import warn @@ -34,14 +34,21 @@ def asdict(self): return asdict(self) @classmethod - def from_dict(self, flags_dict): + def from_dict(cls, flags_dict): """Helper for JSON serializable requirement.""" - if isinstance(flags_dict, ObsFlags): + if isinstance(flags_dict, cls): return flags_dict if not isinstance(flags_dict, dict): - raise ValueError(f"Unregcognized type for flags_dict of type {type(flags_dict)}.") - return ObsFlags(**flags_dict) + raise ValueError(f"Unrecognized type for flags_dict of type {type(flags_dict)}.") + + # Get the names of the fields of the dataclass + class_fields = {f.name for f in fields(cls)} + + # Filter the dictionary to only include keys that are fields of the class + filtered_dict = {k: v for k, v in flags_dict.items() if k in class_fields} + + return cls(**filtered_dict) @dataclass