diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py index 6bf3ec9a..15bec994 100644 --- a/src/openlayer/lib/__init__.py +++ b/src/openlayer/lib/__init__.py @@ -7,12 +7,15 @@ "trace_openai_assistant_thread_run", "trace_mistral", "trace_groq", + "trace_async_openai", + "trace_async", ] # ---------------------------------- Tracing --------------------------------- # from .tracing import tracer trace = tracer.trace +trace_async = tracer.trace_async def trace_anthropic(client): diff --git a/src/openlayer/lib/integrations/async_openai_tracer.py b/src/openlayer/lib/integrations/async_openai_tracer.py index 4e65f45a..8576d575 100644 --- a/src/openlayer/lib/integrations/async_openai_tracer.py +++ b/src/openlayer/lib/integrations/async_openai_tracer.py @@ -4,7 +4,7 @@ import logging import time from functools import wraps -from typing import Any, Dict, Iterator, Optional, Union +from typing import Any, AsyncIterator, Optional, Union import openai @@ -56,7 +56,7 @@ async def traced_create_func(*args, **kwargs): stream = kwargs.get("stream", False) if stream: - return await handle_async_streaming_create( + return handle_async_streaming_create( *args, **kwargs, create_func=create_func, @@ -81,7 +81,7 @@ async def handle_async_streaming_create( is_azure_openai: bool = False, inference_id: Optional[str] = None, **kwargs, -) -> Iterator[Any]: +) -> AsyncIterator[Any]: """Handles the create method when streaming is enabled. Parameters @@ -95,25 +95,12 @@ async def handle_async_streaming_create( Returns ------- - Iterator[Any] + AsyncIterator[Any] A generator that yields the chunks of the completion. """ chunks = await create_func(*args, **kwargs) - return await stream_async_chunks( - chunks=chunks, - kwargs=kwargs, - inference_id=inference_id, - is_azure_openai=is_azure_openai, - ) - -async def stream_async_chunks( - chunks: Iterator[Any], - kwargs: Dict[str, any], - is_azure_openai: bool = False, - inference_id: Optional[str] = None, -): - """Streams the chunks of the completion and traces the completion.""" + # Create and return a new async generator that processes chunks collected_output_data = [] collected_function_call = { "name": "", @@ -143,9 +130,9 @@ async def stream_async_chunks( if delta.function_call.name: collected_function_call["name"] += delta.function_call.name if delta.function_call.arguments: - collected_function_call["arguments"] += ( - delta.function_call.arguments - ) + collected_function_call[ + "arguments" + ] += delta.function_call.arguments elif delta.tool_calls: if delta.tool_calls[0].function.name: collected_function_call["name"] += delta.tool_calls[0].function.name @@ -155,6 +142,7 @@ async def stream_async_chunks( ].function.arguments yield chunk + end_time = time.time() latency = (end_time - start_time) * 1000 # pylint: disable=broad-except diff --git a/src/openlayer/lib/integrations/openai_tracer.py b/src/openlayer/lib/integrations/openai_tracer.py index e3faab0d..3d8773c5 100644 --- a/src/openlayer/lib/integrations/openai_tracer.py +++ b/src/openlayer/lib/integrations/openai_tracer.py @@ -137,9 +137,9 @@ def stream_chunks( if delta.function_call.name: collected_function_call["name"] += delta.function_call.name if delta.function_call.arguments: - collected_function_call["arguments"] += ( - delta.function_call.arguments - ) + collected_function_call[ + "arguments" + ] += delta.function_call.arguments elif delta.tool_calls: if delta.tool_calls[0].function.name: collected_function_call["name"] += delta.tool_calls[0].function.name @@ -257,9 +257,10 @@ def add_to_trace(is_azure_openai: bool = False, **kwargs) -> None: tracer.add_chat_completion_step_to_trace( **kwargs, name="Azure OpenAI Chat Completion", provider="Azure" ) - tracer.add_chat_completion_step_to_trace( - **kwargs, name="OpenAI Chat Completion", provider="OpenAI" - ) + else: + tracer.add_chat_completion_step_to_trace( + **kwargs, name="OpenAI Chat Completion", provider="OpenAI" + ) def handle_non_streaming_create(