@@ -393,11 +393,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
393
393
return passthrough_output
394
394
395
395
def _format_chat_prompt_output (
396
- self , result : Any , tool_calls : Optional [list ] = None
396
+ self ,
397
+ result : Any ,
398
+ tool_calls : Optional [list ] = None ,
399
+ metadata : Optional [dict ] = None ,
397
400
) -> AIMessage :
398
401
"""Format output for ChatPromptValue input."""
399
402
content = self ._extract_content_from_result (result )
400
- if tool_calls :
403
+
404
+ if metadata and isinstance (metadata , dict ):
405
+ metadata_copy = metadata .copy ()
406
+ metadata_copy .pop ("content" , None )
407
+ if tool_calls :
408
+ metadata_copy ["tool_calls" ] = tool_calls
409
+ return AIMessage (content = content , ** metadata_copy )
410
+ elif tool_calls :
401
411
return AIMessage (content = content , tool_calls = tool_calls )
402
412
return AIMessage (content = content )
403
413
@@ -406,11 +416,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
406
416
return self ._extract_content_from_result (result )
407
417
408
418
def _format_message_output (
409
- self , result : Any , tool_calls : Optional [list ] = None
419
+ self ,
420
+ result : Any ,
421
+ tool_calls : Optional [list ] = None ,
422
+ metadata : Optional [dict ] = None ,
410
423
) -> AIMessage :
411
424
"""Format output for BaseMessage input types."""
412
425
content = self ._extract_content_from_result (result )
413
- if tool_calls :
426
+
427
+ if metadata and isinstance (metadata , dict ):
428
+ metadata_copy = metadata .copy ()
429
+ metadata_copy .pop ("content" , None )
430
+ if tool_calls :
431
+ metadata_copy ["tool_calls" ] = tool_calls
432
+ return AIMessage (content = content , ** metadata_copy )
433
+ elif tool_calls :
414
434
return AIMessage (content = content , tool_calls = tool_calls )
415
435
return AIMessage (content = content )
416
436
@@ -434,25 +454,50 @@ def _format_dict_output_for_dict_message_list(
434
454
}
435
455
436
456
def _format_dict_output_for_base_message_list (
437
- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
457
+ self ,
458
+ result : Any ,
459
+ output_key : str ,
460
+ tool_calls : Optional [list ] = None ,
461
+ metadata : Optional [dict ] = None ,
438
462
) -> Dict [str , Any ]:
439
463
"""Format dict output when user input was a list of BaseMessage objects."""
440
464
content = self ._extract_content_from_result (result )
441
- if tool_calls :
465
+
466
+ if metadata and isinstance (metadata , dict ):
467
+ metadata_copy = metadata .copy ()
468
+ metadata_copy .pop ("content" , None )
469
+ if tool_calls :
470
+ metadata_copy ["tool_calls" ] = tool_calls
471
+ return {output_key : AIMessage (content = content , ** metadata_copy )}
472
+ elif tool_calls :
442
473
return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
443
474
return {output_key : AIMessage (content = content )}
444
475
445
476
def _format_dict_output_for_base_message (
446
- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
477
+ self ,
478
+ result : Any ,
479
+ output_key : str ,
480
+ tool_calls : Optional [list ] = None ,
481
+ metadata : Optional [dict ] = None ,
447
482
) -> Dict [str , Any ]:
448
483
"""Format dict output when user input was a BaseMessage."""
449
484
content = self ._extract_content_from_result (result )
450
- if tool_calls :
485
+
486
+ if metadata :
487
+ metadata_copy = metadata .copy ()
488
+ if tool_calls :
489
+ metadata_copy ["tool_calls" ] = tool_calls
490
+ return {output_key : AIMessage (content = content , ** metadata_copy )}
491
+ elif tool_calls :
451
492
return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
452
493
return {output_key : AIMessage (content = content )}
453
494
454
495
def _format_dict_output (
455
- self , input_dict : dict , result : Any , tool_calls : Optional [list ] = None
496
+ self ,
497
+ input_dict : dict ,
498
+ result : Any ,
499
+ tool_calls : Optional [list ] = None ,
500
+ metadata : Optional [dict ] = None ,
456
501
) -> Dict [str , Any ]:
457
502
"""Format output for dictionary input."""
458
503
output_key = self .passthrough_bot_output_key
@@ -471,13 +516,13 @@ def _format_dict_output(
471
516
)
472
517
elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
473
518
return self ._format_dict_output_for_base_message_list (
474
- result , output_key , tool_calls
519
+ result , output_key , tool_calls , metadata
475
520
)
476
521
else :
477
522
return {output_key : result }
478
523
elif isinstance (user_input , BaseMessage ):
479
524
return self ._format_dict_output_for_base_message (
480
- result , output_key , tool_calls
525
+ result , output_key , tool_calls , metadata
481
526
)
482
527
483
528
# Generic fallback for dictionaries
@@ -490,6 +535,7 @@ def _format_output(
490
535
result : Any ,
491
536
context : Dict [str , Any ],
492
537
tool_calls : Optional [list ] = None ,
538
+ metadata : Optional [dict ] = None ,
493
539
) -> Any :
494
540
"""Format the output based on the input type and rails result.
495
541
@@ -512,17 +558,17 @@ def _format_output(
512
558
return self ._format_passthrough_output (result , context )
513
559
514
560
if isinstance (input , ChatPromptValue ):
515
- return self ._format_chat_prompt_output (result , tool_calls )
561
+ return self ._format_chat_prompt_output (result , tool_calls , metadata )
516
562
elif isinstance (input , StringPromptValue ):
517
563
return self ._format_string_prompt_output (result )
518
564
elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
519
- return self ._format_message_output (result , tool_calls )
565
+ return self ._format_message_output (result , tool_calls , metadata )
520
566
elif isinstance (input , list ) and all (
521
567
isinstance (msg , BaseMessage ) for msg in input
522
568
):
523
- return self ._format_message_output (result , tool_calls )
569
+ return self ._format_message_output (result , tool_calls , metadata )
524
570
elif isinstance (input , dict ):
525
- return self ._format_dict_output (input , result , tool_calls )
571
+ return self ._format_dict_output (input , result , tool_calls , metadata )
526
572
elif isinstance (input , str ):
527
573
return self ._format_string_prompt_output (result )
528
574
else :
@@ -669,7 +715,9 @@ def _full_rails_invoke(
669
715
result = result [0 ]
670
716
671
717
# Format and return the output based in input type
672
- return self ._format_output (input , result , context , res .tool_calls )
718
+ return self ._format_output (
719
+ input , result , context , res .tool_calls , res .llm_metadata
720
+ )
673
721
674
722
async def ainvoke (
675
723
self ,
@@ -731,7 +779,9 @@ async def _full_rails_ainvoke(
731
779
result = res .response
732
780
733
781
# Format and return the output based on input type
734
- return self ._format_output (input , result , context , res .tool_calls )
782
+ return self ._format_output (
783
+ input , result , context , res .tool_calls , res .llm_metadata
784
+ )
735
785
736
786
def stream (
737
787
self ,
0 commit comments