Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def __init__(
conn_options: APIConnectOptions,
) -> None:
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
self._stt: STT = stt
self._opts = opts
self._session = stt._ensure_session()
self._request_id = str(utils.shortuuid("stt_request_"))

self._speaking = False
Expand Down Expand Up @@ -588,6 +588,7 @@ async def _send_session_update(self, msg: dict[str, Any]) -> None:
async def _run(self) -> None:
"""Main loop for streaming transcription."""
closing_ws = False
http_session = self._stt._ensure_session()

@utils.log_exceptions(logger=logger)
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
Expand Down Expand Up @@ -632,7 +633,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws or self._session.closed:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we have the same pattern in other files as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this back with a local http_session var

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we fix all "or self._session.closed:" usage?

if closing_ws or http_session.closed:
return
raise APIStatusError(
message="LiveKit Inference STT connection closed unexpectedly"
Expand Down Expand Up @@ -669,7 +670,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:

ws: aiohttp.ClientWebSocketResponse | None = None
try:
ws = await self._connect_ws()
ws = await self._connect_ws(http_session)
self._ws = ws
tasks = [
asyncio.create_task(send_task(ws)),
Expand All @@ -684,7 +685,9 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
if ws is not None:
await ws.close()

async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
async def _connect_ws(
self, http_session: aiohttp.ClientSession
) -> aiohttp.ClientWebSocketResponse:
"""Connect to the LiveKit Inference STT WebSocket."""
params: dict[str, Any] = {
"settings": {
Expand Down Expand Up @@ -722,7 +725,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
}
try:
ws = await asyncio.wait_for(
self._session.ws_connect(
http_session.ws_connect(
f"{base_url}/stt?model={self._opts.model}", headers=headers
),
self._conn_options.timeout,
Expand Down
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/utils/http_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ async def _close_http_ctx() -> None:
logger.debug("http_session(): closing the httpclient ctx")
await val().close()
_ContextVar.set(None)


def _is_http_session_ctx_set() -> bool:
"""Return True if an http session factory is bound to the current context."""
return _ContextVar.get(None) is not None
22 changes: 22 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ def __init__(
self._closing_task: asyncio.Task[None] | None = None
self._closing: bool = False
self._job_context_cb_registered: bool = False
self._owned_http_session_ctx: bool = False

self._global_run_state: RunResult | None = None
# TODO(theomonnom): need a better way to expose early assistant metrics
Expand Down Expand Up @@ -624,6 +625,18 @@ async def start(
# configure observability first
record_is_given = is_given(record)
job_ctx = get_job_context(required=False)

# Outside a job context (tests, scripts, ad-hoc usage) there's no
# http session bound to the event loop. Create one scoped to this
# session so STT/TTS can use http_context.http_session()
if (
job_ctx is None
and not self._owned_http_session_ctx
and not utils.http_context._is_http_session_ctx_set()
):
utils.http_context._new_session_ctx()
self._owned_http_session_ctx = True

if not is_given(record):
# defer to server-side setting for recording
record = job_ctx.job.enable_recording if job_ctx else False
Expand Down Expand Up @@ -906,6 +919,11 @@ async def _aclose_impl(

async with self._lock:
if not self._started:
# start() may have set up the http session ctx before failing —
# clean it up so we don't leak the factory on a failed start.
if self._owned_http_session_ctx:
await utils.http_context._close_http_ctx()
self._owned_http_session_ctx = False
return

self._closing = True
Expand Down Expand Up @@ -1008,6 +1026,10 @@ async def _aclose_impl(
await self._room_io.aclose()
self._room_io = None

if self._owned_http_session_ctx:
await utils.http_context._close_http_ctx()
self._owned_http_session_ctx = False

logger.debug("session closed", extra={"reason": reason.value, "error": error})

async def aclose(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ unit-tests:
tests/test_tool_search.py \
tests/test_tool_proxy.py \
tests/test_endpointing.py \
tests/test_session_host.py
tests/test_session_host.py \
tests/test_http_session_lifecycle.py

# ============================================
# Development Workflows
Expand Down
272 changes: 272 additions & 0 deletions tests/test_http_session_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
"""
Tests for AgentSession-owned http_context lifecycle.

When running outside a job context (tests, scripts, ad-hoc usage) there is no
process-level http_session bound to the event loop. AgentSession sets one up in
start() and tears it down in aclose() so that STT/TTS/etc. can call
``utils.http_context.http_session()`` without a job process running.
"""

from __future__ import annotations

import asyncio
from pathlib import Path
from unittest.mock import MagicMock, patch

import aiohttp
import pytest

from livekit.agents import (
NOT_GIVEN,
Agent,
AgentSession,
NotGivenOr,
stt as stt_module,
tts as tts_module,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from livekit.agents.utils import http_context

from .fake_io import FakeAudioInput, FakeAudioOutput, FakeTextOutput
from .fake_llm import FakeLLM
from .fake_vad import FakeVAD

_AGENT_SESSION_MOD = "livekit.agents.voice.agent_session"


class _CapturingSTT(stt_module.STT):
"""STT that records the http session it sees during stream() — no network."""

def __init__(self) -> None:
super().__init__(
capabilities=stt_module.STTCapabilities(streaming=True, interim_results=False),
)
self.captured_session: aiohttp.ClientSession | None = None

async def _recognize_impl(self, *args, **kwargs): # pragma: no cover - unused
raise NotImplementedError

def stream(
self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> _NoopSTTStream:
# The point of the test: this call must succeed inside an active
# AgentSession, regardless of whether a job context is set.
self.captured_session = http_context.http_session()
return _NoopSTTStream(stt=self, conn_options=conn_options)


class _NoopSTTStream(stt_module.RecognizeStream):
async def _run(self) -> None:
async for _ in self._input_ch:
pass


class _CapturingTTS(tts_module.TTS):
"""TTS that records the http session it sees during synthesize() — no network."""

def __init__(self) -> None:
super().__init__(
capabilities=tts_module.TTSCapabilities(streaming=False),
sample_rate=24000,
num_channels=1,
)
self.captured_session: aiohttp.ClientSession | None = None

def synthesize(
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> _NoopChunkedStream:
self.captured_session = http_context.http_session()
return _NoopChunkedStream(tts=self, input_text=text, conn_options=conn_options)


class _NoopChunkedStream(tts_module.ChunkedStream):
async def _run(self, output_emitter: tts_module.AudioEmitter) -> None:
output_emitter.initialize(
request_id="noop",
sample_rate=24000,
num_channels=1,
mime_type="audio/pcm",
)
output_emitter.flush()


class _NoopAgent(Agent):
def __init__(self) -> None:
super().__init__(instructions="noop")


def _make_session(
stt: _CapturingSTT | None = None, tts: _CapturingTTS | None = None
) -> AgentSession:
session = AgentSession[None](
vad=FakeVAD(fake_user_speeches=[], min_silence_duration=0.5, min_speech_duration=0.05),
stt=stt or _CapturingSTT(),
llm=FakeLLM(fake_responses=[]),
tts=tts or _CapturingTTS(),
# disable AEC warmup so we don't leak the timer
aec_warmup_duration=None,
)
session.input.audio = FakeAudioInput()
session.output.audio = FakeAudioOutput()
session.output.transcription = FakeTextOutput()
return session


async def test_http_session_available_during_agent_session() -> None:
"""Inside a started AgentSession, http_context.http_session() returns a working session.

After aclose, the context is reset and http_session() raises again.
"""
# Sanity: nothing set in this task before start
with pytest.raises(RuntimeError):
http_context.http_session()

capturing_stt = _CapturingSTT()
session = _make_session(stt=capturing_stt)

await session.start(_NoopAgent())

# The set in start() propagates to this task's context (start awaited here).
sess = http_context.http_session()
assert isinstance(sess, aiohttp.ClientSession)
assert not sess.closed

# The STT.stream() called during activity start sees the same session.
assert capturing_stt.captured_session is sess

await session.aclose()

# After aclose the underlying session is closed and the contextvar is reset.
assert sess.closed
with pytest.raises(RuntimeError):
http_context.http_session()


async def test_concurrent_sessions_in_separate_tasks_are_isolated() -> None:
"""Two AgentSessions started inside their own asyncio.Task each get their own
http session. Closing one does not affect the other.
"""
barrier = asyncio.Event()

async def session_worker() -> tuple[aiohttp.ClientSession, aiohttp.ClientSession]:
capturing_stt = _CapturingSTT()
session = _make_session(stt=capturing_stt)

await session.start(_NoopAgent())
seen = http_context.http_session()
# wait so both tasks are alive simultaneously — proves isolation
await barrier.wait()
# session is still live and accessible from this task's context
still_seen = http_context.http_session()
await session.aclose()
return seen, still_seen

task_a = asyncio.create_task(session_worker())
task_b = asyncio.create_task(session_worker())

# let both reach the barrier
await asyncio.sleep(0.05)
barrier.set()

(a_first, a_second), (b_first, b_second) = await asyncio.gather(task_a, task_b)

# each task sees a stable session before close
assert a_first is a_second
assert b_first is b_second

# tasks see different sessions — not a single global one
assert a_first is not b_first

# both got closed independently
assert a_first.closed
assert b_first.closed


def _mock_job_ctx() -> MagicMock:
"""Build the minimal JobContext mock that AgentSession.start() reads from."""
mock = MagicMock()
mock.job.enable_recording = False
mock.job.id = "test-job-id"
mock.job.agent_name = "test-agent"
mock.room.name = "test-room"
mock._primary_agent_session = None
mock.session_directory = Path("/tmp/test-session")
return mock


async def test_session_does_not_own_http_ctx_inside_job_context(
job_process: None, # fixture sets up http_context for the test
) -> None:
"""When AgentSession runs inside a real job context, it must not overwrite or
close the process-level http_context on aclose.
"""
outer_session = http_context.http_session()
assert not outer_session.closed

session = _make_session()

with patch(f"{_AGENT_SESSION_MOD}.get_job_context", return_value=_mock_job_ctx()):
await session.start(_NoopAgent())

# AgentSession reuses the existing context — same ClientSession surfaces.
assert http_context.http_session() is outer_session

await session.aclose()

# The job-context session is still alive — only the job_process fixture closes it.
assert not outer_session.closed
assert http_context.http_session() is outer_session


async def test_nested_sessions_in_same_task_share_http_ctx() -> None:
"""A second AgentSession started inside a still-running outer session (same
task) must reuse the outer's http session and not close it on aclose.
"""
outer = _make_session()
await outer.start(_NoopAgent())
outer_session = http_context.http_session()
assert outer._owned_http_session_ctx is True

inner = _make_session()
await inner.start(_NoopAgent())

# inner sees the contextvar already set → does not take ownership
assert inner._owned_http_session_ctx is False
assert http_context.http_session() is outer_session

await inner.aclose()

# outer's session is unaffected by inner's close
assert not outer_session.closed
assert http_context.http_session() is outer_session

await outer.aclose()
assert outer_session.closed


async def test_start_failure_cleans_up_http_ctx() -> None:
"""If start() fails after setting up the http session, aclose() must still
clean it up. Otherwise __aexit__ on the async-with would leak the factory.
"""
session = _make_session()

with patch.object(AgentSession, "_update_activity", side_effect=RuntimeError("boom")):
with pytest.raises(BaseException): # noqa: B017,PT011 - want any failure
await session.start(_NoopAgent())

# the session never reached _started=True, but the http_session ctx was set
assert session._started is False
assert session._owned_http_session_ctx is True

await session.aclose()

# aclose must clean up even when start failed
assert session._owned_http_session_ctx is False
with pytest.raises(RuntimeError):
http_context.http_session()
Loading