From f9c68f78b7d4fe5ecc705a16776e0c22827dcba3 Mon Sep 17 00:00:00 2001 From: Gustavo Cid Date: Thu, 17 Jul 2025 12:44:28 -0300 Subject: [PATCH] fix(bedrock): return identical Bedrock object --- .../lib/integrations/bedrock_tracer.py | 56 ++++++++----------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/openlayer/lib/integrations/bedrock_tracer.py b/src/openlayer/lib/integrations/bedrock_tracer.py index e1dca78c..336d7cda 100644 --- a/src/openlayer/lib/integrations/bedrock_tracer.py +++ b/src/openlayer/lib/integrations/bedrock_tracer.py @@ -1,10 +1,14 @@ """Module with methods used to trace AWS Bedrock LLMs.""" +import io import json import logging import time from functools import wraps -from typing import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union + +from botocore.response import StreamingBody + try: import boto3 @@ -89,20 +93,7 @@ def handle_non_streaming_invoke( inference_id: Optional[str] = None, **kwargs, ) -> Dict[str, Any]: - """Handles the invoke_model method for non-streaming requests. - - Parameters - ---------- - invoke_func : callable - The invoke_model method to handle. - inference_id : Optional[str], optional - A user-generated inference id, by default None - - Returns - ------- - Dict[str, Any] - The model invocation response. - """ + """Handles the invoke_model method for non-streaming requests.""" start_time = time.time() response = invoke_func(*args, **kwargs) end_time = time.time() @@ -115,21 +106,27 @@ def handle_non_streaming_invoke( body_str = body_str.decode("utf-8") body_data = json.loads(body_str) if isinstance(body_str, str) else body_str - # Parse the response body - response_body = response["body"].read() - if isinstance(response_body, bytes): - response_body = response_body.decode("utf-8") - response_data = json.loads(response_body) + # Read the response body ONCE and preserve it + original_body = response["body"] + response_body_bytes = original_body.read() + + # Parse the response data for tracing + if isinstance(response_body_bytes, bytes): + response_body_str = response_body_bytes.decode("utf-8") + else: + response_body_str = response_body_bytes + response_data = json.loads(response_body_str) - # Extract input and output data + # Create a NEW StreamingBody with the same data and type + # This preserves the exact botocore.response.StreamingBody type + new_stream = io.BytesIO(response_body_bytes) + response["body"] = StreamingBody(new_stream, len(response_body_bytes)) + + # Extract data for tracing inputs = extract_inputs_from_body(body_data) output_data = extract_output_data(response_data) - - # Extract tokens and model info tokens_info = extract_tokens_info(response_data) model_id = kwargs.get("modelId", "unknown") - - # Extract metadata including stop information metadata = extract_metadata(response_data) trace_args = create_trace_args( @@ -149,19 +146,12 @@ def handle_non_streaming_invoke( add_to_trace(**trace_args) - # pylint: disable=broad-except except Exception as e: logger.error( "Failed to trace the Bedrock model invocation with Openlayer. %s", e ) - # Reset response body for return (since we read it) - response_bytes = json.dumps(response_data).encode("utf-8") - response["body"] = type( - "MockBody", - (), - {"read": lambda size=-1: response_bytes[:size] if size > 0 else response_bytes}, - )() + # Return the response with the properly restored body return response