Skip to content

Commit 0e0ea4f

Browse files
committed
feat(runnable-rails): implement AIMessage metadata parity in RunnableRails
Ensure AIMessage responses from RunnableRails contain the same metadata fields (response_metadata, usage_metadata, additional_kwargs, id) as direct LLM calls, enabling consistent LangChain integration behavior.
1 parent 0c2a65e commit 0e0ea4f

File tree

7 files changed

+533
-17
lines changed

7 files changed

+533
-17
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
2727
from nemoguardrails.context import (
2828
llm_call_info_var,
29+
llm_response_metadata_var,
2930
reasoning_trace_var,
3031
tool_calls_var,
3132
)
@@ -85,6 +86,7 @@ async def llm_call(
8586
response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)
8687

8788
_store_tool_calls(response)
89+
_store_response_metadata(response)
8890
return _extract_content(response)
8991

9092

@@ -173,6 +175,20 @@ def _store_tool_calls(response) -> None:
173175
tool_calls_var.set(tool_calls)
174176

175177

178+
def _store_response_metadata(response) -> None:
    """Store response metadata, excluding content, for later re-attachment.

    The metadata is kept in the ``llm_response_metadata_var`` context variable
    so callers (e.g. RunnableRails) can rebuild an AIMessage with the same
    metadata fields as a direct LLM call.

    Args:
        response: The raw LLM response. Metadata is only captured when it is a
            pydantic model (i.e. exposes ``model_fields``); otherwise the
            context variable is reset to ``None``.
    """
    if hasattr(response, "model_fields"):
        # NOTE: look up model_fields on the class — accessing it on the
        # instance is deprecated in pydantic >= 2.11.
        # "content" is excluded since it may be modified by rails.
        metadata = {
            field_name: getattr(response, field_name)
            for field_name in type(response).model_fields
            if field_name != "content"
        }
        llm_response_metadata_var.set(metadata)
    else:
        llm_response_metadata_var.set(None)
190+
191+
176192
def _extract_content(response) -> str:
177193
"""Extract text content from response."""
178194
if hasattr(response, "content"):
@@ -655,3 +671,15 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:
655671
tool_calls_var.set(None)
656672
return tool_calls
657673
return None
674+
675+
676+
def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
    """Get the current response metadata and clear it from the context.

    Returns:
        Optional[dict]: The response metadata if it exists, None otherwise.
    """
    metadata = llm_response_metadata_var.get()
    # Nothing captured (or an empty dict): leave the context var untouched.
    if not metadata:
        return None
    llm_response_metadata_var.set(None)
    return metadata

nemoguardrails/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@
4242
tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
4343
"tool_calls", default=None
4444
)
45+
46+
# The response metadata from the current LLM response.
llm_response_metadata_var: contextvars.ContextVar[Optional[dict]] = (
    contextvars.ContextVar("llm_response_metadata", default=None)
)

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
393393
return passthrough_output
394394

395395
def _format_chat_prompt_output(
    self,
    result: Any,
    tool_calls: Optional[list] = None,
    metadata: Optional[dict] = None,
) -> AIMessage:
    """Format output for ChatPromptValue input.

    Builds an AIMessage from the rails result, re-attaching captured LLM
    response metadata and tool calls when available.
    """
    content = self._extract_content_from_result(result)

    if isinstance(metadata, dict) and metadata:
        extra_fields = dict(metadata)
        # "content" always comes from the rails result, never raw metadata.
        extra_fields.pop("content", None)
        if tool_calls:
            extra_fields["tool_calls"] = tool_calls
        return AIMessage(content=content, **extra_fields)

    if tool_calls:
        return AIMessage(content=content, tool_calls=tool_calls)
    return AIMessage(content=content)
403413

@@ -406,11 +416,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
406416
return self._extract_content_from_result(result)
407417

408418
def _format_message_output(
    self,
    result: Any,
    tool_calls: Optional[list] = None,
    metadata: Optional[dict] = None,
) -> AIMessage:
    """Format output for BaseMessage input types.

    Builds an AIMessage from the rails result, re-attaching captured LLM
    response metadata and tool calls when available.
    """
    content = self._extract_content_from_result(result)

    if isinstance(metadata, dict) and metadata:
        extra_fields = dict(metadata)
        # "content" always comes from the rails result, never raw metadata.
        extra_fields.pop("content", None)
        if tool_calls:
            extra_fields["tool_calls"] = tool_calls
        return AIMessage(content=content, **extra_fields)

    if tool_calls:
        return AIMessage(content=content, tool_calls=tool_calls)
    return AIMessage(content=content)
416436

@@ -434,25 +454,50 @@ def _format_dict_output_for_dict_message_list(
434454
}
435455

436456
def _format_dict_output_for_base_message_list(
    self,
    result: Any,
    output_key: str,
    tool_calls: Optional[list] = None,
    metadata: Optional[dict] = None,
) -> Dict[str, Any]:
    """Format dict output when user input was a list of BaseMessage objects.

    Returns a dict mapping ``output_key`` to an AIMessage carrying the rails
    result plus any captured LLM metadata and tool calls.
    """
    content = self._extract_content_from_result(result)

    if isinstance(metadata, dict) and metadata:
        extra_fields = dict(metadata)
        # "content" always comes from the rails result, never raw metadata.
        extra_fields.pop("content", None)
        if tool_calls:
            extra_fields["tool_calls"] = tool_calls
        return {output_key: AIMessage(content=content, **extra_fields)}

    if tool_calls:
        return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
    return {output_key: AIMessage(content=content)}
444475

