Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions livekit-agents/livekit/agents/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ class InterruptionMetrics(_BaseMetrics):
metadata: Metadata | None = None


class AvatarMetrics(_BaseMetrics):
    """Metrics describing an avatar session.

    An instance carries either the playback latency of an utterance or the
    start/join timestamps of the avatar; fields that do not apply to a given
    report keep their defaults.
    """

    type: Literal["avatar_metrics"] = "avatar_metrics"
    # Set by the emitter of the metric.
    timestamp: float
    playback_latency: float = 0
    """Delay between forwarding the first audio frame to the avatar and the playback started."""
    session_started_time: float | None = None
    """Time when the avatar session was started."""
    avatar_joined_time: float | None = None
    """Time when the avatar participant joined and started video track."""
    metadata: Metadata | None = None


AgentMetrics = (
STTMetrics
| LLMMetrics
Expand All @@ -189,4 +201,5 @@ class InterruptionMetrics(_BaseMetrics):
| EOUMetrics
| RealtimeModelMetrics
| InterruptionMetrics
| AvatarMetrics
)
10 changes: 10 additions & 0 deletions livekit-agents/livekit/agents/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..log import logger as default_logger
from .base import (
AgentMetrics,
AvatarMetrics,
EOUMetrics,
InterruptionMetrics,
LLMMetrics,
Expand Down Expand Up @@ -104,3 +105,12 @@ def log_metrics(metrics: AgentMetrics, *, logger: logging.Logger | None = None)
"num_requests": metrics.num_requests,
},
)
elif isinstance(metrics, AvatarMetrics):
extra: dict[str, str | float] = {}
if metrics.session_started_time and metrics.avatar_joined_time:
extra["avatar_join_latency"] = round(
metrics.avatar_joined_time - metrics.session_started_time, 3
)
if metrics.playback_latency:
extra["playback_latency"] = round(metrics.playback_latency, 3)
logger.info("Avatar metrics", extra=metadata | extra)
86 changes: 79 additions & 7 deletions livekit-agents/livekit/agents/utils/participant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Literal, overload

from livekit import rtc

Expand Down Expand Up @@ -67,23 +68,46 @@ def on_connection_state_changed(state: int) -> None:
room.off("connection_state_changed", on_connection_state_changed)


@overload
async def wait_for_participant(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.ParticipantKind.ValueType] | rtc.ParticipantKind.ValueType | None = None,
) -> rtc.RemoteParticipant:
include_local: Literal[False] = False,
) -> rtc.RemoteParticipant: ...


@overload
async def wait_for_participant(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.ParticipantKind.ValueType] | rtc.ParticipantKind.ValueType | None = None,
include_local: Literal[True],
) -> rtc.Participant: ...


async def wait_for_participant(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.ParticipantKind.ValueType] | rtc.ParticipantKind.ValueType | None = None,
include_local: bool = False,
) -> rtc.Participant:
"""
Returns a participant that matches the given identity. If identity is None, the first
participant that joins the room will be returned.
If the participant has already joined, the function will return immediately.

When `include_local` is True, the local participant is also considered.
"""
if not room.isconnected():
raise RuntimeError("room is not connected")

fut = asyncio.Future[rtc.RemoteParticipant]()
fut = asyncio.Future[rtc.Participant]()

def kind_match(p: rtc.RemoteParticipant) -> bool:
def kind_match(p: rtc.Participant) -> bool:
if kind is None:
return True

Expand All @@ -105,6 +129,11 @@ def _on_connection_state_changed(state: int) -> None:
room.on("connection_state_changed", _on_connection_state_changed)

try:
if include_local:
local = room.local_participant
if (identity is None or local.identity == identity) and kind_match(local):
return local

for p in room.remote_participants.values():
if p.state == rtc.ParticipantState.PARTICIPANT_STATE_ACTIVE:
_on_participant_active(p)
Expand All @@ -117,20 +146,43 @@ def _on_connection_state_changed(state: int) -> None:
room.off("connection_state_changed", _on_connection_state_changed)


@overload
async def wait_for_track_publication(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.TrackKind.ValueType] | rtc.TrackKind.ValueType | None = None,
) -> rtc.RemoteTrackPublication:
"""Returns a remote track matching the given identity and kind.
include_local: Literal[False] = False,
) -> rtc.RemoteTrackPublication: ...


@overload
async def wait_for_track_publication(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.TrackKind.ValueType] | rtc.TrackKind.ValueType | None = None,
include_local: Literal[True],
) -> rtc.TrackPublication: ...


async def wait_for_track_publication(
room: rtc.Room,
*,
identity: str | None = None,
kind: list[rtc.TrackKind.ValueType] | rtc.TrackKind.ValueType | None = None,
include_local: bool = False,
) -> rtc.TrackPublication:
"""Returns a track publication matching the given identity and kind.
If identity is None, the first track matching the kind will be returned.
If the track has already been published, the function will return immediately.

