From d17fa0202bd7f194cdfa3b7eaa04d5cd5d84d418 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Tue, 28 Apr 2026 11:54:29 +0800 Subject: [PATCH] Support publishing track(s) when join room Save 2 rtt (add track signaling & sdp negotiation) if client wants to publish track immediately after join room. --- livekit-api/src/signal_client/mod.rs | 78 ++++++++++++------ livekit/src/prelude.rs | 1 + livekit/src/room/mod.rs | 114 +++++++++++++++++++++++++- livekit/src/rtc_engine/mod.rs | 47 +++++++++-- livekit/src/rtc_engine/rtc_session.rs | 83 ++++++++++++++++--- 5 files changed, 278 insertions(+), 45 deletions(-) diff --git a/livekit-api/src/signal_client/mod.rs b/livekit-api/src/signal_client/mod.rs index 59f0c4494..e4d735841 100644 --- a/livekit-api/src/signal_client/mod.rs +++ b/livekit-api/src/signal_client/mod.rs @@ -160,6 +160,7 @@ impl SignalClient { token: &str, options: SignalOptions, publisher_offer: Option, + add_track_requests: Vec, ) -> SignalResult<(Self, proto::JoinResponse, SignalEvents)> { let handle_success = |inner: Arc, join_response, stream_events| { let (emitter, events) = mpsc::unbounded_channel(); @@ -169,7 +170,15 @@ impl SignalClient { (Self { inner, emitter, handle: Mutex::new(Some(signal_task)) }, join_response, events) }; - match SignalInner::connect(url, token, options.clone(), publisher_offer.clone()).await { + match SignalInner::connect( + url, + token, + options.clone(), + publisher_offer.clone(), + add_track_requests.clone(), + ) + .await + { Ok((inner, join_response, stream_events)) => { Ok(handle_success(inner, join_response, stream_events)) } @@ -183,8 +192,14 @@ impl SignalClient { for url in urls.iter() { log::info!("fallback connection to: {}", url); - match SignalInner::connect(url, token, options.clone(), publisher_offer.clone()) - .await + match SignalInner::connect( + url, + token, + options.clone(), + publisher_offer.clone(), + add_track_requests.clone(), + ) + .await { Ok((inner, join_response, stream_events)) => { return Ok(handle_success(inner, join_response, stream_events)) @@ -269,6 +284,7 @@ impl SignalInner { token: &str, options: SignalOptions, publisher_offer: Option, + add_track_requests: Vec, ) -> SignalResult<( Arc, proto::JoinResponse, @@ -277,8 +293,16 @@ impl SignalInner { // Try v1 path first if single_peer_connection is enabled let use_v1_path = options.single_peer_connection; // For initial connection: reconnect=false, reconnect_reason=None, participant_sid="" - let lk_url = - get_livekit_url(url, &options, use_v1_path, false, None, "", publisher_offer.as_ref())?; + let lk_url = get_livekit_url( + url, + &options, + use_v1_path, + false, + None, + "", + publisher_offer.as_ref(), + &add_track_requests, + )?; // Try to connect to the SignalClient let (stream, mut events, single_pc_mode_active) = match SignalStream::connect(lk_url.clone(), token, options.connect_timeout).await { @@ -309,7 +333,7 @@ impl SignalInner { if use_v1_path && is_not_found { let lk_url_v0 = - get_livekit_url(url, &options, false, false, None, "", None)?; + get_livekit_url(url, &options, false, false, None, "", None, &[])?; log::warn!("v1 path not found (404), falling back to v0 path"); match SignalStream::connect( lk_url_v0.clone(), @@ -414,6 +438,7 @@ impl SignalInner { None, sid, None, + &[], ) .unwrap(); @@ -574,6 +599,7 @@ fn create_join_request_param( os_version: String, device_model: String, publisher_offer: Option<&proto::SessionDescription>, + add_track_requests: &[proto::AddTrackRequest], ) -> String { let connection_settings = proto::ConnectionSettings { auto_subscribe: options.auto_subscribe, @@ -596,6 +622,7 @@ fn create_join_request_param( connection_settings: Some(connection_settings), reconnect, publisher_offer: publisher_offer.cloned(), + add_track_requests: add_track_requests.to_vec(), ..Default::default() }; @@ -612,21 +639,22 @@ fn create_join_request_param( // Serialize JoinRequest to bytes let join_request_bytes = join_request.encode_to_vec(); - // Use gzip compression when publisher offer is included (SDP makes payload large) - let (compressed_bytes, compression) = if publisher_offer.is_some() { - let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); - if encoder.write_all(&join_request_bytes).is_ok() { - if let Ok(compressed) = encoder.finish() { - (compressed, proto::wrapped_join_request::Compression::Gzip as i32) + // Use gzip compression when publisher offer or add_track_requests are included + let (compressed_bytes, compression) = + if publisher_offer.is_some() || !add_track_requests.is_empty() { + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + if encoder.write_all(&join_request_bytes).is_ok() { + if let Ok(compressed) = encoder.finish() { + (compressed, proto::wrapped_join_request::Compression::Gzip as i32) + } else { + (join_request_bytes, proto::wrapped_join_request::Compression::None as i32) + } } else { (join_request_bytes, proto::wrapped_join_request::Compression::None as i32) } } else { (join_request_bytes, proto::wrapped_join_request::Compression::None as i32) - } - } else { - (join_request_bytes, proto::wrapped_join_request::Compression::None as i32) - }; + }; let wrapped_join_request = proto::WrappedJoinRequest { join_request: compressed_bytes, compression }; @@ -654,6 +682,7 @@ fn get_livekit_url( reconnect_reason: Option, participant_sid: &str, publisher_offer: Option<&proto::SessionDescription>, + add_track_requests: &[proto::AddTrackRequest], ) -> SignalResult { let mut lk_url = url::Url::parse(url).map_err(|err| SignalError::UrlParse(err.to_string()))?; @@ -692,6 +721,7 @@ fn get_livekit_url( os_info.version().to_string(), device_model.to_string(), publisher_offer, + add_track_requests, ); lk_url.query_pairs_mut().append_pair("join_request", &join_request_param); } else { @@ -796,39 +826,41 @@ mod tests { fn livekit_url_test() { let io = SignalOptions::default(); - assert!(get_livekit_url("localhost:7880", &io, false, false, None, "", None).is_err()); + assert!(get_livekit_url("localhost:7880", &io, false, false, None, "", None, &[]).is_err()); assert_eq!( - get_livekit_url("https://localhost:7880", &io, false, false, None, "", None) + get_livekit_url("https://localhost:7880", &io, false, false, None, "", None, &[]) .unwrap() .scheme(), "wss" ); assert_eq!( - get_livekit_url("http://localhost:7880", &io, false, false, None, "", None) + get_livekit_url("http://localhost:7880", &io, false, false, None, "", None, &[]) .unwrap() .scheme(), "ws" ); assert_eq!( - get_livekit_url("wss://localhost:7880", &io, false, false, None, "", None) + get_livekit_url("wss://localhost:7880", &io, false, false, None, "", None, &[]) .unwrap() .scheme(), "wss" ); assert_eq!( - get_livekit_url("ws://localhost:7880", &io, false, false, None, "", None) + get_livekit_url("ws://localhost:7880", &io, false, false, None, "", None, &[]) .unwrap() .scheme(), "ws" ); - assert!(get_livekit_url("ftp://localhost:7880", &io, false, false, None, "", None).is_err()); + assert!(get_livekit_url("ftp://localhost:7880", &io, false, false, None, "", None, &[]) + .is_err()); } #[test] fn validate_url_test() { let io = SignalOptions::default(); let lk_url = - get_livekit_url("wss://localhost:7880", &io, false, false, None, "", None).unwrap(); + get_livekit_url("wss://localhost:7880", &io, false, false, None, "", None, &[]) + .unwrap(); let validate_url = get_validate_url(lk_url); // Should be /rtc/validate, not /rtc/rtc/validate diff --git a/livekit/src/prelude.rs b/livekit/src/prelude.rs index e1e71f1c7..02703edc6 100644 --- a/livekit/src/prelude.rs +++ b/livekit/src/prelude.rs @@ -21,6 +21,7 @@ pub use crate::{ PushFrameError, PushFrameErrorReason, RemoteDataTrack, }, id::*, + options::TrackPublishOptions, participant::{ ConnectionQuality, DisconnectReason, LocalParticipant, Participant, PerformRpcData, RemoteParticipant, RpcError, RpcErrorCode, RpcInvocationData, diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 89ae9c205..2832349e6 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::e2ee::EncryptionType; use bmrng::unbounded::UnboundedRequestReceiver; use futures_util::{Stream, StreamExt}; +use libwebrtc::prelude::RtpEncodingParameters; use libwebrtc::{ native::frame_cryptor::EncryptionState, prelude::{ @@ -31,6 +33,7 @@ use livekit_datatrack::{ use livekit_protocol::observer::Dispatcher; use livekit_protocol::{self as proto, encryption}; use livekit_runtime::JoinHandle; +use options::TrackPublishOptions; use parking_lot::RwLock; pub use proto::DisconnectReason; use proto::{promise::Promise, SignalTarget}; @@ -41,6 +44,7 @@ use tokio::sync::{ mpsc::{self, UnboundedReceiver}, oneshot, Mutex as AsyncMutex, }; +use track::LocalTrack; pub use utils::take_cell::TakeCell; pub use self::{ @@ -55,8 +59,8 @@ use crate::{ prelude::*, registered_audio_filter_plugins, rtc_engine::{ - EngineError, EngineEvent, EngineEvents, EngineOptions, EngineResult, RtcEngine, - SessionStats, INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, + EngineError, EngineEvent, EngineEvents, EngineOptions, EngineResult, PrePublishTrack, + RtcEngine, SessionStats, INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, }, }; @@ -394,6 +398,8 @@ pub struct RoomOptions { pub single_peer_connection: bool, /// Timeout for each individual signal connection attempt pub connect_timeout: Duration, + /// Tracks to publish immediately upon joining. Only effective when `single_peer_connection` is true. + pub publish_tracks: Vec<(LocalTrack, TrackPublishOptions)>, } impl Default for RoomOptions { @@ -416,6 +422,7 @@ impl Default for RoomOptions { sdk_options: RoomSdkOptions::default(), single_peer_connection: false, connect_timeout: SIGNAL_CONNECT_TIMEOUT, + publish_tracks: Vec::new(), } } } @@ -519,7 +526,22 @@ impl Room { signal_options.adaptive_stream = options.adaptive_stream; signal_options.single_peer_connection = options.single_peer_connection; signal_options.connect_timeout = options.connect_timeout; - let (rtc_engine, join_response, engine_events) = RtcEngine::connect( + + if !options.publish_tracks.is_empty() && !options.single_peer_connection { + return Err(RoomError::Internal( + "publish_tracks requires single_peer_connection to be enabled".into(), + )); + } + let encryption_type = e2ee_manager.encryption_type(); + let pre_publish_tracks: Vec = options + .publish_tracks + .iter() + .map(|(track, opts)| { + Self::build_pre_publish_track(track.clone(), opts.clone(), encryption_type) + }) + .collect(); + + let (rtc_engine, join_response, engine_events, pre_publish_receivers) = RtcEngine::connect( url, token, EngineOptions { @@ -527,6 +549,7 @@ impl Room { signal_options, join_retries: options.join_retries, single_peer_connection: options.single_peer_connection, + publish_tracks: pre_publish_tracks.clone(), }, Some(e2ee_manager.clone()), ) @@ -788,9 +811,94 @@ impl Room { }; inner.handle.lock().await.replace(handle); + let mut receiver_map: HashMap> = + pre_publish_receivers.into_iter().collect(); + for pt in pre_publish_tracks { + let Some(rx) = receiver_map.remove(&pt.request.cid) else { + log::warn!("no receiver for pre-published track {}", pt.track.name()); + continue; + }; + match rtc_engine.wait_track_published_by_cid(pt.request.cid.clone(), rx).await { + Ok(track_info) => { + let publication = + LocalTrackPublication::new(track_info.clone(), pt.track.clone()); + pt.track.update_info(track_info); + publication.set_track(Some(pt.track.clone().into())); + publication.update_publish_options(pt.options); + inner.local_participant.add_publication(TrackPublication::Local(publication)); + pt.track.enable(); + log::debug!("pre-published track completed: {}", pt.track.name()); + } + Err(err) => { + log::warn!( + "failed to complete pre-published track {}: {:?}", + pt.track.name(), + err + ); + } + } + } + Ok((Self { inner }, events)) } + fn build_pre_publish_track( + track: LocalTrack, + opts: TrackPublishOptions, + encryption_type: EncryptionType, + ) -> PrePublishTrack { + let disable_red = encryption_type != EncryptionType::None || !opts.red; + + let mut req = proto::AddTrackRequest { + cid: track.rtc_track().id(), + name: track.name(), + r#type: proto::TrackType::from(track.kind()) as i32, + muted: track.is_muted(), + source: proto::TrackSource::from(opts.source) as i32, + disable_dtx: !opts.dtx, + disable_red, + encryption: proto::encryption::Type::from(encryption_type) as i32, + stream: opts.stream.clone(), + ..Default::default() + }; + + if opts.preconnect_buffer { + req.audio_features.push(proto::AudioTrackFeature::TfPreconnectBuffer as i32); + } + + let encodings = match &track { + LocalTrack::Video(video_track) => { + let resolution = video_track.rtc_source().video_resolution(); + req.width = resolution.width; + req.height = resolution.height; + + let encodings = options::compute_video_encodings(req.width, req.height, &opts); + req.layers = + options::video_layers_from_encodings(req.width, req.height, &encodings); + + if opts.simulcast && encodings.len() > 1 { + req.simulcast_codecs = vec![proto::SimulcastCodec { + codec: opts.video_codec.as_str().to_string(), + cid: track.rtc_track().id(), + layers: req.layers.clone(), + ..Default::default() + }]; + } + encodings + } + LocalTrack::Audio(_) => { + let audio_encoding = + opts.audio_encoding.as_ref().unwrap_or(&options::audio::MUSIC.encoding); + vec![RtpEncodingParameters { + max_bitrate: Some(audio_encoding.max_bitrate), + ..Default::default() + }] + } + }; + + PrePublishTrack { track, options: opts, encodings, request: req } + } + pub async fn close(&self) -> RoomResult<()> { self.inner.close(DisconnectReason::ClientInitiated).await } diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index 1ad21f7de..bb2837feb 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -76,6 +76,15 @@ pub enum EngineError { Internal(Cow<'static, str>), // Unexpected error, generally we can't recover } +/// A track to be published at join time, bundled with its options and encodings. +#[derive(Debug, Clone)] +pub struct PrePublishTrack { + pub track: LocalTrack, + pub options: TrackPublishOptions, + pub encodings: Vec, + pub request: proto::AddTrackRequest, +} + #[derive(Default, Debug, Clone)] pub struct EngineOptions { pub rtc_config: RtcConfiguration, @@ -83,6 +92,8 @@ pub struct EngineOptions { pub join_retries: u32, /// Enable single peer connection mode pub single_peer_connection: bool, + /// Tracks to publish at join time (pre-publish optimization) + pub publish_tracks: Vec, } #[derive(Debug)] @@ -245,10 +256,15 @@ impl RtcEngine { token: &str, options: EngineOptions, e2ee_manager: Option, - ) -> EngineResult<(Self, proto::JoinResponse, EngineEvents)> { - let (inner, join_response, engine_events) = + ) -> EngineResult<( + Self, + proto::JoinResponse, + EngineEvents, + Vec<(String, oneshot::Receiver)>, + )> { + let (inner, join_response, engine_events, pre_publish_receivers) = EngineInner::connect(url, token, options, e2ee_manager).await?; - Ok((Self { inner }, join_response, engine_events)) + Ok((Self { inner }, join_response, engine_events, pre_publish_receivers)) } pub async fn close(&self, reason: DisconnectReason) { @@ -308,6 +324,18 @@ impl RtcEngine { session.add_track(req).await } + pub async fn wait_track_published_by_cid( + &self, + cid: String, + rx: oneshot::Receiver, + ) -> EngineResult { + let (session, _r_lock) = { + let (handle, _r_lock) = self.inner.wait_reconnection().await?; + (handle.session.clone(), _r_lock) + }; + session.wait_track_published_by_cid(cid, rx).await + } + pub fn remove_track(&self, sender: RtpSender) -> EngineResult<()> { // We don't need to wait for the reconnection let session = self.inner.running_handle.read().session.clone(); @@ -378,7 +406,12 @@ impl EngineInner { token: &str, options: EngineOptions, e2ee_manager: Option, - ) -> EngineResult<(Arc, proto::JoinResponse, EngineEvents)> { + ) -> EngineResult<( + Arc, + proto::JoinResponse, + EngineEvents, + Vec<(String, oneshot::Receiver)>, + )> { let lk_runtime = LkRuntime::instance(); let max_retries = options.join_retries; @@ -388,7 +421,7 @@ impl EngineInner { let lk_runtime = lk_runtime.clone(); let e2ee_manager = e2ee_manager.clone(); async move { - let (session, join_response, session_events) = + let (session, join_response, session_events, pre_publish_receivers) = RtcSession::connect(url, token, options.clone(), e2ee_manager).await?; session.wait_pc_connection().await?; @@ -419,7 +452,7 @@ impl EngineInner { )); inner.running_handle.write().engine_task = Some((session_task, close_tx)); - Ok((inner, join_response, engine_rx)) + Ok((inner, join_response, engine_rx, pre_publish_receivers)) } } }; @@ -855,7 +888,7 @@ impl EngineInner { let _ = engine_task.await; } - let (new_session, join_response, session_events) = + let (new_session, join_response, session_events, _) = RtcSession::connect(url, token, options, e2ee_manager).await?; // On SignalRestarted, the room will try to unpublish the local tracks diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index 72abc36c0..b3ab74834 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -449,13 +449,19 @@ impl RtcSession { token: &str, options: EngineOptions, e2ee_manager: Option, - ) -> EngineResult<(Self, proto::JoinResponse, SessionEvents)> { + ) -> EngineResult<( + Self, + proto::JoinResponse, + SessionEvents, + Vec<(String, oneshot::Receiver)>, + )> { let (emitter, session_events) = mpsc::unbounded_channel(); let lk_runtime = LkRuntime::instance(); let use_single_pc = options.signal_options.single_peer_connection; let mut publisher_offer = None; + let mut add_track_requests = Vec::new(); let early_publisher_pc = if use_single_pc { let publisher_pc = PeerTransport::new( lk_runtime.pc_factory().create_peer_connection(options.rtc_config.clone())?, @@ -466,6 +472,18 @@ impl RtcSession { let dcs = Self::create_data_channels(&publisher_pc, &emitter)?; Self::add_recv_media_sections(&publisher_pc.peer_connection(), 3, 3)?; + // Add SendOnly transceivers for pre-publish tracks so the initial offer + // includes their media sections, and include AddTrackRequests in JoinRequest. + for pt in &options.publish_tracks { + let init = RtpTransceiverInit { + direction: RtpTransceiverDirection::SendOnly, + stream_ids: Default::default(), + send_encodings: pt.encodings.clone(), + }; + publisher_pc.peer_connection().add_transceiver(pt.track.rtc_track(), init)?; + add_track_requests.push(pt.request.clone()); + } + match publisher_pc.create_initial_offer().await { Ok(Some(offer)) => { publisher_offer = Some(proto::SessionDescription { @@ -491,6 +509,7 @@ impl RtcSession { token, options.signal_options.clone(), publisher_offer.clone(), + add_track_requests, ) .await?; let signal_client = Arc::new(signal_client); @@ -556,6 +575,12 @@ impl RtcSession { let (close_tx, close_rx) = watch::channel(false); + let pre_publish_cids: Vec = options + .publish_tracks + .iter() + .map(|pt| pt.request.cid.clone()) + .collect(); + let dt_sender_options = DataChannelSenderOptions { low_buffer_threshold: DATA_TRACK_BUFFERED_AMOUNT_LOW_THRESHOLD, dc: data_track_dc.clone(), @@ -600,6 +625,14 @@ impl RtcSession { pc_state_notify: Notify::new(), }); + let pre_publish_receivers: Vec<(String, oneshot::Receiver)> = + pre_publish_cids + .iter() + .filter_map(|cid| { + inner.register_pending_track(cid).ok().map(|rx| (cid.clone(), rx)) + }) + .collect(); + // Start session tasks let signal_task = livekit_runtime::spawn(inner.clone().signal_task(signal_events, close_rx.clone())); @@ -626,7 +659,7 @@ impl RtcSession { inner.publisher_negotiation_needed(); } - Ok((Self { inner, handle }, join_response, session_events)) + Ok((Self { inner, handle }, join_response, session_events, pre_publish_receivers)) } fn create_data_channels( @@ -699,6 +732,14 @@ impl RtcSession { self.inner.add_track(req).await } + pub async fn wait_track_published_by_cid( + &self, + cid: String, + rx: oneshot::Receiver, + ) -> EngineResult { + self.inner.wait_track_published_by_cid(cid, rx).await + } + pub async fn mute_track(&self, req: proto::MuteTrackRequest) -> EngineResult<()> { self.inner.mute_track(req).await } @@ -1620,20 +1661,38 @@ impl SessionInner { } async fn add_track(&self, req: proto::AddTrackRequest) -> EngineResult { - let (tx, rx) = oneshot::channel(); let cid = req.cid.clone(); - { - let mut pendings_tracks = self.pending_tracks.lock(); - if pendings_tracks.contains_key(&req.cid) { - Err(EngineError::Internal("track already published".into()))?; - } + let rx = self.register_pending_track(&cid)?; + self.signal_client.send(proto::signal_request::Message::AddTrack(req)).await; + self.wait_track_published(cid, rx).await + } - pendings_tracks.insert(cid.clone(), tx); - } + async fn wait_track_published_by_cid( + &self, + cid: String, + rx: oneshot::Receiver, + ) -> EngineResult { + self.wait_track_published(cid, rx).await + } - self.signal_client.send(proto::signal_request::Message::AddTrack(req)).await; + fn register_pending_track( + &self, + cid: &str, + ) -> EngineResult> { + let (tx, rx) = oneshot::channel(); + let mut pending_tracks = self.pending_tracks.lock(); + if pending_tracks.contains_key(cid) { + Err(EngineError::Internal("track already published".into()))?; + } + pending_tracks.insert(cid.to_string(), tx); + Ok(rx) + } - // Wait the result from the server (TrackInfo) + async fn wait_track_published( + &self, + cid: String, + rx: oneshot::Receiver, + ) -> EngineResult { tokio::select! { Ok(info) = rx => Ok(info), _ = sleep(TRACK_PUBLISH_TIMEOUT) => {