Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions src/openlayer/lib/integrations/bedrock_tracer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Module with methods used to trace AWS Bedrock LLMs."""

import io
import json
import logging
import time
from functools import wraps
from typing import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union

from botocore.response import StreamingBody


try:
import boto3
@@ -89,20 +93,7 @@ def handle_non_streaming_invoke(
inference_id: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
"""Handles the invoke_model method for non-streaming requests.

Parameters
----------
invoke_func : callable
The invoke_model method to handle.
inference_id : Optional[str], optional
A user-generated inference id, by default None

Returns
-------
Dict[str, Any]
The model invocation response.
"""
"""Handles the invoke_model method for non-streaming requests."""
start_time = time.time()
response = invoke_func(*args, **kwargs)
end_time = time.time()
@@ -115,21 +106,27 @@ def handle_non_streaming_invoke(
body_str = body_str.decode("utf-8")
body_data = json.loads(body_str) if isinstance(body_str, str) else body_str

# Parse the response body
response_body = response["body"].read()
if isinstance(response_body, bytes):
response_body = response_body.decode("utf-8")
response_data = json.loads(response_body)
# Read the response body ONCE and preserve it
original_body = response["body"]
response_body_bytes = original_body.read()

# Parse the response data for tracing
if isinstance(response_body_bytes, bytes):
response_body_str = response_body_bytes.decode("utf-8")
else:
response_body_str = response_body_bytes
response_data = json.loads(response_body_str)

# Extract input and output data
# Create a NEW StreamingBody with the same data and type
# This preserves the exact botocore.response.StreamingBody type
new_stream = io.BytesIO(response_body_bytes)
response["body"] = StreamingBody(new_stream, len(response_body_bytes))

# Extract data for tracing
inputs = extract_inputs_from_body(body_data)
output_data = extract_output_data(response_data)

# Extract tokens and model info
tokens_info = extract_tokens_info(response_data)
model_id = kwargs.get("modelId", "unknown")

# Extract metadata including stop information
metadata = extract_metadata(response_data)

trace_args = create_trace_args(
@@ -149,19 +146,12 @@ def handle_non_streaming_invoke(

add_to_trace(**trace_args)

# pylint: disable=broad-except
except Exception as e:
logger.error(
"Failed to trace the Bedrock model invocation with Openlayer. %s", e
)

# Reset response body for return (since we read it)
response_bytes = json.dumps(response_data).encode("utf-8")
response["body"] = type(
"MockBody",
(),
{"read": lambda size=-1: response_bytes[:size] if size > 0 else response_bytes},
)()
# Return the response with the properly restored body
return response


Expand Down