Skip to content

Commit a16cd4f

Browse files
authored
Refa: add result to callback for agent tool use. (#9137)
### What problem does this PR solve? ### Type of change - [x] Refactoring
1 parent c5823a3 commit a16cd4f

26 files changed

+10870
-892
lines changed

agent/canvas.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,6 @@ def _node_finished(cpn_obj):
252252
"created_at": cpn_obj.output("_created_time"),
253253
})
254254

255-
def _append_path(cpn_id):
256-
if self.path[-1] == cpn_id:
257-
return
258-
self.path.append(cpn_id)
259-
260-
def _extend_path(cpn_ids):
261-
for cpn_id in cpn_ids:
262-
_append_path(cpn_id)
263-
264255
self.error = ""
265256
idx = len(self.path) - 1
266257
partials = []
@@ -279,10 +270,11 @@ def _extend_path(cpn_ids):
279270
# post processing of components invocation
280271
for i in range(idx, to):
281272
cpn = self.get_component(self.path[i])
282-
if cpn["obj"].component_name.lower() == "message":
283-
if isinstance(cpn["obj"].output("content"), partial):
273+
cpn_obj = self.get_component_obj(self.path[i])
274+
if cpn_obj.component_name.lower() == "message":
275+
if isinstance(cpn_obj.output("content"), partial):
284276
_m = ""
285-
for m in cpn["obj"].output("content")():
277+
for m in cpn_obj.output("content")():
286278
if not m:
287279
continue
288280
if m == "<think>":
@@ -292,48 +284,65 @@ def _extend_path(cpn_ids):
292284
else:
293285
yield decorate("message", {"content": m})
294286
_m += m
295-
cpn["obj"].set_output("content", _m)
287+
cpn_obj.set_output("content", _m)
296288
else:
297-
yield decorate("message", {"content": cpn["obj"].output("content")})
289+
yield decorate("message", {"content": cpn_obj.output("content")})
298290
yield decorate("message_end", {"reference": self.get_reference()})
299291

300292
while partials:
301-
_cpn = self.get_component(partials[0])
302-
if isinstance(_cpn["obj"].output("content"), partial):
293+
_cpn_obj = self.get_component_obj(partials[0])
294+
if isinstance(_cpn_obj.output("content"), partial):
303295
break
304-
yield _node_finished(_cpn["obj"])
296+
yield _node_finished(_cpn_obj)
305297
partials.pop(0)
306298

307-
if cpn["obj"].error():
308-
ex = cpn["obj"].exception_handler()
309-
if ex and ex["comment"]:
310-
yield decorate("message", {"content": ex["comment"]})
311-
yield decorate("message_end", {})
299+
other_branch = False
300+
if cpn_obj.error():
301+
ex = cpn_obj.exception_handler()
312302
if ex and ex["goto"]:
313-
self.path.append(ex["goto"])
314-
elif not ex or not ex["default_value"]:
315-
self.error = cpn["obj"].error()
303+
self.path.extend(ex["goto"])
304+
other_branch = True
305+
elif ex and ex["default_value"]:
306+
yield decorate("message", {"content": ex["default_value"]})
307+
yield decorate("message_end", {})
308+
else:
309+
self.error = cpn_obj.error()
316310

