diff --git a/.changeset/wip_support_for_large_rpc_messages_using_data_streams.md b/.changeset/wip_support_for_large_rpc_messages_using_data_streams.md new file mode 100644 index 000000000..64807aeee --- /dev/null +++ b/.changeset/wip_support_for_large_rpc_messages_using_data_streams.md @@ -0,0 +1,7 @@ +--- +livekit: patch +livekit-api: patch +livekit-ffi: patch +--- + +Support for large RPC messages using data streams - #1013 (@1egoman) diff --git a/livekit-api/src/signal_client/mod.rs b/livekit-api/src/signal_client/mod.rs index 689621103..024f33fba 100644 --- a/livekit-api/src/signal_client/mod.rs +++ b/livekit-api/src/signal_client/mod.rs @@ -54,6 +54,16 @@ const REGION_FETCH_TIMEOUT: Duration = Duration::from_secs(3); const VALIDATE_TIMEOUT: Duration = Duration::from_secs(3); pub const PROTOCOL_VERSION: u32 = 17; +/// Default value for `ClientInfo.client_protocol` when a participant has not +/// advertised one (treat as v1-only / no data-stream RPC support). +pub const CLIENT_PROTOCOL_DEFAULT: i32 = 0; +/// `ClientInfo.client_protocol` value indicating support for RPC v2 over data streams. +pub const CLIENT_PROTOCOL_DATA_STREAM_RPC: i32 = 1; + +/// The client protocol which is sent to other clients and indicates the set of apis that other +/// clients should assume this client supports. +const CLIENT_PROTOCOL_VERSION: i32 = CLIENT_PROTOCOL_DATA_STREAM_RPC; + #[derive(Error, Debug)] pub enum SignalError { #[error("ws failure: {0}")] @@ -571,6 +581,7 @@ fn create_join_request_param( os, os_version, device_model, + client_protocol: CLIENT_PROTOCOL_VERSION, ..Default::default() }; @@ -667,6 +678,7 @@ fn get_livekit_url( .append_pair("os_version", os_info.version().to_string().as_str()) .append_pair("device_model", device_model.to_string().as_str()) .append_pair("protocol", PROTOCOL_VERSION.to_string().as_str()) + .append_pair("client_protocol", CLIENT_PROTOCOL_VERSION.to_string().as_str()) .append_pair("auto_subscribe", if options.auto_subscribe { "1" } else { "0" }) .append_pair("adaptive_stream", if options.adaptive_stream { "1" } else { "0" }); diff --git a/livekit/src/room/data_stream/incoming.rs b/livekit/src/room/data_stream/incoming.rs index d9b9ecf3b..e0782f96a 100644 --- a/livekit/src/room/data_stream/incoming.rs +++ b/livekit/src/room/data_stream/incoming.rs @@ -148,6 +148,17 @@ impl Stream for ByteStreamReader { } } +#[cfg(test)] +impl TextStreamReader { + /// Create a TextStreamReader for testing purposes. + pub(crate) fn new_for_test( + info: TextStreamInfo, + chunk_rx: UnboundedReceiver>, + ) -> Self { + Self { info, chunk_rx } + } +} + impl StreamReader for TextStreamReader { type Output = String; type Info = TextStreamInfo; diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 0035237b6..419e1c43f 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -23,7 +23,10 @@ use libwebrtc::{ rtp_transceiver::RtpTransceiver, RtcError, }; -use livekit_api::signal_client::{SignalOptions, SignalSdkOptions, SIGNAL_CONNECT_TIMEOUT}; +use livekit_api::signal_client::{ + SignalOptions, SignalSdkOptions, CLIENT_PROTOCOL_DATA_STREAM_RPC, CLIENT_PROTOCOL_DEFAULT, + SIGNAL_CONNECT_TIMEOUT, +}; use livekit_datatrack::{ api::{DataTrackSid, RemoteDataTrack}, backend as dt, @@ -67,6 +70,7 @@ pub mod id; pub mod options; pub mod participant; pub mod publication; +pub mod rpc; pub mod track; pub(crate) mod utils; @@ -329,30 +333,6 @@ pub struct ChatMessage { pub generated: Option, } -#[derive(Debug, Clone)] -pub struct RpcRequest { - pub destination_identity: String, - pub id: String, - pub method: String, - pub payload: String, - pub response_timeout: Duration, - pub version: u32, -} - -#[derive(Debug, Clone)] -pub struct RpcResponse { - destination_identity: String, - request_id: String, - payload: Option, - error: Option, -} - -#[derive(Debug, Clone)] -pub struct RpcAck { - destination_identity: String, - request_id: String, -} - #[derive(Debug, Clone)] #[non_exhaustive] pub struct RoomSdkOptions { @@ -473,9 +453,11 @@ pub(crate) struct RoomSession { remote_participants: RwLock>, e2ee_manager: E2eeManager, incoming_stream_manager: IncomingStreamManager, - outgoing_stream_manager: OutgoingStreamManager, + pub(crate) outgoing_stream_manager: OutgoingStreamManager, local_dt_input: dt::local::ManagerInput, remote_dt_input: dt::remote::ManagerInput, + pub(crate) rpc_client: rpc::RpcClientManager, + pub(crate) rpc_server: rpc::RpcServerManager, handle: AsyncMutex>, } @@ -554,6 +536,7 @@ impl Room { pi.joined_at_ms, e2ee_manager.encryption_type(), pi.permission, + pi.client_protocol, ); let dispatcher = Dispatcher::::default(); @@ -688,6 +671,8 @@ impl Room { outgoing_stream_manager, local_dt_input, remote_dt_input, + rpc_client: rpc::RpcClientManager::new(), + rpc_server: rpc::RpcServerManager::new(), handle: Default::default(), }); inner.local_participant.set_session(Arc::downgrade(&inner)); @@ -733,6 +718,7 @@ impl Room { pi.attributes, pi.joined_at_ms, pi.permission, + pi.client_protocol, ) }; participant.update_info(pi.clone()); @@ -757,6 +743,7 @@ impl Room { open_rx, dispatcher.clone(), close_rx.resubscribe(), + inner.clone(), )); let outgoing_stream_handle = livekit_runtime::spawn(outgoing_data_stream_task( packet_rx, @@ -985,25 +972,31 @@ impl RoomSession { log::warn!("Received RPC request with null caller identity"); return Ok(()); } - let local_participant = self.local_participant.clone(); + let session = self.clone(); + let caller = caller_identity.unwrap(); livekit_runtime::spawn(async move { - local_participant - .handle_incoming_rpc_request( - caller_identity.unwrap(), - request_id, - method, - payload, - response_timeout, - version, + let transport = rpc::SessionTransport(session.clone()); + session + .rpc_server + .handle_request( + rpc::HandleRequestOptions { + caller_identity: caller, + request_id, + method, + payload, + response_timeout, + version, + }, + &transport, ) .await; }); } EngineEvent::RpcResponse { request_id, payload, error } => { - self.local_participant.handle_incoming_rpc_response(request_id, payload, error); + self.rpc_client.handle_response(request_id, payload, error); } EngineEvent::RpcAck { request_id } => { - self.local_participant.handle_incoming_rpc_ack(request_id); + self.rpc_client.handle_ack(request_id); } EngineEvent::SpeakersChanged { speakers } => self.handle_speakers_changed(speakers), EngineEvent::ConnectionQuality { updates } => { @@ -1143,6 +1136,7 @@ impl RoomSession { pi.attributes, pi.joined_at_ms, pi.permission, + pi.client_protocol, ) }; @@ -1828,6 +1822,7 @@ impl RoomSession { attributes: HashMap, joined_at: i64, permission: Option, + client_protocol: i32, ) -> RemoteParticipant { let participant = RemoteParticipant::new( self.rtc_engine.clone(), @@ -1842,6 +1837,7 @@ impl RoomSession { joined_at, self.options.auto_subscribe, permission, + client_protocol, ); participant.on_track_published({ @@ -1984,6 +1980,14 @@ impl RoomSession { self.remote_participants.read().get(identity).cloned() } + pub(crate) fn get_remote_client_protocol(&self, identity: &ParticipantIdentity) -> i32 { + self.remote_participants + .read() + .get(identity) + .map(|p| p.client_protocol()) + .unwrap_or(CLIENT_PROTOCOL_DEFAULT) + } + fn get_local_or_remote_participant( &self, identity: &ParticipantIdentity, @@ -2053,10 +2057,14 @@ impl RoomSession { } /// Receives stream readers for newly-opened streams and dispatches room events. +/// +/// Intercepts text streams on RPC topics (`lk.rpc_request`, `lk.rpc_response`) +/// and routes them to the RPC managers instead of emitting them as room events. async fn incoming_data_stream_task( mut open_rx: UnboundedReceiver<(AnyStreamReader, String)>, dispatcher: Dispatcher, mut close_rx: broadcast::Receiver<()>, + session: Arc, ) { loop { tokio::select! { @@ -2067,11 +2075,36 @@ async fn incoming_data_stream_task( reader: TakeCell::new(reader), participant_identity: ParticipantIdentity(identity) }), - AnyStreamReader::Text(reader) => dispatcher.dispatch(&RoomEvent::TextStreamOpened { - topic: reader.info().topic.clone(), - reader: TakeCell::new(reader), - participant_identity: ParticipantIdentity(identity) - }), + AnyStreamReader::Text(reader) => { + let topic = reader.info().topic.clone(); + match topic.as_str() { + rpc::RPC_REQUEST_TOPIC => { + let caller_identity = ParticipantIdentity(identity); + let session = session.clone(); + livekit_runtime::spawn(async move { + let transport = rpc::SessionTransport(session.clone()); + session.rpc_server.handle_request_stream( + reader, + caller_identity, + &transport, + ).await; + }); + } + rpc::RPC_RESPONSE_TOPIC => { + let session = session.clone(); + livekit_runtime::spawn(async move { + session.rpc_client.handle_response_stream(reader).await; + }); + } + _ => { + dispatcher.dispatch(&RoomEvent::TextStreamOpened { + topic, + reader: TakeCell::new(reader), + participant_identity: ParticipantIdentity(identity) + }); + } + } + } } }, _ = close_rx.recv() => { diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index 1053abde5..28e21acf0 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -35,10 +35,10 @@ use crate::{ e2ee::EncryptionType, options::{self, compute_video_encodings, video_layers_from_encodings, TrackPublishOptions}, prelude::*, - room::participant::rpc::{RpcError, RpcErrorCode, RpcInvocationData, MAX_PAYLOAD_BYTES}, + room::rpc::{RpcError, RpcErrorCode, RpcInvocationData}, rtc_engine::lk_runtime::LkRuntime, rtc_engine::{EngineError, RtcEngine}, - ChatMessage, DataPacket, RoomSession, RpcAck, RpcRequest, RpcResponse, SipDTMF, Transcription, + ChatMessage, DataPacket, RoomSession, SipDTMF, Transcription, }; use chrono::Utc; use libwebrtc::{ @@ -51,14 +51,6 @@ use livekit_protocol as proto; use livekit_runtime::timeout; use parking_lot::{Mutex, RwLock}; use proto::request_response::Reason; -use semver::Version; -use tokio::sync::oneshot; - -type RpcHandler = Arc< - dyn Fn(RpcInvocationData) -> Pin> + Send>> - + Send - + Sync, ->; const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); @@ -71,25 +63,9 @@ struct LocalEvents { local_track_unpublished: Mutex>, } -struct RpcState { - pending_acks: HashMap>, - pending_responses: HashMap>>, - handlers: HashMap, -} - -impl RpcState { - fn new() -> Self { - Self { - pending_acks: HashMap::new(), - pending_responses: HashMap::new(), - handlers: HashMap::new(), - } - } -} struct LocalInfo { events: LocalEvents, encryption_type: EncryptionType, - rpc_state: Mutex, all_participants_allowed: Mutex, track_permissions: Mutex>, session: RwLock>>, @@ -126,6 +102,7 @@ impl LocalParticipant { joined_at: i64, encryption_type: EncryptionType, permission: Option, + client_protocol: i32, ) -> Self { Self { inner: super::new_inner( @@ -140,11 +117,11 @@ impl LocalParticipant { kind_details, joined_at, permission, + client_protocol, ), local: Arc::new(LocalInfo { events: LocalEvents::default(), encryption_type, - rpc_state: Mutex::new(RpcState::new()), all_participants_allowed: Mutex::new(true), track_permissions: Mutex::new(vec![]), session: Default::default(), @@ -682,76 +659,6 @@ impl LocalParticipant { .map_err(Into::into) } - async fn publish_rpc_request(&self, rpc_request: RpcRequest) -> RoomResult<()> { - let destination_identities = vec![rpc_request.destination_identity]; - let rpc_request_message = proto::RpcRequest { - id: rpc_request.id, - method: rpc_request.method, - payload: rpc_request.payload, - response_timeout_ms: rpc_request.response_timeout.as_millis() as u32, - version: rpc_request.version, - ..Default::default() - }; - - let data = proto::DataPacket { - value: Some(proto::data_packet::Value::RpcRequest(rpc_request_message)), - destination_identities, - ..Default::default() - }; - - self.inner - .rtc_engine - .publish_data(data, DataPacketKind::Reliable, false) - .await - .map_err(Into::into) - } - - async fn publish_rpc_response(&self, rpc_response: RpcResponse) -> RoomResult<()> { - let destination_identities = vec![rpc_response.destination_identity]; - let rpc_response_message = proto::RpcResponse { - request_id: rpc_response.request_id, - value: Some(match rpc_response.error { - Some(error) => proto::rpc_response::Value::Error(proto::RpcError { - code: error.code, - message: error.message, - data: error.data, - }), - None => proto::rpc_response::Value::Payload(rpc_response.payload.unwrap()), - }), - ..Default::default() - }; - - let data = proto::DataPacket { - value: Some(proto::data_packet::Value::RpcResponse(rpc_response_message)), - destination_identities: destination_identities.clone(), - ..Default::default() - }; - - self.inner - .rtc_engine - .publish_data(data, DataPacketKind::Reliable, false) - .await - .map_err(Into::into) - } - - async fn publish_rpc_ack(&self, rpc_ack: RpcAck) -> RoomResult<()> { - let destination_identities = vec![rpc_ack.destination_identity]; - let rpc_ack_message = - proto::RpcAck { request_id: rpc_ack.request_id, ..Default::default() }; - - let data = proto::DataPacket { - value: Some(proto::data_packet::Value::RpcAck(rpc_ack_message)), - destination_identities: destination_identities.clone(), - ..Default::default() - }; - - self.inner - .rtc_engine - .publish_data(data, DataPacketKind::Reliable, false) - .await - .map_err(Into::into) - } - pub(crate) async fn update_track_subscription_permissions(&self) { let all_participants_allowed = *self.local.all_participants_allowed.lock(); let track_permissions = self @@ -855,100 +762,16 @@ impl LocalParticipant { self.inner.info.read().permission.clone() } - pub async fn perform_rpc(&self, data: PerformRpcData) -> Result { - // Maximum amount of time it should ever take for an RPC request to reach the destination, and the ACK to come back - // This is set to 7 seconds to account for various relay timeouts and retries in LiveKit Cloud that occur in rare cases - - let max_round_trip_latency = Duration::from_millis(7000); - let min_effective_timeout = Duration::from_millis(1000); - - if data.payload.len() > MAX_PAYLOAD_BYTES { - return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None)); - } - - if let Some(server_info) = - self.inner.rtc_engine.session().signal_client().join_response().server_info - { - if !server_info.version.is_empty() { - let server_version = Version::parse(&server_info.version).unwrap(); - let min_required_version = Version::parse("1.8.0").unwrap(); - if server_version < min_required_version { - return Err(RpcError::built_in(RpcErrorCode::UnsupportedServer, None)); - } - } - } - - let id = create_random_uuid(); - let (ack_tx, ack_rx) = oneshot::channel(); - let (response_tx, response_rx) = oneshot::channel(); - let effective_timeout = std::cmp::max( - data.response_timeout.saturating_sub(max_round_trip_latency), - min_effective_timeout, - ); - - // Register channels BEFORE sending the request to avoid race condition - // where the response arrives before we've registered the handlers - { - let mut rpc_state = self.local.rpc_state.lock(); - rpc_state.pending_acks.insert(id.clone(), ack_tx); - rpc_state.pending_responses.insert(id.clone(), response_tx); - } - - if let Err(e) = self - .publish_rpc_request(RpcRequest { - destination_identity: data.destination_identity.clone(), - id: id.clone(), - method: data.method.clone(), - payload: data.payload.clone(), - response_timeout: effective_timeout, - version: 1, - }) - .await - { - // Clean up on failure - let mut rpc_state = self.local.rpc_state.lock(); - rpc_state.pending_acks.remove(&id); - rpc_state.pending_responses.remove(&id); - log::error!("Failed to publish RPC request: {}", e); - return Err(RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string()))); - } - - // Wait for ack timeout - match tokio::time::timeout(max_round_trip_latency, ack_rx).await { - Err(_) => { - let mut rpc_state = self.local.rpc_state.lock(); - rpc_state.pending_acks.remove(&id); - rpc_state.pending_responses.remove(&id); - return Err(RpcError::built_in(RpcErrorCode::ConnectionTimeout, None)); - } - Ok(_) => { - // Ack received, continue to wait for response - } - } - - // Wait for response timout - let response = match tokio::time::timeout(data.response_timeout, response_rx).await { - Err(_) => { - self.local.rpc_state.lock().pending_responses.remove(&id); - return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None)); - } - Ok(result) => result, - }; + pub fn client_protocol(&self) -> i32 { + self.inner.info.read().client_protocol + } - match response { - Err(_) => { - // Something went wrong locally - Err(RpcError::built_in(RpcErrorCode::RecipientDisconnected, None)) - } - Ok(Err(e)) => { - // RPC error from remote, forward it - Err(e) - } - Ok(Ok(payload)) => { - // Successful response - Ok(payload) - } - } + pub async fn perform_rpc(&self, data: PerformRpcData) -> Result { + let session = self.session().ok_or_else(|| { + RpcError::built_in(RpcErrorCode::SendFailed, Some("Not connected".to_string())) + })?; + let transport = crate::room::rpc::SessionTransport(session.clone()); + session.rpc_client.perform_rpc(data, &transport).await } pub fn register_rpc_method( @@ -959,7 +782,9 @@ impl LocalParticipant { + Sync + 'static, ) { - self.local.rpc_state.lock().handlers.insert(method, Arc::new(handler)); + if let Some(session) = self.session() { + session.rpc_server.register_method(method, handler); + } // Pre-connect the publisher PC so ACKs can be sent immediately when requests arrive. // Without this, the first RPC request would trigger publisher negotiation, causing @@ -968,104 +793,8 @@ impl LocalParticipant { } pub fn unregister_rpc_method(&self, method: String) { - self.local.rpc_state.lock().handlers.remove(&method); - } - - pub(crate) fn handle_incoming_rpc_ack(&self, request_id: String) { - let mut rpc_state = self.local.rpc_state.lock(); - if let Some(tx) = rpc_state.pending_acks.remove(&request_id) { - let _ = tx.send(()); - } else { - log::error!("Ack received for unexpected RPC request: {}", request_id); - } - } - - pub(crate) fn handle_incoming_rpc_response( - &self, - request_id: String, - payload: Option, - error: Option, - ) { - let mut rpc_state = self.local.rpc_state.lock(); - if let Some(tx) = rpc_state.pending_responses.remove(&request_id) { - let _ = tx.send(match error { - Some(e) => Err(RpcError::from_proto(e)), - None => Ok(payload.unwrap_or_default()), - }); - } else { - log::error!("Response received for unexpected RPC request: {}", request_id); - } - } - - pub(crate) async fn handle_incoming_rpc_request( - &self, - caller_identity: ParticipantIdentity, - request_id: String, - method: String, - payload: String, - response_timeout: Duration, - version: u32, - ) { - if let Err(e) = self - .publish_rpc_ack(RpcAck { - destination_identity: caller_identity.to_string(), - request_id: request_id.clone(), - }) - .await - { - log::error!("Failed to publish RPC ACK: {:?}", e); - } - - let caller_identity_2 = caller_identity.clone(); - let request_id_2 = request_id.clone(); - - let response = if version != 1 { - Err(RpcError::built_in(RpcErrorCode::UnsupportedVersion, None)) - } else { - let handler = self.local.rpc_state.lock().handlers.get(&method).cloned(); - - match handler { - Some(handler) => { - match tokio::task::spawn(async move { - handler(RpcInvocationData { - request_id: request_id.clone(), - caller_identity: caller_identity.clone(), - payload: payload.clone(), - response_timeout, - }) - .await - }) - .await - { - Ok(result) => result, - Err(e) => { - log::error!("RPC method handler returned an error: {:?}", e); - Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) - } - } - } - None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), - } - }; - - let (payload, error) = match response { - Ok(response_payload) if response_payload.len() <= MAX_PAYLOAD_BYTES => { - (Some(response_payload), None) - } - Ok(_) => (None, Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None))), - Err(e) => (None, Some(e.into())), - }; - - if let Err(e) = self - .publish_rpc_response(RpcResponse { - destination_identity: caller_identity_2.to_string(), - request_id: request_id_2, - payload, - error: error.map(|e| e.to_proto()), - }) - .await - { - log::error!("Failed to publish RPC response: {:?}", e); + if let Some(session) = self.session() { + session.rpc_server.unregister_method(&method); } } diff --git a/livekit/src/room/participant/mod.rs b/livekit/src/room/participant/mod.rs index c3a52adfa..fea11090b 100644 --- a/livekit/src/room/participant/mod.rs +++ b/livekit/src/room/participant/mod.rs @@ -145,6 +145,7 @@ struct ParticipantInfo { pub disconnect_reason: DisconnectReason, pub joined_at: i64, pub permission: Option, + pub client_protocol: i32, } type TrackMutedHandler = Box; @@ -195,6 +196,7 @@ pub(super) fn new_inner( kind_details: Vec, joined_at: i64, permission: Option, + client_protocol: i32, ) -> Arc { Arc::new(ParticipantInner { rtc_engine, @@ -213,6 +215,7 @@ pub(super) fn new_inner( disconnect_reason: DisconnectReason::UnknownReason, joined_at, permission, + client_protocol, }), track_publications: Default::default(), events: Default::default(), @@ -264,6 +267,8 @@ pub(super) fn update_info( cb(participant.clone(), new_info.permission.clone()); } } + + info.client_protocol = new_info.client_protocol; } pub(super) fn set_speaking( diff --git a/livekit/src/room/participant/remote_participant.rs b/livekit/src/room/participant/remote_participant.rs index eea7e5672..b19882a57 100644 --- a/livekit/src/room/participant/remote_participant.rs +++ b/livekit/src/room/participant/remote_participant.rs @@ -89,6 +89,7 @@ impl RemoteParticipant { joined_at: i64, auto_subscribe: bool, permission: Option, + client_protocol: i32, ) -> Self { Self { inner: super::new_inner( @@ -103,6 +104,7 @@ impl RemoteParticipant { kind_details, joined_at, permission, + client_protocol, ), remote: Arc::new(RemoteInfo { events: Default::default(), auto_subscribe }), } @@ -575,6 +577,10 @@ impl RemoteParticipant { self.inner.info.read().permission.clone() } + pub fn client_protocol(&self) -> i32 { + self.inner.info.read().client_protocol + } + pub fn is_encrypted(&self) -> bool { *self.inner.is_encrypted.read() } diff --git a/livekit/src/room/participant/rpc.rs b/livekit/src/room/participant/rpc.rs index b04691dda..28c1d4463 100644 --- a/livekit/src/room/participant/rpc.rs +++ b/livekit/src/room/participant/rpc.rs @@ -12,157 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::room::participant::ParticipantIdentity; -use livekit_protocol::RpcError as RpcError_Proto; -use std::{error::Error, fmt::Display, time::Duration}; - -/// Parameters for performing an RPC call -#[derive(Debug, Clone)] -pub struct PerformRpcData { - pub destination_identity: String, - pub method: String, - pub payload: String, - pub response_timeout: Duration, -} - -impl Default for PerformRpcData { - fn default() -> Self { - Self { - destination_identity: Default::default(), - method: Default::default(), - payload: Default::default(), - response_timeout: Duration::from_secs(15), - } - } -} - -/// Data passed to method handler for incoming RPC invocations -/// -/// Attributes: -/// request_id (String): The unique request ID. Will match at both sides of the call, useful for debugging or logging. -/// caller_identity (ParticipantIdentity): The unique participant identity of the caller. -/// payload (String): The payload of the request. User-definable format, typically JSON. -/// response_timeout (Duration): The maximum time the caller will wait for a response. -#[derive(Debug, Clone)] -pub struct RpcInvocationData { - pub request_id: String, - pub caller_identity: ParticipantIdentity, - pub payload: String, - pub response_timeout: Duration, -} - -/// Specialized error handling for RPC methods. -/// -/// Instances of this type, when thrown in a method handler, will have their `message` -/// serialized and sent across the wire. The caller will receive an equivalent error on the other side. -/// -/// Build-in types are included but developers may use any string, with a max length of 256 bytes. -#[derive(Debug, Clone)] -pub struct RpcError { - pub code: u32, - pub message: String, - pub data: Option, -} - -impl RpcError { - pub const MAX_MESSAGE_BYTES: usize = 256; - pub const MAX_DATA_BYTES: usize = 15360; // 15 KB - - /// Creates an error object with the given code and message, plus an optional data payload. - /// - /// If thrown in an RPC method handler, the error will be sent back to the caller. - /// - /// Error codes 1001-1999 are reserved for built-in errors (see RpcErrorCode for their meanings). - pub fn new(code: u32, message: String, data: Option) -> Self { - Self { - code, - message: truncate_bytes(&message, Self::MAX_MESSAGE_BYTES), - data: data.map(|d| truncate_bytes(&d, Self::MAX_DATA_BYTES)), - } - } - - pub fn from_proto(proto: RpcError_Proto) -> Self { - Self::new(proto.code, proto.message, Some(proto.data)) - } - - pub fn to_proto(&self) -> RpcError_Proto { - RpcError_Proto { - code: self.code, - message: self.message.clone(), - data: self.data.clone().unwrap_or_default(), - } - } -} - -impl Display for RpcError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "RPC Error: {} ({})", self.message, self.code) - } -} -impl Error for RpcError {} - -#[derive(Debug, Clone, Copy)] -pub enum RpcErrorCode { - ApplicationError = 1500, - ConnectionTimeout = 1501, - ResponseTimeout = 1502, - RecipientDisconnected = 1503, - ResponsePayloadTooLarge = 1504, - SendFailed = 1505, - - UnsupportedMethod = 1400, - RecipientNotFound = 1401, - RequestPayloadTooLarge = 1402, - UnsupportedServer = 1403, - UnsupportedVersion = 1404, -} - -impl RpcErrorCode { - pub(crate) fn message(&self) -> &'static str { - match self { - Self::ApplicationError => "Application error in method handler", - Self::ConnectionTimeout => "Connection timeout", - Self::ResponseTimeout => "Response timeout", - Self::RecipientDisconnected => "Recipient disconnected", - Self::ResponsePayloadTooLarge => "Response payload too large", - Self::SendFailed => "Failed to send", - - Self::UnsupportedMethod => "Method not supported at destination", - Self::RecipientNotFound => "Recipient not found", - Self::RequestPayloadTooLarge => "Request payload too large", - Self::UnsupportedServer => "RPC not supported by server", - Self::UnsupportedVersion => "Unsupported RPC version", - } - } -} - -impl RpcError { - /// Creates an error object from the code, with an auto-populated message. - pub(crate) fn built_in(code: RpcErrorCode, data: Option) -> Self { - Self::new(code as u32, code.message().to_string(), data) - } -} - -/// Maximum payload size in bytes -pub const MAX_PAYLOAD_BYTES: usize = 15360; // 15 KB - -/// Calculate the byte length of a string -pub(crate) fn byte_length(s: &str) -> usize { - s.as_bytes().len() -} - -/// Truncate a string to a maximum number of bytes -pub(crate) fn truncate_bytes(s: &str, max_bytes: usize) -> String { - if byte_length(s) <= max_bytes { - return s.to_string(); - } - - let mut result = String::new(); - for c in s.chars() { - if byte_length(&(result.clone() + &c.to_string())) > max_bytes { - break; - } - result.push(c); - } - result -} +// Re-export all RPC types from the room::rpc module. +// This keeps existing imports from `room::participant::*` working. +pub use crate::room::rpc::*; diff --git a/livekit/src/room/rpc/client.rs b/livekit/src/room/rpc/client.rs new file mode 100644 index 000000000..4c35ca0f8 --- /dev/null +++ b/livekit/src/room/rpc/client.rs @@ -0,0 +1,308 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::{ + PerformRpcData, RpcError, RpcErrorCode, RpcTransport, ATTR_METHOD, ATTR_REQUEST_ID, + ATTR_RESPONSE_TIMEOUT_MS, ATTR_VERSION, MAX_PAYLOAD_BYTES, RPC_REQUEST_TOPIC, RPC_VERSION_V1, + RPC_VERSION_V2, +}; +use crate::data_stream::{StreamReader, StreamTextOptions, TextStreamReader}; +use crate::room::id::ParticipantIdentity; +use libwebrtc::native::create_random_uuid; +use livekit_api::signal_client::CLIENT_PROTOCOL_DATA_STREAM_RPC; +use livekit_protocol as proto; +use parking_lot::Mutex; +use semver::Version; +use std::collections::HashMap; +use std::time::Duration; +use tokio::sync::oneshot; + +/// Manages outgoing RPC calls (caller/client side). +/// +/// Tracks pending ACKs and responses, handles v1 packet and v2 data stream +/// transport selection based on the remote participant's client protocol. +pub struct RpcClientManager { + pending_acks: Mutex>>, + pending_responses: Mutex>>>, +} + +impl RpcClientManager { + pub fn new() -> Self { + Self { + pending_acks: Mutex::new(HashMap::new()), + pending_responses: Mutex::new(HashMap::new()), + } + } + + /// Perform an RPC call to a remote participant. + /// + /// Selects v1 (data packet) or v2 (data stream) transport based on + /// the remote participant's client_protocol. + pub(crate) async fn perform_rpc( + &self, + data: PerformRpcData, + transport: &(impl RpcTransport + 'static), + ) -> Result { + let max_round_trip_latency = Duration::from_millis(7000); + let min_effective_timeout = Duration::from_millis(1000); + + if let Some(version_str) = transport.server_version() { + let server_version = Version::parse(&version_str).unwrap(); + let min_required_version = Version::parse("1.8.0").unwrap(); + if server_version < min_required_version { + return Err(RpcError::built_in(RpcErrorCode::UnsupportedServer, None)); + } + } + + // Determine transport version based on remote participant's client_protocol + let remote_protocol = transport + .remote_client_protocol(&ParticipantIdentity(data.destination_identity.clone())); + let use_v2 = remote_protocol >= CLIENT_PROTOCOL_DATA_STREAM_RPC; + + // Only enforce payload size limit for v1 transport + if !use_v2 && data.payload.len() > MAX_PAYLOAD_BYTES { + return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None)); + } + + let id = create_random_uuid(); + let (ack_tx, ack_rx) = oneshot::channel(); + let (response_tx, response_rx) = oneshot::channel(); + let effective_timeout = std::cmp::max( + data.response_timeout.saturating_sub(max_round_trip_latency), + min_effective_timeout, + ); + + // Register channels BEFORE sending the request to avoid race condition + // where the response arrives before we've registered the handlers + { + let mut pending_acks = self.pending_acks.lock(); + let mut pending_responses = self.pending_responses.lock(); + pending_acks.insert(id.clone(), ack_tx); + pending_responses.insert(id.clone(), response_tx); + } + + let send_result = if use_v2 { + self.send_v2_request( + transport, + &data.destination_identity, + &id, + &data.method, + &data.payload, + effective_timeout, + ) + .await + } else { + self.send_v1_request( + transport, + &data.destination_identity, + &id, + &data.method, + &data.payload, + effective_timeout, + ) + .await + .map_err(|e| RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string()))) + }; + + if let Err(e) = send_result { + // Clean up on failure + let mut pending_acks = self.pending_acks.lock(); + let mut pending_responses = self.pending_responses.lock(); + pending_acks.remove(&id); + pending_responses.remove(&id); + log::error!("Failed to publish RPC request: {}", e); + return Err(e); + } + + // Wait for ack timeout + match tokio::time::timeout(max_round_trip_latency, ack_rx).await { + Err(_) => { + let mut pending_acks = self.pending_acks.lock(); + let mut pending_responses = self.pending_responses.lock(); + pending_acks.remove(&id); + pending_responses.remove(&id); + return Err(RpcError::built_in(RpcErrorCode::ConnectionTimeout, None)); + } + Ok(_) => { + // Ack received, continue to wait for response + } + } + + // Wait for response timeout + let response = match tokio::time::timeout(data.response_timeout, response_rx).await { + Err(_) => { + self.pending_responses.lock().remove(&id); + return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None)); + } + Ok(result) => result, + }; + + match response { + Err(_) => { + // Channel closed — sender dropped (e.g. disconnect) + Err(RpcError::built_in(RpcErrorCode::RecipientDisconnected, None)) + } + Ok(Err(e)) => { + // RPC error from remote, forward it + Err(e) + } + Ok(Ok(payload)) => { + // Successful response + Ok(payload) + } + } + } + + /// Publish a v1 RPC request data packet. + pub(crate) async fn send_v1_request( + &self, + transport: &impl RpcTransport, + destination_identity: &str, + id: &str, + method: &str, + payload: &str, + response_timeout: Duration, + ) -> Result<(), crate::room::RoomError> { + let rpc_request_message = proto::RpcRequest { + id: id.to_string(), + method: method.to_string(), + payload: payload.to_string(), + response_timeout_ms: response_timeout.as_millis() as u32, + version: RPC_VERSION_V1, + ..Default::default() + }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcRequest(rpc_request_message)), + destination_identities: vec![destination_identity.to_string()], + ..Default::default() + }; + + transport.publish_data(data).await + } + + /// Send an RPC request as a v2 text data stream. + async fn send_v2_request( + &self, + transport: &impl RpcTransport, + destination_identity: &str, + id: &str, + method: &str, + payload: &str, + response_timeout: Duration, + ) -> Result<(), RpcError> { + let mut attributes = HashMap::new(); + attributes.insert(ATTR_REQUEST_ID.to_string(), id.to_string()); + attributes.insert(ATTR_METHOD.to_string(), method.to_string()); + attributes + .insert(ATTR_RESPONSE_TIMEOUT_MS.to_string(), response_timeout.as_millis().to_string()); + attributes.insert(ATTR_VERSION.to_string(), RPC_VERSION_V2.to_string()); + + let options = StreamTextOptions { + topic: RPC_REQUEST_TOPIC.to_string(), + attributes, + destination_identities: vec![ParticipantIdentity(destination_identity.to_string())], + ..Default::default() + }; + + transport + .send_text(payload, options) + .await + .map(|_| ()) + .map_err(|e| RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string()))) + } + + /// Drop the pending response sender for a request, simulating a disconnect. + #[cfg(test)] + pub(crate) fn drop_pending_response(&self, request_id: &str) { + self.pending_responses.lock().remove(request_id); + } + + /// Register a pending response channel for testing. + #[cfg(test)] + pub(crate) fn insert_pending_response( + &self, + request_id: String, + tx: tokio::sync::oneshot::Sender>, + ) { + self.pending_responses.lock().insert(request_id, tx); + } + + pub(crate) fn handle_ack(&self, request_id: String) { + let mut pending = self.pending_acks.lock(); + if let Some(tx) = pending.remove(&request_id) { + let _ = tx.send(()); + } else { + log::error!("Ack received for unexpected RPC request: {}", request_id); + } + } + + /// Handle a v1 RPC response packet. + /// + /// Also handles error responses for v2 calls, since error responses + /// always use v1 packets regardless of transport version. + pub(crate) fn handle_response( + &self, + request_id: String, + payload: Option, + error: Option, + ) { + let mut pending = self.pending_responses.lock(); + if let Some(tx) = pending.remove(&request_id) { + let _ = tx.send(match error { + Some(e) => Err(RpcError::from_proto(e)), + None => Ok(payload.unwrap_or_default()), + }); + } else { + log::error!("Response received for unexpected RPC request: {}", request_id); + } + } + + /// Handle a v2 RPC success response received as a data stream. + /// + /// Success responses between v2 clients arrive as text data streams + /// on the `lk.rpc_response` topic. Error responses always arrive + /// as v1 packets and are handled by `handle_response`. + pub(crate) async fn handle_response_stream(&self, reader: TextStreamReader) { + let request_id = reader.info().attributes.get(ATTR_REQUEST_ID).cloned().unwrap_or_default(); + + if request_id.is_empty() { + log::error!("RPC v2 response stream missing request_id attribute"); + return; + } + + let payload = match reader.read_all().await { + Ok(payload) => payload, + Err(e) => { + log::error!("Failed to read RPC v2 response stream: {:?}", e); + // Resolve with error so the caller doesn't hang + let mut pending = self.pending_responses.lock(); + if let Some(tx) = pending.remove(&request_id) { + let _ = tx.send(Err(RpcError::built_in( + RpcErrorCode::ApplicationError, + Some(format!("Failed to read response stream: {}", e)), + ))); + } + return; + } + }; + + let mut pending = self.pending_responses.lock(); + if let Some(tx) = pending.remove(&request_id) { + let _ = tx.send(Ok(payload)); + } else { + log::error!("Response stream received for unexpected RPC request: {}", request_id); + } + } +} diff --git a/livekit/src/room/rpc/mod.rs b/livekit/src/room/rpc/mod.rs new file mode 100644 index 000000000..4dc1aa95a --- /dev/null +++ b/livekit/src/room/rpc/mod.rs @@ -0,0 +1,256 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod client; +mod server; + +#[cfg(test)] +mod tests; + +pub use client::RpcClientManager; +pub use server::{HandleRequestOptions, RpcServerManager}; + +use crate::data_stream::{StreamResult, StreamTextOptions, TextStreamInfo}; +use crate::room::id::ParticipantIdentity; +use livekit_protocol::RpcError as RpcError_Proto; +use std::{error::Error, fmt::Display, future::Future, time::Duration}; + +// RPC protocol version constants (distinct from client_protocol; this is the +// version field on RpcRequest / v2 stream attributes). +pub(crate) const RPC_VERSION_V1: u32 = 1; +pub(crate) const RPC_VERSION_V2: u32 = 2; + +// Data stream topic constants for RPC v2 +pub(crate) const RPC_REQUEST_TOPIC: &str = "lk.rpc_request"; +pub(crate) const RPC_RESPONSE_TOPIC: &str = "lk.rpc_response"; + +// Stream attribute keys for RPC v2 +pub(crate) const ATTR_REQUEST_ID: &str = "lk.rpc_request_id"; +pub(crate) const ATTR_METHOD: &str = "lk.rpc_request_method"; +pub(crate) const ATTR_RESPONSE_TIMEOUT_MS: &str = "lk.rpc_request_response_timeout_ms"; +pub(crate) const ATTR_VERSION: &str = "lk.rpc_request_version"; + +/// Transport abstraction for RPC operations. +/// +/// Decouples the RPC managers from concrete engine/session types, +/// enabling in-memory unit testing with a mock transport. +pub(crate) trait RpcTransport: Send + Sync { + /// Send a data packet (used for v1 RPC packets and ACKs). + fn publish_data( + &self, + data: livekit_protocol::DataPacket, + ) -> impl Future> + Send; + + /// Send text as a data stream (used for v2 RPC requests and responses). + fn send_text( + &self, + text: &str, + options: StreamTextOptions, + ) -> impl Future> + Send; + + /// Look up a remote participant's client_protocol value. + fn remote_client_protocol(&self, identity: &ParticipantIdentity) -> i32; + + /// Get the server version string, if available. + fn server_version(&self) -> Option; +} + +/// Production implementation of `RpcTransport` backed by a `RoomSession`. +pub(crate) struct SessionTransport(pub(crate) std::sync::Arc); + +impl RpcTransport for SessionTransport { + async fn publish_data( + &self, + data: livekit_protocol::DataPacket, + ) -> Result<(), crate::room::RoomError> { + self.0 + .rtc_engine + .publish_data(data, crate::DataPacketKind::Reliable, false) + .await + .map_err(Into::into) + } + + async fn send_text( + &self, + text: &str, + options: StreamTextOptions, + ) -> StreamResult { + self.0.outgoing_stream_manager.send_text(text, options).await + } + + fn remote_client_protocol(&self, identity: &ParticipantIdentity) -> i32 { + self.0.get_remote_client_protocol(identity) + } + + fn server_version(&self) -> Option { + self.0 + .rtc_engine + .session() + .signal_client() + .join_response() + .server_info + .and_then(|info| info.version.is_empty().then(|| info.version)) + } +} + +/// Parameters for performing an RPC call +#[derive(Debug, Clone)] +pub struct PerformRpcData { + pub destination_identity: String, + pub method: String, + pub payload: String, + pub response_timeout: Duration, +} + +impl Default for PerformRpcData { + fn default() -> Self { + Self { + destination_identity: Default::default(), + method: Default::default(), + payload: Default::default(), + response_timeout: Duration::from_secs(15), + } + } +} + +/// Data passed to method handler for incoming RPC invocations +/// +/// Attributes: +/// request_id (String): The unique request ID. Will match at both sides of the call, useful for debugging or logging. +/// caller_identity (ParticipantIdentity): The unique participant identity of the caller. +/// payload (String): The payload of the request. User-definable format, typically JSON. +/// response_timeout (Duration): The maximum time the caller will wait for a response. +#[derive(Debug, Clone)] +pub struct RpcInvocationData { + pub request_id: String, + pub caller_identity: ParticipantIdentity, + pub payload: String, + pub response_timeout: Duration, +} + +/// Specialized error handling for RPC methods. +/// +/// Instances of this type, when thrown in a method handler, will have their `message` +/// serialized and sent across the wire. The caller will receive an equivalent error on the other side. +/// +/// Build-in types are included but developers may use any string, with a max length of 256 bytes. +#[derive(Debug, Clone)] +pub struct RpcError { + pub code: u32, + pub message: String, + pub data: Option, +} + +impl RpcError { + pub const MAX_MESSAGE_BYTES: usize = 256; + pub const MAX_DATA_BYTES: usize = 15360; // 15 KB + + /// Creates an error object with the given code and message, plus an optional data payload. + /// + /// If thrown in an RPC method handler, the error will be sent back to the caller. + /// + /// Error codes 1001-1999 are reserved for built-in errors (see RpcErrorCode for their meanings). + pub fn new(code: u32, message: String, data: Option) -> Self { + Self { + code, + message: truncate_bytes(&message, Self::MAX_MESSAGE_BYTES), + data: data.map(|d| truncate_bytes(&d, Self::MAX_DATA_BYTES)), + } + } + + pub fn from_proto(proto: RpcError_Proto) -> Self { + Self::new(proto.code, proto.message, Some(proto.data)) + } + + pub fn to_proto(&self) -> RpcError_Proto { + RpcError_Proto { + code: self.code, + message: self.message.clone(), + data: self.data.clone().unwrap_or_default(), + } + } +} + +impl Display for RpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RPC Error: {} ({})", self.message, self.code) + } +} +impl Error for RpcError {} + +#[derive(Debug, Clone, Copy)] +pub enum RpcErrorCode { + ApplicationError = 1500, + ConnectionTimeout = 1501, + ResponseTimeout = 1502, + RecipientDisconnected = 1503, + ResponsePayloadTooLarge = 1504, + SendFailed = 1505, + + UnsupportedMethod = 1400, + RecipientNotFound = 1401, + RequestPayloadTooLarge = 1402, + UnsupportedServer = 1403, + UnsupportedVersion = 1404, +} + +impl RpcErrorCode { + pub(crate) fn message(&self) -> &'static str { + match self { + Self::ApplicationError => "Application error in method handler", + Self::ConnectionTimeout => "Connection timeout", + Self::ResponseTimeout => "Response timeout", + Self::RecipientDisconnected => "Recipient disconnected", + Self::ResponsePayloadTooLarge => "Response payload too large", + Self::SendFailed => "Failed to send", + + Self::UnsupportedMethod => "Method not supported at destination", + Self::RecipientNotFound => "Recipient not found", + Self::RequestPayloadTooLarge => "Request payload too large", + Self::UnsupportedServer => "RPC not supported by server", + Self::UnsupportedVersion => "Unsupported RPC version", + } + } +} + +impl RpcError { + /// Creates an error object from the code, with an auto-populated message. + pub(crate) fn built_in(code: RpcErrorCode, data: Option) -> Self { + Self::new(code as u32, code.message().to_string(), data) + } +} + +/// Maximum payload size in bytes for RPC v1 +pub const MAX_PAYLOAD_BYTES: usize = 15360; // 15 KB + +/// Calculate the byte length of a string +pub(crate) fn byte_length(s: &str) -> usize { + s.as_bytes().len() +} + +/// Truncate a string to a maximum number of bytes +pub(crate) fn truncate_bytes(s: &str, max_bytes: usize) -> String { + if byte_length(s) <= max_bytes { + return s.to_string(); + } + + let mut result = String::new(); + for c in s.chars() { + if byte_length(&(result.clone() + &c.to_string())) > max_bytes { + break; + } + result.push(c); + } + result +} diff --git a/livekit/src/room/rpc/server.rs b/livekit/src/room/rpc/server.rs new file mode 100644 index 000000000..578d66b14 --- /dev/null +++ b/livekit/src/room/rpc/server.rs @@ -0,0 +1,325 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::{ + RpcError, RpcErrorCode, RpcInvocationData, RpcTransport, ATTR_METHOD, ATTR_REQUEST_ID, + ATTR_RESPONSE_TIMEOUT_MS, ATTR_VERSION, MAX_PAYLOAD_BYTES, RPC_RESPONSE_TOPIC, RPC_VERSION_V1, + RPC_VERSION_V2, +}; +use crate::data_stream::{StreamReader, StreamTextOptions, TextStreamReader}; +use crate::room::id::ParticipantIdentity; +use livekit_protocol as proto; +use parking_lot::Mutex; +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Duration}; + +pub(crate) type RpcHandlerFn = Arc< + dyn Fn(RpcInvocationData) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Parameters for [`RpcServerManager::handle_request`]. +pub struct HandleRequestOptions { + pub caller_identity: ParticipantIdentity, + pub request_id: String, + pub method: String, + pub payload: String, + pub response_timeout: Duration, + pub version: u32, +} + +/// Manages incoming RPC requests (handler/server side). +/// +/// Stores registered method handlers and dispatches incoming requests +/// to the appropriate handler. Handles both v1 packet and v2 data stream +/// request formats. +pub struct RpcServerManager { + handlers: Mutex>, +} + +impl RpcServerManager { + pub fn new() -> Self { + Self { handlers: Mutex::new(HashMap::new()) } + } + + pub fn register_method( + &self, + method: String, + handler: impl Fn(RpcInvocationData) -> Pin> + Send>> + + Send + + Sync + + 'static, + ) { + self.handlers.lock().insert(method, Arc::new(handler)); + } + + pub fn unregister_method(&self, method: &str) { + self.handlers.lock().remove(method); + } + + pub(crate) fn get_handler(&self, method: &str) -> Option { + self.handlers.lock().get(method).cloned() + } + + /// Handle an incoming v1 RPC request (received as a DataPacket). + /// + /// Sends ACK, invokes the registered handler, and sends the response + /// as a v1 RPC response packet. + pub(crate) async fn handle_request( + &self, + options: HandleRequestOptions, + transport: &(impl RpcTransport + 'static), + ) { + let HandleRequestOptions { + caller_identity, + request_id, + method, + payload, + response_timeout, + version, + } = options; + + // Send ACK immediately + if let Err(e) = self.publish_rpc_ack(transport, &caller_identity.0, &request_id).await { + log::error!("Failed to publish RPC ACK: {:?}", e); + } + + let response = if version != RPC_VERSION_V1 { + Err(RpcError::built_in(RpcErrorCode::UnsupportedVersion, None)) + } else { + self.invoke_handler(&caller_identity, &request_id, &method, &payload, response_timeout) + .await + }; + + let (resp_payload, error) = match response { + Ok(response_payload) if response_payload.len() <= MAX_PAYLOAD_BYTES => { + (Some(response_payload), None) + } + Ok(_) => ( + None, + Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None).to_proto()), + ), + Err(e) => (None, Some(e.to_proto())), + }; + + if let Err(e) = self + .publish_rpc_response_packet( + transport, + &caller_identity.0, + &request_id, + resp_payload, + error, + ) + .await + { + log::error!("Failed to publish RPC response: {:?}", e); + } + } + + /// Handle an incoming v2 RPC request (received as a data stream). + /// + /// Parses request metadata from stream attributes, sends ACK, + /// invokes the handler, and sends the response. Success responses + /// use a v2 data stream; error responses always use v1 packets. + pub(crate) async fn handle_request_stream( + &self, + reader: TextStreamReader, + caller_identity: ParticipantIdentity, + transport: &(impl RpcTransport + 'static), + ) { + let attrs = &reader.info().attributes; + + let request_id = attrs.get(ATTR_REQUEST_ID).cloned().unwrap_or_default(); + let method = attrs.get(ATTR_METHOD).cloned().unwrap_or_default(); + let response_timeout_ms: u64 = + attrs.get(ATTR_RESPONSE_TIMEOUT_MS).and_then(|v| v.parse().ok()).unwrap_or(15000); + let version: u32 = attrs.get(ATTR_VERSION).and_then(|v| v.parse().ok()).unwrap_or(0); + + let response_timeout = Duration::from_millis(response_timeout_ms); + + // Send ACK immediately (always v1 packet) + if let Err(e) = self.publish_rpc_ack(transport, &caller_identity.0, &request_id).await { + log::error!("Failed to publish RPC ACK: {:?}", e); + } + + if version != RPC_VERSION_V2 { + let error = RpcError::built_in(RpcErrorCode::UnsupportedVersion, None); + let _ = self + .publish_rpc_response_packet( + transport, + &caller_identity.0, + &request_id, + None, + Some(error.to_proto()), + ) + .await; + return; + } + + // Read the full payload from the stream + let payload = match reader.read_all().await { + Ok(payload) => payload, + Err(e) => { + log::error!("Failed to read RPC v2 request stream: {:?}", e); + let error = RpcError::built_in( + RpcErrorCode::ApplicationError, + Some(format!("Failed to read request stream: {}", e)), + ); + let _ = self + .publish_rpc_response_packet( + transport, + &caller_identity.0, + &request_id, + None, + Some(error.to_proto()), + ) + .await; + return; + } + }; + + let response = self + .invoke_handler(&caller_identity, &request_id, &method, &payload, response_timeout) + .await; + + match response { + Ok(response_payload) => { + // Success: send response as v2 data stream + let mut attributes = HashMap::new(); + attributes.insert(ATTR_REQUEST_ID.to_string(), request_id.clone()); + + let options = StreamTextOptions { + topic: RPC_RESPONSE_TOPIC.to_string(), + attributes, + destination_identities: vec![caller_identity.clone()], + ..Default::default() + }; + + if let Err(e) = transport.send_text(&response_payload, options).await { + log::error!("Failed to send RPC v2 response stream: {:?}", e); + // Fall back to error via v1 packet + let error = RpcError::built_in(RpcErrorCode::SendFailed, Some(e.to_string())); + let _ = self + .publish_rpc_response_packet( + transport, + &caller_identity.0, + &request_id, + None, + Some(error.to_proto()), + ) + .await; + } + } + Err(e) => { + // Error: always send as v1 packet + if let Err(send_err) = self + .publish_rpc_response_packet( + transport, + &caller_identity.0, + &request_id, + None, + Some(e.to_proto()), + ) + .await + { + log::error!("Failed to publish RPC error response: {:?}", send_err); + } + } + } + } + + /// Invoke a registered handler for an RPC method, with error handling. + async fn invoke_handler( + &self, + caller_identity: &ParticipantIdentity, + request_id: &str, + method: &str, + payload: &str, + response_timeout: Duration, + ) -> Result { + let handler = self.get_handler(method); + + match handler { + Some(handler) => { + let caller_id = caller_identity.clone(); + let req_id = request_id.to_string(); + let req_payload = payload.to_string(); + match tokio::task::spawn(async move { + handler(RpcInvocationData { + request_id: req_id, + caller_identity: caller_id, + payload: req_payload, + response_timeout, + }) + .await + }) + .await + { + Ok(result) => result, + Err(e) => { + log::error!("RPC method handler returned an error: {:?}", e); + Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) + } + } + } + None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), + } + } + + /// Publish a v1 RPC response data packet. + async fn publish_rpc_response_packet( + &self, + transport: &impl RpcTransport, + destination_identity: &str, + request_id: &str, + payload: Option, + error: Option, + ) -> Result<(), crate::room::RoomError> { + let rpc_response_message = proto::RpcResponse { + request_id: request_id.to_string(), + value: Some(match error { + Some(error) => proto::rpc_response::Value::Error(error), + None => proto::rpc_response::Value::Payload(payload.unwrap()), + }), + ..Default::default() + }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcResponse(rpc_response_message)), + destination_identities: vec![destination_identity.to_string()], + ..Default::default() + }; + + transport.publish_data(data).await + } + + /// Publish a v1 RPC ack data packet. + async fn publish_rpc_ack( + &self, + transport: &impl RpcTransport, + destination_identity: &str, + request_id: &str, + ) -> Result<(), crate::room::RoomError> { + let rpc_ack_message = + proto::RpcAck { request_id: request_id.to_string(), ..Default::default() }; + + let data = proto::DataPacket { + value: Some(proto::data_packet::Value::RpcAck(rpc_ack_message)), + destination_identities: vec![destination_identity.to_string()], + ..Default::default() + }; + + transport.publish_data(data).await + } +} diff --git a/livekit/src/room/rpc/tests.rs b/livekit/src/room/rpc/tests.rs new file mode 100644 index 000000000..f366a0b57 --- /dev/null +++ b/livekit/src/room/rpc/tests.rs @@ -0,0 +1,794 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::*; +use crate::data_stream::{ + OperationType, StreamResult, StreamTextOptions, TextStreamInfo, TextStreamReader, +}; +use crate::e2ee::EncryptionType; +use crate::room::id::ParticipantIdentity; +use crate::room::RoomError; +use bytes::Bytes; +use chrono::Utc; +use livekit_api::signal_client::{CLIENT_PROTOCOL_DATA_STREAM_RPC, CLIENT_PROTOCOL_DEFAULT}; +use livekit_protocol as proto; +use parking_lot::Mutex as ParkingMutex; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, Notify}; + +// --------------------------------------------------------------------------- +// Mock transport +// --------------------------------------------------------------------------- + +/// Captures all outgoing packets and text streams for assertion. +struct MockTransport { + sent_packets: Arc>>, + sent_texts: Arc>>, + packet_sent: Arc, + text_sent: Arc, + remote_protocols: HashMap, + server_ver: Option, +} + +impl MockTransport { + fn new() -> Self { + Self { + sent_packets: Default::default(), + sent_texts: Default::default(), + packet_sent: Arc::new(Notify::new()), + text_sent: Arc::new(Notify::new()), + remote_protocols: HashMap::new(), + server_ver: Some("1.8.0".to_string()), + } + } + + fn with_remote_protocol(mut self, identity: &str, protocol: i32) -> Self { + self.remote_protocols.insert(identity.to_string(), protocol); + self + } + + /// Wait until at least one packet has been sent. + async fn wait_for_packet(&self) { + self.packet_sent.notified().await; + } + + /// Wait until at least one text stream has been sent. + async fn wait_for_text(&self) { + self.text_sent.notified().await; + } + + /// Return all sent packets. + fn packets(&self) -> Vec { + self.sent_packets.lock().clone() + } + + /// Return all sent text streams as (body, options). + fn texts(&self) -> Vec<(String, StreamTextOptions)> { + self.sent_texts.lock().clone() + } + + /// Count packets matching a predicate on their `value`. + fn count_packets bool>(&self, f: F) -> usize { + self.packets().iter().filter(|p| p.value.as_ref().map_or(false, &f)).count() + } + + /// Extract the request ID from the first RPC request packet or text stream. + fn extract_request_id(&self) -> String { + // Try v1 packets first + for p in self.packets() { + if let Some(proto::data_packet::Value::RpcRequest(req)) = &p.value { + return req.id.clone(); + } + } + // Try v2 text streams + for (_, opts) in self.texts() { + if opts.topic == RPC_REQUEST_TOPIC { + if let Some(id) = opts.attributes.get(ATTR_REQUEST_ID) { + return id.clone(); + } + } + } + panic!("No RPC request found in mock transport"); + } +} + +impl RpcTransport for MockTransport { + async fn publish_data(&self, data: proto::DataPacket) -> Result<(), RoomError> { + self.sent_packets.lock().push(data); + self.packet_sent.notify_waiters(); + Ok(()) + } + + async fn send_text( + &self, + text: &str, + options: StreamTextOptions, + ) -> StreamResult { + self.sent_texts.lock().push((text.to_string(), options.clone())); + self.text_sent.notify_waiters(); + Ok(TextStreamInfo { + id: "mock-stream-id".to_string(), + topic: options.topic, + timestamp: Utc::now(), + total_length: Some(text.len() as u64), + attributes: options.attributes, + mime_type: "text/plain".to_string(), + operation_type: OperationType::Create, + version: 0, + reply_to_stream_id: None, + attached_stream_ids: vec![], + generated: false, + encryption_type: EncryptionType::None, + }) + } + + fn remote_client_protocol(&self, identity: &ParticipantIdentity) -> i32 { + self.remote_protocols.get(&identity.0).copied().unwrap_or(CLIENT_PROTOCOL_DEFAULT) + } + + fn server_version(&self) -> Option { + self.server_ver.clone() + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn make_text_reader( + text: &str, + attributes: HashMap, + topic: &str, +) -> TextStreamReader { + let (tx, rx) = mpsc::unbounded_channel(); + tx.send(Ok(Bytes::from(text.to_string()))).unwrap(); + drop(tx); // close the stream + TextStreamReader::new_for_test( + TextStreamInfo { + id: "test-stream".to_string(), + topic: topic.to_string(), + timestamp: Utc::now(), + total_length: Some(text.len() as u64), + attributes, + mime_type: "text/plain".to_string(), + operation_type: OperationType::Create, + version: 0, + reply_to_stream_id: None, + attached_stream_ids: vec![], + generated: false, + encryption_type: EncryptionType::None, + }, + rx, + ) +} + +fn v2_request_attrs(request_id: &str, method: &str, timeout_ms: u64) -> HashMap { + let mut attrs = HashMap::new(); + attrs.insert(ATTR_REQUEST_ID.to_string(), request_id.to_string()); + attrs.insert(ATTR_METHOD.to_string(), method.to_string()); + attrs.insert(ATTR_RESPONSE_TIMEOUT_MS.to_string(), timeout_ms.to_string()); + attrs.insert(ATTR_VERSION.to_string(), "2".to_string()); + attrs +} + +fn v2_response_attrs(request_id: &str) -> HashMap { + let mut attrs = HashMap::new(); + attrs.insert(ATTR_REQUEST_ID.to_string(), request_id.to_string()); + attrs +} + +fn is_rpc_request_packet(v: &proto::data_packet::Value) -> bool { + matches!(v, proto::data_packet::Value::RpcRequest(_)) +} + +fn is_rpc_response_packet(v: &proto::data_packet::Value) -> bool { + matches!(v, proto::data_packet::Value::RpcResponse(_)) +} + +fn is_rpc_ack_packet(v: &proto::data_packet::Value) -> bool { + matches!(v, proto::data_packet::Value::RpcAck(_)) +} + +fn extract_response_error(transport: &MockTransport) -> Option { + for p in transport.packets() { + if let Some(proto::data_packet::Value::RpcResponse(resp)) = &p.value { + if let Some(proto::rpc_response::Value::Error(e)) = &resp.value { + return Some(e.clone()); + } + } + } + None +} + +/// Run `perform_rpc` in a background task and return a handle. +/// +/// Uses `Arc` to share the client and transport safely across the spawn boundary. +async fn spawn_perform_rpc( + client: Arc, + transport: Arc, + data: PerformRpcData, +) -> tokio::task::JoinHandle> { + tokio::spawn(async move { client.perform_rpc(data, &*transport).await }) +} + +// ========================================================================= +// v2 -> v2 tests (both sides support data streams) +// ========================================================================= + +/// Spec #1: Caller happy path (short payload) — v2 data stream used. +#[tokio::test] +async fn test_v2_v2_caller_happy_path_short() { + let client = Arc::new(RpcClientManager::new()); + let transport = Arc::new( + MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DATA_STREAM_RPC), + ); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "greet".into(), + payload: "hello".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + // Wait for the request to be sent + transport.wait_for_text().await; + + // Verify: sent as v2 data stream, NOT a v1 packet + assert_eq!(transport.count_packets(is_rpc_request_packet), 0); + assert_eq!(transport.texts().len(), 1); + let (body, opts) = &transport.texts()[0]; + assert_eq!(opts.topic, RPC_REQUEST_TOPIC); + assert_eq!(body, "hello"); + assert_eq!(opts.attributes.get(ATTR_VERSION).unwrap(), "2"); + + let request_id = transport.extract_request_id(); + + // Simulate ACK + response + client.handle_ack(request_id.clone()); + client.handle_response(request_id, Some("world".into()), None); + + let result = handle.await.unwrap(); + assert_eq!(result.unwrap(), "world"); +} + +/// Spec #2: Caller happy path (large payload > 15 KB) — no size error. +#[tokio::test] +async fn test_v2_v2_caller_happy_path_large_payload() { + let client = Arc::new(RpcClientManager::new()); + let transport = Arc::new( + MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DATA_STREAM_RPC), + ); + + let large_payload = "x".repeat(20_000); + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "big".into(), + payload: large_payload, + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_text().await; + + let (body, _) = &transport.texts()[0]; + assert_eq!(body.len(), 20_000); + + let request_id = transport.extract_request_id(); + client.handle_ack(request_id.clone()); + client.handle_response(request_id, Some("ok".into()), None); + + let result = handle.await.unwrap(); + assert_eq!(result.unwrap(), "ok"); +} + +/// Spec #3: Handler happy path — response sent via v2 data stream. +#[tokio::test] +async fn test_v2_v2_handler_happy_path() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("echo".to_string(), |data| Box::pin(async move { Ok(data.payload) })); + + let reader = make_text_reader( + "request-body", + v2_request_attrs("req-1", "echo", 5000), + RPC_REQUEST_TOPIC, + ); + + server.handle_request_stream(reader, ParticipantIdentity("caller".into()), &transport).await; + + // ACK should be sent as v1 packet + assert_eq!(transport.count_packets(is_rpc_ack_packet), 1); + + // Success response should be sent as v2 data stream, NOT a v1 packet + assert_eq!(transport.count_packets(is_rpc_response_packet), 0); + assert_eq!(transport.texts().len(), 1); + let (body, opts) = &transport.texts()[0]; + assert_eq!(opts.topic, RPC_RESPONSE_TOPIC); + assert_eq!(body, "request-body"); // echo + assert_eq!(opts.attributes.get(ATTR_REQUEST_ID).unwrap(), "req-1"); +} + +/// Spec #4: Unhandled error in handler — error sent via v1 packet. +#[tokio::test] +async fn test_v2_v2_handler_unhandled_error() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("crash".to_string(), |_data| { + Box::pin(async move { + panic!("handler panic"); + }) + }); + + let reader = + make_text_reader("payload", v2_request_attrs("req-2", "crash", 5000), RPC_REQUEST_TOPIC); + + server.handle_request_stream(reader, ParticipantIdentity("caller".into()), &transport).await; + + // Error responses always use v1 packets, even between v2 clients + assert_eq!(transport.count_packets(is_rpc_response_packet), 1); + assert_eq!(transport.texts().len(), 0); // no data stream response + + let err = extract_response_error(&transport).unwrap(); + assert_eq!(err.code, RpcErrorCode::ApplicationError as u32); +} + +/// Spec #5: RpcError passthrough in handler — custom error code preserved. +#[tokio::test] +async fn test_v2_v2_handler_rpc_error_passthrough() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("fail".to_string(), |_data| { + Box::pin(async move { Err(RpcError::new(101, "custom".into(), Some("data".into()))) }) + }); + + let reader = + make_text_reader("payload", v2_request_attrs("req-3", "fail", 5000), RPC_REQUEST_TOPIC); + + server.handle_request_stream(reader, ParticipantIdentity("caller".into()), &transport).await; + + // Error sent as v1 packet + let err = extract_response_error(&transport).unwrap(); + assert_eq!(err.code, 101); + assert_eq!(err.message, "custom"); +} + +/// Spec #6: Response timeout — caller gives up after timeout. +#[tokio::test] +async fn test_v2_v2_response_timeout() { + let client = RpcClientManager::new(); + let transport = + MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DATA_STREAM_RPC); + + // Very short timeout — no ack or response will arrive. + // The ack timeout (7s) is larger than response_timeout (50ms), + // so connection timeout fires. + let result = client + .perform_rpc( + PerformRpcData { + destination_identity: "dest".into(), + method: "slow".into(), + payload: "x".into(), + response_timeout: Duration::from_millis(50), + }, + &transport, + ) + .await; + + let err = result.unwrap_err(); + assert_eq!(err.code, RpcErrorCode::ConnectionTimeout as u32); +} + +/// Spec #7: Error response — v1 error packet received by v2 caller. +#[tokio::test] +async fn test_v2_v2_error_response() { + let client = Arc::new(RpcClientManager::new()); + let transport = Arc::new( + MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DATA_STREAM_RPC), + ); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "err".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_text().await; + let request_id = transport.extract_request_id(); + + client.handle_ack(request_id.clone()); + // Error response arrives as v1 packet (per spec) + client.handle_response( + request_id, + None, + Some(proto::RpcError { code: 101, message: "nope".into(), data: "details".into() }), + ); + + let result = handle.await.unwrap(); + let err = result.unwrap_err(); + assert_eq!(err.code, 101); + assert_eq!(err.message, "nope"); +} + +/// Spec #8: Participant disconnection — channel dropped before response. +#[tokio::test] +async fn test_v2_v2_participant_disconnection() { + let client = Arc::new(RpcClientManager::new()); + let transport = Arc::new( + MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DATA_STREAM_RPC), + ); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "dc".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_text().await; + let request_id = transport.extract_request_id(); + + // ACK arrives, then the responder disconnects (response channel dropped) + client.handle_ack(request_id.clone()); + // Simulate disconnect by dropping the pending response sender + client.drop_pending_response(&request_id); + + let result = handle.await.unwrap(); + let err = result.unwrap_err(); + assert_eq!(err.code, RpcErrorCode::RecipientDisconnected as u32); +} + +// ========================================================================= +// v2 -> v1 tests (v2 caller, v1 handler) +// ========================================================================= + +/// Spec #10: Caller falls back to v1 packet when remote is v1. +#[tokio::test] +async fn test_v2_v1_caller_request_fallback() { + let client = Arc::new(RpcClientManager::new()); + // Remote has client_protocol = 0 (v1 only) + let transport = + Arc::new(MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DEFAULT)); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "greet".into(), + payload: "hi".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_packet().await; + + // Verify: sent as v1 packet, NOT a data stream + assert_eq!(transport.count_packets(is_rpc_request_packet), 1); + assert_eq!(transport.texts().iter().filter(|(_, o)| o.topic == RPC_REQUEST_TOPIC).count(), 0); + + let request_id = transport.extract_request_id(); + client.handle_ack(request_id.clone()); + client.handle_response(request_id, Some("yo".into()), None); + + let result = handle.await.unwrap(); + assert_eq!(result.unwrap(), "yo"); +} + +/// Spec #11: v1 handler receives v1 request and responds with v1 packet. +#[tokio::test] +async fn test_v2_v1_handler_v1_request() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("echo".to_string(), |data| Box::pin(async move { Ok(data.payload) })); + + server + .handle_request( + HandleRequestOptions { + caller_identity: ParticipantIdentity("caller".into()), + request_id: "req-v1".into(), + method: "echo".into(), + payload: "v1-body".into(), + response_timeout: Duration::from_secs(5), + version: RPC_VERSION_V1, + }, + &transport, + ) + .await; + + // ACK sent + assert_eq!(transport.count_packets(is_rpc_ack_packet), 1); + // Response sent as v1 packet (not data stream) + assert_eq!(transport.count_packets(is_rpc_response_packet), 1); + assert_eq!(transport.texts().len(), 0); + + // Verify response payload + for p in transport.packets() { + if let Some(proto::data_packet::Value::RpcResponse(resp)) = &p.value { + if let Some(proto::rpc_response::Value::Payload(payload)) = &resp.value { + assert_eq!(payload, "v1-body"); + } + } + } +} + +/// Spec #12: Payload too large rejected for v1 remote. +#[tokio::test] +async fn test_v2_v1_payload_too_large() { + let client = RpcClientManager::new(); + let transport = MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DEFAULT); + + let large_payload = "x".repeat(MAX_PAYLOAD_BYTES + 1); + let result = client + .perform_rpc( + PerformRpcData { + destination_identity: "dest".into(), + method: "big".into(), + payload: large_payload, + response_timeout: Duration::from_secs(5), + }, + &transport, + ) + .await; + + let err = result.unwrap_err(); + assert_eq!(err.code, RpcErrorCode::RequestPayloadTooLarge as u32); +} + +/// Spec #13: Response timeout with v1 remote. +#[tokio::test] +async fn test_v2_v1_response_timeout() { + let client = RpcClientManager::new(); + let transport = MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DEFAULT); + + let result = client + .perform_rpc( + PerformRpcData { + destination_identity: "dest".into(), + method: "slow".into(), + payload: "x".into(), + response_timeout: Duration::from_millis(50), + }, + &transport, + ) + .await; + + let err = result.unwrap_err(); + assert_eq!(err.code, RpcErrorCode::ConnectionTimeout as u32); +} + +/// Spec #14: Error response from v1 handler. +#[tokio::test] +async fn test_v2_v1_error_response() { + let client = Arc::new(RpcClientManager::new()); + let transport = + Arc::new(MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DEFAULT)); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "err".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_packet().await; + let request_id = transport.extract_request_id(); + + client.handle_ack(request_id.clone()); + client.handle_response( + request_id, + None, + Some(proto::RpcError { code: 101, message: "v1-err".into(), data: String::new() }), + ); + + let result = handle.await.unwrap(); + let err = result.unwrap_err(); + assert_eq!(err.code, 101); + assert_eq!(err.message, "v1-err"); +} + +/// Spec #15: Participant disconnection with v1 remote. +#[tokio::test] +async fn test_v2_v1_participant_disconnection() { + let client = Arc::new(RpcClientManager::new()); + let transport = + Arc::new(MockTransport::new().with_remote_protocol("dest", CLIENT_PROTOCOL_DEFAULT)); + + let handle = spawn_perform_rpc( + client.clone(), + transport.clone(), + PerformRpcData { + destination_identity: "dest".into(), + method: "dc".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + }, + ) + .await; + + transport.wait_for_packet().await; + let request_id = transport.extract_request_id(); + + client.handle_ack(request_id.clone()); + // Simulate disconnect by dropping the pending response sender + client.drop_pending_response(&request_id); + + let result = handle.await.unwrap(); + let err = result.unwrap_err(); + assert_eq!(err.code, RpcErrorCode::RecipientDisconnected as u32); +} + +// ========================================================================= +// v1 -> v2 tests (v1 caller, v2 handler) +// ========================================================================= + +/// Spec #16: v2 handler responds with v1 packet when request was v1. +#[tokio::test] +async fn test_v1_v2_handler_response_fallback() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("echo".to_string(), |data| Box::pin(async move { Ok(data.payload) })); + + // v1 caller sends a v1 packet request to our v2 handler + server + .handle_request( + HandleRequestOptions { + caller_identity: ParticipantIdentity("v1-caller".into()), + request_id: "req-v1-to-v2".into(), + method: "echo".into(), + payload: "hello-from-v1".into(), + response_timeout: Duration::from_secs(5), + version: RPC_VERSION_V1, + }, + &transport, + ) + .await; + + // ACK via v1 packet + assert_eq!(transport.count_packets(is_rpc_ack_packet), 1); + // Response via v1 packet (not data stream), even though handler supports v2 + assert_eq!(transport.count_packets(is_rpc_response_packet), 1); + assert_eq!(transport.texts().len(), 0); +} + +/// Spec #17: Unhandled error in v2 handler for v1 caller — APPLICATION_ERROR. +#[tokio::test] +async fn test_v1_v2_handler_unhandled_error() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("crash".to_string(), |_data| { + Box::pin(async move { + panic!("boom"); + }) + }); + + server + .handle_request( + HandleRequestOptions { + caller_identity: ParticipantIdentity("v1-caller".into()), + request_id: "req-crash".into(), + method: "crash".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + version: RPC_VERSION_V1, + }, + &transport, + ) + .await; + + let err = extract_response_error(&transport).unwrap(); + assert_eq!(err.code, RpcErrorCode::ApplicationError as u32); +} + +/// Spec #18: RpcError passthrough in v2 handler for v1 caller. +#[tokio::test] +async fn test_v1_v2_handler_rpc_error_passthrough() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + server.register_method("fail".to_string(), |_data| { + Box::pin(async move { Err(RpcError::new(101, "custom-err".into(), Some("extra".into()))) }) + }); + + server + .handle_request( + HandleRequestOptions { + caller_identity: ParticipantIdentity("v1-caller".into()), + request_id: "req-fail".into(), + method: "fail".into(), + payload: "x".into(), + response_timeout: Duration::from_secs(5), + version: RPC_VERSION_V1, + }, + &transport, + ) + .await; + + let err = extract_response_error(&transport).unwrap(); + assert_eq!(err.code, 101); + assert_eq!(err.message, "custom-err"); +} + +// ========================================================================= +// Additional tests +// ========================================================================= + +/// Verify handle_response_stream resolves the pending caller correctly. +#[tokio::test] +async fn test_v2_response_stream_resolves_caller() { + let client = RpcClientManager::new(); + + // Manually register a pending response + let (tx, rx) = tokio::sync::oneshot::channel(); + client.insert_pending_response("req-stream".to_string(), tx); + + let reader = + make_text_reader("stream-result", v2_response_attrs("req-stream"), RPC_RESPONSE_TOPIC); + + client.handle_response_stream(reader).await; + + let result: Result = rx.await.unwrap(); + assert_eq!(result.unwrap(), "stream-result"); +} + +/// Verify unregistered method returns UNSUPPORTED_METHOD error via v2 path. +#[tokio::test] +async fn test_v2_handler_unsupported_method() { + let server = RpcServerManager::new(); + let transport = MockTransport::new(); + + let reader = make_text_reader( + "payload", + v2_request_attrs("req-unsup", "nonexistent", 5000), + RPC_REQUEST_TOPIC, + ); + + server.handle_request_stream(reader, ParticipantIdentity("caller".into()), &transport).await; + + let err = extract_response_error(&transport).unwrap(); + assert_eq!(err.code, RpcErrorCode::UnsupportedMethod as u32); +} diff --git a/livekit/tests/rpc_test.rs b/livekit/tests/rpc_test.rs index edee6063d..ba7fd88fa 100644 --- a/livekit/tests/rpc_test.rs +++ b/livekit/tests/rpc_test.rs @@ -77,6 +77,71 @@ pub async fn test_rpc_unregistered() -> Result<()> { Ok(()) } +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +pub async fn test_rpc_large_payload() -> Result<()> { + let mut rooms = test_rooms(2).await?; + let (caller_room, _) = rooms.pop().unwrap(); + let (callee_room, _) = rooms.pop().unwrap(); + let callee_identity = callee_room.local_participant().identity(); + + const METHOD_NAME: &str = "large-payload-method"; + // 20KB payload - exceeds 15KB v1 limit but works with v2 data streams + let large_payload: String = "x".repeat(20_000); + + callee_room.local_participant().register_rpc_method(METHOD_NAME.to_string(), |data| { + Box::pin(async move { Ok(data.payload.to_string()) }) + }); + + let perform_data = PerformRpcData { + method: METHOD_NAME.to_string(), + destination_identity: callee_identity.to_string(), + payload: large_payload.clone(), + response_timeout: Duration::from_secs(5), + ..Default::default() + }; + let return_payload = caller_room + .local_participant() + .perform_rpc(perform_data) + .await + .context("Large payload invocation failed")?; + assert_eq!(return_payload, large_payload, "Large payload mismatch"); + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +pub async fn test_rpc_error_response() -> Result<()> { + use livekit::prelude::{RpcError, RpcErrorCode}; + + let mut rooms = test_rooms(2).await?; + let (caller_room, _) = rooms.pop().unwrap(); + let (callee_room, _) = rooms.pop().unwrap(); + let callee_identity = callee_room.local_participant().identity(); + + const METHOD_NAME: &str = "error-method"; + + callee_room.local_participant().register_rpc_method(METHOD_NAME.to_string(), |_data| { + Box::pin(async move { + Err(RpcError::new(42, "custom error".to_string(), Some("error data".to_string()))) + }) + }); + + let perform_data = PerformRpcData { + method: METHOD_NAME.to_string(), + destination_identity: callee_identity.to_string(), + payload: "test".to_string(), + response_timeout: Duration::from_secs(5), + ..Default::default() + }; + let result = caller_room.local_participant().perform_rpc(perform_data).await; + assert!(result.is_err(), "Expected error response"); + let err = result.unwrap_err(); + assert_eq!(err.code, 42, "Error code mismatch"); + assert_eq!(err.message, "custom error", "Error message mismatch"); + Ok(()) +} + #[cfg(feature = "__lk-e2e-test")] #[test_log::test(tokio::test)] pub async fn test_rpc_unknown_destination() -> Result<()> {