When `include_local` is True, tracks published by the local participant are also considered.
"""
if not room.isconnected():
raise RuntimeError("room is not connected")

fut = asyncio.Future[rtc.RemoteTrackPublication]()
fut = asyncio.Future[rtc.TrackPublication]()

def kind_match(k: rtc.TrackKind.ValueType) -> bool:
if kind is None:
Expand All @@ -150,15 +202,33 @@ def _on_track_published(
if (identity is None or participant.identity == identity) and kind_match(publication.kind):
fut.set_result(publication)

def _on_local_track_published(
publication: rtc.LocalTrackPublication, _track: rtc.Track
) -> None:
if fut.done():
return

local = room.local_participant
if (identity is None or local.identity == identity) and kind_match(publication.kind):
fut.set_result(publication)

def _on_connection_state_changed(state: int) -> None:
if state == rtc.ConnectionState.CONN_DISCONNECTED and not fut.done():
fut.set_exception(RuntimeError("room disconnected while waiting for track publication"))

# room.on("track_subscribed", _on_track_subscribed)
room.on("track_published", _on_track_published)
if include_local:
room.on("local_track_published", _on_local_track_published)
room.on("connection_state_changed", _on_connection_state_changed)

try:
if include_local:
local = room.local_participant
if identity is None or local.identity == identity:
for local_publication in local.track_publications.values():
if kind_match(local_publication.kind):
return local_publication

for p in room.remote_participants.values():
for publication in p.track_publications.values():
_on_track_published(publication, p)
Expand All @@ -168,4 +238,6 @@ def _on_connection_state_changed(state: int) -> None:
return await fut
finally:
room.off("track_published", _on_track_published)
if include_local:
room.off("local_track_published", _on_local_track_published)
room.off("connection_state_changed", _on_connection_state_changed)
101 changes: 98 additions & 3 deletions livekit-agents/livekit/agents/voice/avatar/_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Coroutine
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

from livekit import rtc

from ... import utils
from ...job import get_job_context
from ...log import logger
from ...metrics.base import AvatarMetrics, Metadata
from ..events import ConversationItemAddedEvent, MetricsCollectedEvent

if TYPE_CHECKING:
from ..agent_session import AgentSession
Expand Down Expand Up @@ -55,9 +60,29 @@ def __aiter__(
"""Continuously stream out video and audio frames, or AudioSegmentEnd when the audio segment ends""" # noqa: E501


class AvatarSession:
TEvent = TypeVar("TEvent")


class AvatarSession(ABC, rtc.EventEmitter[Literal["metrics_collected"] | TEvent], Generic[TEvent]):
"""Base class for avatar plugin sessions."""

def __init__(self) -> None:
    """Initialise the event emitter and the not-yet-started session state."""
    super().__init__()
    # Both references are filled in by start() and dropped again by aclose().
    self._agent_session: AgentSession | None = None
    self._room: rtc.Room | None = None
    # Background task that waits for the avatar; created once the room is connected.
    self._wait_avatar_join_task: asyncio.Task[None] | None = None

@property
@abstractmethod
def avatar_identity(self) -> str:
    """The participant identity of the avatar.

    Subclasses must implement this; the value is used to look up the avatar
    participant and its video track in the room.
    """
    ...

@property
def provider(self) -> str:
    """The provider of the avatar.

    Defaults to ``"unknown"``; the value is reported as ``model_provider`` in
    the metadata of emitted avatar metrics.
    """
    return "unknown"

async def start(self, agent_session: AgentSession, room: rtc.Room) -> None:
job_ctx = get_job_context(required=False)
if job_ctx is not None:
Expand All @@ -78,5 +103,75 @@ async def start(self, agent_session: AgentSession, room: rtc.Room) -> None:
extra={"audio_output": audio_output.label},
)

self._room = room
self._agent_session = agent_session
self._agent_session.on("conversation_item_added", self._on_conversation_item_added)

if self._room.isconnected():
self._wait_avatar_join_task = asyncio.create_task(self._wait_avatar_join())
else:
self._room.on("connection_state_changed", self._on_connection_state_changed)

async def aclose(self) -> None:
    """Detach the session's event handlers and stop the avatar-join waiter.

    Safe to call when the session was never started: every step is skipped
    if the corresponding state is unset.
    """
    session = self._agent_session
    if session:
        session.off("conversation_item_added", self._on_conversation_item_added)
        self._agent_session = None

    room = self._room
    if room:
        room.off("connection_state_changed", self._on_connection_state_changed)
        self._room = None

    join_task = self._wait_avatar_join_task
    if not join_task:
        return

    # Cancel the background waiter and let it finish before forgetting it.
    await utils.aio.cancel_and_wait(join_task)
    self._wait_avatar_join_task = None

async def _wait_avatar_join(self) -> None:
    """Wait for the avatar to join and publish video, then emit join metrics.

    Runs as a background task. It records the time at which waiting began,
    waits for the participant with ``avatar_identity`` (the local participant
    is considered too) and for a video track publication from that identity,
    and then emits an ``AvatarMetrics`` carrying both timestamps.

    If the room is not connected or disconnects while waiting, the wait
    helpers raise ``RuntimeError``. Nobody awaits this task's result, so the
    error is logged here instead of surfacing later as an unretrieved task
    exception; no metrics are emitted in that case.
    """
    room = self._room
    assert room is not None

    started_time = time.time()
    try:
        await utils.wait_for_participant(
            room=room, identity=self.avatar_identity, include_local=True
        )
        await utils.wait_for_track_publication(
            room=room,
            identity=self.avatar_identity,
            kind=rtc.TrackKind.KIND_VIDEO,
            include_local=True,
        )
    except RuntimeError:
        logger.warning(
            "room disconnected before the avatar joined, skipping avatar join metrics",
            extra={"avatar_identity": self.avatar_identity},
        )
        return

    joined_time = time.time()
    self._emit_metrics(
        AvatarMetrics(
            timestamp=joined_time,
            session_started_time=started_time,
            avatar_joined_time=joined_time,
            metadata=Metadata(
                model_provider=self.provider,
            ),
        )
    )

def _on_conversation_item_added(self, ev: ConversationItemAddedEvent) -> None:
    """Emit playback-latency metrics for assistant messages that report one."""
    item = ev.item
    # Only assistant messages are of interest; `role` is read only for messages.
    if item.type != "message" or item.role != "assistant":
        return

    latency = item.metrics.get("playback_latency")
    if latency is None:
        return

    avatar_metrics = AvatarMetrics(
        timestamp=ev.created_at,
        playback_latency=latency,
        metadata=Metadata(model_provider=self.provider),
    )
    self._emit_metrics(avatar_metrics)

def _on_connection_state_changed(self, state: rtc.ConnectionState.ValueType) -> None:
    """Start waiting for the avatar once the room reports it is connected."""
    assert self._room is not None

    if state != rtc.ConnectionState.CONN_CONNECTED:
        return
    if self._wait_avatar_join_task:
        # A waiter already exists; do not create a second one.
        return

    self._wait_avatar_join_task = asyncio.create_task(self._wait_avatar_join())
Comment thread
longcw marked this conversation as resolved.

def _emit_metrics(self, metrics: AvatarMetrics) -> None:
    """Publish avatar metrics on this session and on the agent session.

    Requires the agent session to be attached (asserted). The metrics object
    is emitted as-is on this emitter and wrapped in a
    ``MetricsCollectedEvent`` for the agent session.
    """
    assert self._agent_session is not None

    self.emit("metrics_collected", metrics)
    self._agent_session.emit("metrics_collected", MetricsCollectedEvent(metrics=metrics))
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
avatar_participant_name: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> None:
super().__init__()
self._http_session: aiohttp.ClientSession | None = None
self._conn_options = conn_options
self.session_id: str | None = None
Expand All @@ -58,6 +59,14 @@ def __init__(
self._api_url = api_url_val
self._api_key = api_key_val

@property
def avatar_identity(self) -> str:
    """The participant identity stored for this avatar."""
    return self._avatar_participant_identity

@property
def provider(self) -> str:
    """The provider name of this avatar plugin (``"anam"``)."""
    return "anam"

def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = utils.http_context.http_session()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
room. Defaults to "avatario-avatar-agent"
conn_options: Connection options for the aiohttp session.
"""
super().__init__()
self._http_session: aiohttp.ClientSession | None = None
self._conn_options = conn_options
video_info = video_info if utils.is_given(video_info) else self.VideoInfo()
Expand Down Expand Up @@ -91,6 +92,14 @@ def __init__(
else _AVATAR_AGENT_NAME
)

@property
def avatar_identity(self) -> str:
    """The participant identity stored for this avatar."""
    return self._avatar_participant_identity

@property
def provider(self) -> str:
    """The provider name of this avatar plugin (``"avatario"``)."""
    return "avatario"

def _ensure_http_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = utils.http_context.http_session()
Expand Down
Loading
Loading