Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
InterruptionOptions,
PreemptiveGenerationOptions,
TurnHandlingOptions,
UserStateSource,
)
from .worker import (
AgentServer,
Expand Down Expand Up @@ -234,6 +235,7 @@ def __getattr__(name: str) -> typing.Any:
"AMDCategory",
"AMDResult",
"TurnHandlingOptions",
"UserStateSource",
"EndpointingOptions",
"InterruptionOptions",
"PreemptiveGenerationOptions",
Expand Down
77 changes: 70 additions & 7 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
update_instructions,
)
from .speech_handle import DEFAULT_INPUT_DETAILS, InputDetails, SpeechHandle
from .turn import EndpointingOptions, TurnDetectionMode
from .turn import EndpointingOptions, TurnDetectionMode, UserStateSource

if TYPE_CHECKING:
from ..llm import mcp
Expand Down Expand Up @@ -200,6 +200,23 @@ def __init__(self, agent: Agent, sess: AgentSession) -> None:
)
self._turn_detection = self._validate_turn_detection(turn_detection)

# validate + resolve user_state_source locally (don't mutate session)
self._user_state_source: UserStateSource = self._session.user_state_source
if self._user_state_source == "stt" and not self.stt:
logger.warning(
"user_state_source is set to 'stt', but no STT model is provided. "
"STT events will never fire, so user_state will be stuck. "
"Falling back to 'auto'."
)
self._user_state_source = "auto"
elif self._user_state_source == "vad" and not self.vad:
logger.warning(
"user_state_source is set to 'vad', but no VAD model is provided. "
"VAD events will never fire, so user_state will be stuck. "
"Falling back to 'auto'."
)
self._user_state_source = "auto"

