Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 210 additions & 74 deletions src/elevenlabs/conversational_ai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from enum import Enum
import json
import logging
import queue
import threading
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union
import urllib.parse

import websockets
from websockets.exceptions import ConnectionClosedOK
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError, ConnectionClosed
from websockets.sync.client import Connection, connect

from ..base_client import BaseElevenLabs
Expand Down Expand Up @@ -645,6 +646,9 @@ def __init__(
self._thread = None
self._ws: Optional[Connection] = None
self._should_stop = threading.Event()
self._audio_queue: Optional[queue.Queue] = None
self._audio_sender_thread: Optional[threading.Thread] = None
self._connection_closed = threading.Event()

def start_session(self):
"""Starts the conversation session.
Expand All @@ -657,10 +661,21 @@ def start_session(self):

def end_session(self):
    """Ends the conversation session and cleans up resources.

    Signals both stop events, stops the audio interface, shuts down the
    audio sender thread when one is running, stops the client tools, drops
    the websocket reference and finally invokes ``callback_end_session`` if
    one is configured.
    """
    # Raise both flags first so the receive loop, the sender thread and the
    # input callback all stop doing work before resources are torn down.
    self._should_stop.set()
    self._connection_closed.set()

    self.audio_interface.stop()

    sender = self._audio_sender_thread
    # A thread cannot join itself (RuntimeError), so skip the hand-shake when
    # this method happens to be invoked from the sender thread.
    if (
        sender is not None
        and sender.is_alive()
        and sender is not threading.current_thread()
    ):
        if self._audio_queue is not None:
            try:
                # Sentinel: wakes the sender right away instead of letting it
                # wait for its next poll of the stop flags.
                self._audio_queue.put(None, timeout=0.1)
            except queue.Full:
                # Best effort only: the sender loop also checks the stop
                # flags set above, so it still terminates without the sentinel.
                pass
        sender.join(timeout=1.0)

    self.client_tools.stop()
    self._ws = None

    if self.callback_end_session:
        self.callback_end_session()
Expand Down Expand Up @@ -737,43 +752,95 @@ def send_contextual_update(self, text: str):
raise

def _run(self, ws_url: str):
    """Runs the conversation over a websocket until stopped or disconnected.

    Captured audio is decoupled from the socket: the audio interface callback
    only enqueues chunks, and a dedicated daemon thread drains the queue and
    sends them, so a slow or broken socket never blocks audio capture.

    Args:
        ws_url: URL of the websocket endpoint to connect to.
    """
    # Bounded so a stalled connection cannot grow memory without limit; the
    # callback below drops chunks once the queue is full.
    audio_queue: queue.Queue = queue.Queue(maxsize=50)
    self._audio_queue = audio_queue
    self._connection_closed.clear()

    def audio_sender():
        """Dedicated thread for sending audio chunks from the queue."""
        while not self._should_stop.is_set() and not self._connection_closed.is_set():
            try:
                audio = audio_queue.get(timeout=0.1)
            except queue.Empty:
                # Nothing captured yet; loop to re-check the stop flags.
                continue

            try:
                if audio is None:  # Sentinel queued by end_session().
                    break

                # Read the socket reference once: another thread may reset
                # self._ws to None between a check and the send.
                ws = self._ws
                if ws is None or self._connection_closed.is_set():
                    logger.debug("Connection closed, dropping audio chunk")
                    continue

                ws.send(
                    json.dumps(
                        {
                            "user_audio_chunk": base64.b64encode(audio).decode(),
                        }
                    )
                )
            except ConnectionClosed as e:
                # ConnectionClosedOK / ConnectionClosedError are subclasses.
                logger.warning(f"WebSocket connection closed while sending audio: {e}")
                self._connection_closed.set()
                break
            except Exception as e:
                # Re-queueing would append the chunk behind newer audio
                # (out-of-order playback) and retry an unsendable chunk
                # forever, so the chunk is dropped instead.
                logger.warning(f"Error sending audio chunk, dropping it: {e}")
            finally:
                # Exactly one task_done() per successful get().
                audio_queue.task_done()

    # Start audio sender thread
    self._audio_sender_thread = threading.Thread(target=audio_sender, daemon=True, name="AudioSender")
    self._audio_sender_thread.start()

    def input_callback(audio):
        """Callback from audio interface - queues audio for sending."""
        if self._should_stop.is_set() or self._connection_closed.is_set():
            return
        try:
            # Non-blocking put - drop if queue is full to prevent blocking audio capture
            audio_queue.put_nowait(audio)
        except queue.Full:
            logger.warning("Audio queue full, dropping audio chunk to prevent blocking")

    try:
        with connect(ws_url, max_size=16 * 1024 * 1024) as ws:
            self._ws = ws
            if self.on_prem_config:
                ws.send(self._create_on_prem_initiation_message())
            ws.send(self._create_initiation_message())

            self.audio_interface.start(input_callback)

            while not self._should_stop.is_set():
                try:
                    message = json.loads(ws.recv(timeout=0.5))
                    if self._should_stop.is_set():
                        return
                    self._handle_message(message, ws)
                except ConnectionClosedOK:
                    logger.info("WebSocket connection closed normally")
                    # Set here (not only in ``finally``) so the sender stops
                    # before the ``with`` block starts closing the socket.
                    self._connection_closed.set()
                    break
                except ConnectionClosed as e:
                    logger.warning(f"WebSocket connection closed with error: {e}")
                    self._connection_closed.set()
                    break
                except TimeoutError:
                    # No message within the poll interval; re-check the flag.
                    pass
                except Exception as e:
                    # Not a connection error: log it and keep the session
                    # alive so later messages can still be processed.
                    logger.error(f"Error receiving message: {e}")
    finally:
        self._connection_closed.set()
        self._ws = None

def _handle_message(self, message, ws):
Expand Down Expand Up @@ -911,6 +978,9 @@ def __init__(
self._task = None
self._ws = None
self._should_stop = asyncio.Event()
self._audio_queue: Optional[asyncio.Queue] = None
self._audio_sender_task: Optional[asyncio.Task] = None
self._connection_closed = asyncio.Event()

async def start_session(self):
"""Starts the conversation session.
Expand All @@ -922,10 +992,22 @@ async def start_session(self):

async def end_session(self):
    """Ends the conversation session and cleans up resources.

    Signals both stop events, stops the audio interface, cancels the audio
    sender task when one is still running, stops the client tools, drops the
    websocket reference and finally awaits ``callback_end_session`` if one is
    configured.
    """
    self._should_stop.set()
    self._connection_closed.set()

    # Stop audio interface
    await self.audio_interface.stop()

    try:
        sender = self._audio_sender_task
        # A task cannot await itself, so skip the hand-shake when this
        # coroutine is running inside the sender task.
        if (
            sender is not None
            and not sender.done()
            and sender is not asyncio.current_task()
        ):
            sender.cancel()
            # return_exceptions=True hands the sender's CancelledError back as
            # a result, whereas a cancellation aimed at the caller of
            # end_session() still propagates instead of being swallowed.
            await asyncio.gather(sender, return_exceptions=True)
    finally:
        # Release the remaining resources even if we were cancelled above.
        self.client_tools.stop()
        self._ws = None

    if self.callback_end_session:
        await self.callback_end_session()
Expand Down Expand Up @@ -1002,48 +1084,102 @@ async def send_contextual_update(self, text: str):
raise

async def _run(self, ws_url: str):
    """Runs the conversation over a websocket until stopped or disconnected.

    Captured audio is decoupled from the socket: the audio interface callback
    only enqueues chunks, and a dedicated task drains the queue and sends
    them, so a slow or broken socket never blocks audio capture.

    Args:
        ws_url: URL of the websocket endpoint to connect to.
    """
    # Bounded so a stalled connection cannot grow memory without limit; the
    # callback below drops chunks once the queue is full.
    audio_queue: asyncio.Queue = asyncio.Queue(maxsize=50)
    self._audio_queue = audio_queue
    self._connection_closed.clear()

    async def audio_sender():
        """Dedicated task for sending audio chunks from the queue."""
        while not self._should_stop.is_set() and not self._connection_closed.is_set():
            try:
                # Bounded wait so the stop flags are re-checked regularly.
                audio = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue

            if audio is None:  # Sentinel value to stop
                break

            # Read the socket reference once: end_session() may reset
            # self._ws to None while this task is suspended.
            ws = self._ws
            if ws is None or self._connection_closed.is_set():
                logger.debug("Connection closed, dropping audio chunk")
                continue

            try:
                await ws.send(
                    json.dumps(
                        {
                            "user_audio_chunk": base64.b64encode(audio).decode(),
                        }
                    )
                )
            except ConnectionClosed as e:
                # ConnectionClosedOK / ConnectionClosedError are subclasses.
                logger.warning(f"WebSocket connection closed while sending audio: {e}")
                self._connection_closed.set()
                break
            except Exception as e:
                # Re-queueing would append the chunk behind newer audio
                # (out-of-order playback) and retry an unsendable chunk
                # forever, so the chunk is dropped instead.
                logger.warning(f"Error sending audio chunk, dropping it: {e}")

    # Start audio sender task
    self._audio_sender_task = asyncio.create_task(audio_sender())

    async def input_callback(audio):
        """Callback from audio interface - queues audio for sending."""
        if self._should_stop.is_set() or self._connection_closed.is_set():
            return
        try:
            # Non-blocking put - drop if queue is full to prevent blocking audio capture
            audio_queue.put_nowait(audio)
        except asyncio.QueueFull:
            logger.warning("Audio queue full, dropping audio chunk to prevent blocking")

    try:
        async with websockets.connect(ws_url, max_size=16 * 1024 * 1024) as ws:
            self._ws = ws
            if self.on_prem_config:
                await ws.send(self._create_on_prem_initiation_message())
            await ws.send(self._create_initiation_message())

            await self.audio_interface.start(input_callback)

            try:
                while not self._should_stop.is_set():
                    try:
                        message_str = await asyncio.wait_for(ws.recv(), timeout=0.5)
                        if self._should_stop.is_set():
                            return
                        message = json.loads(message_str)
                        await self._handle_message(message, ws)
                    except asyncio.TimeoutError:
                        # No message within the poll interval; re-check the flag.
                        pass
                    except ConnectionClosedOK:
                        logger.info("WebSocket connection closed normally")
                        self._connection_closed.set()
                        break
                    except ConnectionClosed as e:
                        logger.warning(f"WebSocket connection closed with error: {e}")
                        self._connection_closed.set()
                        break
                    except Exception as e:
                        # Not a connection error: log it and keep the session
                        # alive so later messages can still be processed.
                        logger.error(f"Error receiving message: {e}")
            finally:
                # Drop the reference before the socket starts closing so the
                # sender stops using it.
                self._ws = None
    finally:
        # Also covers failures while connecting or sending the initiation
        # messages, where the inner ``finally`` is never reached.
        self._connection_closed.set()
        self._ws = None

async def _handle_message(self, message, ws):
class AsyncMessageHandler:
Expand Down