317-
if cpn["obj"].component_name.lower() != "iteration":
318-
if isinstance(cpn["obj"].output("content"), partial):
311+
if cpn_obj.component_name.lower() != "iteration":
312+
if isinstance(cpn_obj.output("content"), partial):
319313
if self.error:
320-
cpn["obj"].set_output("content", None)
321-
yield _node_finished(cpn["obj"])
314+
cpn_obj.set_output("content", None)
315+
yield _node_finished(cpn_obj)
322316
else:
323317
partials.append(self.path[i])
324318
else:
325-
yield _node_finished(cpn["obj"])
326-
327-
if cpn["obj"].component_name.lower() == "iterationitem" and cpn["obj"].end():
328-
iter = cpn["obj"].get_parent()
319+
yield _node_finished(cpn_obj)
320+
321+
def _append_path(cpn_id):
322+
nonlocal other_branch
323+
if other_branch:
324+
return
325+
if self.path[-1] == cpn_id:
326+
return
327+
self.path.append(cpn_id)
328+
329+
def _extend_path(cpn_ids):
330+
nonlocal other_branch
331+
if other_branch:
332+
return
333+
for cpn_id in cpn_ids:
334+
_append_path(cpn_id)
335+
336+
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
337+
iter = cpn_obj.get_parent()
329338
yield _node_finished(iter)
330339
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
331-
elif cpn["obj"].component_name.lower() in ["categorize", "switch"]:
332-
_extend_path(cpn["obj"].output("_next"))
333-
elif cpn["obj"].component_name.lower() == "iteration":
334-
_append_path(cpn["obj"].get_start())
335-
elif not cpn["downstream"] and cpn["obj"].get_parent():
336-
_append_path(cpn["obj"].get_parent().get_start())
340+
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
341+
_extend_path(cpn_obj.output("_next"))
342+
elif cpn_obj.component_name.lower() == "iteration":
343+
_append_path(cpn_obj.get_start())
344+
elif not cpn["downstream"] and cpn_obj.get_parent():
345+
_append_path(cpn_obj.get_parent().get_start())
337346
else:
338347
_extend_path(cpn["downstream"])
339348

@@ -342,13 +351,13 @@ def _extend_path(cpn_ids):
342351
break
343352
idx = to
344353

345-
if any([self.get_component(c)["obj"].component_name.lower() == "userfillup" for c in self.path[idx:]]):
354+
if any([self.get_component_obj(c).component_name.lower() == "userfillup" for c in self.path[idx:]]):
346355
path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
347356
path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
348357
another_inputs = {}
349358
tips = ""
350359
for c in path:
351-
o = self.get_component(c)["obj"]
360+
o = self.get_component_obj(c)
352361
if o.component_name.lower() == "userfillup":
353362
another_inputs.update(o.get_input_elements())
354363
if o.get_param("enable_tips"):

agent/component/agent_with_tools.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ def _invoke(self, **kwargs):
157157
prompt, msg = self._prepare_prompt_variables()
158158

159159
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
160-
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
160+
ex = self.exception_handler()
161+
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
161162
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
162163
return
163164

@@ -169,7 +170,10 @@ def _invoke(self, **kwargs):
169170

170171
if ans.find("**ERROR**") >= 0:
171172
logging.error(f"Agent._chat got error. response: {ans}")
172-
self.set_output("_ERROR", ans)
173+
if self.get_exception_default_value():
174+
self.set_output("content", self.get_exception_default_value())
175+
else:
176+
self.set_output("_ERROR", ans)
173177
return
174178

175179
self.set_output("content", ans)
@@ -182,6 +186,12 @@ def stream_output_with_tools(self, prompt, msg):
182186
answer_without_toolcall = ""
183187
use_tools = []
184188
for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools):
189+
if delta_ans.find("**ERROR**") >= 0:
190+
if self.get_exception_default_value():
191+
self.set_output("content", self.get_exception_default_value())
192+
yield self.get_exception_default_value()
193+
else:
194+
self.set_output("_ERROR", delta_ans)
185195
answer_without_toolcall += delta_ans
186196
yield delta_ans
187197

@@ -204,8 +214,8 @@ def _react_with_tools_streamly(self, history: list[dict], use_tools):
204214
hist = deepcopy(history)
205215
last_calling = ""
206216
if len(hist) > 3:
207-
self.callback("Multi-turn conversation optimization", {}, " running ...")
208217
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
218+
self.callback("Multi-turn conversation optimization", {}, user_request)
209219
else:
210220
user_request = history[-1]["content"]
211221

@@ -241,9 +251,6 @@ def complete():
241251
cited = True
242252
yield "", token_count
243253

244-
if not cited and need2cite:
245-
self.callback("gen_citations", {}, " running ...")
246-
247254
_hist = hist
248255
if len(hist) > 12:
249256
_hist = [hist[0], hist[1], *hist[-10:]]
@@ -255,17 +262,21 @@ def complete():
255262
if not need2cite or cited:
256263
return
257264