445476
def _format_dict_output_for_base_message(
    self,
    result: Any,
    output_key: str,
    tool_calls: Optional[list] = None,
    metadata: Optional[dict] = None,
) -> Dict[str, Any]:
    """Format dict output when user input was a BaseMessage.

    Args:
        result: The rails result to extract the message content from.
        output_key: Dict key under which the AIMessage is returned.
        tool_calls: Tool calls captured from the LLM response, if any.
        metadata: LLM response metadata (response_metadata, usage_metadata,
            additional_kwargs, id, ...) to re-attach to the AIMessage.

    Returns:
        Dict mapping ``output_key`` to the formatted AIMessage.
    """
    content = self._extract_content_from_result(result)

    # Guard with isinstance and drop any "content" key, matching the other
    # _format_* helpers: a truthy non-dict would crash on .copy()/**, and a
    # "content" entry would clash with the content= keyword passed below.
    if metadata and isinstance(metadata, dict):
        metadata_copy = metadata.copy()
        metadata_copy.pop("content", None)
        if tool_calls:
            metadata_copy["tool_calls"] = tool_calls
        return {output_key: AIMessage(content=content, **metadata_copy)}
    elif tool_calls:
        return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
    return {output_key: AIMessage(content=content)}
453494

454495
def _format_dict_output(
455-
self, input_dict: dict, result: Any, tool_calls: Optional[list] = None
496+
self,
497+
input_dict: dict,
498+
result: Any,
499+
tool_calls: Optional[list] = None,
500+
metadata: Optional[dict] = None,
456501
) -> Dict[str, Any]:
457502
"""Format output for dictionary input."""
458503
output_key = self.passthrough_bot_output_key
@@ -471,13 +516,13 @@ def _format_dict_output(
471516
)
472517
elif all(isinstance(msg, BaseMessage) for msg in user_input):
473518
return self._format_dict_output_for_base_message_list(
474-
result, output_key, tool_calls
519+
result, output_key, tool_calls, metadata
475520
)
476521
else:
477522
return {output_key: result}
478523
elif isinstance(user_input, BaseMessage):
479524
return self._format_dict_output_for_base_message(
480-
result, output_key, tool_calls
525+
result, output_key, tool_calls, metadata
481526
)
482527

483528
# Generic fallback for dictionaries
@@ -490,6 +535,7 @@ def _format_output(
490535
result: Any,
491536
context: Dict[str, Any],
492537
tool_calls: Optional[list] = None,
538+
metadata: Optional[dict] = None,
493539
) -> Any:
494540
"""Format the output based on the input type and rails result.
495541
@@ -512,17 +558,17 @@ def _format_output(
512558
return self._format_passthrough_output(result, context)
513559

514560
if isinstance(input, ChatPromptValue):
515-
return self._format_chat_prompt_output(result, tool_calls)
561+
return self._format_chat_prompt_output(result, tool_calls, metadata)
516562
elif isinstance(input, StringPromptValue):
517563
return self._format_string_prompt_output(result)
518564
elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
519-
return self._format_message_output(result, tool_calls)
565+
return self._format_message_output(result, tool_calls, metadata)
520566
elif isinstance(input, list) and all(
521567
isinstance(msg, BaseMessage) for msg in input
522568
):
523-
return self._format_message_output(result, tool_calls)
569+
return self._format_message_output(result, tool_calls, metadata)
524570
elif isinstance(input, dict):
525-
return self._format_dict_output(input, result, tool_calls)
571+
return self._format_dict_output(input, result, tool_calls, metadata)
526572
elif isinstance(input, str):
527573
return self._format_string_prompt_output(result)
528574
else:
@@ -669,7 +715,9 @@ def _full_rails_invoke(
669715
result = result[0]
670716

671717
# Format and return the output based on input type
672-
return self._format_output(input, result, context, res.tool_calls)
718+
return self._format_output(
719+
input, result, context, res.tool_calls, res.llm_metadata
720+
)
673721

674722
async def ainvoke(
675723
self,
@@ -731,7 +779,9 @@ async def _full_rails_ainvoke(
731779
result = res.response
732780

733781
# Format and return the output based on input type
734-
return self._format_output(input, result, context, res.tool_calls)
782+
return self._format_output(
783+
input, result, context, res.tool_calls, res.llm_metadata
784+
)
735785

736786
def stream(
737787
self,

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from nemoguardrails.actions.llm.generation import LLMGenerationActions
3434
from nemoguardrails.actions.llm.utils import (
3535
get_and_clear_reasoning_trace_contextvar,
36+
get_and_clear_response_metadata_contextvar,
3637
get_and_clear_tool_calls_contextvar,
3738
get_colang_history,
3839
)
@@ -1086,6 +1087,7 @@ async def generate_async(
10861087
options.log.internal_events = True
10871088

10881089
tool_calls = get_and_clear_tool_calls_contextvar()
1090+
llm_metadata = get_and_clear_response_metadata_contextvar()
10891091

10901092
# If we have generation options, we prepare a GenerationResponse instance.
10911093
if options:
@@ -1106,6 +1108,9 @@ async def generate_async(
11061108
if tool_calls:
11071109
res.tool_calls = tool_calls
11081110

1111+
if llm_metadata:
1112+
res.llm_metadata = llm_metadata
1113+
11091114
if self.config.colang_version == "1.0":
11101115
# If output variables are specified, we extract their values
11111116
if options.output_vars:

nemoguardrails/rails/llm/options.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ class GenerationResponse(BaseModel):
412412
default=None,
413413
description="Tool calls extracted from the LLM response, if any.",
414414
)
415+
llm_metadata: Optional[dict] = Field(
416+
default=None,
417+
description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",
418+
)
415419

416420

417421
if __name__ == "__main__":

0 commit comments

Comments
 (0)