
Commit b5511f7

fix(bedrock): return identical Bedrock object
1 parent: 2bcd306

1 file changed

src/openlayer/lib/integrations/bedrock_tracer.py

Lines changed: 23 additions & 33 deletions
@@ -1,10 +1,14 @@
 """Module with methods used to trace AWS Bedrock LLMs."""
 
+import io
 import json
 import logging
 import time
 from functools import wraps
-from typing import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union
+
+from botocore.response import StreamingBody
+
 
 try:
     import boto3
@@ -89,20 +93,7 @@ def handle_non_streaming_invoke(
     inference_id: Optional[str] = None,
     **kwargs,
 ) -> Dict[str, Any]:
-    """Handles the invoke_model method for non-streaming requests.
-
-    Parameters
-    ----------
-    invoke_func : callable
-        The invoke_model method to handle.
-    inference_id : Optional[str], optional
-        A user-generated inference id, by default None
-
-    Returns
-    -------
-    Dict[str, Any]
-        The model invocation response.
-    """
+    """Handles the invoke_model method for non-streaming requests."""
     start_time = time.time()
     response = invoke_func(*args, **kwargs)
     end_time = time.time()
@@ -115,21 +106,27 @@ def handle_non_streaming_invoke(
             body_str = body_str.decode("utf-8")
         body_data = json.loads(body_str) if isinstance(body_str, str) else body_str
 
-        # Parse the response body
-        response_body = response["body"].read()
-        if isinstance(response_body, bytes):
-            response_body = response_body.decode("utf-8")
-        response_data = json.loads(response_body)
+        # Read the response body ONCE and preserve it
+        original_body = response["body"]
+        response_body_bytes = original_body.read()
+
+        # Parse the response data for tracing
+        if isinstance(response_body_bytes, bytes):
+            response_body_str = response_body_bytes.decode("utf-8")
+        else:
+            response_body_str = response_body_bytes
+        response_data = json.loads(response_body_str)
 
-        # Extract input and output data
+        # Create a NEW StreamingBody with the same data and type
+        # This preserves the exact botocore.response.StreamingBody type
+        new_stream = io.BytesIO(response_body_bytes)
+        response["body"] = StreamingBody(new_stream, len(response_body_bytes))
+
+        # Extract data for tracing
         inputs = extract_inputs_from_body(body_data)
         output_data = extract_output_data(response_data)
-
-        # Extract tokens and model info
         tokens_info = extract_tokens_info(response_data)
         model_id = kwargs.get("modelId", "unknown")
-
-        # Extract metadata including stop information
        metadata = extract_metadata(response_data)
 
         trace_args = create_trace_args(
@@ -149,19 +146,12 @@ def handle_non_streaming_invoke(
 
         add_to_trace(**trace_args)
 
-        # pylint: disable=broad-except
     except Exception as e:
         logger.error(
             "Failed to trace the Bedrock model invocation with Openlayer. %s", e
         )
 
-    # Reset response body for return (since we read it)
-    response_bytes = json.dumps(response_data).encode("utf-8")
-    response["body"] = type(
-        "MockBody",
-        (),
-        {"read": lambda size=-1: response_bytes[:size] if size > 0 else response_bytes},
-    )()
+    # Return the response with the properly restored body
     return response
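For readers skimming the diff: the body-restoration pattern it introduces can be sketched in isolation, as below. The removed code replaced response["body"] with an ad-hoc "MockBody" object built via type() and re-serialized the parsed JSON, so the caller no longer received a genuine botocore StreamingBody; the new code hands back the original bytes wrapped in the real type. In this sketch, read_and_restore_body and the bare response dict are illustrative names only, not part of the tracer's API; the one API assumed is botocore.response.StreamingBody(raw_stream, content_length), which is exactly what the commit itself calls.

import io
import json
from typing import Any, Dict, Tuple

from botocore.response import StreamingBody


def read_and_restore_body(response: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Parse response["body"] for tracing, then put an equivalent StreamingBody back."""
    raw_bytes = response["body"].read()             # consumes the original stream
    parsed = json.loads(raw_bytes.decode("utf-8"))  # the payload a tracer would inspect
    # Rebuild a real StreamingBody from the same bytes so the caller can still
    # .read() the body (and isinstance-check its type) as if it were never touched.
    response["body"] = StreamingBody(io.BytesIO(raw_bytes), len(raw_bytes))
    return response, parsed

After the call, isinstance(response["body"], StreamingBody) is True and response["body"].read() returns the original payload bytes, which the removed MockBody shim did not provide.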