2828 List ,
2929 Optional ,
3030 Set ,
31+ Tuple ,
3132 Type ,
3233 Union ,
3334)
7879if TYPE_CHECKING :
7980 from camel .terminators import ResponseTerminator
8081
81-
8282logger = logging .getLogger (__name__ )
8383
8484# AgentOps decorator setting
@@ -110,10 +110,17 @@ class ChatAgent(BaseAgent):
110110
111111 Args:
112112 system_message (Union[BaseMessage, str], optional): The system message
113- for the chat agent.
114- model (BaseModelBackend, optional): The model backend to use for
115- generating responses. (default: :obj:`ModelPlatformType.DEFAULT`
116- with `ModelType.DEFAULT`)
113+ for the chat agent. (default: :obj:`None`)
114+ model (Union[BaseModelBackend, Tuple[str, str], str, ModelType,
115+ Tuple[ModelPlatformType, ModelType], List[BaseModelBackend],
116+ List[str], List[ModelType], List[Tuple[str, str]],
117+ List[Tuple[ModelPlatformType, ModelType]]], optional):
118+ The model backend(s) to use. Can be a single instance,
119+ a specification (string, enum, tuple), or a list of instances
120+ or specifications to be managed by `ModelManager`. If a list of
121+ specifications (not `BaseModelBackend` instances) is provided,
122+ they will be instantiated using `ModelFactory`. (default:
123+ :obj:`ModelPlatformType.DEFAULT` with `ModelType.DEFAULT`)
117124 memory (AgentMemory, optional): The agent memory for managing chat
118125 messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
119126 (default: :obj:`None`)
@@ -150,7 +157,18 @@ def __init__(
150157 self ,
151158 system_message : Optional [Union [BaseMessage , str ]] = None ,
152159 model : Optional [
153- Union [BaseModelBackend , List [BaseModelBackend ]]
160+ Union [
161+ BaseModelBackend ,
162+ Tuple [str , str ],
163+ str ,
164+ ModelType ,
165+ Tuple [ModelPlatformType , ModelType ],
166+ List [BaseModelBackend ],
167+ List [str ],
168+ List [ModelType ],
169+ List [Tuple [str , str ]],
170+ List [Tuple [ModelPlatformType , ModelType ]],
171+ ]
154172 ] = None ,
155173 memory : Optional [AgentMemory ] = None ,
156174 message_window_size : Optional [int ] = None ,
@@ -165,19 +183,14 @@ def __init__(
165183 single_iteration : bool = False ,
166184 agent_id : Optional [str ] = None ,
167185 ) -> None :
168- # Set up model backend
186+ # Resolve model backends and set up model manager
187+ resolved_models = self ._resolve_models (model )
169188 self .model_backend = ModelManager (
170- (
171- model
172- if model is not None
173- else ModelFactory .create (
174- model_platform = ModelPlatformType .DEFAULT ,
175- model_type = ModelType .DEFAULT ,
176- )
177- ),
189+ resolved_models ,
178190 scheduling_strategy = scheduling_strategy ,
179191 )
180192 self .model_type = self .model_backend .model_type
193+
181194 # Assign unique ID
182195 self .agent_id = agent_id if agent_id else str (uuid .uuid4 ())
183196
@@ -247,6 +260,137 @@ def reset(self):
247260 for terminator in self .response_terminators :
248261 terminator .reset ()
249262
263+ def _resolve_models (
264+ self ,
265+ model : Optional [
266+ Union [
267+ BaseModelBackend ,
268+ Tuple [str , str ],
269+ str ,
270+ ModelType ,
271+ Tuple [ModelPlatformType , ModelType ],
272+ List [BaseModelBackend ],
273+ List [str ],
274+ List [ModelType ],
275+ List [Tuple [str , str ]],
276+ List [Tuple [ModelPlatformType , ModelType ]],
277+ ]
278+ ],
279+ ) -> Union [BaseModelBackend , List [BaseModelBackend ]]:
280+ r"""Resolves model specifications into model backend instances.
281+
282+ This method handles various input formats for model specifications and
283+ returns the appropriate model backend(s).
284+
285+ Args:
286+ model: Model specification in various formats including single
287+ model, list of models, or model type specifications.
288+
289+ Returns:
290+ Union[BaseModelBackend, List[BaseModelBackend]]: Resolved model
291+ backend(s).
292+
293+ Raises:
294+ TypeError: If the model specification format is not supported.
295+ """
296+ if model is None :
297+ # Default single model if none provided
298+ return ModelFactory .create (
299+ model_platform = ModelPlatformType .DEFAULT ,
300+ model_type = ModelType .DEFAULT ,
301+ )
302+ elif isinstance (model , BaseModelBackend ):
303+ # Already a single pre-instantiated model
304+ return model
305+ elif isinstance (model , list ):
306+ return self ._resolve_model_list (model )
307+ elif isinstance (model , (ModelType , str )):
308+ # Single string or ModelType -> use default platform
309+ model_platform = ModelPlatformType .DEFAULT
310+ model_type = model
311+ logger .warning (
312+ f"Model type '{ model_type } ' provided without a platform. "
313+ f"Using platform '{ model_platform } '. Note: platform "
314+ "is not automatically inferred based on model type."
315+ )
316+ return ModelFactory .create (
317+ model_platform = model_platform ,
318+ model_type = model_type ,
319+ )
320+ elif isinstance (model , tuple ) and len (model ) == 2 :
321+ # Single tuple (platform, type)
322+ model_platform , model_type = model # type: ignore[assignment]
323+ return ModelFactory .create (
324+ model_platform = model_platform ,
325+ model_type = model_type ,
326+ )
327+ else :
328+ raise TypeError (
329+ f"Unsupported type for model parameter: { type (model )} "
330+ )
331+
332+ def _resolve_model_list (
333+ self , model_list : list
334+ ) -> Union [BaseModelBackend , List [BaseModelBackend ]]:
335+ r"""Resolves a list of model specifications into model backend
336+ instances.
337+
338+ Args:
339+ model_list (list): List of model specifications in various formats.
340+
341+ Returns:
342+ Union[BaseModelBackend, List[BaseModelBackend]]: Resolved model
343+ backend(s).
344+
345+ Raises:
346+ TypeError: If the list elements format is not supported.
347+ """
348+ if not model_list : # Handle empty list
349+ logger .warning (
350+ "Empty list provided for model, using default model."
351+ )
352+ return ModelFactory .create (
353+ model_platform = ModelPlatformType .DEFAULT ,
354+ model_type = ModelType .DEFAULT ,
355+ )
356+ elif isinstance (model_list [0 ], BaseModelBackend ):
357+ # List of pre-instantiated models
358+ return model_list # type: ignore[return-value]
359+ elif isinstance (model_list [0 ], (str , ModelType )):
360+ # List of strings or ModelTypes -> use default platform
361+ model_platform = ModelPlatformType .DEFAULT
362+ logger .warning (
363+ f"List of model types { model_list } provided without "
364+ f"platforms. Using platform '{ model_platform } ' for all. "
365+ "Note: platform is not automatically inferred based on "
366+ "model type."
367+ )
368+ resolved_models_list = []
369+ for model_type_item in model_list :
370+ resolved_models_list .append (
371+ ModelFactory .create (
372+ model_platform = model_platform ,
373+ model_type = model_type_item , # type: ignore[arg-type]
374+ )
375+ )
376+ return resolved_models_list
377+ elif isinstance (model_list [0 ], tuple ) and len (model_list [0 ]) == 2 :
378+ # List of tuples (platform, type)
379+ resolved_models_list = []
380+ for model_spec in model_list :
381+ platform , type_ = model_spec [0 ], model_spec [1 ] # type: ignore[index]
382+ resolved_models_list .append (
383+ ModelFactory .create (
384+ model_platform = platform , model_type = type_
385+ )
386+ )
387+ return resolved_models_list
388+ else :
389+ raise TypeError (
390+ "Unsupported type for list elements in model: "
391+ f"{ type (model_list [0 ])} "
392+ )
393+
250394 @property
251395 def system_message (self ) -> Optional [BaseMessage ]:
252396 r"""Returns the system message for the agent."""
0 commit comments