self._interruption_detector: inference.AdaptiveInterruptionDetector | None = (
self._resolve_interruption_detection()
)
Expand Down Expand Up @@ -444,6 +461,7 @@ def update_options(
tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
endpointing_opts: NotGivenOr[EndpointingOptions] = NOT_GIVEN,
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
user_state_source: NotGivenOr[UserStateSource] = NOT_GIVEN,
# deprecated
min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -469,6 +487,21 @@ def update_options(
if self._rt_session is not None:
self._rt_session.update_options(tool_choice=self._tool_choice)

if utils.is_given(user_state_source):
if user_state_source == "stt" and not self.stt:
logger.warning(
"user_state_source is set to 'stt', but no STT model is provided. "
"Falling back to 'auto'."
)
user_state_source = "auto"
elif user_state_source == "vad" and not self.vad:
logger.warning(
"user_state_source is set to 'vad', but no VAD model is provided. "
"Falling back to 'auto'."
)
user_state_source = "auto"
self._user_state_source = user_state_source

if utils.is_given(turn_detection):
turn_detection = self._validate_turn_detection(turn_detection)

Expand All @@ -490,6 +523,9 @@ def update_options(
if is_given(endpointing_opts)
else NOT_GIVEN,
turn_detection=turn_detection,
user_state_source=self._user_state_source
if utils.is_given(user_state_source)
else NOT_GIVEN,
)

def _create_speech_task(
Expand Down Expand Up @@ -784,6 +820,7 @@ async def _setup_toolset(toolset: llm.Toolset) -> None:
interruption_detection=self._interruption_detector,
endpointing=create_endpointing(self.endpointing_opts),
turn_detection=self._turn_detection,
user_state_source=self._user_state_source,
stt_model=self.stt.model if self.stt else None,
stt_provider=self.stt.provider if self.stt else None,
)
Expand Down Expand Up @@ -1650,12 +1687,27 @@ def _interrupt_by_audio_activity(

# region recognition hooks

def _should_update_user_state(self, from_vad: bool) -> bool:
"""Whether this event source should drive ``user_state`` transitions."""
if self._user_state_source == "vad":
return from_vad
if self._user_state_source == "stt":
return not from_vad
# "auto": VAD drives unless turn_detection is "stt"
if from_vad:
td = self._turn_detection
return not (isinstance(td, str) and td == "stt")
else:
td = self._turn_detection
return isinstance(td, str) and td == "stt"

def on_start_of_speech(
self,
ev: vad.VADEvent | None,
speech_start_time: float,
) -> None:
self._session._update_user_state("speaking", last_speaking_time=speech_start_time)
if self._should_update_user_state(from_vad=ev is not None):
self._session._update_user_state("speaking", last_speaking_time=speech_start_time)
if self._audio_recognition:
self._audio_recognition.on_start_of_speech(
started_at=speech_start_time,
Expand Down Expand Up @@ -1686,7 +1738,7 @@ def on_start_of_speech(
self._update_paused_speech(current_speech, timeout=0)
audio_output.pause()

def on_end_of_speech(self, ev: vad.VADEvent | None) -> None:
def on_end_of_speech(self, ev: vad.VADEvent | None, *, force: bool = False) -> None:
speech_end_time = time.time()
if ev:
speech_end_time = speech_end_time - ev.silence_duration - ev.inference_duration
Expand All @@ -1702,15 +1754,26 @@ def on_end_of_speech(self, ev: vad.VADEvent | None) -> None:
else NOT_GIVEN,
)

self._session._update_user_state(
"listening",
last_speaking_time=speech_end_time,
)
if force or self._should_update_user_state(from_vad=ev is not None):
self._session._update_user_state(
"listening",
last_speaking_time=speech_end_time,
)
self._user_silence_event.set()
Comment thread
MdSadiqMd marked this conversation as resolved.

if self._paused_speech:
self._start_false_interruption_timer(self._paused_speech.timeout)

def on_stt_speech_started(self) -> None:
    """Notify that STT detected speech start (without triggering endpointing).

    Mirrors ``on_stt_speech_ended``: together they keep the STT-derived
    silence bookkeeping in sync while the user is talking.
    """
    # A new utterance began, so any previously observed STT end-of-speech
    # no longer applies.
    self._stt_eos_received = False
    # NOTE(review): presumably _user_silence_event gates tasks waiting for
    # user silence — clearing it marks the user as actively speaking.
    # Confirm against the event's waiters (defined elsewhere in this file).
    self._user_silence_event.clear()

def on_stt_speech_ended(self) -> None:
    """Notify that STT detected speech end (without triggering endpointing).

    Counterpart of ``on_stt_speech_started``; records that STT has emitted
    an end-of-speech signal for the current utterance.
    """
    # Remember that STT reported end-of-speech for the current utterance.
    self._stt_eos_received = True
    # NOTE(review): setting the event presumably wakes tasks waiting for
    # user silence — confirm against the event's waiters elsewhere in file.
    self._user_silence_event.set()

def on_vad_inference_done(self, ev: vad.VADEvent) -> None:
if self._turn_detection in ("manual", "realtime_llm"):
# ignore vad inference done event if turn_detection is manual or realtime_llm
Expand Down
18 changes: 17 additions & 1 deletion livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
PreemptiveGenerationOptions,
TurnDetectionMode,
TurnHandlingOptions,
UserStateSource,
_migrate_turn_handling,
_resolve_endpointing,
_resolve_interruption,
Expand Down Expand Up @@ -354,6 +355,7 @@ def __init__(
interruption = _resolve_interruption(turn_handling.get("interruption"))
preemptive_gen = _resolve_preemptive_generation(turn_handling.get("preemptive_generation"))
raw_turn_detection = turn_handling.get("turn_detection", None)
raw_user_state_source: UserStateSource = turn_handling.get("user_state_source", "auto")

# This is the "global" chat_context, it holds the entire conversation history
self._chat_ctx = ChatContext.empty()
Expand All @@ -362,6 +364,7 @@ def __init__(
endpointing=endpointing,
interruption=interruption,
turn_detection=raw_turn_detection,
user_state_source=raw_user_state_source,
preemptive_generation=preemptive_gen,
),
max_tool_steps=max_tool_steps,
Expand Down Expand Up @@ -396,6 +399,7 @@ def __init__(
self._llm = llm or None
self._tts = tts or None
self._turn_detection = raw_turn_detection
self._user_state_source: UserStateSource = raw_user_state_source
self._interruption_detection = interruption.get("mode", NOT_GIVEN)
self._mcp_servers = mcp_servers or None
self._tools = tools if is_given(tools) else []
Expand Down Expand Up @@ -496,6 +500,10 @@ def userdata(self, value: Userdata_T) -> None:
def turn_detection(self) -> TurnDetectionMode | None:
return self._turn_detection

@property
def user_state_source(self) -> UserStateSource:
    """Which signal drives the ``user_state`` machine.

    Returns the configured source: ``"vad"`` for VAD events, ``"stt"`` for
    STT events, or ``"auto"`` for the default automatic selection.
    """
    return self._user_state_source

@property
def mcp_servers(self) -> list[mcp.MCPServer] | None:
return self._mcp_servers
Expand Down Expand Up @@ -1018,6 +1026,7 @@ def update_options(
*,
endpointing_opts: NotGivenOr[EndpointingOptions] = NOT_GIVEN,
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
user_state_source: NotGivenOr[UserStateSource] = NOT_GIVEN,
# deprecated
min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -1029,6 +1038,8 @@ def update_options(
endpointing_opts (NotGivenOr[EndpointingOptions], optional): Endpointing options.
turn_detection (NotGivenOr[TurnDetectionMode | None], optional): Strategy for deciding
when the user has finished speaking. ``None`` reverts to automatic selection.
user_state_source (NotGivenOr[UserStateSource], optional): Which signal drives the
user_state machine. ``"vad"`` for VAD, ``"stt"`` for STT, ``"auto"`` for default.
min_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
max_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
"""
Expand Down Expand Up @@ -1064,12 +1075,17 @@ def update_options(
if is_given(turn_detection):
self._turn_detection = turn_detection

if is_given(user_state_source):
self._user_state_source = user_state_source
self._opts.turn_handling["user_state_source"] = user_state_source

if self._activity is not None:
self._activity.update_options(
endpointing_opts=(
self._opts.endpointing if is_given(endpointing_opts) else NOT_GIVEN
),
turn_detection=turn_detection,
user_state_source=user_state_source,
)

async def _start_ivr_detection(self, transcript: str | None = None) -> None:
Expand Down Expand Up @@ -1569,7 +1585,7 @@ def _on_audio_enabled_changed(self, enabled: bool) -> None:
"""End user speaking state when audio is disabled by default."""
if not enabled and self._user_state == "speaking":
if self._activity is not None:
self._activity.on_end_of_speech(None)
self._activity.on_end_of_speech(None, force=True)
else:
self._update_user_state("listening")

Expand Down
Loading