265+
txt = ""
258266
for delta_ans in self._gen_citations(entire_txt):
259267
yield delta_ans, 0
268+
txt += delta_ans
269+
270+
self.callback("gen_citations", {}, txt)
260271

261272
def append_user_content(hist, content):
262273
if hist[-1]["role"] == "user":
263274
hist[-1]["content"] += content
264275
else:
265276
hist.append({"role": "user", "content": content})
266277

267-
self.callback("analyze_task", {}, " running ...")
268278
task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
279+
self.callback("analyze_task", {}, task_desc)
269280
for _ in range(self._param.max_rounds + 1):
270281
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
271282
# self.callback("next_step", {}, str(response)[:256]+"...")

agent/component/base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(self):
4444
self.delay_after_error = 2.0
4545
self.exception_method = None
4646
self.exception_default_value = None
47-
self.exception_comment = None
4847
self.exception_goto = None
4948
self.debug_inputs = {}
5049

@@ -97,6 +96,14 @@ def __str__(self):
9796
def as_dict(self):
9897
def _recursive_convert_obj_to_dict(obj):
9998
ret_dict = {}
99+
if isinstance(obj, dict):
100+
for k,v in obj.items():
101+
if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)):
102+
ret_dict[k] = _recursive_convert_obj_to_dict(v)
103+
else:
104+
ret_dict[k] = v
105+
return ret_dict
106+
100107
for attr_name in list(obj.__dict__):
101108
if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
102109
continue
@@ -105,7 +112,7 @@ def _recursive_convert_obj_to_dict(obj):
105112
if isinstance(attr, pd.DataFrame):
106113
ret_dict[attr_name] = attr.to_dict()
107114
continue
108-
if attr and type(attr).__name__ not in dir(builtins):
115+
if isinstance(attr, dict) or (attr and type(attr).__name__ not in dir(builtins)):
109116
ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
110117
else:
111118
ret_dict[attr_name] = attr
@@ -415,7 +422,10 @@ def invoke(self, **kwargs) -> dict[str, Any]:
415422
try:
416423
self._invoke(**kwargs)
417424
except Exception as e:
418-
self._param.outputs["_ERROR"] = {"value": str(e)}
425+
if self.get_exception_default_value():
426+
self.set_exception_default_value()
427+
else:
428+
self.set_output("_ERROR", str(e))
419429
logging.exception(e)
420430
self._param.debug_inputs = {}
421431
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
@@ -427,7 +437,7 @@ def _invoke(self, **kwargs):
427437

428438
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
429439
if var_nm:
430-
return self._param.outputs.get(var_nm, {}).get("value")
440+
return self._param.outputs.get(var_nm, {}).get("value", "")
431441
return {k: o.get("value") for k,o in self._param.outputs.items()}
432442

433443
def set_output(self, key: str, value: Any):
@@ -520,7 +530,7 @@ def get_upstream(self) -> List[str]:
520530
def string_format(content: str, kv: dict[str, str]) -> str:
521531
for n, v in kv.items():
522532
content = re.sub(
523-
r"\{%s\}" % re.escape(n), re.escape(v), content
533+
r"\{%s\}" % re.escape(n), v, content
524534
)
525535
return content
526536

@@ -529,13 +539,17 @@ def exception_handler(self):
529539
return
530540
return {
531541
"goto": self._param.exception_goto,
532-
"comment": self._param.exception_comment,
533542
"default_value": self._param.exception_default_value
534543
}
535544

536545
def get_exception_default_value(self):
546+
if self._param.exception_method != "comment":
547+
return ""
537548
return self._param.exception_default_value
538549

550+
def set_exception_default_value(self):
551+
self.set_output("result", self.get_exception_default_value())
552+
539553
@abstractmethod
540554
def thoughts(self) -> str:
541555
...

agent/component/begin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def _invoke(self, **kwargs):
4646
self.set_input_value(k, v)
4747

4848
def thoughts(self) -> str:
49-
return "☕ Here we go..."
49+
return ""

0 commit comments

Comments
 (0)