diff --git a/rust/README.md b/rust/README.md index f4d80fefd..ce232bde3 100644 --- a/rust/README.md +++ b/rust/README.md @@ -82,7 +82,25 @@ With the default `CliProgram::Resolve`, `Client::start()` resolves the CLI in th ### Session -Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`. +Created via `Client::create_session`, `Client::create_cloud_session`, or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`. + +#### Cloud sessions + +`Client::create_cloud_session` creates a Mission Control–backed cloud session. The runtime owns the session ID: do **not** set `session_id` or `provider` on the config (the SDK rejects both with `Error::InvalidConfig`). Build the config with `SessionConfig::with_cloud(...)`; `Client::create_session` will reject any config that has `cloud` set. + +```rust,ignore +use github_copilot_sdk::types::{CloudSessionOptions, CloudSessionRepository, SessionConfig}; + +let cloud = CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), +); +let session = client + .create_cloud_session(SessionConfig::default().with_cloud(cloud)) + .await?; +println!("cloud session id: {}", session.id()); +``` + +The SDK buffers any `session.event` notifications or inbound JSON-RPC requests that arrive before the `session.create` response (bounded, drop-oldest) and replays them once the runtime-assigned session ID is registered. ```rust,ignore use github_copilot_sdk::MessageOptions; diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs index 88a9670cd..0b92783a3 100644 --- a/rust/src/jsonrpc.rs +++ b/rust/src/jsonrpc.rs @@ -63,7 +63,6 @@ pub mod error_codes { /// Invalid method parameters (-32602). pub const INVALID_PARAMS: i32 = -32602; /// Internal server error (-32603). - #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")] pub const INTERNAL_ERROR: i32 = -32603; } @@ -490,6 +489,59 @@ impl JsonRpcClient { ))), } } + + /// Clone a sync handle onto the outbound writer for fire-and-forget + /// frames. Use only for paths that cannot `.await` (currently the + /// session router, which holds a `parking_lot::Mutex` while deciding + /// to discard a buffered request). + pub(crate) fn writer_handle(&self) -> WriterHandle { + WriterHandle { + write_tx: self.write_tx.clone(), + } + } +} + +/// Sync, fire-and-forget handle onto the JSON-RPC writer actor. Cloned +/// from [`JsonRpcClient::writer_handle`]; serializes the message on the +/// caller's thread and enqueues it without awaiting an ack. Loss of the +/// ack means we'll never observe a write error here, which is acceptable +/// for the one current caller (error responses to dropped pending +/// requests): if the wire is broken, the runtime will time out the +/// request on its own. +pub(crate) struct WriterHandle { + write_tx: mpsc::UnboundedSender, +} + +impl Clone for WriterHandle { + fn clone(&self) -> Self { + Self { + write_tx: self.write_tx.clone(), + } + } +} + +impl WriterHandle { + /// Serialize and enqueue a JSON-RPC message without waiting for the + /// writer actor to flush it. Drops silently if serialization fails or + /// the writer actor has shut down — both indicate the transport is + /// already unusable. + pub(crate) fn send_fire_and_forget(&self, message: &T) { + let body = match serde_json::to_vec(message) { + Ok(body) => body, + Err(e) => { + warn!(error = %e, "WriterHandle failed to serialize fire-and-forget message"); + return; + } + }; + let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4); + frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes()); + frame.extend_from_slice(body.len().to_string().as_bytes()); + frame.extend_from_slice(b"\r\n\r\n"); + frame.extend_from_slice(&body); + + let (ack_tx, _ack_rx) = oneshot::channel(); + let _ = self.write_tx.send(WriteCommand { frame, ack: ack_tx }); + } } /// RAII guard that removes a pending-request entry from the map if the diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 787697e2e..29c599a1a 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1203,6 +1203,8 @@ impl Client { let pid = child.as_ref().and_then(|c| c.id()); info!(pid = ?pid, "copilot CLI client ready"); + let client_rpc_writer_handle = rpc.writer_handle(); + let client = Self { inner: Arc::new(ClientInner { child: parking_lot::Mutex::new(child), @@ -1210,7 +1212,7 @@ impl Client { cwd, request_rx: parking_lot::Mutex::new(Some(request_rx)), notification_tx: notification_broadcast_tx, - router: router::SessionRouter::new(), + router: router::SessionRouter::with_writer(client_rpc_writer_handle), negotiated_protocol_version: OnceLock::new(), state: parking_lot::Mutex::new(ConnectionState::Connected), lifecycle_tx: broadcast::channel(256).0, @@ -1223,6 +1225,10 @@ impl Client { }), }; client.spawn_lifecycle_dispatcher(); + client + .inner + .router + .start(&client.inner.notification_tx, &client.inner.request_rx); debug!( elapsed_ms = setup_start.elapsed().as_millis(), pid = ?pid, @@ -1579,12 +1585,18 @@ impl Client { &self, session_id: &SessionId, ) -> crate::router::SessionChannels { - self.inner - .router - .ensure_started(&self.inner.notification_tx, &self.inner.request_rx); self.inner.router.register(session_id) } + /// Enter pending-routing mode on the router. While the returned guard is + /// alive, notifications and requests addressed to session ids that are + /// not yet registered are buffered instead of being dropped. Used by + /// [`Client::create_cloud_session`] so the SDK can receive events that + /// the runtime emits between `session.create` and the response. + pub(crate) fn begin_pending_session_routing(&self) -> crate::router::PendingSessionRouting { + self.inner.router.begin_pending_session_routing() + } + /// Unregister a session, dropping its per-session channels. pub(crate) fn unregister_session(&self, session_id: &SessionId) { self.inner.router.unregister(session_id); diff --git a/rust/src/router.rs b/rust/src/router.rs index e14630e03..e726318eb 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -1,13 +1,23 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use parking_lot::Mutex; use tokio::sync::{broadcast, mpsc}; use tracing::warn; -use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest}; +use crate::jsonrpc::{ + JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, WriterHandle, error_codes, +}; use crate::types::{SessionEventNotification, SessionId}; +/// Upper bound on buffered notifications/requests per pending session id. +/// +/// Holds traffic that arrives between `session.create` being sent and the +/// SDK learning the runtime-assigned session id from the response (cloud +/// path). Drop-oldest behavior is acceptable: cloud handshakes are short, +/// and 128 entries is well above realistic init/replay bursts. +const PENDING_SESSION_BUFFER_LIMIT: usize = 128; + /// Per-session channels created by the router during session registration. pub(crate) struct SessionChannels { /// Filtered `session.event` notifications for this session. @@ -21,19 +31,263 @@ struct SessionSenders { requests: mpsc::UnboundedSender, } +#[derive(Default)] +struct PendingSessionMessages { + items: VecDeque, +} + +enum PendingItem { + Notification(SessionEventNotification), + Request(JsonRpcRequest), +} + +#[derive(Default)] +struct SessionRouterState { + sessions: HashMap, + pending: HashMap, + pending_registration_count: usize, + /// Outbound writer used to synthesize JSON-RPC error responses when + /// the pending buffer overflows. `None` in tests that exercise the + /// router in isolation; production construction goes through + /// [`SessionRouter::new`] which threads a real handle in. + writer: Option, +} + +impl SessionRouterState { + fn register(&mut self, session_id: &SessionId, senders: SessionSenders) { + if let Some(pending) = self.pending.remove(session_id.as_str()) { + for item in pending.items { + match item { + PendingItem::Notification(n) => { + let _ = senders.notifications.send(n); + } + PendingItem::Request(r) => { + let _ = senders.requests.send(r); + } + } + } + } + self.sessions.insert(session_id.clone(), senders); + } + + fn route_notification(&mut self, session_id: &str, notification: SessionEventNotification) { + if let Some(sender) = self.sessions.get(session_id) { + let _ = sender.notifications.send(notification); + return; + } + if self.pending_registration_count == 0 { + return; + } + + let session_id = SessionId::from(session_id); + push_pending( + self.pending.entry(session_id.clone()).or_default(), + &session_id, + PendingItem::Notification(notification), + self.writer.as_ref(), + ); + } + + fn route_request(&mut self, request: JsonRpcRequest) { + let Some(session_id) = request + .params + .as_ref() + .and_then(|p| p.get("sessionId")) + .and_then(|v| v.as_str()) + else { + warn!(method = %request.method, "request missing sessionId"); + return; + }; + if let Some(sender) = self.sessions.get(session_id) { + let _ = sender.requests.send(request); + return; + } + if self.pending_registration_count == 0 { + warn!( + session_id = session_id, + method = %request.method, + "request for unregistered session" + ); + return; + } + + let session_id = SessionId::from(session_id); + push_pending( + self.pending.entry(session_id.clone()).or_default(), + &session_id, + PendingItem::Request(request), + self.writer.as_ref(), + ); + } +} + +/// Push an item into a session's pending buffer, evicting the oldest entry +/// (regardless of type) when the per-session limit is reached. A single +/// FIFO across notifications and requests keeps the eviction policy fair +/// across both types and avoids the previous behavior where flushing +/// drained all buffered notifications before any buffered request, +/// artificially batching one type ahead of the other. +/// +/// Note: this does not give the consumer a strict cross-type total order. +/// After `register`, notifications and requests still arrive on two +/// separate per-session mpsc channels and are consumed via `select!`, so +/// the observed order across types is implementation-defined. Strict +/// ordering would require unifying the per-session channels — tracked +/// for a follow-up. +/// +/// When the evicted entry is a request, we synthesize a JSON-RPC error +/// response back to the runtime so it doesn't block waiting for a reply +/// that will never arrive. Notifications are fire-and-forget, so dropping +/// one only emits a warning. +fn push_pending( + buf: &mut PendingSessionMessages, + session_id: &SessionId, + item: PendingItem, + writer: Option<&WriterHandle>, +) { + if buf.items.len() >= PENDING_SESSION_BUFFER_LIMIT { + match buf.items.pop_front() { + Some(PendingItem::Request(dropped)) => { + warn!( + session_id = %session_id, + method = %dropped.method, + request_id = dropped.id, + limit = PENDING_SESSION_BUFFER_LIMIT, + "pending session buffer full; dropping oldest request and responding with error" + ); + if let Some(writer) = writer { + writer.send_fire_and_forget(&pending_overflow_response(dropped.id)); + } + } + Some(PendingItem::Notification(_)) => { + warn!( + session_id = %session_id, + limit = PENDING_SESSION_BUFFER_LIMIT, + "pending session buffer full; dropping oldest notification" + ); + } + None => {} + } + } + buf.items.push_back(item); +} + +/// Build a JSON-RPC error response for a request the SDK had to discard +/// because the pending-session buffer overflowed before the runtime +/// returned `session.create`. +fn pending_overflow_response(id: u64) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { + code: error_codes::INTERNAL_ERROR, + message: "request dropped: pending session buffer overflow before session.create \ + response" + .to_string(), + data: None, + }), + } +} + +/// Build a JSON-RPC error response for a request the SDK buffered while +/// awaiting `session.create` but had to discard because the pending +/// routing guard dropped without a matching `register` (e.g. cloud +/// session creation failed end-to-end). +fn pending_unregistered_response(id: u64) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { + code: error_codes::INTERNAL_ERROR, + message: "request dropped: pending session routing ended before session was registered" + .to_string(), + data: None, + }), + } +} + +/// Guard that keeps the router in "pending routing" mode for cloud +/// `session.create`: while any guard is alive, notifications/requests with +/// unknown session ids are buffered (up to [`PENDING_SESSION_BUFFER_LIMIT`]) +/// instead of dropped. On `register`, buffered messages flush in arrival +/// order into the freshly-created per-session channels. +/// +/// When the last guard drops without a matching `register` (e.g. cloud +/// `session.create` failed), any still-pending buffers are drained and +/// each pending request gets a JSON-RPC error response so the runtime +/// isn't left waiting on a reply that will never come. Notifications are +/// fire-and-forget and just get logged. +pub(crate) struct PendingSessionRouting { + state: Arc>, +} + +impl Drop for PendingSessionRouting { + fn drop(&mut self) { + let mut state = self.state.lock(); + state.pending_registration_count = state.pending_registration_count.saturating_sub(1); + if state.pending_registration_count != 0 { + return; + } + let pending = std::mem::take(&mut state.pending); + let writer = state.writer.clone(); + drop(state); + for (session_id, buf) in pending { + for item in buf.items { + match item { + PendingItem::Request(req) => { + warn!( + session_id = %session_id, + method = %req.method, + request_id = req.id, + "pending session routing ended without registration; \ + responding to buffered request with error" + ); + if let Some(writer) = writer.as_ref() { + writer.send_fire_and_forget(&pending_unregistered_response(req.id)); + } + } + PendingItem::Notification(_) => { + warn!( + session_id = %session_id, + "pending session routing ended without registration; \ + dropping buffered notification" + ); + } + } + } + } + } +} + /// Routes notifications and requests by sessionId to per-session channels. /// /// Internal to the SDK — consumers interact via `Client::register_session()`. pub(crate) struct SessionRouter { - sessions: Arc>>, - started: Mutex, + state: Arc>, } impl SessionRouter { + /// Test-only constructor. Production callers must use + /// [`SessionRouter::with_writer`] so dropped requests get error + /// responses. Tests that don't exercise the writer can use this. + #[cfg(test)] pub(crate) fn new() -> Self { Self { - sessions: Arc::new(Mutex::new(HashMap::new())), - started: Mutex::new(false), + state: Arc::new(Mutex::new(SessionRouterState::default())), + } + } + + /// Construct a router with a handle onto the JSON-RPC outbound writer, + /// used to synthesize error responses when pending-buffer overflow + /// forces us to discard an inbound request. + pub(crate) fn with_writer(writer: WriterHandle) -> Self { + Self { + state: Arc::new(Mutex::new(SessionRouterState { + writer: Some(writer), + ..SessionRouterState::default() + })), } } @@ -41,8 +295,8 @@ impl SessionRouter { pub(crate) fn register(&self, session_id: &SessionId) -> SessionChannels { let (notif_tx, notif_rx) = mpsc::unbounded_channel(); let (req_tx, req_rx) = mpsc::unbounded_channel(); - self.sessions.lock().insert( - session_id.clone(), + self.state.lock().register( + session_id, SessionSenders { notifications: notif_tx, requests: req_tx, @@ -54,9 +308,21 @@ impl SessionRouter { } } - /// Unregister a session, dropping its channels. + /// Enter pending-routing mode. While the returned guard is alive, + /// notifications and requests addressed to session ids that are not + /// yet registered are buffered instead of being dropped. + pub(crate) fn begin_pending_session_routing(&self) -> PendingSessionRouting { + self.state.lock().pending_registration_count += 1; + PendingSessionRouting { + state: self.state.clone(), + } + } + + /// Unregister a session, dropping its channels and any pending buffer. pub(crate) fn unregister(&self, session_id: &SessionId) { - self.sessions.lock().remove(session_id.as_str()); + let mut state = self.state.lock(); + state.sessions.remove(session_id.as_str()); + state.pending.remove(session_id.as_str()); } /// Snapshot every currently-registered session ID. @@ -65,35 +331,32 @@ impl SessionRouter { /// sessions for cooperative shutdown without holding the router lock /// across `.await`. pub(crate) fn session_ids(&self) -> Vec { - self.sessions.lock().keys().cloned().collect() + self.state.lock().sessions.keys().cloned().collect() } - /// Drop all registered session channels. + /// Drop all registered session channels and pending buffers. /// /// Used by [`Client::force_stop`](crate::Client::force_stop) to release /// per-session state without waiting for graceful unregistration. pub(crate) fn clear(&self) { - self.sessions.lock().clear(); + let mut state = self.state.lock(); + state.sessions.clear(); + state.pending.clear(); } - /// Start the router tasks if not already running. + /// Spawn the notification and request routing tasks. /// - /// Takes the notification broadcast and request channel from the Client. - /// If `request_rx` is `None` (already taken by `take_request_rx()`), - /// only notification routing is available. - pub(crate) fn ensure_started( + /// Called exactly once during [`Client::from_streams`]. Takes the + /// notification broadcast and request channel from the Client. If + /// `request_rx` is `None` (already taken by `take_request_rx()`), only + /// notification routing is available. + pub(crate) fn start( &self, notification_tx: &broadcast::Sender, request_rx: &Mutex>>, ) { - let mut started = self.started.lock(); - if *started { - return; - } - *started = true; - // Notification routing task - let sessions = self.sessions.clone(); + let state = self.state.clone(); let mut notif_rx = notification_tx.subscribe(); tokio::spawn(async move { loop { @@ -110,27 +373,20 @@ impl SessionRouter { continue; }; - let sender = { - let guard = sessions.lock(); - guard.get(session_id).map(|s| s.notifications.clone()) - }; - if let Some(sender) = sender { - match serde_json::from_value::(params.clone()) - { - Ok(event_notification) => { - let _ = sender.send(event_notification); - } - Err(e) => { - warn!( - error = %e, - session_id = session_id, - "failed to deserialize session event notification" - ); - } + match serde_json::from_value::(params.clone()) { + Ok(event_notification) => { + state + .lock() + .route_notification(session_id, event_notification); + } + Err(e) => { + warn!( + error = %e, + session_id = session_id, + "failed to deserialize session event notification" + ); } } - // Unknown session IDs are silently dropped — the session - // may have been unregistered between dispatch and delivery. } Err(broadcast::error::RecvError::Lagged(n)) => { warn!(missed = n, "notification router lagged"); @@ -142,37 +398,298 @@ impl SessionRouter { // Request routing task (if request_rx is available) if let Some(mut rx) = request_rx.lock().take() { - let sessions = self.sessions.clone(); + let state = self.state.clone(); tokio::spawn(async move { while let Some(request) = rx.recv().await { - let session_id = request - .params - .as_ref() - .and_then(|p| p.get("sessionId")) - .and_then(|v| v.as_str()); - - if let Some(sid) = session_id { - let sender = { - let guard = sessions.lock(); - guard.get(sid).map(|s| s.requests.clone()) - }; - if let Some(sender) = sender { - let _ = sender.send(request); - } else { - warn!( - session_id = sid, - method = %request.method, - "request for unregistered session" - ); - } - } else { - warn!( - method = %request.method, - "request missing sessionId" - ); - } + state.lock().route_request(request); } }); } } } + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + use crate::jsonrpc::JsonRpcRequest; + + fn make_notification(session_id: &str, kind: &str) -> SessionEventNotification { + let value = json!({ + "sessionId": session_id, + "event": { + "id": "evt-id", + "timestamp": "1970-01-01T00:00:00Z", + "parentId": null, + "type": kind, + "data": {}, + }, + }); + serde_json::from_value(value).expect("valid session event notification") + } + + fn make_request(id: u64, session_id: &str, method: &str) -> JsonRpcRequest { + JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params: Some(json!({ "sessionId": session_id })), + } + } + + #[test] + fn pending_buffer_off_drops_unknown_session() { + let router = SessionRouter::new(); + router + .state + .lock() + .route_notification("ghost", make_notification("ghost", "session.start")); + assert!(router.state.lock().pending.is_empty()); + } + + #[test] + fn pending_buffer_on_buffers_and_flushes_in_order() { + let router = SessionRouter::new(); + let guard = router.begin_pending_session_routing(); + + for i in 0..3 { + router + .state + .lock() + .route_notification("remote", make_notification("remote", &format!("evt-{i}"))); + } + for i in 0..2 { + router + .state + .lock() + .route_request(make_request(100 + i, "remote", "userInput.request")); + } + + let sid = SessionId::from("remote"); + let mut channels = router.register(&sid); + drop(guard); + + let mut got_notifications = 0; + while channels.notifications.try_recv().is_ok() { + got_notifications += 1; + } + assert_eq!(got_notifications, 3, "all buffered notifications flushed"); + + let mut got_requests = 0; + while channels.requests.try_recv().is_ok() { + got_requests += 1; + } + assert_eq!(got_requests, 2, "all buffered requests flushed"); + } + + #[test] + fn pending_buffer_drops_oldest_at_limit() { + let router = SessionRouter::new(); + let _guard = router.begin_pending_session_routing(); + + for i in 0..(PENDING_SESSION_BUFFER_LIMIT + 5) { + router + .state + .lock() + .route_notification("remote", make_notification("remote", &format!("evt-{i}"))); + } + + let state = router.state.lock(); + let pending = state.pending.get("remote").expect("pending bucket exists"); + assert_eq!(pending.items.len(), PENDING_SESSION_BUFFER_LIMIT); + } + + #[test] + fn pending_buffer_flush_interleaves_types_in_arrival_order() { + // The pending FIFO accepts notifications and requests interleaved, + // and `register` drains them in arrival order to their respective + // per-session channels. This test asserts the FIFO order is + // preserved through the flush, not that the downstream consumer + // observes a strict cross-type total order — after register the + // two channels are consumed via `select!`, so observed cross-type + // order is implementation-defined. + let router = SessionRouter::new(); + let guard = router.begin_pending_session_routing(); + + { + let mut state = router.state.lock(); + state.route_notification("remote", make_notification("remote", "evt-0")); + state.route_request(make_request(1, "remote", "userInput.request")); + state.route_notification("remote", make_notification("remote", "evt-1")); + } + + let sid = SessionId::from("remote"); + let mut channels = router.register(&sid); + drop(guard); + + // Notifications drain in arrival order to the notif channel. + let n0 = channels.notifications.try_recv().expect("first notif"); + assert_eq!(n0.event.event_type, "evt-0"); + let n1 = channels.notifications.try_recv().expect("trailing notif"); + assert_eq!(n1.event.event_type, "evt-1"); + // The buffered request drains to the request channel. + let r = channels.requests.try_recv().expect("request"); + assert_eq!(r.id, 1); + } + + /// Read one Content-Length-framed JSON-RPC response off the duplex + /// reader. Times out after 1s; CI has a comfortable margin for one + /// short frame. + async fn read_one_framed_response( + mut reader: tokio::io::DuplexStream, + ) -> crate::jsonrpc::JsonRpcResponse { + use tokio::io::AsyncReadExt; + let mut buf = Vec::with_capacity(1024); + let range = tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") { + let header = std::str::from_utf8(&buf[..pos]).expect("header utf-8"); + let len: usize = header + .strip_prefix("Content-Length: ") + .expect("Content-Length header") + .trim() + .parse() + .expect("numeric length"); + let body_start = pos + 4; + if buf.len() >= body_start + len { + return body_start..body_start + len; + } + } + let mut chunk = [0u8; 256]; + let n = reader.read(&mut chunk).await.expect("read"); + if n == 0 { + panic!("eof before frame complete"); + } + buf.extend_from_slice(&chunk[..n]); + } + }) + .await + .expect("frame within timeout"); + serde_json::from_slice(&buf[range]).expect("parse JsonRpcResponse") + } + + fn stand_up_router_with_capture() -> ( + SessionRouter, + tokio::io::DuplexStream, + crate::jsonrpc::JsonRpcClient, + ) { + use tokio::sync::{broadcast, mpsc}; + + use crate::jsonrpc::JsonRpcClient; + let (server_read, client_write) = tokio::io::duplex(64 * 1024); + let (client_read, _server_write) = tokio::io::duplex(64); + let (notif_tx, _) = broadcast::channel(16); + let (req_tx, _req_rx) = mpsc::unbounded_channel(); + let rpc = JsonRpcClient::new(client_write, client_read, notif_tx, req_tx); + let router = SessionRouter::with_writer(rpc.writer_handle()); + (router, server_read, rpc) + } + + #[tokio::test] + async fn pending_request_overflow_emits_jsonrpc_error_response() { + use crate::jsonrpc::error_codes; + + let (router, server_read, _rpc) = stand_up_router_with_capture(); + let _guard = router.begin_pending_session_routing(); + + // First buffered request is the one we expect to evict. + let evicted_id = 7777; + router + .state + .lock() + .route_request(make_request(evicted_id, "remote", "userInput.request")); + for i in 0..PENDING_SESSION_BUFFER_LIMIT { + router.state.lock().route_request(make_request( + i as u64, + "remote", + "userInput.request", + )); + } + + let response = read_one_framed_response(server_read).await; + assert_eq!(response.id, evicted_id); + let err = response.error.expect("error payload"); + assert_eq!(err.code, error_codes::INTERNAL_ERROR); + assert!(err.message.contains("pending session buffer overflow")); + } + + #[tokio::test] + async fn last_guard_drop_without_register_responds_to_buffered_requests() { + use crate::jsonrpc::error_codes; + + let (router, server_read, _rpc) = stand_up_router_with_capture(); + let guard = router.begin_pending_session_routing(); + + let pending_id = 4242; + router + .state + .lock() + .route_request(make_request(pending_id, "remote", "userInput.request")); + // A buffered notification just gets logged on guard drop. + router + .state + .lock() + .route_notification("remote", make_notification("remote", "evt")); + + // Cloud session.create failed; the guard drops without anyone + // registering "remote". Buffered request must be responded to so + // the runtime doesn't hang. + drop(guard); + + let response = read_one_framed_response(server_read).await; + assert_eq!(response.id, pending_id); + let err = response.error.expect("error payload"); + assert_eq!(err.code, error_codes::INTERNAL_ERROR); + assert!( + err.message + .contains("pending session routing ended before session was registered") + ); + + assert!(router.state.lock().pending.is_empty()); + } + + #[test] + fn last_guard_drop_clears_pending_buffers() { + let router = SessionRouter::new(); + let g1 = router.begin_pending_session_routing(); + let g2 = router.begin_pending_session_routing(); + + router + .state + .lock() + .route_notification("a", make_notification("a", "evt")); + router + .state + .lock() + .route_notification("b", make_notification("b", "evt")); + + drop(g1); + assert_eq!(router.state.lock().pending.len(), 2, "still buffering"); + drop(g2); + assert!( + router.state.lock().pending.is_empty(), + "last guard drop clears pending" + ); + } + + #[test] + fn unregister_clears_pending_for_session() { + let router = SessionRouter::new(); + let _guard = router.begin_pending_session_routing(); + router + .state + .lock() + .route_notification("doomed", make_notification("doomed", "evt")); + router + .state + .lock() + .route_notification("kept", make_notification("kept", "evt")); + + router.unregister(&SessionId::from("doomed")); + + let state = router.state.lock(); + assert!(!state.pending.contains_key("doomed")); + assert!(state.pending.contains_key("kept")); + } +} diff --git a/rust/src/session.rs b/rust/src/session.rs index f216b866b..2235ae816 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -29,12 +29,14 @@ use crate::types::{ CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest, ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, ResumeSessionResult, SectionOverride, - SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, - SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext, - UiInputOptions, ensure_attachment_display_names, + SessionCapabilities, SessionConfig, SessionConfigRuntime, SessionEvent, SessionId, + SetModelOptions, SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, + TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +type CommandHandlerMap = HashMap>; + /// Bundle of the per-session callbacks the SDK dispatches to. Built from a /// [`SessionConfig`] / [`ResumeSessionConfig`] at /// [`Client::create_session`] / [`Client::resume_session`] time. Each @@ -52,6 +54,20 @@ pub(crate) struct SessionHandlers { pub tools: Arc>>, } +/// Bundle of everything `create_session` / `create_cloud_session` / +/// `resume_session` need to spawn the per-session event loop, extracted +/// from a `SessionConfigRuntime`. Built by [`prepare_session_runtime`]. +struct PreparedSessionRuntime { + handlers: SessionHandlers, + hooks: Option>, + transforms: Option>, + command_handlers: Arc, + canvas_handler: Option>, + session_fs_provider: Option>, + commands_count: usize, + has_hooks: bool, +} + /// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. struct IdleWaiter { tx: oneshot::Sender, Error>>, @@ -788,6 +804,13 @@ impl Client { /// broadcast (and silently skips dispatch if one arrives anyway). pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); + if config.cloud.is_some() { + return Err(Error::InvalidConfig( + "Client::create_session does not support cloud sessions; \ + use Client::create_cloud_session instead" + .to_string(), + )); + } let session_id = config .session_id .clone() @@ -799,41 +822,19 @@ impl Client { if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections(&mut config, transforms.as_ref()); } - let (wire, mut runtime) = config.into_wire(session_id.clone())?; - - let permission_handler = crate::permission::resolve_handler( - runtime.permission_handler.take(), - runtime.permission_policy.take(), - ); - let handlers = SessionHandlers { - permission: permission_handler, - elicitation: runtime.elicitation_handler.take(), - user_input: runtime.user_input_handler.take(), - exit_plan_mode: runtime.exit_plan_mode_handler.take(), - auto_mode_switch: runtime.auto_mode_switch_handler.take(), - tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)), - }; - let hooks = runtime.hooks_handler.take(); - let transforms = runtime.system_message_transform.take(); + let (wire, runtime) = config.into_wire(session_id.clone())?; let tools_count = wire.tools.as_ref().map_or(0, Vec::len); - let commands_count = runtime.commands.as_ref().map_or(0, Vec::len); - let has_hooks = hooks.is_some(); - let command_handlers = build_command_handler_map(runtime.commands.as_deref()); - let canvas_handler = runtime.canvas_handler.take(); - let session_fs_provider = runtime.session_fs_provider.take(); - if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); - } - if self.inner.session_fs_sqlite_declared - && let Some(ref provider) = session_fs_provider - && provider.sqlite().is_none() - { - return Err(Error::InvalidConfig( - "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), - )); - } + + let PreparedSessionRuntime { + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + } = prepare_session_runtime(self, runtime)?; let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; @@ -920,6 +921,165 @@ impl Client { }) } + /// Create a cloud (Mission Control) session. + /// + /// Unlike [`Self::create_session`], the runtime owns the session ID: + /// the SDK does **not** pre-assign one, and the caller must not set + /// `session_id` or `provider` on the config. Send a cloud config built + /// with [`SessionConfig::with_cloud`]. The returned [`Session`] is + /// keyed by the runtime-assigned Mission Control id. + /// + /// Routing for the runtime-chosen id is buffered (bounded, drop-oldest) + /// from before `session.create` is sent until the response arrives, so + /// `session.event` notifications and inbound JSON-RPC requests that + /// arrive early are delivered after registration. + pub async fn create_cloud_session(&self, mut config: SessionConfig) -> Result { + let total_start = Instant::now(); + if config.cloud.is_none() { + return Err(Error::InvalidConfig( + "Client::create_cloud_session requires a cloud config; \ + build the config with SessionConfig::with_cloud" + .to_string(), + )); + } + if config.session_id.is_some() { + return Err(Error::InvalidConfig( + "Client::create_cloud_session does not accept a caller-provided \ + session_id; the runtime assigns the session id" + .to_string(), + )); + } + if config.provider.is_some() { + return Err(Error::InvalidConfig( + "Client::create_cloud_session does not accept a caller-provided \ + provider; the runtime selects the provider" + .to_string(), + )); + } + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.system_message_transform.clone() { + inject_transform_sections(&mut config, transforms.as_ref()); + } + let (wire, runtime) = config.into_cloud_wire()?; + let tools_count = wire.tools.as_ref().map_or(0, Vec::len); + + let PreparedSessionRuntime { + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + } = prepare_session_runtime(self, runtime)?; + + let mut params = serde_json::to_value(&wire)?; + let trace_ctx = self.resolve_trace_context().await; + inject_trace_context(&mut params, &trace_ctx); + + let setup_start = Instant::now(); + let pending_guard = self.begin_pending_session_routing(); + tracing::debug!( + elapsed_ms = setup_start.elapsed().as_millis(), + tools_count, + commands_count, + has_hooks, + "Client::create_cloud_session local setup complete" + ); + + let rpc_start = Instant::now(); + let result = self.call("session.create", Some(params)).await?; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Client::create_cloud_session session creation request completed successfully" + ); + // Pre-extract the runtime-assigned session id from the raw response so + // we can `session.destroy` it on decode failure without cloning the + // whole response. On success we still consume `result` to decode. + let recovered_session_id = result + .get("sessionId") + .and_then(|value| value.as_str()) + .map(SessionId::from); + let create_result: CreateSessionResult = match serde_json::from_value(result) { + Ok(result) => result, + Err(error) => { + // Keep the pending guard alive across the destroy so any + // straggler events for the runtime-assigned id are still + // routed (and then dropped on guard release). + if let Some(recovered_id) = recovered_session_id { + if let Err(destroy_err) = self + .call( + "session.destroy", + Some(serde_json::json!({ "sessionId": recovered_id })), + ) + .await + { + tracing::warn!( + session_id = %recovered_id, + error = %destroy_err, + "failed to destroy cloud session after create response decode failed" + ); + } + } else { + tracing::warn!( + "Client::create_cloud_session: decode failure with no recoverable session id; \ + skipping session.destroy (runtime session may leak)" + ); + } + drop(pending_guard); + return Err(error.into()); + } + }; + let session_id = create_result.session_id.clone(); + + let capabilities = Arc::new(parking_lot::RwLock::new( + create_result.capabilities.unwrap_or_default(), + )); + let channels = self.register_session(&session_id); + drop(pending_guard); + + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + session_id.clone(), + self.clone(), + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + tracing::debug!( + elapsed_ms = total_start.elapsed().as_millis(), + session_id = %session_id, + "Client::create_cloud_session complete" + ); + Ok(Session { + id: session_id, + cwd: self.cwd().clone(), + workspace_path: create_result.workspace_path, + remote_url: create_result.remote_url, + client: self.clone(), + event_loop: ParkingLotMutex::new(Some(event_loop)), + shutdown, + idle_waiter, + capabilities, + event_tx, + open_canvases: Arc::new(parking_lot::RwLock::new(Vec::new())), + }) + } + /// Resume an existing session on the CLI. /// /// Sends `session.resume` and `session.skills.reload`, registers the @@ -939,41 +1099,19 @@ impl Client { if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections_resume(&mut config, transforms.as_ref()); } - let (wire, mut runtime) = config.into_wire()?; - - let permission_handler = crate::permission::resolve_handler( - runtime.permission_handler.take(), - runtime.permission_policy.take(), - ); - let handlers = SessionHandlers { - permission: permission_handler, - elicitation: runtime.elicitation_handler.take(), - user_input: runtime.user_input_handler.take(), - exit_plan_mode: runtime.exit_plan_mode_handler.take(), - auto_mode_switch: runtime.auto_mode_switch_handler.take(), - tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)), - }; - let hooks = runtime.hooks_handler.take(); - let transforms = runtime.system_message_transform.take(); + let (wire, runtime) = config.into_wire()?; let tools_count = wire.tools.as_ref().map_or(0, Vec::len); - let commands_count = runtime.commands.as_ref().map_or(0, Vec::len); - let has_hooks = hooks.is_some(); - let command_handlers = build_command_handler_map(runtime.commands.as_deref()); - let canvas_handler = runtime.canvas_handler.take(); - let session_fs_provider = runtime.session_fs_provider.take(); - if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); - } - if self.inner.session_fs_sqlite_declared - && let Some(ref provider) = session_fs_provider - && provider.sqlite().is_none() - { - return Err(Error::InvalidConfig( - "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), - )); - } + + let PreparedSessionRuntime { + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + } = prepare_session_runtime(self, runtime)?; let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; @@ -1094,8 +1232,6 @@ impl Client { } } -type CommandHandlerMap = HashMap>; - fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc { let map = match commands { Some(commands) => commands @@ -1108,6 +1244,62 @@ fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc Result { + let SessionConfigRuntime { + permission_handler, + permission_policy, + elicitation_handler, + user_input_handler, + exit_plan_mode_handler, + auto_mode_switch_handler, + hooks_handler, + system_message_transform, + tool_handlers, + canvas_handler, + session_fs_provider, + commands, + } = runtime; + let handlers = SessionHandlers { + permission: crate::permission::resolve_handler(permission_handler, permission_policy), + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_handlers), + }; + let commands_count = commands.as_ref().map_or(0, Vec::len); + let has_hooks = hooks_handler.is_some(); + let command_handlers = build_command_handler_map(commands.as_deref()); + + if client.inner.session_fs_configured && session_fs_provider.is_none() { + return Err(Error::Session(SessionError::SessionFsProviderRequired)); + } + if client.inner.session_fs_sqlite_declared + && let Some(ref provider) = session_fs_provider + && provider.sqlite().is_none() + { + return Err(Error::InvalidConfig( + "SessionFs capabilities declare SQLite support but the provider \ + does not implement SessionFsSqliteProvider" + .to_string(), + )); + } + + Ok(PreparedSessionRuntime { + handlers, + hooks: hooks_handler, + transforms: system_message_transform, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + }) +} + #[allow(clippy::too_many_arguments)] fn spawn_event_loop( session_id: SessionId, diff --git a/rust/src/types.rs b/rust/src/types.rs index d841096c5..00f01302d 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1406,8 +1406,25 @@ impl SessionConfig { /// /// [`SessionCreateWire`]: crate::wire::SessionCreateWire pub(crate) fn into_wire( - mut self, + self, session_id: SessionId, + ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { + self.into_create_wire(Some(session_id)) + } + + /// Consume this config to produce the [`SessionCreateWire`] payload for + /// cloud `session.create`. Cloud create follows the runtime contract: + /// the caller does not provide a `sessionId`; the runtime returns the + /// Mission Control task/session ID. + pub(crate) fn into_cloud_wire( + self, + ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { + self.into_create_wire(None) + } + + fn into_create_wire( + mut self, + session_id: Option, ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { let permission_active = self.permission_handler.is_some() || self.permission_policy.is_some(); diff --git a/rust/src/wire.rs b/rust/src/wire.rs index b97aea261..a1d1ec094 100644 --- a/rust/src/wire.rs +++ b/rust/src/wire.rs @@ -42,7 +42,8 @@ pub(crate) struct CommandWireDefinition { #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub(crate) struct SessionCreateWire { - pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 050c5898d..ad13d98c3 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -16,9 +16,10 @@ use github_copilot_sdk::handler::{ ExitPlanModeHandler, ExitPlanModeResult, UserInputHandler, UserInputResponse, }; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, - ElicitationResult, ExitPlanModeData, ExtensionInfo, MessageOptions, RequestId, SessionConfig, - SessionId, Tool, ToolInvocation, ToolResult, + CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, + DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, + MessageOptions, ProviderConfig, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, + ToolResult, }; use github_copilot_sdk::{Client, tool}; use serde_json::Value; @@ -222,6 +223,22 @@ fn requested_session_id(request: &Value) -> &str { .expect("session request should include sessionId") } +fn cloud_session_config() -> SessionConfig { + SessionConfig::default().with_cloud(CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), + )) +} + +fn expect_sdk_error( + result: Result, + message: &str, +) -> github_copilot_sdk::Error { + match result { + Ok(_) => panic!("{message}"), + Err(error) => error, + } +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair().await; @@ -322,6 +339,21 @@ async fn create_session_sends_correct_rpc() { assert_eq!(session.workspace_path(), Some(Path::new("/ws"))); } +#[tokio::test] +async fn create_session_rejects_cloud_config() { + let (client, _server_read, _server_write) = make_client(); + + let error = expect_sdk_error( + client.create_session(cloud_session_config()).await, + "cloud config should use create_cloud_session", + ); + + assert!( + matches!(error, github_copilot_sdk::Error::InvalidConfig(ref message) if message.contains("create_cloud_session")), + "unexpected error: {error:?}", + ); +} + #[tokio::test] async fn create_session_sends_canvas_wire_fields() { let (client, mut server_read, mut server_write) = make_client(); @@ -369,6 +401,297 @@ async fn create_session_sends_canvas_wire_fields() { timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); } +#[tokio::test] +async fn create_cloud_session_sends_cloud_create_without_session_id() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_cloud_session(cloud_session_config()) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert!(request["params"].get("sessionId").is_none()); + assert_eq!(request["params"]["cloud"]["repository"]["owner"], "github"); + assert_eq!( + request["params"]["cloud"]["repository"]["name"], + "copilot-sdk" + ); + assert_eq!(request["params"]["cloud"]["repository"]["branch"], "main"); + assert!(request["params"].get("provider").is_none()); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "sessionId": "remote-cloud-session", + "remoteUrl": "https://copilot.example.test/agents/remote-cloud-session", + "capabilities": { "ui": { "elicitation": true } } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + assert_eq!(session.id(), "remote-cloud-session"); + assert_eq!( + session.remote_url(), + Some("https://copilot.example.test/agents/remote-cloud-session") + ); + assert_eq!( + session.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); +} + +#[tokio::test] +async fn create_cloud_session_rejects_caller_session_id_and_provider() { + let (client, _server_read, _server_write) = make_client(); + + let error = expect_sdk_error( + client + .create_cloud_session(cloud_session_config().with_session_id("caller-id")) + .await, + "cloud create should reject caller session id", + ); + assert!( + matches!(error, github_copilot_sdk::Error::InvalidConfig(ref message) if message.contains("session_id")), + "unexpected error: {error:?}", + ); + + let mut config = cloud_session_config(); + config.provider = Some(ProviderConfig::new("https://api.example.test/v1")); + let error = expect_sdk_error( + client.create_cloud_session(config).await, + "cloud create should reject provider", + ); + assert!( + matches!(error, github_copilot_sdk::Error::InvalidConfig(ref message) if message.contains("provider")), + "unexpected error: {error:?}", + ); +} + +#[tokio::test] +async fn create_cloud_session_request_flags_follow_handlers() { + struct InputHandler; + #[async_trait] + impl UserInputHandler for InputHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + None + } + } + + struct ExitHandler; + #[async_trait] + impl ExitPlanModeHandler for ExitHandler { + async fn handle( + &self, + _session_id: SessionId, + _data: ExitPlanModeData, + ) -> ExitPlanModeResult { + ExitPlanModeResult::default() + } + } + + struct AutoHandler; + #[async_trait] + impl AutoModeSwitchHandler for AutoHandler { + async fn handle( + &self, + _session_id: SessionId, + _error_code: Option, + _retry_after_seconds: Option, + ) -> AutoModeSwitchResponse { + AutoModeSwitchResponse::No + } + } + + struct ElicitHandler; + #[async_trait] + impl ElicitationHandler for ElicitHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, + } + } + } + + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_cloud_session( + cloud_session_config() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_user_input_handler(Arc::new(InputHandler)) + .with_exit_plan_mode_handler(Arc::new(ExitHandler)) + .with_auto_mode_switch_handler(Arc::new(AutoHandler)) + .with_elicitation_handler(Arc::new(ElicitHandler)), + ) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["requestPermission"], true); + assert_eq!(request["params"]["requestUserInput"], true); + assert_eq!(request["params"]["requestExitPlanMode"], true); + assert_eq!(request["params"]["requestAutoModeSwitch"], true); + assert_eq!(request["params"]["requestElicitation"], true); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "remote-cloud-session" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn create_cloud_session_buffers_early_notifications_until_session_id_is_registered() { + let (client, server_read, server_write) = make_client(); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "remote-cloud-session".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_cloud_session(cloud_session_config()) + .await + .unwrap() + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.create"); + server + .send_event( + "capabilities.changed", + serde_json::json!({ "ui": { "elicitation": true } }), + ) + .await; + server + .respond( + &request, + serde_json::json!({ "sessionId": server.session_id.clone() }), + ) + .await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + for _ in 0..50 { + if session + .capabilities() + .ui + .and_then(|ui| ui.elicitation) + .unwrap_or(false) + { + return; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert_eq!( + session.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); +} + +#[tokio::test] +async fn create_cloud_session_buffers_early_requests_until_session_id_is_registered() { + struct InputHandler; + #[async_trait] + impl UserInputHandler for InputHandler { + async fn handle( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + assert_eq!(question, "Pick a color"); + Some(UserInputResponse { + answer: "blue".to_string(), + was_freeform: true, + }) + } + } + + let (client, server_read, server_write) = make_client(); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "remote-cloud-session".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_cloud_session( + cloud_session_config().with_user_input_handler(Arc::new(InputHandler)), + ) + .await + .unwrap() + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["requestUserInput"], true); + server + .send_request( + 301, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id.clone(), + "question": "Pick a color", + "choices": ["red", "blue"], + "allowFreeform": true, + }), + ) + .await; + server + .respond( + &request, + serde_json::json!({ "sessionId": server.session_id.clone() }), + ) + .await; + + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 301); + assert_eq!(response["result"]["answer"], "blue"); + assert_eq!(response["result"]["wasFreeform"], true); +} + #[tokio::test] async fn provider_canvas_dispatch_routes_direct_canvas_action_requests() { let (session, mut server) = create_session_pair_with_config(|cfg| {