diff --git a/.fernignore b/.fernignore index 910bfd6b..35b23289 100644 --- a/.fernignore +++ b/.fernignore @@ -11,6 +11,7 @@ src/elevenlabs/music_custom.py src/elevenlabs/speech_to_text_custom.py src/elevenlabs/url_utils.py src/elevenlabs/realtime/ +src/elevenlabs/speech_engine/ # Ignore CI files .github/ diff --git a/README.md b/README.md index dfdd534c..5d820344 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,141 @@ client_tools.register("calculate_sum", calculate_sum, is_async=False) client_tools.register("fetch_data", fetch_data, is_async=True) ``` +## Speech Engine + +Speech Engine lets you build server-side voice agents that receive real-time transcripts from the ElevenLabs API and stream LLM responses back for text-to-speech synthesis. Your server acts as a WebSocket endpoint — ElevenLabs connects to it, sends user transcripts, and your code decides how to respond. + +Speech Engine is async-only and available on `AsyncElevenLabs`. + +### Quick Start + +```python +import asyncio +from openai import AsyncOpenAI +from elevenlabs import AsyncElevenLabs + +openai_client = AsyncOpenAI() +elevenlabs = AsyncElevenLabs() + +async def main(): + engine = await elevenlabs.speech_engine.get("seng_123") + + async def on_transcript(transcript, session): + stream = await openai_client.responses.create( + model="gpt-4o", + input=[ + {"role": "assistant" if m.role == "agent" else m.role, "content": m.content} + for m in transcript + ], + stream=True, + ) + await session.send_response(stream) + + async def on_init(conversation_id, session): + print(f"Session started: {conversation_id}") + + async def on_close(session): + print(f"Session ended: {session.conversation_id}") + + async def on_error(err, session): + print(f"Error: {err}") + + await engine.serve( + port=3001, + debug=True, + on_init=on_init, + on_transcript=on_transcript, + on_close=on_close, + on_error=on_error, + ) + +asyncio.run(main()) +``` + +### How It Works + +When `engine.serve()` starts, it opens a WebSocket server on the specified port. For each incoming connection from the ElevenLabs API: + +1. An `init` message arrives with a `conversation_id` +2. As the user speaks, `user_transcript` messages arrive with the full conversation history +3. Your `on_transcript` handler generates a response (using any LLM) and calls `session.send_response()` +4. If the user interrupts (speaks again mid-response), the previous handler is automatically cancelled + +### Sending Responses + +`send_response()` accepts strings or async iterators. LLM stream formats from OpenAI, Anthropic, and Google Gemini are auto-detected: + +```python +# Plain string +await session.send_response("Hello world") + +# OpenAI stream (auto-parsed) +stream = await openai_client.responses.create(model="gpt-4o", ..., stream=True) +await session.send_response(stream) + +# Anthropic stream (auto-parsed) +stream = anthropic_client.messages.stream(model="claude-sonnet-4-20250514", ...) +await session.send_response(stream) + +# Any async iterator of strings +async def my_generator(): + yield "Hello " + yield "world" +await session.send_response(my_generator()) +``` + +### Interruption Handling + +When a new transcript arrives while a previous response is still streaming, the previous handler's `asyncio.Task` is cancelled automatically. Any `await` in your handler (including LLM SDK calls) will raise `asyncio.CancelledError`, which cleanly aborts the in-flight request. No manual signal handling needed. + +### Custom Server Integration (FastAPI, Starlette) + +For integrating with an existing web server, use `create_session()` instead of `serve()`: + +```python +from fastapi import FastAPI, WebSocket + +app = FastAPI() +engine = ... # SpeechEngineResource from await client.speech_engine.get(...) + +@app.websocket("/api/speech-engine/ws") +async def speech_engine_ws(ws: WebSocket): + await ws.accept() + session = engine.create_session(ws, debug=True) + session.on("user_transcript", handle_transcript) + await session.run() +``` + +When using `session.on()` directly, handlers receive just the event data (no `session` argument, since you already have the reference): + +| Event | Handler signature | +|---|---| +| `"init"` | `async (conversation_id: str) -> None` | +| `"user_transcript"` | `async (transcript: list[ConversationMessage]) -> None` | +| `"close"` | `async () -> None` | +| `"disconnected"` | `async () -> None` | +| `"error"` | `async (error: Exception) -> None` | + +### Standalone Server + +For full control over the server lifecycle, use `SpeechEngineServer` directly: + +```python +from elevenlabs.speech_engine import SpeechEngineServer + +server = SpeechEngineServer( + port=3001, + debug=True, + on_transcript=handle_transcript, +) + +# In one task: +await server.serve() + +# In another task (e.g. signal handler): +await server.stop() +``` + ## Languages Supported Explore [all models & languages](https://elevenlabs.io/docs/models). diff --git a/pyproject.toml b/pyproject.toml index 2e7ec3e4..f81527db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ pydantic = ">= 1.9.2" pydantic-core = ">=2.18.2" requests = ">=2.20" typing_extensions = ">= 4.0.0" -websockets = ">=11.0" +websockets = ">=13.0" [tool.poetry.group.dev.dependencies] mypy = "==1.13.0" diff --git a/src/elevenlabs/client.py b/src/elevenlabs/client.py index 6fd75a53..e907e986 100644 --- a/src/elevenlabs/client.py +++ b/src/elevenlabs/client.py @@ -7,6 +7,7 @@ from .environment import ElevenLabsEnvironment from .music_custom import AsyncMusicClient, MusicClient from .realtime_tts import RealtimeTextToSpeechClient +from .speech_engine_custom import AsyncSpeechEngineClient, SpeechEngineClient from .speech_to_text_custom import AsyncSpeechToTextClient, SpeechToTextClient from .webhooks_custom import AsyncWebhooksClient, WebhooksClient @@ -62,6 +63,11 @@ def __init__( self._webhooks = WebhooksClient(client_wrapper=self._client_wrapper) self._music = MusicClient(client_wrapper=self._client_wrapper) self._speech_to_text = SpeechToTextClient(client_wrapper=self._client_wrapper) + self._speech_engine = SpeechEngineClient(client_wrapper=self._client_wrapper) + + @property + def speech_engine(self) -> SpeechEngineClient: + return typing.cast(SpeechEngineClient, self._speech_engine) class AsyncElevenLabs(AsyncBaseElevenLabs): @@ -107,3 +113,8 @@ def __init__( self._webhooks = AsyncWebhooksClient(client_wrapper=self._client_wrapper) self._music = AsyncMusicClient(client_wrapper=self._client_wrapper) self._speech_to_text = AsyncSpeechToTextClient(client_wrapper=self._client_wrapper) + self._speech_engine = AsyncSpeechEngineClient(client_wrapper=self._client_wrapper) + + @property + def speech_engine(self) -> AsyncSpeechEngineClient: + return typing.cast(AsyncSpeechEngineClient, self._speech_engine) diff --git a/src/elevenlabs/speech_engine/__init__.py b/src/elevenlabs/speech_engine/__init__.py index 5cde0202..7b613d7d 100644 --- a/src/elevenlabs/speech_engine/__init__.py +++ b/src/elevenlabs/speech_engine/__init__.py @@ -2,3 +2,31 @@ # isort: skip_file +"""ElevenLabs Speech Engine SDK module.""" + +from .resource import SpeechEngineResource, verify_speech_engine_jwt +from .server import SpeechEngineServer +from .session import SpeechEngineSession +from .types import ( + CLOSE, + DISCONNECTED, + ERROR, + INIT, + USER_TRANSCRIPT, + ConversationMessage, + WebSocketLike, +) + +__all__ = [ + "ConversationMessage", + "SpeechEngineResource", + "SpeechEngineServer", + "SpeechEngineSession", + "WebSocketLike", + "verify_speech_engine_jwt", + "CLOSE", + "DISCONNECTED", + "ERROR", + "INIT", + "USER_TRANSCRIPT", +] diff --git a/src/elevenlabs/speech_engine/resource.py b/src/elevenlabs/speech_engine/resource.py new file mode 100644 index 00000000..7d25bb16 --- /dev/null +++ b/src/elevenlabs/speech_engine/resource.py @@ -0,0 +1,181 @@ +"""SpeechEngineResource — client-facing handle for a speech engine instance.""" + +import base64 +import hashlib +import hmac +import json +import logging +import time +import typing + +from .server import SpeechEngineServer +from .session import SpeechEngineSession +from .types import WebSocketLike + +logger = logging.getLogger("elevenlabs.speech_engine") + +_ISSUER = "https://api.elevenlabs.io/convai/speech-engine" +_SUBJECT = "convai_speech_engine_upstream" +_LEEWAY_SECONDS = 60 + + +def _base64url_decode(data: str) -> bytes: + padded = data.replace("-", "+").replace("_", "/") + remainder = len(padded) % 4 + if remainder: + padded += "=" * (4 - remainder) + return base64.b64decode(padded) + + +def verify_speech_engine_jwt(value: str, api_key: str) -> typing.Dict[str, typing.Any]: + """Verify an HS256 JWT from the ElevenLabs Speech Engine API. + + The HMAC secret is the SHA-256 hash of the API key. Returns the + decoded payload on success, raises :class:`ValueError` on failure. + """ + token = value.strip() + if token.lower().startswith("bearer "): + token = token[7:].strip() + + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT: expected 3 parts") + + header_b64, payload_b64, signature_b64 = parts + + try: + payload = json.loads(_base64url_decode(payload_b64)) + except Exception: + raise ValueError("Invalid JWT: failed to decode payload") + + trimmed_key = api_key.strip() + secret = hashlib.sha256(trimmed_key.encode("utf-8")).digest() + + expected_sig = hmac.new( + secret, f"{header_b64}.{payload_b64}".encode(), hashlib.sha256 + ).digest() + actual_sig = _base64url_decode(signature_b64) + + if not hmac.compare_digest(expected_sig, actual_sig): + raise ValueError( + "Invalid JWT: signature mismatch — make sure the API key used " + "by your Speech Engine server matches the one used to create " + "the engine." + ) + + if payload.get("iss") != _ISSUER: + raise ValueError( + f'Invalid JWT: expected issuer "{_ISSUER}", got "{payload.get("iss")}"' + ) + if payload.get("sub") != _SUBJECT: + raise ValueError( + f'Invalid JWT: expected subject "{_SUBJECT}", got "{payload.get("sub")}"' + ) + + now = int(time.time()) + + exp = payload.get("exp") + if not isinstance(exp, (int, float)): + raise ValueError("Invalid JWT: missing exp claim") + iat = payload.get("iat") + if not isinstance(iat, (int, float)): + raise ValueError("Invalid JWT: missing iat claim") + if exp + _LEEWAY_SECONDS < now: + raise ValueError("Invalid JWT: token has expired") + if iat - _LEEWAY_SECONDS > now: + raise ValueError("Invalid JWT: iat is in the future") + + return payload + + +class SpeechEngineResource: + """Represents a speech engine instance. + + Returned by ``await client.speech_engine.get("seng_123")``. + + Use :meth:`serve` to start a standalone WebSocket server, or + :meth:`create_session` to wrap an existing WebSocket for custom + server integration (FastAPI, Starlette, etc.). + + Example:: + + engine = await elevenlabs.speech_engine.get("seng_123") + + async def on_transcript(transcript, session): + stream = await openai.responses.create(...) + await session.send_response(stream) + + await engine.serve( + port=3001, + debug=True, + on_transcript=on_transcript, + ) + """ + + def __init__( + self, + engine_id: str, + client_wrapper: typing.Any = None, + ) -> None: + self.engine_id = engine_id + self._client_wrapper = client_wrapper + + def _get_api_key(self) -> typing.Optional[str]: + if self._client_wrapper is None: + return None + headers = self._client_wrapper.get_headers() + return headers.get("xi-api-key") + + def verify_request( + self, headers: typing.Dict[str, typing.Any] + ) -> bool: + """Verify that an incoming request is from the ElevenLabs API. + + Checks the ``X-Elevenlabs-Speech-Engine-Authorization`` header + for a valid JWT signed with the SHA-256 hash of the API key. + + Only needed when managing the WebSocket upgrade yourself. + When using :meth:`serve`, verification is handled automatically. + """ + api_key = self._get_api_key() + if not api_key: + return False + raw = headers.get("x-elevenlabs-speech-engine-authorization") + if isinstance(raw, list): + raw = raw[0] if raw else None + if not raw: + return False + try: + verify_speech_engine_jwt(raw, api_key) + return True + except ValueError: + return False + + async def serve( + self, + *, + port: int = 3001, + path: typing.Optional[str] = None, + debug: bool = False, + **handlers: typing.Any, + ) -> None: + """Start a standalone WebSocket server. Blocks until stopped.""" + api_key = self._get_api_key() + server = SpeechEngineServer( + port=port, path=path, debug=debug, api_key=api_key, **handlers + ) + await server.serve() + + def create_session( + self, + ws: WebSocketLike, + *, + debug: bool = False, + ) -> SpeechEngineSession: + """Wrap *ws* in a :class:`SpeechEngineSession`. + + Use this for custom server integration (e.g. FastAPI, Starlette). + Wire handlers via :meth:`~SpeechEngineSession.on` then ``await + session.run()``. + """ + return SpeechEngineSession(ws, debug=debug) diff --git a/src/elevenlabs/speech_engine/server.py b/src/elevenlabs/speech_engine/server.py new file mode 100644 index 00000000..adf9f252 --- /dev/null +++ b/src/elevenlabs/speech_engine/server.py @@ -0,0 +1,137 @@ +"""SpeechEngineServer — standalone WebSocket server for Speech Engine.""" + +import asyncio +import http +import os +import typing + +from .session import SpeechEngineSession, _make_log, _wire_handlers +from .types import WebSocketLike + + +class SpeechEngineServer: + """Standalone WebSocket server that produces :class:`SpeechEngineSession` + instances for each incoming connection from the ElevenLabs Speech Engine + API. + + Every incoming connection is verified against the ElevenLabs API using + the configured API key before being accepted. + + Example:: + + server = SpeechEngineServer( + port=3001, + api_key="sk_...", + debug=True, + on_transcript=handle_transcript, + ) + await server.serve() + """ + + def __init__( + self, + *, + port: int = 3001, + path: typing.Optional[str] = None, + api_key: typing.Optional[str] = None, + debug: bool = False, + **handlers: typing.Any, + ) -> None: + self._port = port + self._path = path + self._api_key = api_key + self._debug = debug + self._handlers = handlers + self._stop_event = None # type: typing.Optional[asyncio.Event] + self._server = None # type: typing.Any + self._log = _make_log(debug) + + def handle_connection(self, ws: WebSocketLike) -> SpeechEngineSession: + """Wrap *ws* in a :class:`SpeechEngineSession` with the server's + handlers wired up. + + Use this when you manage your own WebSocket server and want to wrap + individual connections. The returned session's :meth:`run` must + still be awaited by the caller. + """ + self._log("creating new session") + session = SpeechEngineSession(ws, debug=self._debug) + _wire_handlers(session, self._handlers) + return session + + async def serve(self) -> None: + """Start the WebSocket server. Blocks until :meth:`stop` is called.""" + from .resource import verify_speech_engine_jwt # noqa: E402 + + import websockets # noqa: E402 — keep import lazy + + api_key = self._api_key or os.environ.get("ELEVENLABS_API_KEY") + if not api_key: + raise RuntimeError( + "SpeechEngineServer requires an API key to verify incoming " + "connections. Pass api_key= or set the ELEVENLABS_API_KEY " + "environment variable." + ) + + self._stop_event = asyncio.Event() + + def _process_request( + connection: typing.Any, request: typing.Any + ) -> typing.Any: + if self._path is not None and request.path != self._path: + self._log( + "rejected connection — path mismatch: " + "expected %s, got %s", + self._path, + request.path, + ) + return connection.respond( + http.HTTPStatus.NOT_FOUND, "not found\n" + ) + + header_value = request.headers.get( + "x-elevenlabs-speech-engine-authorization" + ) + if not header_value: + self._log( + "rejected connection — missing " + "X-Elevenlabs-Speech-Engine-Authorization header" + ) + return connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "missing authorization header\n", + ) + + try: + verify_speech_engine_jwt(header_value, api_key) + except ValueError as e: + self._log("rejected connection — %s", e) + return connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "authorization failed\n", + ) + + return None + + async def _handler(websocket: typing.Any) -> None: + self._log("verified connection, accepting WebSocket") + session = self.handle_connection(websocket) + await session.run() + + self._server = await websockets.serve( # type: ignore[attr-defined] + _handler, + "", + self._port, + process_request=_process_request, + ) + self._log("speech engine server listening on port %d", self._port) + try: + await self._stop_event.wait() + finally: + self._server.close() + await self._server.wait_closed() + + async def stop(self) -> None: + """Signal the server to shut down gracefully.""" + if self._stop_event is not None: + self._stop_event.set() diff --git a/src/elevenlabs/speech_engine/session.py b/src/elevenlabs/speech_engine/session.py new file mode 100644 index 00000000..3279a966 --- /dev/null +++ b/src/elevenlabs/speech_engine/session.py @@ -0,0 +1,533 @@ +"""SpeechEngineSession — WebSocket session for Speech Engine conversations.""" + +import asyncio +import json +import logging +import typing + +from .types import ConversationMessage, WebSocketLike, wrap_websocket + +logger = logging.getLogger("elevenlabs.speech_engine") + + +def _make_log( + debug: bool, +) -> typing.Callable[..., None]: + """Return a per-instance log function, mirroring the JS SDK pattern.""" + if debug: + def _log(msg: str, *args: typing.Any) -> None: + print("[SpeechEngine]", msg % args if args else msg) + return _log + return lambda *_args, **_kw: None + +Callback = typing.Callable[..., typing.Any] + + +# --------------------------------------------------------------------------- +# LLM stream text extraction +# --------------------------------------------------------------------------- + + +def _get(obj: typing.Any, key: str) -> typing.Any: + """Attribute-or-key access helper for LLM stream chunks.""" + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def _extract_text(chunk: typing.Any) -> typing.Optional[str]: + """Extract text content from an LLM stream chunk. + + Handles plain strings and common LLM streaming formats: + + - OpenAI Responses API (``response.output_text.delta``) + - OpenAI Chat Completions API (``choices[0].delta.content``) + - Anthropic Messages API (``content_block_delta`` with ``text_delta``) + - Google Gemini API (``candidates[0].content.parts[0].text``) + """ + if isinstance(chunk, str): + return chunk + if chunk is None or isinstance(chunk, (int, float, bool)): + return None + + # OpenAI Responses API + if _get(chunk, "type") == "response.output_text.delta": + delta = _get(chunk, "delta") + if isinstance(delta, str): + return delta + + # OpenAI Chat Completions API + choices = _get(chunk, "choices") + if isinstance(choices, (list, tuple)) and len(choices) > 0: + delta = _get(choices[0], "delta") + if delta is not None: + content = _get(delta, "content") + if isinstance(content, str): + return content + + # Anthropic Messages API + if _get(chunk, "type") == "content_block_delta": + delta = _get(chunk, "delta") + if delta is not None: + if _get(delta, "type") == "text_delta": + text = _get(delta, "text") + if isinstance(text, str): + return text + + # Google Gemini API + candidates = _get(chunk, "candidates") + if isinstance(candidates, (list, tuple)) and len(candidates) > 0: + content = _get(candidates[0], "content") + if content is not None: + parts = _get(content, "parts") + if isinstance(parts, (list, tuple)) and len(parts) > 0: + text = _get(parts[0], "text") + if isinstance(text, str): + return text + + return None + + +# --------------------------------------------------------------------------- +# Handler wiring (kwargs -> event emitter) +# --------------------------------------------------------------------------- + + +def _wire_handlers( + session: "SpeechEngineSession", + handlers: typing.Dict[str, typing.Any], +) -> None: + """Wire keyword-argument handlers onto *session* events.""" + on_init = handlers.get("on_init") + on_transcript = handlers.get("on_transcript") + on_close = handlers.get("on_close") + on_disconnect = handlers.get("on_disconnect") + on_error = handlers.get("on_error") + + if on_init: + async def _init_handler(conversation_id: str) -> None: + result = on_init(conversation_id, session) + if asyncio.iscoroutine(result): + await result + session.on("init", _init_handler) + + if on_transcript: + async def _transcript_handler( + transcript: typing.List[ConversationMessage], + ) -> None: + result = on_transcript(transcript, session) + if asyncio.iscoroutine(result): + await result + session.on("user_transcript", _transcript_handler) + + if on_close: + async def _close_handler() -> None: + result = on_close(session) + if asyncio.iscoroutine(result): + await result + session.on("close", _close_handler) + + if on_disconnect: + async def _disconnect_handler() -> None: + result = on_disconnect(session) + if asyncio.iscoroutine(result): + await result + session.on("disconnected", _disconnect_handler) + + if on_error: + async def _error_handler(err: Exception) -> None: + result = on_error(err, session) + if asyncio.iscoroutine(result): + await result + session.on("error", _error_handler) + + +# --------------------------------------------------------------------------- +# SpeechEngineSession +# --------------------------------------------------------------------------- + + +class SpeechEngineSession: + """Wraps a WebSocket connection from the ElevenLabs Speech Engine API. + + Each connection represents one conversation. The session emits events + for transcripts and lifecycle changes, and provides methods to send LLM + responses back. When a new transcript arrives the previous transcript's + handler task is cancelled automatically, interrupting any in-flight LLM + call. + + Example:: + + session = SpeechEngineSession(ws, debug=True) + + async def handle(transcript): + stream = await openai.responses.create(...) + await session.send_response(stream) + + session.on("user_transcript", handle) + await session.run() + """ + + def __init__( + self, + ws: typing.Any, + *, + debug: bool = False, + ) -> None: + self._ws = wrap_websocket(ws) + self._conversation_id = None # type: typing.Optional[str] + self._current_task = None # type: typing.Optional[asyncio.Task] # type: ignore[type-arg] + self._current_event_id = None # type: typing.Optional[int] + self._in_transcript_handler = False + self._closed = False + self._event_handlers = {} # type: typing.Dict[str, typing.List[Callback]] + self._once_handlers = {} # type: typing.Dict[str, typing.List[Callback]] + self._log = _make_log(debug) + + # ------------------------------------------------------------------ + # Event emitter interface + # ------------------------------------------------------------------ + + def on(self, event: str, handler: Callback) -> "SpeechEngineSession": + """Register *handler* for *event*.""" + self._event_handlers.setdefault(event, []).append(handler) + return self + + def off(self, event: str, handler: Callback) -> "SpeechEngineSession": + """Remove *handler* from *event*.""" + handlers = self._event_handlers.get(event, []) + try: + handlers.remove(handler) + except ValueError: + pass + return self + + def once(self, event: str, handler: Callback) -> "SpeechEngineSession": + """Register *handler* for *event*, removed after the first call.""" + self._once_handlers.setdefault(event, []).append(handler) + return self + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def conversation_id(self) -> typing.Optional[str]: + """The conversation ID assigned by the Speech Engine API. + + Available after the ``init`` event fires. + """ + return self._conversation_id + + @property + def is_open(self) -> bool: + """Whether the session is still open.""" + return not self._closed + + # ------------------------------------------------------------------ + # Main message loop + # ------------------------------------------------------------------ + + async def run(self) -> None: + """Run the receive loop until the WebSocket closes. + + This is the main entry point after constructing a session. It + processes incoming messages and dispatches events to registered + handlers. + """ + try: + while not self._closed: + try: + raw = await self._ws.recv() + except asyncio.CancelledError: + raise + except Exception: + self._log("WebSocket connection lost") + break + + try: + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + msg = json.loads(raw) + except (ValueError, TypeError, UnicodeDecodeError) as e: + await self._emit("error", e) + continue + + if not isinstance(msg, dict): + await self._emit( + "error", + ValueError(f"expected JSON object, got {type(msg).__name__}"), + ) + continue + + await self._handle_message(msg) + except asyncio.CancelledError: + raise + finally: + if not self._closed: + self._closed = True + await self._cancel_current_and_wait() + await self._emit("disconnected") + + # ------------------------------------------------------------------ + # Sending responses + # ------------------------------------------------------------------ + + async def send_response( + self, + response: typing.Any, + ) -> None: + """Send an LLM response back for TTS synthesis. + + Accepts: + + * A plain **string** — sent as a single agent response. + * An **async iterator** yielding strings or LLM stream event objects + (OpenAI, Anthropic, Gemini formats are auto-detected). + + This method is a coroutine so the caller can ``await`` it to know + when the full response has been sent. + """ + if self._closed: + raise RuntimeError("Cannot send response: session is closed") + + if not self._in_transcript_handler: + logger.warning( + "send_response() called outside of an on_transcript handler. " + "Responses can only be sent in reply to a user transcript. " + "To have the agent speak first, set a first message in your " + "Speech Engine conversation config on the client." + ) + return + + event_id = self._current_event_id + + if isinstance(response, str): + self._log( + 'sending string response: "%s", event_id=%s', + response, + event_id, + ) + await self._send_agent_response(response, False, event_id) + await self._send_agent_response("", True, event_id) + else: + self._log( + "starting streamed response, event_id=%s", + event_id, + ) + await self._stream_response(response, event_id) + + def close(self) -> None: + """Close the session and the underlying WebSocket connection.""" + if self._closed: + return + self._closed = True + self._cancel_current() + try: + asyncio.ensure_future(self._ws.close()) + except RuntimeError: + # No running event loop — best-effort. + pass + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + async def _handle_message(self, msg: typing.Dict[str, typing.Any]) -> None: + msg_type = msg.get("type") + + if msg_type == "init": + self._conversation_id = msg.get("conversation_id") + self._log( + "session initialized, conversation_id=%s", + self._conversation_id, + ) + await self._emit("init", self._conversation_id) + + elif msg_type == "user_transcript": + incoming_event_id = msg.get("event_id") + + if ( + incoming_event_id is not None + and incoming_event_id == self._current_event_id + and self._current_task is not None + and not self._current_task.done() + ): + self._log( + "skipping duplicate transcript, event_id=%s", + incoming_event_id, + ) + return + + was_active = ( + self._current_task is not None + and not self._current_task.done() + ) + await self._cancel_current_and_wait() + if was_active: + self._log( + "interrupted: cancelling previous response " + "(event_id=%s) for new transcript (event_id=%s)", + self._current_event_id, + incoming_event_id, + ) + + self._current_event_id = incoming_event_id + transcript_data = msg.get("user_transcript") or [] + + try: + transcript = [ + ConversationMessage(role=m["role"], content=m["content"]) + for m in transcript_data + ] + except (KeyError, TypeError) as e: + await self._emit("error", e) + return + + self._log( + "received transcript, event_id=%s, messages=%d", + self._current_event_id, + len(transcript), + ) + + handlers = list( + self._event_handlers.get("user_transcript", []) + ) + once_handlers = self._once_handlers.pop("user_transcript", []) + all_handlers = handlers + once_handlers + + if all_handlers: + self._current_task = asyncio.create_task( + self._run_transcript_handlers(all_handlers, transcript) + ) + # Yield so the handler task can start before the next + # message is read. This mirrors the JS behaviour where + # emitter.emit() invokes listeners synchronously. + await asyncio.sleep(0) + + elif msg_type == "ping": + await self._send({"type": "pong"}) + + elif msg_type == "close": + self._closed = True + self._in_transcript_handler = False + await self._cancel_current_and_wait() + await self._emit("close") + + elif msg_type == "error": + await self._emit("error", Exception(msg.get("message", ""))) + + # Unknown types are silently ignored for forward compatibility. + + async def _run_transcript_handlers( + self, + handlers: typing.List[Callback], + transcript: typing.List[ConversationMessage], + ) -> None: + self._in_transcript_handler = True + try: + for handler in handlers: + result = handler(transcript) + if asyncio.iscoroutine(result): + await result + except asyncio.CancelledError: + raise + except Exception as e: + await self._emit("error", e) + finally: + self._in_transcript_handler = False + + async def _stream_response( + self, stream: typing.Any, event_id: typing.Optional[int] = None + ) -> None: + chunks = 0 + try: + async for chunk in stream: + if self._closed: + self._log( + "stream stopped: session closed after %d chunks, " + "event_id=%s", + chunks, + event_id, + ) + return + text = _extract_text(chunk) + if text: + chunks += 1 + await self._send_agent_response(text, False, event_id) + if not self._closed: + self._log( + "stream complete: %d chunks sent, event_id=%s", + chunks, + event_id, + ) + await self._send_agent_response("", True, event_id) + except asyncio.CancelledError: + raise + except Exception as e: + await self._emit("error", e) + if not self._closed: + await self._send_agent_response("", True, event_id) + + async def _send_agent_response( + self, + content: str, + is_final: bool, + event_id: typing.Optional[int] = None, + ) -> None: + if event_id is None: + event_id = self._current_event_id + await self._send( + { + "type": "agent_response", + "content": content, + "event_id": event_id, + "is_final": is_final, + } + ) + + async def _send(self, msg: typing.Dict[str, typing.Any]) -> None: + if self._closed: + return + try: + await self._ws.send(json.dumps(msg)) + except asyncio.CancelledError: + raise + except Exception: + # Send failed — the recv loop will detect the closed connection. + pass + + async def _emit(self, event: str, *args: typing.Any) -> None: + handlers = list(self._event_handlers.get(event, [])) + once_handlers = self._once_handlers.pop(event, []) + all_handlers = handlers + once_handlers + + for handler in all_handlers: + try: + result = handler(*args) + if asyncio.iscoroutine(result): + await result + except asyncio.CancelledError: + raise + except Exception as e: + if event != "error": + await self._emit("error", e) + else: + logger.exception("unhandled error in error handler: %s", e) + + async def _cancel_current_and_wait(self) -> None: + """Cancel the current handler task and wait for cleanup.""" + task = self._current_task + self._current_task = None + if task is not None and not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + def _cancel_current(self) -> None: + """Cancel the current handler task (fire-and-forget).""" + if self._current_task is not None and not self._current_task.done(): + self._current_task.cancel() + self._current_task = None diff --git a/src/elevenlabs/speech_engine/types.py b/src/elevenlabs/speech_engine/types.py new file mode 100644 index 00000000..0b9f676c --- /dev/null +++ b/src/elevenlabs/speech_engine/types.py @@ -0,0 +1,105 @@ +"""Types for the Speech Engine module.""" + +import typing + +import pydantic + +# --------------------------------------------------------------------------- +# Event name constants +# --------------------------------------------------------------------------- + +INIT = "init" +USER_TRANSCRIPT = "user_transcript" +CLOSE = "close" +ERROR = "error" +DISCONNECTED = "disconnected" + +# --------------------------------------------------------------------------- +# Wire protocol — incoming (ElevenLabs API -> developer server) +# --------------------------------------------------------------------------- +# +# InitMessage: {"type": "init", "conversation_id": "..."} +# UserTranscriptMessage: {"type": "user_transcript", "user_transcript": [...], "event_id": N} +# PingMessage: {"type": "ping"} +# CloseMessage: {"type": "close"} +# ErrorMessage: {"type": "error", "message": "..."} +# +# --------------------------------------------------------------------------- +# Wire protocol — outgoing (developer server -> ElevenLabs API) +# --------------------------------------------------------------------------- +# +# AgentResponseMessage: {"type": "agent_response", "content": "...", "event_id": N, "is_final": bool} +# PongMessage: {"type": "pong"} +# + +# --------------------------------------------------------------------------- +# ConversationMessage +# --------------------------------------------------------------------------- + + +class ConversationMessage(pydantic.BaseModel): + """A single message in a speech engine conversation. + + Attributes: + role: Either ``"user"`` or ``"agent"``. + content: The text content of the message. + """ + + role: str + content: str + + +# --------------------------------------------------------------------------- +# WebSocket abstraction +# --------------------------------------------------------------------------- + + +class WebSocketLike(typing.Protocol): + """Protocol for WebSocket-like objects. + + Compatible with ``websockets.WebSocketServerProtocol`` and + FastAPI/Starlette ``WebSocket`` out of the box (auto-detected). + """ + + async def recv(self) -> typing.Union[str, bytes]: + ... # pragma: no cover + + async def send(self, data: str) -> None: + ... # pragma: no cover + + async def close(self) -> None: + ... # pragma: no cover + + +class _ASGIWebSocketAdapter: + """Adapts a FastAPI/Starlette WebSocket to the :class:`WebSocketLike` + interface expected by :class:`~.session.SpeechEngineSession`.""" + + def __init__(self, ws: typing.Any) -> None: + self._ws = ws + + async def recv(self) -> typing.Union[str, bytes]: + return await self._ws.receive_text() + + async def send(self, data: str) -> None: + await self._ws.send_text(data) + + async def close(self) -> None: + await self._ws.close() + + +def wrap_websocket(ws: typing.Any) -> WebSocketLike: + """Return a :class:`WebSocketLike` wrapper for *ws*. + + If *ws* already has ``recv``/``send`` (e.g. the ``websockets`` library), + it is returned as-is. If it has ``receive_text``/``send_text`` (e.g. + FastAPI/Starlette), it is wrapped with :class:`_ASGIWebSocketAdapter`. + """ + if hasattr(ws, "recv"): + return ws + if hasattr(ws, "receive_text"): + return _ASGIWebSocketAdapter(ws) + raise TypeError( + f"Cannot wrap {type(ws).__name__}: expected a websockets-style " + f"object (recv/send) or an ASGI-style object (receive_text/send_text)" + ) diff --git a/src/elevenlabs/speech_engine_custom.py b/src/elevenlabs/speech_engine_custom.py new file mode 100644 index 00000000..18115838 --- /dev/null +++ b/src/elevenlabs/speech_engine_custom.py @@ -0,0 +1,231 @@ +import typing + +from .core import RequestOptions +from .speech_engine.client import AsyncSpeechEngineClient as AutogeneratedAsyncSpeechEngineClient +from .speech_engine.client import SpeechEngineClient as AutogeneratedSpeechEngineClient +from .speech_engine.resource import SpeechEngineResource +from .types.agent_call_limits import AgentCallLimits +from .types.asr_conversational_config import AsrConversationalConfig +from .types.base_turn_config import BaseTurnConfig +from .types.conversation_config_input import ConversationConfigInput +from .types.privacy_config_input import PrivacyConfigInput +from .types.speech_engine_config import SpeechEngineConfig +from .types.tts_conversational_config_input import TtsConversationalConfigInput + +OMIT = typing.cast(typing.Any, ...) + + +class SpeechEngineClient(AutogeneratedSpeechEngineClient): + """Extends the generated SpeechEngineClient with WebSocket server integration.""" + + def create( # type: ignore[override] + self, + *, + speech_engine: SpeechEngineConfig, + name: typing.Optional[str] = OMIT, + asr: typing.Optional[AsrConversationalConfig] = OMIT, + tts: typing.Optional[TtsConversationalConfigInput] = OMIT, + turn: typing.Optional[BaseTurnConfig] = OMIT, + conversation: typing.Optional[ConversationConfigInput] = OMIT, + privacy: typing.Optional[PrivacyConfigInput] = OMIT, + call_limits: typing.Optional[AgentCallLimits] = OMIT, + language: typing.Optional[str] = OMIT, + tags: typing.Optional[typing.Sequence[str]] = OMIT, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Create a Speech Engine resource. + + Makes an API call to create the engine, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + response = super().create( + speech_engine=speech_engine, + name=name, + asr=asr, + tts=tts, + turn=turn, + conversation=conversation, + privacy=privacy, + call_limits=call_limits, + language=language, + tags=tags, + request_options=request_options, + ) + return SpeechEngineResource( + engine_id=response.speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) + + def get( # type: ignore[override] + self, + speech_engine_id: str, + *, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Retrieve a Speech Engine resource. + + Makes an API call to validate the engine exists, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + super().get(speech_engine_id, request_options=request_options) + return SpeechEngineResource( + engine_id=speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) + + def update( # type: ignore[override] + self, + speech_engine_id: str, + *, + name: typing.Optional[str] = OMIT, + speech_engine: typing.Optional[SpeechEngineConfig] = OMIT, + asr: typing.Optional[AsrConversationalConfig] = OMIT, + tts: typing.Optional[TtsConversationalConfigInput] = OMIT, + turn: typing.Optional[BaseTurnConfig] = OMIT, + conversation: typing.Optional[ConversationConfigInput] = OMIT, + privacy: typing.Optional[PrivacyConfigInput] = OMIT, + call_limits: typing.Optional[AgentCallLimits] = OMIT, + language: typing.Optional[str] = OMIT, + tags: typing.Optional[typing.Sequence[str]] = OMIT, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Update a Speech Engine resource. + + Makes an API call to update the engine, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + super().update( + speech_engine_id, + name=name, + speech_engine=speech_engine, + asr=asr, + tts=tts, + turn=turn, + conversation=conversation, + privacy=privacy, + call_limits=call_limits, + language=language, + tags=tags, + request_options=request_options, + ) + return SpeechEngineResource( + engine_id=speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) + + +class AsyncSpeechEngineClient(AutogeneratedAsyncSpeechEngineClient): + """Extends the generated AsyncSpeechEngineClient with WebSocket server integration.""" + + async def create( # type: ignore[override] + self, + *, + speech_engine: SpeechEngineConfig, + name: typing.Optional[str] = OMIT, + asr: typing.Optional[AsrConversationalConfig] = OMIT, + tts: typing.Optional[TtsConversationalConfigInput] = OMIT, + turn: typing.Optional[BaseTurnConfig] = OMIT, + conversation: typing.Optional[ConversationConfigInput] = OMIT, + privacy: typing.Optional[PrivacyConfigInput] = OMIT, + call_limits: typing.Optional[AgentCallLimits] = OMIT, + language: typing.Optional[str] = OMIT, + tags: typing.Optional[typing.Sequence[str]] = OMIT, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Create a Speech Engine resource. + + Makes an API call to create the engine, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + response = await super().create( + speech_engine=speech_engine, + name=name, + asr=asr, + tts=tts, + turn=turn, + conversation=conversation, + privacy=privacy, + call_limits=call_limits, + language=language, + tags=tags, + request_options=request_options, + ) + return SpeechEngineResource( + engine_id=response.speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) + + async def get( # type: ignore[override] + self, + speech_engine_id: str, + *, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Retrieve a Speech Engine resource. + + Makes an API call to validate the engine exists, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + await super().get(speech_engine_id, request_options=request_options) + return SpeechEngineResource( + engine_id=speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) + + async def update( # type: ignore[override] + self, + speech_engine_id: str, + *, + name: typing.Optional[str] = OMIT, + speech_engine: typing.Optional[SpeechEngineConfig] = OMIT, + asr: typing.Optional[AsrConversationalConfig] = OMIT, + tts: typing.Optional[TtsConversationalConfigInput] = OMIT, + turn: typing.Optional[BaseTurnConfig] = OMIT, + conversation: typing.Optional[ConversationConfigInput] = OMIT, + privacy: typing.Optional[PrivacyConfigInput] = OMIT, + call_limits: typing.Optional[AgentCallLimits] = OMIT, + language: typing.Optional[str] = OMIT, + tags: typing.Optional[typing.Sequence[str]] = OMIT, + request_options: typing.Optional[RequestOptions] = None, + ) -> SpeechEngineResource: + """Update a Speech Engine resource. + + Makes an API call to update the engine, then returns a + :class:`SpeechEngineResource` with WebSocket integration methods + (:meth:`~SpeechEngineResource.serve`, + :meth:`~SpeechEngineResource.create_session`, + :meth:`~SpeechEngineResource.verify_request`). + """ + await super().update( + speech_engine_id, + name=name, + speech_engine=speech_engine, + asr=asr, + tts=tts, + turn=turn, + conversation=conversation, + privacy=privacy, + call_limits=call_limits, + language=language, + tags=tags, + request_options=request_options, + ) + return SpeechEngineResource( + engine_id=speech_engine_id, + client_wrapper=self._raw_client._client_wrapper, + ) diff --git a/tests/test_speech_engine_auth.py b/tests/test_speech_engine_auth.py new file mode 100644 index 00000000..856a6fff --- /dev/null +++ b/tests/test_speech_engine_auth.py @@ -0,0 +1,196 @@ +"""Tests for Speech Engine JWT verification.""" + +import base64 +import hashlib +import hmac +import json +import time +import typing + +import pytest + +from elevenlabs.speech_engine.resource import ( + SpeechEngineResource, + verify_speech_engine_jwt, +) + +TEST_API_KEY = "test-key" +JWT_ISSUER = "https://api.elevenlabs.io/convai/speech-engine" +JWT_SUBJECT = "convai_speech_engine_upstream" + + +# --------------------------------------------------------------------------- +# JWT test helpers +# --------------------------------------------------------------------------- + + +def _base64url_encode(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _create_test_jwt( + payload: typing.Dict[str, typing.Any], + api_key: str = TEST_API_KEY, +) -> str: + header = {"alg": "HS256", "typ": "JWT"} + header_b64 = _base64url_encode(json.dumps(header).encode()) + payload_b64 = _base64url_encode(json.dumps(payload).encode()) + secret = hashlib.sha256(api_key.encode("utf-8")).digest() + signature = hmac.new( + secret, f"{header_b64}.{payload_b64}".encode(), hashlib.sha256 + ).digest() + return f"{header_b64}.{payload_b64}.{_base64url_encode(signature)}" + + +def _valid_payload(**overrides: typing.Any) -> typing.Dict[str, typing.Any]: + now = int(time.time()) + return { + "iss": JWT_ISSUER, + "sub": JWT_SUBJECT, + "iat": now, + "exp": now + 300, + **overrides, + } + + +# --------------------------------------------------------------------------- +# verify_speech_engine_jwt +# --------------------------------------------------------------------------- + + +class TestVerifySpeechEngineJwt: + def test_valid_token(self) -> None: + token = _create_test_jwt(_valid_payload()) + payload = verify_speech_engine_jwt(token, TEST_API_KEY) + assert payload["iss"] == JWT_ISSUER + assert payload["sub"] == JWT_SUBJECT + + def test_accepts_bearer_prefix(self) -> None: + token = _create_test_jwt(_valid_payload()) + payload = verify_speech_engine_jwt(f"Bearer {token}", TEST_API_KEY) + assert payload["iss"] == JWT_ISSUER + + def test_rejects_wrong_key(self) -> None: + token = _create_test_jwt(_valid_payload(), api_key="other-key") + with pytest.raises(ValueError, match="signature mismatch"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_rejects_wrong_issuer(self) -> None: + token = _create_test_jwt(_valid_payload(iss="https://evil.com")) + with pytest.raises(ValueError, match="expected issuer"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_rejects_wrong_subject(self) -> None: + token = _create_test_jwt(_valid_payload(sub="wrong_subject")) + with pytest.raises(ValueError, match="expected subject"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_rejects_expired_token_beyond_leeway(self) -> None: + now = int(time.time()) + token = _create_test_jwt( + _valid_payload(exp=now - 120, iat=now - 420) + ) + with pytest.raises(ValueError, match="expired"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_accepts_expired_within_leeway(self) -> None: + now = int(time.time()) + token = _create_test_jwt( + _valid_payload(exp=now - 30, iat=now - 330) + ) + payload = verify_speech_engine_jwt(token, TEST_API_KEY) + assert payload["iss"] == JWT_ISSUER + + def test_rejects_future_iat_beyond_leeway(self) -> None: + now = int(time.time()) + token = _create_test_jwt( + _valid_payload(iat=now + 120, exp=now + 420) + ) + with pytest.raises(ValueError, match="iat is in the future"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_accepts_future_iat_within_leeway(self) -> None: + now = int(time.time()) + token = _create_test_jwt( + _valid_payload(iat=now + 30, exp=now + 330) + ) + payload = verify_speech_engine_jwt(token, TEST_API_KEY) + assert payload["iss"] == JWT_ISSUER + + def test_rejects_malformed_token(self) -> None: + with pytest.raises(ValueError, match="expected 3 parts"): + verify_speech_engine_jwt("not.a.valid.jwt.token", TEST_API_KEY) + + def test_rejects_missing_exp(self) -> None: + payload = _valid_payload() + del payload["exp"] + token = _create_test_jwt(payload) + with pytest.raises(ValueError, match="missing exp"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + def test_rejects_missing_iat(self) -> None: + payload = _valid_payload() + del payload["iat"] + token = _create_test_jwt(payload) + with pytest.raises(ValueError, match="missing iat"): + verify_speech_engine_jwt(token, TEST_API_KEY) + + +# --------------------------------------------------------------------------- +# SpeechEngineResource.verify_request +# --------------------------------------------------------------------------- + + +class _FakeClientWrapper: + def __init__(self, api_key: str) -> None: + self._api_key = api_key + + def get_headers(self) -> dict: + return {"xi-api-key": self._api_key} + + +class TestVerifyRequest: + def test_valid_header(self) -> None: + resource = SpeechEngineResource( + "seng_test", client_wrapper=_FakeClientWrapper(TEST_API_KEY) + ) + token = _create_test_jwt(_valid_payload()) + assert resource.verify_request( + {"x-elevenlabs-speech-engine-authorization": token} + ) + + def test_missing_header(self) -> None: + resource = SpeechEngineResource( + "seng_test", client_wrapper=_FakeClientWrapper(TEST_API_KEY) + ) + assert not resource.verify_request({}) + + def test_no_api_key(self) -> None: + resource = SpeechEngineResource("seng_test") + token = _create_test_jwt(_valid_payload()) + assert not resource.verify_request( + {"x-elevenlabs-speech-engine-authorization": token} + ) + + def test_invalid_token(self) -> None: + resource = SpeechEngineResource( + "seng_test", client_wrapper=_FakeClientWrapper(TEST_API_KEY) + ) + assert not resource.verify_request( + {"x-elevenlabs-speech-engine-authorization": "bad-token"} + ) + + +# --------------------------------------------------------------------------- +# SpeechEngineServer — api_key requirement +# --------------------------------------------------------------------------- + + +class TestServerApiKeyRequirement: + @pytest.mark.asyncio + async def test_raises_without_api_key(self) -> None: + from elevenlabs.speech_engine import SpeechEngineServer + + server = SpeechEngineServer(port=0) + with pytest.raises(RuntimeError, match="API key"): + await server.serve() diff --git a/tests/test_speech_engine_custom.py b/tests/test_speech_engine_custom.py new file mode 100644 index 00000000..71247ac9 --- /dev/null +++ b/tests/test_speech_engine_custom.py @@ -0,0 +1,95 @@ +"""Tests for SpeechEngineClient and AsyncSpeechEngineClient wrapper methods.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from elevenlabs.speech_engine.client import AsyncSpeechEngineClient as AutogeneratedAsyncSpeechEngineClient +from elevenlabs.speech_engine.client import SpeechEngineClient as AutogeneratedSpeechEngineClient +from elevenlabs.speech_engine.resource import SpeechEngineResource +from elevenlabs.speech_engine_custom import AsyncSpeechEngineClient, SpeechEngineClient +from elevenlabs.types.create_speech_engine_response import CreateSpeechEngineResponse +from elevenlabs.types.speech_engine_config import SpeechEngineConfig + + +def _make_sync_client() -> SpeechEngineClient: + return SpeechEngineClient(client_wrapper=MagicMock()) + + +def _make_async_client() -> AsyncSpeechEngineClient: + return AsyncSpeechEngineClient(client_wrapper=MagicMock()) + + +# --------------------------------------------------------------------------- +# Sync +# --------------------------------------------------------------------------- + + +def test_create_returns_resource() -> None: + client = _make_sync_client() + mock_response = CreateSpeechEngineResponse(speech_engine_id="seng_abc") + + with patch.object(AutogeneratedSpeechEngineClient, "create", return_value=mock_response): + result = client.create(speech_engine=SpeechEngineConfig(ws_url="wss://test")) + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" + + +def test_get_returns_resource() -> None: + client = _make_sync_client() + + with patch.object(AutogeneratedSpeechEngineClient, "get", return_value=None): + result = client.get("seng_abc") + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" + + +def test_update_returns_resource() -> None: + client = _make_sync_client() + + with patch.object(AutogeneratedSpeechEngineClient, "update", return_value=None): + result = client.update("seng_abc", name="Renamed") + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_create_returns_resource() -> None: + client = _make_async_client() + mock_response = CreateSpeechEngineResponse(speech_engine_id="seng_abc") + + with patch.object(AutogeneratedAsyncSpeechEngineClient, "create", new_callable=AsyncMock, return_value=mock_response): + result = await client.create(speech_engine=SpeechEngineConfig(ws_url="wss://test")) + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" + + +@pytest.mark.asyncio +async def test_async_get_returns_resource() -> None: + client = _make_async_client() + + with patch.object(AutogeneratedAsyncSpeechEngineClient, "get", new_callable=AsyncMock, return_value=None): + result = await client.get("seng_abc") + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" + + +@pytest.mark.asyncio +async def test_async_update_returns_resource() -> None: + client = _make_async_client() + + with patch.object(AutogeneratedAsyncSpeechEngineClient, "update", new_callable=AsyncMock, return_value=None): + result = await client.update("seng_abc", name="Renamed") + + assert isinstance(result, SpeechEngineResource) + assert result.engine_id == "seng_abc" diff --git a/tests/test_speech_engine_resource.py b/tests/test_speech_engine_resource.py new file mode 100644 index 00000000..a3ee6aa2 --- /dev/null +++ b/tests/test_speech_engine_resource.py @@ -0,0 +1,107 @@ +"""Tests for SpeechEngineResource — mirrors SpeechEngineResource.test.ts.""" + +import asyncio +import json +import typing + +import pytest + +from elevenlabs.speech_engine import SpeechEngineResource, SpeechEngineSession + + +# --------------------------------------------------------------------------- +# MockWebSocket (same as session tests) +# --------------------------------------------------------------------------- + +_CLOSE_SENTINEL = object() + + +class MockWebSocket: + def __init__(self) -> None: + self._inbox = asyncio.Queue() # type: asyncio.Queue[typing.Any] + self.sent = [] # type: typing.List[str] + self.closed = False + + async def recv(self) -> str: + msg = await self._inbox.get() + if msg is _CLOSE_SENTINEL: + raise ConnectionError("connection closed") + return msg + + async def send(self, data: str) -> None: + self.sent.append(data) + + async def close(self) -> None: + self.closed = True + + def receive_message(self, msg: typing.Dict[str, typing.Any]) -> None: + self._inbox.put_nowait(json.dumps(msg)) + + def simulate_disconnect(self) -> None: + self._inbox.put_nowait(_CLOSE_SENTINEL) + + +# --------------------------------------------------------------------------- +# create_session +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_session_returns_speech_engine_session() -> None: + resource = SpeechEngineResource("seng_test") + ws = MockWebSocket() + session = resource.create_session(ws) + assert isinstance(session, SpeechEngineSession) + + +@pytest.mark.asyncio +async def test_create_session_protocol_works() -> None: + """Full send/receive cycle through a session created by the resource.""" + resource = SpeechEngineResource("seng_test") + ws = MockWebSocket() + session = resource.create_session(ws, debug=False) + + async def handler(transcript: typing.Any) -> None: + last = transcript[-1] + await session.send_response("echo: {}".format(last.content)) + + session.on("user_transcript", handler) + + ws.receive_message( + { + "type": "user_transcript", + "user_transcript": [{"role": "user", "content": "hello"}], + "event_id": 1, + } + ) + ws.simulate_disconnect() + await session.run() + + sent = [json.loads(s) for s in ws.sent] + assert sent[0] == { + "type": "agent_response", + "content": "echo: hello", + "event_id": 1, + "is_final": False, + } + + +# --------------------------------------------------------------------------- +# client accessor stub +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_client_speech_engine_get() -> None: + """The AsyncElevenLabs.speech_engine.get() returns a SpeechEngineResource.""" + from unittest.mock import AsyncMock, patch + + from elevenlabs import AsyncElevenLabs + from elevenlabs.speech_engine.client import AsyncSpeechEngineClient as AutogeneratedAsyncSpeechEngineClient + + client = AsyncElevenLabs(api_key="test-key") + with patch.object(AutogeneratedAsyncSpeechEngineClient, "get", new_callable=AsyncMock, return_value=None): + resource = await client.speech_engine.get("seng_123") + + assert isinstance(resource, SpeechEngineResource) + assert resource.engine_id == "seng_123" diff --git a/tests/test_speech_engine_server.py b/tests/test_speech_engine_server.py new file mode 100644 index 00000000..fb0dbff4 --- /dev/null +++ b/tests/test_speech_engine_server.py @@ -0,0 +1,138 @@ +"""Tests for SpeechEngineServer — mirrors SpeechEngineServer.test.ts.""" + +import asyncio +import json +import typing + +import pytest +import websockets + +from elevenlabs.speech_engine import SpeechEngineServer + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _connect_and_send( + port: int, + messages: typing.List[typing.Dict[str, typing.Any]], +) -> typing.List[typing.Dict[str, typing.Any]]: + """Open a client WS, send *messages*, and collect the first response.""" + uri = "ws://127.0.0.1:{}".format(port) + received = [] # type: typing.List[typing.Dict[str, typing.Any]] + async with websockets.connect(uri) as ws: # type: ignore[attr-defined] + for msg in messages: + await ws.send(json.dumps(msg)) + # Give the server a moment to process and respond. + try: + while True: + raw = await asyncio.wait_for(ws.recv(), timeout=0.2) + received.append(json.loads(raw)) + except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + pass + return received + + +# --------------------------------------------------------------------------- +# handleConnection +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_connection_wraps_ws_and_calls_on_init() -> None: + init_ids = [] # type: typing.List[str] + + async def on_init(conversation_id: str, session: typing.Any) -> None: + init_ids.append(conversation_id) + + server = SpeechEngineServer(port=0, on_init=on_init) + + # Use port 0 to get an ephemeral port; start via websockets directly. + started = asyncio.Event() + actual_port = 0 + + async def _handler(ws: typing.Any, *_args: typing.Any) -> None: + session = server.handle_connection(ws) + await session.run() + + ws_server = await websockets.serve(_handler, "127.0.0.1", 0) # type: ignore[attr-defined] + for sock in ws_server.sockets: + actual_port = sock.getsockname()[1] + break + + try: + uri = "ws://127.0.0.1:{}".format(actual_port) + async with websockets.connect(uri) as client: # type: ignore[attr-defined] + await client.send( + json.dumps({"type": "init", "conversation_id": "conv_1"}) + ) + await asyncio.sleep(0.1) + finally: + ws_server.close() + await ws_server.wait_closed() + + assert init_ids == ["conv_1"] + + +# --------------------------------------------------------------------------- +# Session responses are received by the client +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_responses_received_by_client() -> None: + async def on_transcript( + transcript: typing.Any, session: typing.Any + ) -> None: + last = transcript[-1] + await session.send_response("echo: {}".format(last.content)) + + server = SpeechEngineServer(port=0, on_transcript=on_transcript) + + async def _handler(ws: typing.Any, *_args: typing.Any) -> None: + session = server.handle_connection(ws) + await session.run() + + ws_server = await websockets.serve(_handler, "127.0.0.1", 0) # type: ignore[attr-defined] + actual_port = 0 + for sock in ws_server.sockets: + actual_port = sock.getsockname()[1] + break + + try: + uri = "ws://127.0.0.1:{}".format(actual_port) + async with websockets.connect(uri) as client: # type: ignore[attr-defined] + await client.send( + json.dumps( + { + "type": "user_transcript", + "user_transcript": [{"role": "user", "content": "ping"}], + "event_id": 1, + } + ) + ) + raw = await asyncio.wait_for(client.recv(), timeout=1.0) + response = json.loads(raw) + finally: + ws_server.close() + await ws_server.wait_closed() + + assert response == { + "type": "agent_response", + "content": "echo: ping", + "event_id": 1, + "is_final": False, + } + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_resolves_when_no_server_running() -> None: + server = SpeechEngineServer() + await server.stop() # Should not raise. diff --git a/tests/test_speech_engine_session.py b/tests/test_speech_engine_session.py new file mode 100644 index 00000000..17ff3223 --- /dev/null +++ b/tests/test_speech_engine_session.py @@ -0,0 +1,876 @@ +"""Tests for SpeechEngineSession — mirrors SpeechEngineSession.test.ts.""" + +import asyncio +import json +import typing + +import pytest + +from elevenlabs.speech_engine import ( + CLOSE, + DISCONNECTED, + ERROR, + INIT, + USER_TRANSCRIPT, + ConversationMessage, + SpeechEngineSession, +) + + +# --------------------------------------------------------------------------- +# MockWebSocket +# --------------------------------------------------------------------------- + + +class MockWebSocket: + """In-memory WebSocket stand-in backed by an asyncio.Queue.""" + + def __init__(self) -> None: + self._inbox = asyncio.Queue() # type: asyncio.Queue[typing.Any] + self.sent = [] # type: typing.List[str] + self.closed = False + + async def recv(self) -> str: + msg = await self._inbox.get() + if msg is _CLOSE_SENTINEL: + raise ConnectionError("connection closed") + return msg + + async def send(self, data: str) -> None: + self.sent.append(data) + + async def close(self) -> None: + self.closed = True + + # -- test helpers -- + + def receive_message(self, msg: typing.Dict[str, typing.Any]) -> None: + """Inject an incoming message from the "ElevenLabs API".""" + self._inbox.put_nowait(json.dumps(msg)) + + def receive_raw(self, data: str) -> None: + """Inject raw (possibly invalid) data.""" + self._inbox.put_nowait(data) + + def simulate_disconnect(self) -> None: + """Simulate a WebSocket disconnection.""" + self._inbox.put_nowait(_CLOSE_SENTINEL) + + +_CLOSE_SENTINEL = object() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +TRANSCRIPT = [ + {"role": "agent", "content": "how can I help you today?"}, + {"role": "user", "content": "I need a pizza"}, +] + +TRANSCRIPT_2 = [ + {"role": "agent", "content": "how can I help you today?"}, + {"role": "user", "content": "I need a pizza"}, + {"role": "agent", "content": "what size?"}, + {"role": "user", "content": "large"}, +] + + +@pytest.fixture +def ws() -> MockWebSocket: + return MockWebSocket() + + +@pytest.fixture +def session(ws: MockWebSocket) -> SpeechEngineSession: + return SpeechEngineSession(ws) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _parsed_sent(ws: MockWebSocket) -> typing.List[typing.Dict[str, typing.Any]]: + return [json.loads(s) for s in ws.sent] + + +async def _run_until_idle(session: SpeechEngineSession, ws: MockWebSocket) -> None: + """Drive the session run loop, then disconnect so it returns.""" + ws.simulate_disconnect() + await session.run() + + +# --------------------------------------------------------------------------- +# init event +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emits_init_with_conversation_id( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + captured = [] # type: typing.List[str] + + async def handler(conversation_id: str) -> None: + captured.append(conversation_id) + + session.on(INIT, handler) + ws.receive_message({"type": "init", "conversation_id": "conv_42"}) + await _run_until_idle(session, ws) + + assert captured == ["conv_42"] + assert session.conversation_id == "conv_42" + + +# --------------------------------------------------------------------------- +# user_transcript events +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emits_user_transcript_with_conversation_history( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + captured = [] # type: typing.List[typing.List[ConversationMessage]] + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + captured.append(transcript) + + session.on(USER_TRANSCRIPT, handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + await _run_until_idle(session, ws) + + assert len(captured) == 1 + assert len(captured[0]) == 2 + assert captured[0][0].role == "agent" + assert captured[0][0].content == "how can I help you today?" + assert captured[0][1].role == "user" + assert captured[0][1].content == "I need a pizza" + + +@pytest.mark.asyncio +async def test_cancels_previous_handler_on_new_transcript( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + """When a new transcript arrives the previous handler task is cancelled.""" + cancelled = asyncio.Event() + + async def slow_handler(transcript: typing.List[ConversationMessage]) -> None: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled.set() + raise + + session.on(USER_TRANSCRIPT, slow_handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT_2, "event_id": 2} + ) + await _run_until_idle(session, ws) + + assert cancelled.is_set() + + +# --------------------------------------------------------------------------- +# ping / pong +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auto_responds_to_ping( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message({"type": "ping"}) + await _run_until_idle(session, ws) + + assert len(ws.sent) == 1 + assert json.loads(ws.sent[0]) == {"type": "pong"} + + +# --------------------------------------------------------------------------- +# close event (protocol-level) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emits_close_and_cancels_handler_on_close_message( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + close_called = asyncio.Event() + cancelled = asyncio.Event() + + async def close_handler() -> None: + close_called.set() + + async def slow_handler(transcript: typing.List[ConversationMessage]) -> None: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled.set() + raise + + session.on(CLOSE, close_handler) + session.on(USER_TRANSCRIPT, slow_handler) + + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + ws.receive_message({"type": "close"}) + await _run_until_idle(session, ws) + + assert close_called.is_set() + assert cancelled.is_set() + + +# --------------------------------------------------------------------------- +# error event (protocol-level) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emits_error_on_protocol_error_message( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + errors = [] # type: typing.List[Exception] + + async def handler(err: Exception) -> None: + errors.append(err) + + session.on(ERROR, handler) + ws.receive_message({"type": "error", "message": "something went wrong"}) + await _run_until_idle(session, ws) + + assert len(errors) == 1 + assert isinstance(errors[0], Exception) + assert str(errors[0]) == "something went wrong" + + +# --------------------------------------------------------------------------- +# WebSocket-level events +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emits_disconnected_when_websocket_closes( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + disconnected = asyncio.Event() + + async def handler() -> None: + disconnected.set() + + session.on(DISCONNECTED, handler) + ws.simulate_disconnect() + await session.run() + + assert disconnected.is_set() + assert not session.is_open + + +@pytest.mark.asyncio +async def test_cancels_handler_when_websocket_closes( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + cancelled = asyncio.Event() + + async def slow_handler(transcript: typing.List[ConversationMessage]) -> None: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled.set() + raise + + session.on(USER_TRANSCRIPT, slow_handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + # Give the handler task a moment to start before disconnecting. + await asyncio.sleep(0) + ws.simulate_disconnect() + await session.run() + + # The disconnect triggers _cancel_current in the finally block. + assert cancelled.is_set() + + +@pytest.mark.asyncio +async def test_emits_error_on_malformed_json( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + errors = [] # type: typing.List[Exception] + + async def handler(err: Exception) -> None: + errors.append(err) + + session.on(ERROR, handler) + ws.receive_raw("not json") + await _run_until_idle(session, ws) + + assert len(errors) == 1 + + +@pytest.mark.asyncio +async def test_ignores_unknown_message_types( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + errors = [] # type: typing.List[Exception] + + async def handler(err: Exception) -> None: + errors.append(err) + + session.on(ERROR, handler) + ws.receive_message({"type": "unknown_future_event", "data": 123}) + await _run_until_idle(session, ws) + + assert len(errors) == 0 + + +# --------------------------------------------------------------------------- +# sendResponse (string) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sends_string_response_with_event_id_and_is_final( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 5} + ) + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response("The answer is 42") + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 2 + assert sent[0] == { + "type": "agent_response", + "content": "The answer is 42", + "event_id": 5, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": "", + "event_id": 5, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_raises_when_sending_after_close( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + session.close() + with pytest.raises(RuntimeError, match="session is closed"): + await session.send_response("too late") + + +# --------------------------------------------------------------------------- +# sendResponse (streaming) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_streams_chunks_with_is_final_false_then_terminator( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 3} + ) + + async def tokens(): # type: ignore[return] + yield "Hello" + yield " world" + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(tokens()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 3 + assert sent[0] == { + "type": "agent_response", + "content": "Hello", + "event_id": 3, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": " world", + "event_id": 3, + "is_final": False, + } + assert sent[2] == { + "type": "agent_response", + "content": "", + "event_id": 3, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_stops_streaming_when_session_closed_mid_stream( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 7} + ) + + async def slow_tokens(): # type: ignore[return] + yield "first" + yield "second" + await asyncio.sleep(0.1) + yield "third" + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(slow_tokens()) + + session.on(USER_TRANSCRIPT, handler) + + # Start the run loop in the background. + task = asyncio.create_task(session.run()) + # Let the handler task progress through the first two yields. + await asyncio.sleep(0.05) + session.close() + await asyncio.sleep(0.15) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + sent = _parsed_sent(ws) + chunks = [m for m in sent if m.get("type") == "agent_response"] + # "first" and "second" are sent; session closes before "third". + assert len(chunks) == 2 + assert chunks[0] == { + "type": "agent_response", + "content": "first", + "event_id": 7, + "is_final": False, + } + assert chunks[1] == { + "type": "agent_response", + "content": "second", + "event_id": 7, + "is_final": False, + } + + +# --------------------------------------------------------------------------- +# sendResponse (LLM stream formats) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_extracts_text_from_openai_responses_api( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 10} + ) + + async def openai_responses_stream(): # type: ignore[return] + yield {"type": "response.created", "response": {}} + yield {"type": "response.output_text.delta", "delta": "Hello"} + yield {"type": "response.output_text.delta", "delta": " world"} + yield {"type": "response.output_text.done", "text": "Hello world"} + yield {"type": "response.completed", "response": {}} + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(openai_responses_stream()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 3 + assert sent[0] == { + "type": "agent_response", + "content": "Hello", + "event_id": 10, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": " world", + "event_id": 10, + "is_final": False, + } + assert sent[2] == { + "type": "agent_response", + "content": "", + "event_id": 10, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_extracts_text_from_openai_chat_completions_api( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 11} + ) + + async def openai_chat_stream(): # type: ignore[return] + yield {"choices": [{"delta": {"role": "assistant"}}]} + yield {"choices": [{"delta": {"content": "Hi"}}]} + yield {"choices": [{"delta": {"content": " there"}}]} + yield {"choices": [{"delta": {}}]} + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(openai_chat_stream()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 3 + assert sent[0] == { + "type": "agent_response", + "content": "Hi", + "event_id": 11, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": " there", + "event_id": 11, + "is_final": False, + } + assert sent[2] == { + "type": "agent_response", + "content": "", + "event_id": 11, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_extracts_text_from_anthropic_messages_api( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 12} + ) + + async def anthropic_stream(): # type: ignore[return] + yield {"type": "message_start", "message": {}} + yield {"type": "content_block_start", "content_block": {"type": "text", "text": ""}} + yield {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Good"}} + yield {"type": "content_block_delta", "delta": {"type": "text_delta", "text": " morning"}} + yield {"type": "content_block_stop"} + yield {"type": "message_stop"} + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(anthropic_stream()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 3 + assert sent[0] == { + "type": "agent_response", + "content": "Good", + "event_id": 12, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": " morning", + "event_id": 12, + "is_final": False, + } + assert sent[2] == { + "type": "agent_response", + "content": "", + "event_id": 12, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_extracts_text_from_gemini_api( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 13} + ) + + async def gemini_stream(): # type: ignore[return] + yield {"candidates": [{"content": {"parts": [{"text": "Hey"}], "role": "model"}}]} + yield {"candidates": [{"content": {"parts": [{"text": " buddy"}], "role": "model"}}]} + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(gemini_stream()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 3 + assert sent[0] == { + "type": "agent_response", + "content": "Hey", + "event_id": 13, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": " buddy", + "event_id": 13, + "is_final": False, + } + assert sent[2] == { + "type": "agent_response", + "content": "", + "event_id": 13, + "is_final": True, + } + + +@pytest.mark.asyncio +async def test_skips_unrecognized_stream_events( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 14} + ) + + async def mixed_stream(): # type: ignore[return] + yield {"type": "unknown_event", "data": 123} + yield {"type": "response.output_text.delta", "delta": "text"} + yield 42 + yield None + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + await session.send_response(mixed_stream()) + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + assert len(sent) == 2 + assert sent[0] == { + "type": "agent_response", + "content": "text", + "event_id": 14, + "is_final": False, + } + assert sent[1] == { + "type": "agent_response", + "content": "", + "event_id": 14, + "is_final": True, + } + + +# --------------------------------------------------------------------------- +# event_id tracking across interrupts +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stamps_correct_event_id_after_interrupt( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT_2, "event_id": 2} + ) + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + # Non-blocking so both handlers complete within the run loop. + await session.send_response("response") + + session.on(USER_TRANSCRIPT, handler) + await _run_until_idle(session, ws) + + sent = _parsed_sent(ws) + # Each sendResponse emits content + terminator = 2 messages. + # The first handler's task may or may not complete before cancellation; + # the second handler's response should use event_id 2. + # Filter to the final handler's messages (event_id=2): + id2 = [m for m in sent if m.get("event_id") == 2] + assert len(id2) == 2 + assert id2[0]["content"] == "response" + assert id2[0]["is_final"] is False + assert id2[1]["content"] == "" + assert id2[1]["is_final"] is True + + +# --------------------------------------------------------------------------- +# close() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_close_is_idempotent( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + session.close() + session.close() + assert not session.is_open + + +# --------------------------------------------------------------------------- +# Event constants +# --------------------------------------------------------------------------- + + +def test_event_constants() -> None: + assert INIT == "init" + assert USER_TRANSCRIPT == "user_transcript" + assert CLOSE == "close" + assert ERROR == "error" + assert DISCONNECTED == "disconnected" + + +# --------------------------------------------------------------------------- +# on / off / once +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_off_removes_handler( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + calls = [] # type: typing.List[str] + + async def handler(conversation_id: str) -> None: + calls.append(conversation_id) + + session.on(INIT, handler) + session.off(INIT, handler) + + ws.receive_message({"type": "init", "conversation_id": "conv_1"}) + await _run_until_idle(session, ws) + + assert calls == [] + + +@pytest.mark.asyncio +async def test_once_fires_only_once( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + calls = [] # type: typing.List[str] + + async def handler(conversation_id: str) -> None: + calls.append(conversation_id) + + session.once(INIT, handler) + + ws.receive_message({"type": "init", "conversation_id": "conv_1"}) + ws.receive_message({"type": "init", "conversation_id": "conv_2"}) + await _run_until_idle(session, ws) + + assert calls == ["conv_1"] + + +# --------------------------------------------------------------------------- +# Transcript deduplication +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_skips_duplicate_event_id( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + """Duplicate transcripts with the same event_id are ignored.""" + call_count = 0 + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.5) + + session.on(USER_TRANSCRIPT, handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 30} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 30} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 30} + ) + await _run_until_idle(session, ws) + + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_does_not_skip_different_event_id( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + """Different event_ids are processed normally.""" + call_count = 0 + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + nonlocal call_count + call_count += 1 + + session.on(USER_TRANSCRIPT, handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT, "event_id": 1} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT_2, "event_id": 2} + ) + await _run_until_idle(session, ws) + + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_does_not_deduplicate_when_event_id_is_none( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + """Transcripts without event_id should never be deduplicated.""" + call_count = 0 + + async def handler(transcript: typing.List[ConversationMessage]) -> None: + nonlocal call_count + call_count += 1 + + session.on(USER_TRANSCRIPT, handler) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT} + ) + ws.receive_message( + {"type": "user_transcript", "user_transcript": TRANSCRIPT_2} + ) + await _run_until_idle(session, ws) + + assert call_count == 2 + + +# --------------------------------------------------------------------------- +# send_response outside on_transcript +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_response_warns_outside_transcript( + ws: MockWebSocket, session: SpeechEngineSession +) -> None: + """send_response before any transcript should warn and not send.""" + await session.send_response("hello") + assert len(ws.sent) == 0