From 9bcff24112474456a460fe468a0fd43950a033be Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Mon, 30 Mar 2026 22:54:44 -0400 Subject: [PATCH] feat(acp-nats, acp-nats-agent): wire JetStream into bridge and agent Signed-off-by: Yordis Prieto --- .../crates/acp-nats-agent/src/connection.rs | 1077 ++++++++++++++++- .../crates/acp-nats-agent/src/constants.rs | 1 + .../crates/acp-nats/src/agent/bridge.rs | 68 +- .../acp-nats/src/agent/close_session.rs | 36 +- .../crates/acp-nats/src/agent/fork_session.rs | 36 +- .../crates/acp-nats/src/agent/js_request.rs | 388 ++++++ .../crates/acp-nats/src/agent/load_session.rs | 36 +- rsworkspace/crates/acp-nats/src/agent/mod.rs | 1 + .../crates/acp-nats/src/agent/prompt.rs | 1060 +++++++++++++++- .../acp-nats/src/agent/resume_session.rs | 36 +- .../src/agent/set_session_config_option.rs | 36 +- .../acp-nats/src/agent/set_session_mode.rs | 36 +- .../acp-nats/src/agent/set_session_model.rs | 36 +- .../crates/acp-nats/src/agent/test_support.rs | 45 + .../src/client/ext_session_prompt_response.rs | 3 +- rsworkspace/crates/acp-nats/src/client/mod.rs | 13 +- rsworkspace/crates/acp-nats/src/config.rs | 1 + rsworkspace/crates/acp-nats/src/constants.rs | 1 + .../acp-nats/src/jetstream/consumers.rs | 33 + .../crates/acp-nats/src/jetstream/streams.rs | 31 + rsworkspace/crates/acp-nats/src/lib.rs | 3 + .../crates/acp-nats/src/nats/subjects.rs | 19 + .../trogon-nats/src/jetstream/client.rs | 56 +- .../trogon-nats/src/jetstream/message.rs | 94 +- .../crates/trogon-nats/src/jetstream/mocks.rs | 286 ++++- .../crates/trogon-nats/src/jetstream/mod.rs | 14 +- .../trogon-nats/src/jetstream/traits.rs | 191 ++- 27 files changed, 3338 insertions(+), 299 deletions(-) create mode 100644 rsworkspace/crates/acp-nats/src/agent/js_request.rs diff --git a/rsworkspace/crates/acp-nats-agent/src/connection.rs b/rsworkspace/crates/acp-nats-agent/src/connection.rs index 35757d9e9..ebbd7712a 100644 --- a/rsworkspace/crates/acp-nats-agent/src/connection.rs +++ b/rsworkspace/crates/acp-nats-agent/src/connection.rs @@ -11,6 +11,7 @@ use agent_client_protocol::{ SetSessionModeRequest, SetSessionModelRequest, }; use async_nats::Message; +use async_nats::jetstream::AckKind; #[cfg(test)] use bytes::Bytes; use futures::StreamExt; @@ -22,12 +23,14 @@ use trogon_nats::{FlushClient, PublishClient, RequestClient, SubscribeClient}; pub enum ConnectionError { Subscribe(Box), + JetStream(Box), } impl std::fmt::Debug for ConnectionError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Subscribe(e) => f.debug_tuple("Subscribe").field(e).finish(), + Self::JetStream(e) => f.debug_tuple("JetStream").field(e).finish(), } } } @@ -36,6 +39,7 @@ impl std::fmt::Display for ConnectionError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Subscribe(e) => write!(f, "failed to subscribe: {}", e), + Self::JetStream(e) => write!(f, "jetstream error: {}", e), } } } @@ -63,7 +67,7 @@ impl std::fmt::Display for DispatchError { } } -use crate::constants::DEFAULT_OPERATION_TIMEOUT; +use crate::constants::{DEFAULT_OPERATION_TIMEOUT, KEEPALIVE_INTERVAL}; pub struct AgentSideNatsConnection { nats: N, @@ -97,6 +101,45 @@ where (conn, io_task) } + pub fn with_jetstream( + agent: impl Agent + 'static, + nats: N, + js: J, + acp_prefix: AcpPrefix, + spawn: impl Fn(LocalBoxFuture<'static, ()>) + Copy + 'static, + ) -> ( + Self, + impl std::future::Future>, + ) + where + J: JetStreamConsumerFactory + 'static, + ::Message: JsDispatchMessage, + { + let nats_for_serve = nats.clone(); + let nats_for_js = nats.clone(); + let prefix = acp_prefix.as_str().to_string(); + let prefix_js = prefix.clone(); + + let io_task = async move { + let agent = Rc::new(agent); + + let core = serve_global(agent.clone(), nats_for_serve, &prefix, spawn); + let jetstream = serve_js(agent, nats_for_js, js, &prefix_js, spawn); + + tokio::select! { + result = core => result, + result = jetstream => result, + } + }; + + let conn = Self { + nats, + acp_prefix, + operation_timeout: DEFAULT_OPERATION_TIMEOUT, + }; + (conn, io_task) + } + pub fn client_for_session(&self, session_id: AcpSessionId) -> NatsClientProxy { NatsClientProxy::new( self.nats.clone(), @@ -153,6 +196,51 @@ where Ok(()) } +async fn serve_global( + agent: Rc, + nats: N, + prefix: &str, + spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static, +) -> Result<(), ConnectionError> +where + N: SubscribeClient + PublishClient + FlushClient + Clone + 'static, + A: Agent + 'static, +{ + let global_wildcard = acp_nats::nats::agent::wildcards::all(prefix); + let ext_wildcard = acp_nats::nats::session::wildcards::all_agent_ext(prefix); + + info!( + global = %global_wildcard, + ext = %ext_wildcard, + "Starting global + ext NATS connection (session commands via JetStream)" + ); + + let global_sub = nats + .subscribe(global_wildcard) + .await + .map_err(|e| ConnectionError::Subscribe(Box::new(e)))?; + + let ext_sub = nats + .subscribe(ext_wildcard) + .await + .map_err(|e| ConnectionError::Subscribe(Box::new(e)))?; + + let mut subscriber = futures::stream::select(global_sub, ext_sub); + + let nats = Rc::new(nats); + + while let Some(msg) = subscriber.next().await { + let agent = agent.clone(); + let nats = nats.clone(); + spawn(Box::pin(async move { + dispatch_message(msg, agent.as_ref(), nats.as_ref()).await; + })); + } + + info!("Global-only NATS connection ended"); + Ok(()) +} + async fn dispatch_message( msg: Message, agent: &A, @@ -333,6 +421,267 @@ where .map_err(DispatchError::NotificationHandler) } +use trogon_nats::jetstream::{ + JetStreamConsumer as _, JetStreamConsumerFactory, JsAckWith, JsDispatchMessage, +}; + +async fn handle_request_with_keepalive( + msg: &Message, + nats: &N, + js_msg: &M, + handler: impl FnOnce(ReqT) -> F, +) -> Result<(), DispatchError> +where + N: PublishClient + FlushClient, + ReqT: serde::de::DeserializeOwned, + F: std::future::Future>, + Resp: serde::Serialize, + M: JsAckWith, +{ + let reply_to = msg.reply.as_deref().ok_or(DispatchError::NoReplySubject)?; + + let request: ReqT = match serde_json::from_slice(&msg.payload) { + Ok(req) => req, + Err(e) => { + let error = agent_client_protocol::Error::new( + agent_client_protocol::ErrorCode::InvalidParams.into(), + format!("Failed to deserialize request: {}", e), + ); + let _ = reply(nats, reply_to, &error).await; + return Err(DispatchError::DeserializeRequest(e)); + } + }; + + let handler_fut = handler(request); + tokio::pin!(handler_fut); + + let mut keepalive = tokio::time::interval(KEEPALIVE_INTERVAL); + keepalive.tick().await; + + loop { + tokio::select! { + result = &mut handler_fut => { + return match result { + Ok(resp) => reply(nats, reply_to, &resp).await, + Err(err) => reply(nats, reply_to, &err).await, + }; + } + _ = keepalive.tick() => { + if let Err(e) = js_msg.ack_with(AckKind::Progress).await { + warn!(error = %e, "Failed to send in_progress keepalive"); + } + } + } + } +} + +async fn serve_js( + agent: Rc, + nats: N, + js: J, + prefix: &str, + spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static, +) -> Result<(), ConnectionError> +where + N: PublishClient + FlushClient + Clone + 'static, + J: JetStreamConsumerFactory + 'static, + ::Message: JsDispatchMessage, + A: Agent + 'static, +{ + let stream_name = acp_nats::jetstream::streams::commands_stream_name(prefix); + let config = acp_nats::jetstream::consumers::commands_observer(); + + info!(stream = %stream_name, "Starting JetStream consumer for COMMANDS stream"); + + let consumer = js + .create_consumer(&stream_name, config) + .await + .map_err(|e| ConnectionError::JetStream(Box::new(e)))?; + + let mut messages = consumer + .messages() + .await + .map_err(|e| ConnectionError::JetStream(Box::new(e)))?; + + let nats = Rc::new(nats); + + let prefix = Rc::new(prefix.to_string()); + + while let Some(msg_result) = messages.next().await { + match msg_result { + Ok(js_msg) => { + let agent = agent.clone(); + let nats = nats.clone(); + let prefix = prefix.clone(); + spawn(Box::pin(async move { + dispatch_js_message(js_msg, agent.as_ref(), nats.as_ref(), &prefix).await; + })); + } + Err(e) => { + warn!(error = %e, "JetStream consumer error"); + } + } + } + + info!("JetStream COMMANDS consumer ended"); + Ok(()) +} + +async fn dispatch_js_message( + js_msg: M, + agent: &A, + nats: &N, + prefix: &str, +) { + let subject = js_msg.message().subject.to_string(); + + let (session_id, method) = match parse_agent_subject(&subject) { + Some(ParsedAgentSubject::Session { session_id, method }) => (session_id, method), + Some(ParsedAgentSubject::Global(_)) => { + warn!( + subject, + "Global method on JetStream path; handled by core NATS" + ); + if let Err(e) = js_msg.ack().await { + warn!(subject, error = %e, "Failed to ack misrouted global method"); + } + return; + } + None => { + if let Err(e) = js_msg.ack_with(AckKind::Term).await { + warn!(error = %e, subject, "Failed to term unknown subject"); + } + return; + } + }; + + let req_id = js_msg + .message() + .headers + .as_ref() + .and_then(|h| h.get(trogon_nats::REQ_ID_HEADER)) + .map(|v| v.as_str().to_string()); + + let reply_subject = match (&req_id, &method) { + (Some(rid), SessionAgentMethod::Prompt) => Some( + acp_nats::nats::session::agent::prompt_response(prefix, session_id.as_str(), rid), + ), + (_, SessionAgentMethod::Cancel) => None, + (Some(rid), _) => Some(acp_nats::nats::session::agent::response( + prefix, + session_id.as_str(), + rid, + )), + (None, _) => { + warn!(subject, "JetStream message missing X-Req-Id header"); + None + } + }; + + let inner = js_msg.message(); + let msg = Message { + subject: subject.as_str().into(), + reply: reply_subject.as_deref().map(|s| s.into()), + payload: inner.payload.clone(), + headers: inner.headers.clone(), + status: None, + description: None, + length: inner.payload.len(), + }; + let subject = msg.subject.as_str(); + + let result = match method { + SessionAgentMethod::Load => { + handle_request(&msg, nats, |req: LoadSessionRequest| { + agent.load_session(req) + }) + .await + } + SessionAgentMethod::Prompt => { + handle_request_with_keepalive(&msg, nats, &js_msg, |req: PromptRequest| { + agent.prompt(req) + }) + .await + } + SessionAgentMethod::Cancel => { + handle_notification(&msg, |req: CancelNotification| agent.cancel(req)).await + } + SessionAgentMethod::SetMode => { + handle_request(&msg, nats, |req: SetSessionModeRequest| { + agent.set_session_mode(req) + }) + .await + } + SessionAgentMethod::SetConfigOption => { + handle_request(&msg, nats, |req: SetSessionConfigOptionRequest| { + agent.set_session_config_option(req) + }) + .await + } + SessionAgentMethod::SetModel => { + handle_request(&msg, nats, |req: SetSessionModelRequest| { + agent.set_session_model(req) + }) + .await + } + SessionAgentMethod::Fork => { + handle_request(&msg, nats, |req: ForkSessionRequest| { + agent.fork_session(req) + }) + .await + } + SessionAgentMethod::Resume => { + handle_request(&msg, nats, |req: ResumeSessionRequest| { + agent.resume_session(req) + }) + .await + } + SessionAgentMethod::Close => { + handle_request(&msg, nats, |req: CloseSessionRequest| { + agent.close_session(req) + }) + .await + } + }; + + match &result { + Ok(()) => { + if let Err(e) = js_msg.ack().await { + warn!(subject, error = %e, "Failed to ack JetStream message"); + } + } + Err(DispatchError::DeserializeRequest(_) | DispatchError::DeserializeNotification(_)) => { + if let Err(e) = js_msg.ack_with(AckKind::Term).await { + warn!(subject, error = %e, "Failed to term bad payload"); + } + } + Err(DispatchError::NoReplySubject) => { + if let Err(e) = js_msg.ack_with(AckKind::Term).await { + warn!(subject, error = %e, "Failed to term missing reply subject"); + } + } + Err(DispatchError::Reply(_)) => { + if let Err(e) = js_msg.ack().await { + warn!(subject, error = %e, "Failed to ack after reply failure"); + } + } + Err(DispatchError::NotificationHandler(_)) => { + if let Err(e) = js_msg.ack().await { + warn!(subject, error = %e, "Failed to ack after notification handler error"); + } + } + } + + if let Err(e) = result { + warn!( + subject, + session_id = session_id.as_str(), + error = %e, + "Error handling JetStream request" + ); + } +} + #[cfg(test)] mod tests { use super::*; @@ -745,28 +1094,93 @@ mod tests { .await; } + use trogon_nats::jetstream::mocks::*; + + fn make_js_msg(subject: &str, payload: &[u8], reply: Option<&str>) -> MockJsMessage { + let mut headers = async_nats::HeaderMap::new(); + headers.insert(trogon_nats::REQ_ID_HEADER, "req-1"); + MockJsMessage::new(async_nats::Message { + subject: subject.into(), + reply: reply.map(|r| r.into()), + payload: Bytes::copy_from_slice(payload), + headers: Some(headers), + status: None, + description: None, + length: payload.len(), + }) + } + #[tokio::test] - async fn dispatch_error_logs_warning_with_subscriber() { - use tracing_subscriber::util::SubscriberInitExt; - let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + async fn with_jetstream_runs_both_loops() { + use trogon_nats::jetstream::MockJetStreamConsumerFactory; let nats = MockNatsClient::new(); let agent = MockAgent::new(); - let payload = serialize(&InitializeRequest::new( - agent_client_protocol::ProtocolVersion::V0, - )); - let msg = make_nats_message("acp.agent.initialize", &payload, None); + let factory = MockJetStreamConsumerFactory::new(); - dispatch_message(msg, &agent, &nats).await; + // Global + ext subs — drop immediately to end serve_global + let global_tx = nats.inject_messages(); + let ext_tx = nats.inject_messages(); + drop(global_tx); + drop(ext_tx); + + // serve_js will fail to create consumer (no mock consumer added) — that's OK + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (conn, io_task) = AgentSideNatsConnection::with_jetstream( + agent, + nats, + factory, + AcpPrefix::new("acp").unwrap(), + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + assert_eq!(conn.acp_prefix.as_str(), "acp"); + + let result = io_task.await; + // Either serve_global ends (Ok) or serve_js fails (Err) — both are fine + let _ = result; + }) + .await; } #[tokio::test] - async fn serve_subscribes_and_dispatches_messages() { + async fn serve_global_subscribes_to_global_and_ext() { let nats = MockNatsClient::new(); let agent = MockAgent::new(); let global_tx = nats.inject_messages(); - let session_tx = nats.inject_messages(); + let ext_tx = nats.inject_messages(); + drop(global_tx); + drop(ext_tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let _ = serve_global(Rc::new(agent), nats.clone(), "myprefix", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + let subjects = nats.subscribed_to(); + assert_eq!(subjects.len(), 2); + assert!(subjects.contains(&"myprefix.agent.>".to_string())); + assert!(subjects.contains(&"myprefix.session.*.agent.ext.>".to_string())); + assert!(!subjects.contains(&"myprefix.session.*.agent.>".to_string())); + }) + .await; + } + + #[tokio::test] + async fn serve_global_dispatches_message() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + + let global_tx = nats.inject_messages(); + let ext_tx = nats.inject_messages(); let local = tokio::task::LocalSet::new(); local @@ -783,29 +1197,367 @@ mod tests { description: None, length: 0, }; - global_tx.unbounded_send(msg).unwrap(); drop(global_tx); - drop(session_tx); + drop(ext_tx); - let result = serve(agent, nats.clone(), "acp", |fut| { + let _ = serve_global(Rc::new(agent), nats.clone(), "acp", |fut| { tokio::task::spawn_local(fut); }) .await; - assert!(result.is_ok()); - tokio::task::yield_now().await; tokio::task::yield_now().await; - assert_eq!(nats.published_messages().len(), 1); - assert_eq!(nats.published_messages()[0], "_INBOX.serve"); + assert!(!nats.published_messages().is_empty()); }) .await; } #[tokio::test] - async fn serve_returns_ok_when_subscription_ends() { + async fn serve_js_dispatches_message() { + use trogon_nats::jetstream::{ + MockJetStreamConsumer, MockJetStreamConsumerFactory, MockJsMessage, + }; + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let factory = MockJetStreamConsumerFactory::new(); + + let (consumer, tx) = MockJetStreamConsumer::new(); + factory.add_consumer(consumer); + + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let mut headers = async_nats::HeaderMap::new(); + headers.insert(trogon_nats::REQ_ID_HEADER, "req-1"); + let js_msg = MockJsMessage::new(async_nats::Message { + subject: "acp.session.s1.agent.load".into(), + reply: None, + payload: Bytes::copy_from_slice(&payload), + headers: Some(headers), + status: None, + description: None, + length: payload.len(), + }); + tx.unbounded_send(Ok(js_msg)).unwrap(); + drop(tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let _ = serve_js(Rc::new(agent), nats.clone(), factory, "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + assert!(!nats.published_messages().is_empty()); + }) + .await; + } + + #[tokio::test] + async fn serve_js_handles_consumer_stream_error() { + use trogon_nats::jetstream::{MockJetStreamConsumer, MockJetStreamConsumerFactory}; + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let factory = MockJetStreamConsumerFactory::new(); + + let (consumer, tx) = MockJetStreamConsumer::new(); + factory.add_consumer(consumer); + + tx.unbounded_send(Err(trogon_nats::mocks::MockError("stream error".into()))) + .unwrap(); + drop(tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let _ = serve_js(Rc::new(agent), nats.clone(), factory, "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + }) + .await; + } + + #[tokio::test] + async fn serve_js_consumer_creation_failure() { + use trogon_nats::jetstream::MockJetStreamConsumerFactory; + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let factory = MockJetStreamConsumerFactory::new(); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let result = serve_js(Rc::new(agent), nats.clone(), factory, "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + assert!(result.is_err()); + }) + .await; + } + + #[tokio::test] + async fn dispatch_js_message_success_acks() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let js_msg = make_js_msg("acp.session.s1.agent.load", &payload, None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_unknown_subject_terms() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let js_msg = make_js_msg("acp.unknown.something", b"{}", None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_bad_payload_terms() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let js_msg = make_js_msg("acp.session.s1.agent.load", b"not json", None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + let payloads = nats.published_payloads(); + assert_eq!(payloads.len(), 1); + let error: agent_client_protocol::Error = serde_json::from_slice(&payloads[0]).unwrap(); + assert_eq!(error.code, ErrorCode::InvalidParams); + } + + #[tokio::test] + async fn dispatch_js_message_missing_reply_terms() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let js_msg = make_js_msg("acp.agent.initialize", &serialize(&init_request()), None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_missing_req_id_header() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + // Create message without X-Req-Id header + let js_msg = MockJsMessage::new(async_nats::Message { + subject: "acp.session.s1.agent.load".into(), + reply: None, + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: payload.len(), + }); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + // No reply published because no req_id → no reply subject + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_global_method_returns_early() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let js_msg = make_js_msg("acp.agent.initialize", &payload, Some("_INBOX.1")); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + // Global methods return early — no dispatch, no reply + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_global_method_ack_failure() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let js_msg = make_failing_js_msg("acp.agent.initialize", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + fn make_js_msg_no_headers(subject: &str, payload: &[u8]) -> MockJsMessage { + MockJsMessage::new(async_nats::Message { + subject: subject.into(), + reply: None, + payload: Bytes::copy_from_slice(payload), + headers: None, + status: None, + description: None, + length: payload.len(), + }) + } + + #[tokio::test] + async fn dispatch_js_message_ext_notification_handler_error() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let raw = std::sync::Arc::from( + serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), + ); + let payload = serialize(&agent_client_protocol::ExtNotification::new("my_tool", raw)); + // No X-Req-Id → ext notification path (reply_subject is None → msg.reply is None) + let js_msg = make_js_msg_no_headers("acp.session.s1.agent.ext.my_tool", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_ext_notification_handler_error_ack_failure() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let raw = std::sync::Arc::from( + serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), + ); + let payload = serialize(&agent_client_protocol::ExtNotification::new("my_tool", raw)); + let js_msg = MockJsMessage::with_failing_signals(async_nats::Message { + subject: "acp.session.s1.agent.ext.my_tool".into(), + reply: None, + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: payload.len(), + }); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_global_ext_no_session_id() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let raw = std::sync::Arc::from( + serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), + ); + let payload = serialize(&agent_client_protocol::ExtNotification::new("my_tool", raw)); + let js_msg = make_js_msg_no_headers("acp.agent.ext.my_tool", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_prompt_uses_prompt_response_subject() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&PromptRequest::new("s1", vec![])); + let js_msg = make_js_msg("acp.session.s1.agent.prompt", &payload, None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + let subjects = nats.published_messages(); + assert!( + subjects + .iter() + .any(|s| s.starts_with("acp.session.s1.agent.prompt.response.")), + "expected prompt.response subject, got: {:?}", + subjects + ); + } + + #[tokio::test] + async fn dispatch_js_message_non_prompt_session_uses_response_subject() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let js_msg = make_js_msg("acp.session.s1.agent.load", &payload, None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + let subjects = nats.published_messages(); + assert!( + subjects + .iter() + .any(|s| s.starts_with("acp.session.s1.agent.response.")), + "expected response subject, got: {:?}", + subjects + ); + } + + #[tokio::test] + async fn dispatch_error_logs_warning_with_subscriber() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, None); + + dispatch_message(msg, &agent, &nats).await; + } + + #[tokio::test] + async fn serve_subscribes_and_dispatches_messages() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = Message { + subject: "acp.agent.initialize".into(), + reply: Some("_INBOX.serve".into()), + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: 0, + }; + + global_tx.unbounded_send(msg).unwrap(); + drop(global_tx); + drop(session_tx); + + let result = serve(agent, nats.clone(), "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + assert!(result.is_ok()); + + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + assert_eq!(nats.published_messages().len(), 1); + assert_eq!(nats.published_messages()[0], "_INBOX.serve"); + }) + .await; + } + + #[tokio::test] + async fn serve_returns_ok_when_subscription_ends() { let nats = MockNatsClient::new(); let agent = MockAgent::new(); @@ -853,4 +1605,291 @@ mod tests { }) .await; } + + #[test] + fn connection_error_jetstream_display() { + let err = ConnectionError::JetStream(Box::new(std::io::Error::other("js err"))); + assert!(err.to_string().contains("js err")); + let debug = format!("{:?}", err); + assert!(debug.contains("JetStream")); + } + + #[tokio::test] + async fn dispatch_js_message_cancel_notification() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&CancelNotification::new("s1")); + let js_msg = make_js_msg("acp.session.s1.agent.cancel", &payload, None); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert_eq!(agent.cancelled.borrow().len(), 1); + } + + #[tokio::test] + async fn dispatch_js_message_set_mode() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionModeRequest::new("s1", "code")); + let js_msg = make_js_msg("acp.session.s1.agent.set_mode", &payload, Some("_INBOX.r")); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_close_session() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&CloseSessionRequest::new("s1")); + let js_msg = make_js_msg("acp.session.s1.agent.close", &payload, Some("_INBOX.r")); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_fork_session() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&ForkSessionRequest::new("s1", "/tmp")); + let js_msg = make_js_msg("acp.session.s1.agent.fork", &payload, Some("_INBOX.r")); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_set_config_option() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionConfigOptionRequest::new("s1", "key", "val")); + let js_msg = make_js_msg( + "acp.session.s1.agent.set_config_option", + &payload, + Some("_INBOX.r"), + ); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_set_model() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionModelRequest::new("s1", "gpt-4")); + let js_msg = make_js_msg("acp.session.s1.agent.set_model", &payload, Some("_INBOX.r")); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_resume_session() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&ResumeSessionRequest::new("s1", "/tmp")); + let js_msg = make_js_msg("acp.session.s1.agent.resume", &payload, Some("_INBOX.r")); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_prompt() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&PromptRequest::new("s1", vec![])); + let js_msg = make_js_msg("acp.session.s1.agent.prompt", &payload, Some("_INBOX.r")); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_js_message_reply_failure_acks() { + let nats = trogon_nats::AdvancedMockNatsClient::new(); + nats.fail_next_publish(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let js_msg = make_js_msg("acp.session.s1.agent.load", &payload, Some("_INBOX.r")); + + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + fn make_failing_js_msg(subject: &str, payload: &[u8]) -> MockJsMessage { + let mut headers = async_nats::HeaderMap::new(); + headers.insert(trogon_nats::REQ_ID_HEADER, "req-1"); + MockJsMessage::with_failing_signals(async_nats::Message { + subject: subject.into(), + reply: None, + payload: Bytes::copy_from_slice(payload), + headers: Some(headers), + status: None, + description: None, + length: payload.len(), + }) + } + + #[tokio::test] + async fn dispatch_js_message_ack_failure_logs_warning() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let js_msg = make_failing_js_msg("acp.session.s1.agent.load", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_term_failure_logs_warning() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let js_msg = make_failing_js_msg("unknown.subject", b"{}"); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_term_bad_payload_failure_logs_warning() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let js_msg = make_failing_js_msg("acp.session.s1.agent.load", b"not json"); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_no_reply_term_failure() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + // Session message without X-Req-Id → NoReplySubject → term → term fails + let js_msg = MockJsMessage::with_failing_signals(async_nats::Message { + subject: "acp.session.s1.agent.load".into(), + reply: None, + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: payload.len(), + }); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_reply_failure_ack_failure() { + let nats = trogon_nats::AdvancedMockNatsClient::new(); + nats.fail_next_publish(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("s1", "/tmp")); + let js_msg = make_failing_js_msg("acp.session.s1.agent.load", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn dispatch_js_message_cancel_notification_ack_failure() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&CancelNotification::new("s1")); + let js_msg = make_failing_js_msg("acp.session.s1.agent.cancel", &payload); + dispatch_js_message(js_msg, &agent, &nats, "acp").await; + } + + #[tokio::test] + async fn handle_request_with_keepalive_completes_fast() { + let nats = MockNatsClient::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1")); + let js_msg = make_js_msg("acp.agent.initialize", &payload, Some("_INBOX.1")); + + let agent = MockAgent::new(); + let result = + handle_request_with_keepalive(&msg, &nats, &js_msg, |req: InitializeRequest| { + agent.initialize(req) + }) + .await; + assert!(result.is_ok()); + assert!(!nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn handle_request_with_keepalive_no_reply_subject() { + let nats = MockNatsClient::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, None); + let js_msg = make_js_msg("acp.agent.initialize", &payload, None); + + let result = + handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async { + Err::(agent_client_protocol::Error::new(-1, "not called")) + }) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn handle_request_with_keepalive_bad_payload() { + let nats = MockNatsClient::new(); + let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.1")); + let js_msg = make_js_msg("acp.agent.initialize", b"not json", Some("_INBOX.1")); + + let result = + handle_request_with_keepalive(&msg, &nats, &js_msg, |_: InitializeRequest| async { + Err::(agent_client_protocol::Error::new(-1, "not called")) + }) + .await; + assert!(result.is_err()); + } + + #[tokio::test(start_paused = true)] + async fn handle_request_with_keepalive_progress_ack_failure() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1")); + + let mut headers = async_nats::HeaderMap::new(); + headers.insert(trogon_nats::REQ_ID_HEADER, "req-1"); + let js_msg = MockJsMessage::with_failing_signals(async_nats::Message { + subject: "acp.agent.initialize".into(), + reply: Some("_INBOX.1".into()), + payload: Bytes::copy_from_slice(&payload), + headers: Some(headers), + status: None, + description: None, + length: payload.len(), + }); + + let agent = MockAgent::new(); + let result = + handle_request_with_keepalive(&msg, &nats, &js_msg, |req: InitializeRequest| async { + tokio::time::sleep(Duration::from_secs(20)).await; + agent.initialize(req).await + }) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn handle_request_with_keepalive_handler_error() { + let nats = MockNatsClient::new(); + let payload = serialize(&AuthenticateRequest::new("basic")); + let msg = make_nats_message("acp.agent.authenticate", &payload, Some("_INBOX.1")); + let js_msg = make_js_msg("acp.agent.authenticate", &payload, Some("_INBOX.1")); + + let agent = MockAgent::new(); + let result = + handle_request_with_keepalive(&msg, &nats, &js_msg, |req: AuthenticateRequest| { + agent.authenticate(req) + }) + .await; + assert!(result.is_ok()); + assert!(!nats.published_messages().is_empty()); + } } diff --git a/rsworkspace/crates/acp-nats-agent/src/constants.rs b/rsworkspace/crates/acp-nats-agent/src/constants.rs index 9b910b548..49cc02f40 100644 --- a/rsworkspace/crates/acp-nats-agent/src/constants.rs +++ b/rsworkspace/crates/acp-nats-agent/src/constants.rs @@ -1,3 +1,4 @@ use std::time::Duration; pub const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30); +pub const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15); diff --git a/rsworkspace/crates/acp-nats/src/agent/bridge.rs b/rsworkspace/crates/acp-nats/src/agent/bridge.rs index 360f72830..1aff140ab 100644 --- a/rsworkspace/crates/acp-nats/src/agent/bridge.rs +++ b/rsworkspace/crates/acp-nats/src/agent/bridge.rs @@ -21,22 +21,19 @@ use opentelemetry::metrics::Meter; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tracing::{info, warn}; -#[cfg(not(coverage))] -#[allow(unused_imports)] -use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; use super::{ authenticate, cancel, close_session, ext_method, ext_notification, fork_session, initialize, - list_sessions, load_session, new_session, prompt, resume_session, set_session_config_option, - set_session_mode, set_session_model, + js_request, list_sessions, load_session, new_session, prompt, resume_session, + set_session_config_option, set_session_mode, set_session_model, }; use crate::constants::SESSION_READY_DELAY; pub struct Bridge { pub(crate) nats: N, - #[allow(dead_code)] // Used in prompt.rs JetStream path pub(crate) js: Option, pub(crate) clock: C, pub(crate) config: Config, @@ -68,7 +65,6 @@ impl Bridge { } impl Bridge { - #[cfg(not(coverage))] pub fn with_jetstream( nats: N, js: J, @@ -93,8 +89,6 @@ impl Bridge { &self.nats } - #[cfg(not(coverage))] - #[allow(dead_code)] pub(crate) fn js(&self) -> Option<&J> { self.js.as_ref() } @@ -153,9 +147,61 @@ async fn publish_session_ready( } } +impl< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +> Bridge +where + ::Message: JsRequestMessage, +{ + pub(crate) async fn session_request( + &self, + subject: &str, + args: &Req, + session_id: &str, + ) -> Result + where + Req: serde::Serialize, + Res: serde::de::DeserializeOwned, + { + use crate::error::map_nats_error; + + match self.js() { + Some(js) => { + let req_id = uuid::Uuid::new_v4().to_string(); + js_request::js_request::( + js, + subject, + args, + &trogon_std::StdJsonSerialize, + self.config.acp_prefix(), + session_id, + &req_id, + self.config.operation_timeout, + ) + .await + } + None => nats::request_with_timeout::( + self.nats(), + subject, + args, + self.config.operation_timeout, + ) + .await + .map_err(map_nats_error), + } + } +} + #[async_trait::async_trait(?Send)] -impl Agent - for Bridge +impl< + N: RequestClient + PublishClient + SubscribeClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +> Agent for Bridge +where + ::Message: JsRequestMessage, { async fn initialize(&self, args: InitializeRequest) -> Result { initialize::handle(self, args).await diff --git a/rsworkspace/crates/acp-nats/src/agent/close_session.rs b/rsworkspace/crates/acp-nats/src/agent/close_session.rs index 25c152e40..0465d3927 100644 --- a/rsworkspace/crates/acp-nats/src/agent/close_session.rs +++ b/rsworkspace/crates/acp-nats/src/agent/close_session.rs @@ -1,9 +1,9 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{CloseSessionRequest, CloseSessionResponse, Error, ErrorCode, Result}; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -11,10 +11,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: CloseSessionRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, "Close session request"); @@ -28,17 +35,16 @@ pub async fn handle( format!("Invalid session ID: {}", e), ) })?; - let nats = bridge.nats(); - let subject = session::agent::close(bridge.config.acp_prefix(), session_id.as_str()); - - let result = nats::request_with_timeout::( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::close(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; bridge.metrics.record_request( "close_session", diff --git a/rsworkspace/crates/acp-nats/src/agent/fork_session.rs b/rsworkspace/crates/acp-nats/src/agent/fork_session.rs index 2d681e4a4..6e7a54f9d 100644 --- a/rsworkspace/crates/acp-nats/src/agent/fork_session.rs +++ b/rsworkspace/crates/acp-nats/src/agent/fork_session.rs @@ -1,9 +1,9 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, FlushClient, PublishClient, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{Error, ErrorCode, ForkSessionRequest, ForkSessionResponse, Result}; use tracing::{Span, info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -11,10 +11,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id, new_session_id = tracing::field::Empty) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: ForkSessionRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, "Fork session request"); @@ -28,17 +35,16 @@ pub async fn handle( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::fork(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; if let Ok(ref response) = result { Span::current().record("new_session_id", response.session_id.to_string().as_str()); diff --git a/rsworkspace/crates/acp-nats/src/agent/js_request.rs b/rsworkspace/crates/acp-nats/src/agent/js_request.rs new file mode 100644 index 000000000..548e79be3 --- /dev/null +++ b/rsworkspace/crates/acp-nats/src/agent/js_request.rs @@ -0,0 +1,388 @@ +use agent_client_protocol::{Error, ErrorCode}; +use async_nats::jetstream::AckKind; +use bytes::Bytes; +use futures::StreamExt; +use serde::de::DeserializeOwned; +use std::time::Duration; +use tokio::time::timeout; +use trogon_nats::REQ_ID_HEADER; +use trogon_nats::jetstream::{ + JetStreamConsumer as _, JetStreamConsumerFactory, JetStreamPublisher, JsAck as _, + JsAckWith as _, JsMessageRef as _, JsRequestMessage, +}; +use trogon_std::JsonSerialize; + +use crate::constants::SESSION_ID_HEADER; +use crate::jetstream::{consumers, streams}; + +#[allow(clippy::too_many_arguments)] +pub async fn js_request( + js: &J, + subject: &str, + request: &Req, + serializer: &S, + prefix: &str, + session_id: &str, + req_id: &str, + operation_timeout: Duration, +) -> agent_client_protocol::Result +where + J: JetStreamPublisher + JetStreamConsumerFactory, + ::Message: JsRequestMessage, + Req: serde::Serialize, + Res: DeserializeOwned, + S: JsonSerialize, +{ + // Create consumer BEFORE publishing — prevents missing the response if the + // runner responds before we start consuming. DeliverAll replays from stream start. + let responses_stream = streams::responses_stream_name(prefix); + let resp_config = consumers::response_consumer(prefix, session_id, req_id); + let resp_consumer: J::Consumer = js + .create_consumer(&responses_stream, resp_config) + .await + .map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("create response consumer: {e}"), + ) + })?; + let mut resp_messages: ::Messages = + resp_consumer.messages().await.map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("response messages: {e}"), + ) + })?; + + let payload_bytes = serializer + .to_vec(request) + .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("serialize: {e}")))?; + + let mut headers = async_nats::HeaderMap::new(); + headers.insert(REQ_ID_HEADER, req_id); + headers.insert(SESSION_ID_HEADER, session_id); + + js.js_publish_with_headers(subject.to_string(), headers, Bytes::from(payload_bytes)) + .await + .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("js publish: {e}")))?; + + match timeout(operation_timeout, resp_messages.next()).await { + Ok(Some(Ok(js_msg))) => { + match serde_json::from_slice::(js_msg.message().payload.as_ref()) { + Ok(response) => { + let _ = js_msg.ack().await; + Ok(response) + } + Err(_) => { + if let Ok(agent_err) = + serde_json::from_slice::(js_msg.message().payload.as_ref()) + { + let _ = js_msg.ack().await; + Err(agent_err) + } else { + let _ = js_msg.ack_with(AckKind::Term).await; + Err(Error::new( + ErrorCode::InternalError.into(), + "bad response payload", + )) + } + } + } + } + Ok(Some(Err(e))) => Err(Error::new( + ErrorCode::InternalError.into(), + format!("response consumer: {e}"), + )), + Ok(None) => Err(Error::new( + ErrorCode::InternalError.into(), + "response stream closed unexpectedly", + )), + Err(_elapsed) => Err(Error::new( + ErrorCode::InternalError.into(), + "request timed out waiting for runner", + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_client_protocol::PromptResponse; + use trogon_nats::jetstream::mocks::*; + + use crate::agent::test_support::MockJs; + + fn make_nats_msg(payload: &[u8]) -> async_nats::Message { + async_nats::Message { + subject: "test".into(), + reply: None, + payload: Bytes::from(payload.to_vec()), + headers: None, + status: None, + description: None, + length: payload.len(), + } + } + + #[tokio::test] + async fn js_request_success() { + let js = MockJs::new(); + let (consumer, tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + let response = PromptResponse::new(agent_client_protocol::StopReason::EndTurn); + let msg = MockJsMessage::new(make_nats_msg(&serde_json::to_vec(&response).unwrap())); + tx.unbounded_send(Ok(msg)).unwrap(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap().stop_reason, + agent_client_protocol::StopReason::EndTurn + ); + } + + #[tokio::test] + async fn js_request_publish_failure() { + let js = MockJs::new(); + let (consumer, _tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + js.publisher.fail_next_js_publish(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("js publish")); + } + + #[tokio::test] + async fn js_request_consumer_creation_failure() { + let js = MockJs::new(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("create response consumer") + ); + } + + #[tokio::test] + async fn js_request_messages_failure() { + let js = MockJs::new(); + let failing_consumer = MockJetStreamConsumer::failing(); + js.consumer_factory.add_consumer(failing_consumer); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("response messages")); + } + + #[tokio::test] + async fn js_request_bad_response_payload() { + let js = MockJs::new(); + let (consumer, tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + let msg = MockJsMessage::new(make_nats_msg(b"not json")); + tx.unbounded_send(Ok(msg)).unwrap(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("bad response payload")); + } + + #[tokio::test] + async fn js_request_timeout() { + let js = MockJs::new(); + let (consumer, _tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_millis(10), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn js_request_agent_error_response() { + let js = MockJs::new(); + let (consumer, tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + let agent_err = agent_client_protocol::Error::new( + agent_client_protocol::ErrorCode::InternalError.into(), + "agent failed", + ); + let msg = MockJsMessage::new(async_nats::Message { + subject: "test".into(), + reply: None, + payload: Bytes::from(serde_json::to_vec(&agent_err).unwrap()), + headers: None, + status: None, + description: None, + length: 0, + }); + tx.unbounded_send(Ok(msg)).unwrap(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("agent failed")); + } + + #[tokio::test] + async fn js_request_stream_closed() { + let js = MockJs::new(); + let (consumer, tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + drop(tx); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("response stream closed") + ); + } + + #[tokio::test] + async fn js_request_consumer_stream_error() { + let js = MockJs::new(); + let (consumer, tx) = MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + tx.unbounded_send(Err(trogon_nats::mocks::MockError( + "stream error".to_string(), + ))) + .unwrap(); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("response consumer")); + } + + #[tokio::test] + async fn js_request_serialize_failure() { + let js = MockJs::new(); + let (consumer, _tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(consumer); + + let result: agent_client_protocol::Result = js_request( + &js, + "acp.session.s1.agent.prompt", + &agent_client_protocol::PromptRequest::new("s1", vec![]), + &trogon_std::FailNextSerialize::new(1), + "acp", + "s1", + "req-1", + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("serialize")); + } +} diff --git a/rsworkspace/crates/acp-nats/src/agent/load_session.rs b/rsworkspace/crates/acp-nats/src/agent/load_session.rs index da8cde0f6..00a12ecc3 100644 --- a/rsworkspace/crates/acp-nats/src/agent/load_session.rs +++ b/rsworkspace/crates/acp-nats/src/agent/load_session.rs @@ -1,9 +1,9 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, FlushClient, PublishClient, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{Error, ErrorCode, LoadSessionRequest, LoadSessionResponse, Result}; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -11,10 +11,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: LoadSessionRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, "Load session request"); @@ -28,17 +35,16 @@ pub async fn handle( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::load(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; if result.is_ok() { bridge.schedule_session_ready(args.session_id.clone()); diff --git a/rsworkspace/crates/acp-nats/src/agent/mod.rs b/rsworkspace/crates/acp-nats/src/agent/mod.rs index 8f1cf4538..32c6ec600 100644 --- a/rsworkspace/crates/acp-nats/src/agent/mod.rs +++ b/rsworkspace/crates/acp-nats/src/agent/mod.rs @@ -6,6 +6,7 @@ mod ext_method; mod ext_notification; mod fork_session; mod initialize; +pub(crate) mod js_request; mod list_sessions; mod load_session; mod new_session; diff --git a/rsworkspace/crates/acp-nats/src/agent/prompt.rs b/rsworkspace/crates/acp-nats/src/agent/prompt.rs index aefad2ade..62b9a9061 100644 --- a/rsworkspace/crates/acp-nats/src/agent/prompt.rs +++ b/rsworkspace/crates/acp-nats/src/agent/prompt.rs @@ -1,13 +1,20 @@ use agent_client_protocol::{ Error, ErrorCode, PromptRequest, PromptResponse, SessionNotification, StopReason, }; +use async_nats::jetstream::AckKind; use bytes::Bytes; use futures::StreamExt; use tokio::time::timeout; use tracing::{instrument, warn}; +use trogon_nats::jetstream::{ + JetStreamConsumer as _, JetStreamConsumerFactory, JetStreamPublisher, JsAck as _, + JsAckWith as _, JsMessageRef as _, JsRequestMessage, +}; use trogon_std::JsonSerialize; use crate::agent::Bridge; +use crate::constants::SESSION_ID_HEADER; +use crate::jetstream::{consumers, streams}; use crate::nats::{FlushClient, PublishClient, RequestClient, SubscribeClient, session}; use crate::session_id::AcpSessionId; @@ -26,6 +33,8 @@ pub async fn handle( where N: RequestClient + PublishClient + SubscribeClient + FlushClient, C: trogon_std::time::GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, + ::Message: JsRequestMessage, S: JsonSerialize, { let start = bridge.clock.now(); @@ -39,16 +48,43 @@ where let sid = session_id.as_ref(); let prefix = bridge.config.acp_prefix(); + let result = match bridge.js() { + Some(js) => handle_js(bridge, js, &args, serializer, sid, prefix, &req_id).await, + None => handle_nats(bridge, &args, serializer, sid, prefix, &req_id).await, + }; + + bridge.metrics.record_request( + "prompt", + bridge.clock.elapsed(start).as_secs_f64(), + result.is_ok(), + ); + + result +} + +async fn handle_nats( + bridge: &Bridge, + args: &PromptRequest, + serializer: &S, + sid: &str, + prefix: &str, + req_id: &str, +) -> agent_client_protocol::Result +where + N: PublishClient + SubscribeClient + FlushClient, + C: trogon_std::time::GetElapsed, + S: JsonSerialize, +{ // Subscribe BEFORE publishing — prevents losing the first event if the runner responds instantly. let mut notifications_sub = bridge .nats - .subscribe(session::agent::update(prefix, sid, &req_id)) + .subscribe(session::agent::update(prefix, sid, req_id)) .await .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("subscribe: {e}")))?; let mut response_sub = bridge .nats - .subscribe(session::agent::prompt_response(prefix, sid, &req_id)) + .subscribe(session::agent::prompt_response(prefix, sid, req_id)) .await .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("subscribe: {e}")))?; @@ -64,11 +100,12 @@ where })?; let payload_bytes = serializer - .to_vec(&args) + .to_vec(args) .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("serialize: {e}")))?; let mut headers = async_nats::HeaderMap::new(); - headers.insert(REQ_ID_HEADER, req_id.as_str()); + headers.insert(REQ_ID_HEADER, req_id); + headers.insert(SESSION_ID_HEADER, sid); let prompt_subject = session::agent::prompt(prefix, sid); bridge @@ -85,7 +122,7 @@ where let op_timeout = bridge.config.prompt_timeout(); - let result = loop { + loop { tokio::select! { notif = notifications_sub.next() => { let Some(msg) = notif else { @@ -111,11 +148,14 @@ where Ok(Some(msg)) => { match serde_json::from_slice::(&msg.payload) { Ok(response) => break Ok(response), - Err(e) => { + Err(_) => { + if let Ok(agent_err) = serde_json::from_slice::(&msg.payload) { + break Err(agent_err); + } bridge.metrics.record_error("prompt", "bad_response_payload"); break Err(Error::new( ErrorCode::InternalError.into(), - format!("bad response payload: {e}"), + "bad response payload", )); } } @@ -140,15 +180,178 @@ where break Ok(PromptResponse::new(StopReason::Cancelled)); } } - }; + } +} - bridge.metrics.record_request( - "prompt", - bridge.clock.elapsed(start).as_secs_f64(), - result.is_ok(), - ); +async fn handle_js( + bridge: &Bridge, + js: &J, + args: &PromptRequest, + serializer: &S, + sid: &str, + prefix: &str, + req_id: &str, +) -> agent_client_protocol::Result +where + N: SubscribeClient, + C: trogon_std::time::GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, + ::Message: JsRequestMessage, + S: JsonSerialize, +{ + // Create consumers BEFORE publishing — same principle as subscribe-before-publish. + // JetStream consumers with DeliverAll replay from stream start, so they'll see the + // response even if the runner responds before we start consuming. + let notifications_stream = streams::notifications_stream_name(prefix); + let notif_config = consumers::prompt_notifications_consumer(prefix, sid, req_id); + let notif_consumer: J::Consumer = js + .create_consumer(¬ifications_stream, notif_config) + .await + .map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("create notification consumer: {e}"), + ) + })?; + let mut notif_messages: ::Messages = + notif_consumer.messages().await.map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("notification messages: {e}"), + ) + })?; - result + let responses_stream = streams::responses_stream_name(prefix); + let resp_config = consumers::prompt_response_consumer(prefix, sid, req_id); + let resp_consumer: J::Consumer = js + .create_consumer(&responses_stream, resp_config) + .await + .map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("create response consumer: {e}"), + ) + })?; + let mut resp_messages: ::Messages = + resp_consumer.messages().await.map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("response messages: {e}"), + ) + })?; + + // Cancel still uses core NATS — it's a fire-and-forget signal, not persisted. + let mut cancel_sub = bridge + .nats + .subscribe(session::agent::cancelled(prefix, sid)) + .await + .map_err(|e| { + Error::new( + ErrorCode::InternalError.into(), + format!("subscribe cancelled: {e}"), + ) + })?; + + // Now publish — consumers are ready, no race condition. + let payload_bytes = serializer + .to_vec(args) + .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("serialize: {e}")))?; + + let mut headers = async_nats::HeaderMap::new(); + headers.insert(REQ_ID_HEADER, req_id); + headers.insert(SESSION_ID_HEADER, sid); + + let prompt_subject = session::agent::prompt(prefix, sid); + js.js_publish_with_headers(prompt_subject, headers, Bytes::from(payload_bytes)) + .await + .map_err(|e| Error::new(ErrorCode::InternalError.into(), format!("js publish: {e}")))?; + + let op_timeout = bridge.config.prompt_timeout(); + + loop { + tokio::select! { + notif = notif_messages.next() => { + match notif { + None => { + bridge.metrics.record_error("prompt", "notification_stream_closed"); + break Err(Error::new( + ErrorCode::InternalError.into(), + "notification stream closed unexpectedly", + )); + } + Some(Err(e)) => { + bridge.metrics.record_error("prompt", "notification_consumer_error"); + break Err(Error::new( + ErrorCode::InternalError.into(), + format!("notification consumer: {e}"), + )); + } + Some(Ok(js_msg)) => { + let notification: SessionNotification = match serde_json::from_slice(js_msg.message().payload.as_ref()) { + Ok(n) => n, + Err(e) => { + warn!(error = %e, "bad notification payload; skipping"); + let _ = js_msg.ack().await; + continue; + } + }; + let _ = js_msg.ack().await; + if bridge.notification_sender.send(notification).await.is_err() { + warn!("notification receiver dropped; continuing prompt"); + } + } + } + } + resp = timeout(op_timeout, resp_messages.next()) => { + match resp { + Ok(Some(Ok(js_msg))) => { + match serde_json::from_slice::(js_msg.message().payload.as_ref()) { + Ok(response) => { + let _ = js_msg.ack().await; + break Ok(response); + } + Err(_) => { + if let Ok(agent_err) = serde_json::from_slice::(js_msg.message().payload.as_ref()) { + let _ = js_msg.ack().await; + break Err(agent_err); + } + let _ = js_msg.ack_with(AckKind::Term).await; + bridge.metrics.record_error("prompt", "bad_response_payload"); + break Err(Error::new( + ErrorCode::InternalError.into(), + "bad response payload", + )); + } + } + } + Ok(Some(Err(e))) => { + bridge.metrics.record_error("prompt", "response_consumer_error"); + break Err(Error::new( + ErrorCode::InternalError.into(), + format!("response consumer: {e}"), + )); + } + Ok(None) => { + bridge.metrics.record_error("prompt", "response_stream_closed"); + break Err(Error::new( + ErrorCode::InternalError.into(), + "response stream closed unexpectedly", + )); + } + Err(_elapsed) => { + bridge.metrics.record_error("prompt", "prompt_timeout"); + break Err(Error::new( + ErrorCode::InternalError.into(), + "prompt timed out waiting for runner", + )); + } + } + } + _ = cancel_sub.next() => { + break Ok(PromptResponse::new(StopReason::Cancelled)); + } + } + } } #[cfg(test)] @@ -180,7 +383,7 @@ mod tests { mock.clone(), trogon_std::time::SystemClock, &opentelemetry::global::meter("prompt-test"), - Config::for_test("acp"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_secs(5)), notification_tx, ); (mock, bridge) @@ -374,6 +577,833 @@ mod tests { assert!(result.is_err()); } + use crate::agent::test_support::MockJs; + + fn mock_bridge_with_js() -> ( + AdvancedMockNatsClient, + MockJs, + Bridge, + ) { + let mock = AdvancedMockNatsClient::new(); + let js = MockJs::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge = Bridge::with_jetstream( + mock.clone(), + js.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-js-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_secs(5)), + notification_tx, + ); + (mock, js, bridge) + } + + #[tokio::test] + async fn prompt_js_success() { + let (mock, js, bridge) = mock_bridge_with_js(); + + // cancel sub for core NATS + let _cancel_tx = mock.inject_messages(); + + // notification consumer + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + // response consumer + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let response = PromptResponse::new(StopReason::EndTurn); + let msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg( + &serde_json::to_vec(&response).unwrap(), + )); + resp_tx.unbounded_send(Ok(msg)).unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + + drop(notif_tx); + let response = result.expect("expected Ok prompt response"); + assert_eq!(response.stop_reason, StopReason::EndTurn); + } + + #[tokio::test] + async fn prompt_js_cancel() { + let (mock, js, bridge) = mock_bridge_with_js(); + + let cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + cancel_tx.unbounded_send(make_nats_msg(b"")).unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().stop_reason, StopReason::Cancelled); + } + + #[tokio::test] + async fn prompt_js_timeout() { + let mock = AdvancedMockNatsClient::new(); + let js = MockJs::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge = Bridge::with_jetstream( + mock.clone(), + js.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-js-timeout-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_millis(50)), + notification_tx, + ); + + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_js_notification_forwarding() { + let mock = AdvancedMockNatsClient::new(); + let js = MockJs::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge = Bridge::with_jetstream( + mock.clone(), + js.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-js-notif-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_secs(5)), + notification_tx, + ); + + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let notification = SessionNotification::new( + "s1", + agent_client_protocol::SessionUpdate::AgentThoughtChunk( + agent_client_protocol::ContentChunk::new( + agent_client_protocol::ContentBlock::Text( + agent_client_protocol::TextContent::new("thinking..."), + ), + ), + ), + ); + let notif_msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg( + &serde_json::to_vec(¬ification).unwrap(), + )); + notif_tx.unbounded_send(Ok(notif_msg)).unwrap(); + + let response = PromptResponse::new(StopReason::EndTurn); + let resp_msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg( + &serde_json::to_vec(&response).unwrap(), + )); + resp_tx.unbounded_send(Ok(resp_msg)).unwrap(); + let _notif_keeper = notif_tx; + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + + let response = result.expect("expected Ok prompt response"); + assert_eq!(response.stop_reason, StopReason::EndTurn); + } + + #[tokio::test] + async fn prompt_js_publish_failure() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + js.publisher.fail_next_js_publish(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("js publish")); + } + + #[tokio::test] + async fn prompt_js_bad_response_payload() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg(b"not json")); + resp_tx.unbounded_send(Ok(msg)).unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("bad response payload")); + } + + #[tokio::test] + async fn prompt_js_agent_error_response() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let agent_err = Error::new(ErrorCode::InternalError.into(), "agent blew up"); + let msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg( + &serde_json::to_vec(&agent_err).unwrap(), + )); + resp_tx.unbounded_send(Ok(msg)).unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.code, ErrorCode::InternalError); + assert!(err.message.contains("agent blew up")); + } + + #[tokio::test] + async fn prompt_js_response_stream_closed() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + drop(resp_tx); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn prompt_js_notif_consumer_creation_failure() { + let (mock, _js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + // Don't add any consumers — first create_consumer call will fail + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("create notification consumer") + ); + } + + #[tokio::test] + async fn prompt_js_resp_consumer_creation_failure() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + // Add notif consumer but not response consumer + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("create response consumer") + ); + } + + #[tokio::test] + async fn prompt_js_cancel_subscribe_failure() { + let (_mock, js, bridge) = mock_bridge_with_js(); + // Don't inject cancel_tx — subscribe will fail (no streams in mock) + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("subscribe cancelled")); + } + + #[tokio::test] + async fn prompt_js_notif_messages_failure() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let failing_consumer = trogon_nats::jetstream::MockJetStreamConsumer::failing(); + js.consumer_factory.add_consumer(failing_consumer); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("notification messages") + ); + } + + #[tokio::test] + async fn prompt_js_resp_messages_failure() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let failing_consumer = trogon_nats::jetstream::MockJetStreamConsumer::failing(); + js.consumer_factory.add_consumer(failing_consumer); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("response messages")); + } + + #[tokio::test] + async fn prompt_nats_notification_forwarding() { + let (mock, bridge) = mock_bridge_no_js(); + + let notif_tx = mock.inject_messages(); + let resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + let notification = SessionNotification::new( + "s1", + agent_client_protocol::SessionUpdate::AgentThoughtChunk( + agent_client_protocol::ContentChunk::new( + agent_client_protocol::ContentBlock::Text( + agent_client_protocol::TextContent::new("thinking..."), + ), + ), + ), + ); + notif_tx + .unbounded_send(make_nats_msg(&serde_json::to_vec(¬ification).unwrap())) + .unwrap(); + + let response = PromptResponse::new(StopReason::EndTurn); + resp_tx + .unbounded_send(make_nats_msg(&serde_json::to_vec(&response).unwrap())) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn prompt_nats_notification_rx_dropped_unit_type() { + let mock = AdvancedMockNatsClient::new(); + let (notification_tx, notification_rx) = + tokio::sync::mpsc::channel::(64); + drop(notification_rx); + let bridge = Bridge::new( + mock.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-rx-dropped-unit-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_millis(100)), + notification_tx, + ); + + let notif_tx = mock.inject_messages(); + let _resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + let notification = SessionNotification::new( + "s1", + agent_client_protocol::SessionUpdate::AgentThoughtChunk( + agent_client_protocol::ContentChunk::new( + agent_client_protocol::ContentBlock::Text( + agent_client_protocol::TextContent::new("thinking..."), + ), + ), + ), + ); + notif_tx + .unbounded_send(make_nats_msg(&serde_json::to_vec(¬ification).unwrap())) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_nats_notification_stream_closed() { + let (mock, bridge) = mock_bridge_no_js(); + + let notif_tx = mock.inject_messages(); + let _resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + // Close notification stream immediately + drop(notif_tx); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("notification stream closed") + ); + } + + fn mock_bridge_no_js() -> ( + AdvancedMockNatsClient, + Bridge, + ) { + let mock = AdvancedMockNatsClient::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge: Bridge = + Bridge { + nats: mock.clone(), + js: None, + clock: trogon_std::time::SystemClock, + config: Config::for_test("acp") + .with_prompt_timeout(std::time::Duration::from_secs(5)), + metrics: crate::telemetry::metrics::Metrics::new(&opentelemetry::global::meter( + "prompt-no-js-test", + )), + notification_sender: notification_tx, + pending_session_prompt_responses: + crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters::new(), + background_tasks: std::cell::RefCell::new(Vec::new()), + }; + (mock, bridge) + } + + #[tokio::test] + async fn prompt_nats_bad_notification_payload_skipped() { + let mock = AdvancedMockNatsClient::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge: Bridge = + Bridge { + nats: mock.clone(), + js: None, + clock: trogon_std::time::SystemClock, + config: Config::for_test("acp") + .with_prompt_timeout(std::time::Duration::from_millis(100)), + metrics: crate::telemetry::metrics::Metrics::new(&opentelemetry::global::meter( + "prompt-nats-bad-notif-test", + )), + notification_sender: notification_tx, + pending_session_prompt_responses: + crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters::new(), + background_tasks: std::cell::RefCell::new(Vec::new()), + }; + + let notif_tx = mock.inject_messages(); + let _resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + // Send only bad notification, no response. select! must pick notification. + notif_tx.unbounded_send(make_nats_msg(b"not json")).unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + // Bad notification processed (warn, continue), then times out + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_nats_notification_receiver_dropped() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let mock = AdvancedMockNatsClient::new(); + let (notification_tx, notification_rx) = + tokio::sync::mpsc::channel::(64); + drop(notification_rx); + let bridge: Bridge = + Bridge { + nats: mock.clone(), + js: None, + clock: trogon_std::time::SystemClock, + config: Config::for_test("acp") + .with_prompt_timeout(std::time::Duration::from_millis(100)), + metrics: crate::telemetry::metrics::Metrics::new(&opentelemetry::global::meter( + "prompt-rx-dropped-test", + )), + notification_sender: notification_tx, + pending_session_prompt_responses: + crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters::new(), + background_tasks: std::cell::RefCell::new(Vec::new()), + }; + + let notif_tx = mock.inject_messages(); + let _resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + // Send only notification, no response. Handler processes the notification + // (send fails because rx dropped, warn, continue), then times out. + let notification = SessionNotification::new( + "s1", + agent_client_protocol::SessionUpdate::AgentThoughtChunk( + agent_client_protocol::ContentChunk::new( + agent_client_protocol::ContentBlock::Text( + agent_client_protocol::TextContent::new("thinking..."), + ), + ), + ), + ); + notif_tx + .unbounded_send(make_nats_msg(&serde_json::to_vec(¬ification).unwrap())) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_nats_agent_error_response() { + let (mock, bridge) = mock_bridge_no_js(); + + let _notif_tx = mock.inject_messages(); + let resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + let agent_err = Error::new(ErrorCode::InternalError.into(), "agent failed"); + resp_tx + .unbounded_send(make_nats_msg(&serde_json::to_vec(&agent_err).unwrap())) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("agent failed")); + } + + #[tokio::test] + async fn prompt_nats_timeout() { + let mock = AdvancedMockNatsClient::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge: Bridge = + Bridge { + nats: mock.clone(), + js: None, + clock: trogon_std::time::SystemClock, + config: Config::for_test("acp") + .with_prompt_timeout(std::time::Duration::from_secs(5)), + metrics: crate::telemetry::metrics::Metrics::new(&opentelemetry::global::meter( + "prompt-nats-timeout-test", + )), + notification_sender: notification_tx, + pending_session_prompt_responses: + crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters::new(), + background_tasks: std::cell::RefCell::new(Vec::new()), + }; + + let _notif_tx = mock.inject_messages(); + let _resp_tx = mock.inject_messages(); + let _cancel_tx = mock.inject_messages(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_js_notification_consumer_error() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + notif_tx + .unbounded_send(Err(trogon_nats::mocks::MockError( + "consumer error".to_string(), + ))) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("notification consumer") + ); + } + + #[tokio::test] + async fn prompt_js_response_consumer_error() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, _notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + resp_tx + .unbounded_send(Err(trogon_nats::mocks::MockError( + "consumer error".to_string(), + ))) + .unwrap(); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("response consumer")); + } + + #[tokio::test] + async fn prompt_js_bad_notification_payload_skipped() { + let mock = AdvancedMockNatsClient::new(); + let js = MockJs::new(); + let (notification_tx, _notification_rx) = + tokio::sync::mpsc::channel::(64); + let bridge = Bridge::with_jetstream( + mock.clone(), + js.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-js-bad-notif-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_millis(100)), + notification_tx, + ); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + // Send only bad notification, no response. Handler processes bad notif + // (warn, ack, continue), then times out waiting for response. + let bad_notif = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg(b"not json")); + notif_tx.unbounded_send(Ok(bad_notif)).unwrap(); + let _notif_keeper = notif_tx; + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + // Times out after processing bad notification + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_js_notification_receiver_dropped() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let mock = AdvancedMockNatsClient::new(); + let js = MockJs::new(); + let (notification_tx, notification_rx) = + tokio::sync::mpsc::channel::(64); + drop(notification_rx); + let bridge = Bridge::with_jetstream( + mock.clone(), + js.clone(), + trogon_std::time::SystemClock, + &opentelemetry::global::meter("prompt-js-rx-dropped-test"), + Config::for_test("acp").with_prompt_timeout(std::time::Duration::from_millis(100)), + notification_tx, + ); + + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + let notification = SessionNotification::new( + "s1", + agent_client_protocol::SessionUpdate::AgentThoughtChunk( + agent_client_protocol::ContentChunk::new( + agent_client_protocol::ContentBlock::Text( + agent_client_protocol::TextContent::new("thinking..."), + ), + ), + ), + ); + let notif_msg = trogon_nats::jetstream::MockJsMessage::new(make_nats_msg( + &serde_json::to_vec(¬ification).unwrap(), + )); + notif_tx.unbounded_send(Ok(notif_msg)).unwrap(); + let _notif_keeper = notif_tx; + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("timed out")); + } + + #[tokio::test] + async fn prompt_js_notification_stream_closed() { + let (mock, js, bridge) = mock_bridge_with_js(); + let _cancel_tx = mock.inject_messages(); + + let (notif_consumer, notif_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(notif_consumer); + let (resp_consumer, _resp_tx) = trogon_nats::jetstream::MockJetStreamConsumer::new(); + js.consumer_factory.add_consumer(resp_consumer); + + drop(notif_tx); + + let result = handle( + &bridge, + PromptRequest::new("s1", vec![]), + &trogon_std::StdJsonSerialize, + ) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .message + .contains("notification stream closed") + ); + } + #[tokio::test] async fn prompt_publishes_to_correct_subject() { let (mock, bridge) = mock_bridge(); diff --git a/rsworkspace/crates/acp-nats/src/agent/resume_session.rs b/rsworkspace/crates/acp-nats/src/agent/resume_session.rs index 809f2f02a..5f94a7da0 100644 --- a/rsworkspace/crates/acp-nats/src/agent/resume_session.rs +++ b/rsworkspace/crates/acp-nats/src/agent/resume_session.rs @@ -1,11 +1,11 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, FlushClient, PublishClient, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{ Error, ErrorCode, Result, ResumeSessionRequest, ResumeSessionResponse, }; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -13,10 +13,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: ResumeSessionRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, "Resume session request"); @@ -30,17 +37,16 @@ pub async fn handle( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::resume(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; if result.is_ok() { bridge.schedule_session_ready(args.session_id.clone()); diff --git a/rsworkspace/crates/acp-nats/src/agent/set_session_config_option.rs b/rsworkspace/crates/acp-nats/src/agent/set_session_config_option.rs index d2eaca8e7..b4307b4b9 100644 --- a/rsworkspace/crates/acp-nats/src/agent/set_session_config_option.rs +++ b/rsworkspace/crates/acp-nats/src/agent/set_session_config_option.rs @@ -1,11 +1,11 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{ Error, ErrorCode, Result, SetSessionConfigOptionRequest, SetSessionConfigOptionResponse, }; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -13,10 +13,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id, config_id = %args.config_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: SetSessionConfigOptionRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, config_id = %args.config_id, "Set session config option request"); @@ -30,17 +37,16 @@ pub async fn handle( format!("Invalid session ID: {}", e), ) })?; - let nats = bridge.nats(); - let subject = - session::agent::set_config_option(bridge.config.acp_prefix(), session_id.as_str()); - - let result = nats::request_with_timeout::< - N, - SetSessionConfigOptionRequest, - SetSessionConfigOptionResponse, - >(nats, &subject, &args, bridge.config.operation_timeout) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::set_config_option(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; bridge.metrics.record_request( "set_session_config_option", diff --git a/rsworkspace/crates/acp-nats/src/agent/set_session_mode.rs b/rsworkspace/crates/acp-nats/src/agent/set_session_mode.rs index 955129cc0..2c2af96ed 100644 --- a/rsworkspace/crates/acp-nats/src/agent/set_session_mode.rs +++ b/rsworkspace/crates/acp-nats/src/agent/set_session_mode.rs @@ -1,11 +1,11 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{ Error, ErrorCode, Result, SetSessionModeRequest, SetSessionModeResponse, }; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -13,10 +13,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id, mode_id = %args.mode_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: SetSessionModeRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, mode_id = %args.mode_id, "Set session mode request"); @@ -30,17 +37,16 @@ pub async fn handle( format!("Invalid session ID: {}", e), ) })?; - let nats = bridge.nats(); - let subject = session::agent::set_mode(bridge.config.acp_prefix(), session_id.as_str()); - - let result = nats::request_with_timeout::( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::set_mode(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; bridge.metrics.record_request( "set_session_mode", diff --git a/rsworkspace/crates/acp-nats/src/agent/set_session_model.rs b/rsworkspace/crates/acp-nats/src/agent/set_session_model.rs index f5c1b8a98..2fa75d32b 100644 --- a/rsworkspace/crates/acp-nats/src/agent/set_session_model.rs +++ b/rsworkspace/crates/acp-nats/src/agent/set_session_model.rs @@ -1,11 +1,11 @@ use super::Bridge; -use crate::error::map_nats_error; -use crate::nats::{self, RequestClient, session}; +use crate::nats::{FlushClient, PublishClient, RequestClient, session}; use crate::session_id::AcpSessionId; use agent_client_protocol::{ Error, ErrorCode, Result, SetSessionModelRequest, SetSessionModelResponse, }; use tracing::{info, instrument}; +use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher, JsRequestMessage}; use trogon_std::time::GetElapsed; #[instrument( @@ -13,10 +13,17 @@ use trogon_std::time::GetElapsed; skip(bridge, args), fields(session_id = %args.session_id, model_id = %args.model_id) )] -pub async fn handle( +pub async fn handle< + N: RequestClient + PublishClient + FlushClient, + C: GetElapsed, + J: JetStreamPublisher + JetStreamConsumerFactory, +>( bridge: &Bridge, args: SetSessionModelRequest, -) -> Result { +) -> Result +where + ::Message: JsRequestMessage, +{ let start = bridge.clock.now(); info!(session_id = %args.session_id, model_id = %args.model_id, "Set session model request"); @@ -30,17 +37,16 @@ pub async fn handle( format!("Invalid session ID: {}", e), ) })?; - let nats = bridge.nats(); - let subject = session::agent::set_model(bridge.config.acp_prefix(), session_id.as_str()); - - let result = nats::request_with_timeout::( - nats, - &subject, - &args, - bridge.config.operation_timeout, - ) - .await - .map_err(map_nats_error); + let prefix = bridge.config.acp_prefix(); + let subject = session::agent::set_model(prefix, session_id.as_str()); + + let result = bridge + .session_request::( + &subject, + &args, + session_id.as_str(), + ) + .await; bridge.metrics.record_request( "set_session_model", diff --git a/rsworkspace/crates/acp-nats/src/agent/test_support.rs b/rsworkspace/crates/acp-nats/src/agent/test_support.rs index 988c7b2eb..6d3926d33 100644 --- a/rsworkspace/crates/acp-nats/src/agent/test_support.rs +++ b/rsworkspace/crates/acp-nats/src/agent/test_support.rs @@ -24,6 +24,51 @@ pub fn mock_bridge() -> ( (mock, bridge) } +#[derive(Clone)] +pub struct MockJs { + pub publisher: trogon_nats::jetstream::MockJetStreamPublisher, + pub consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory, +} + +impl MockJs { + pub fn new() -> Self { + Self { + publisher: trogon_nats::jetstream::MockJetStreamPublisher::new(), + consumer_factory: trogon_nats::jetstream::MockJetStreamConsumerFactory::new(), + } + } +} + +impl trogon_nats::jetstream::JetStreamPublisher for MockJs { + type PublishError = trogon_nats::mocks::MockError; + + async fn js_publish_with_headers( + &self, + subject: String, + headers: async_nats::HeaderMap, + payload: bytes::Bytes, + ) -> Result { + self.publisher + .js_publish_with_headers(subject, headers, payload) + .await + } +} + +impl trogon_nats::jetstream::JetStreamConsumerFactory for MockJs { + type Error = trogon_nats::mocks::MockError; + type Consumer = trogon_nats::jetstream::MockJetStreamConsumer; + + async fn create_consumer( + &self, + stream_name: &str, + config: async_nats::jetstream::consumer::pull::Config, + ) -> Result { + self.consumer_factory + .create_consumer(stream_name, config) + .await + } +} + pub fn mock_bridge_with_metrics() -> ( AdvancedMockNatsClient, Bridge, diff --git a/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs b/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs index 89bfb9c7e..c39cb3fcb 100644 --- a/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs +++ b/rsworkspace/crates/acp-nats/src/client/ext_session_prompt_response.rs @@ -14,11 +14,12 @@ use trogon_std::time::GetElapsed; pub async fn handle< N: RequestClient + PublishClient + FlushClient + SubscribeClient, C: GetElapsed, + J, >( session_id: &str, payload: &[u8], reply: Option<&str>, - bridge: &Bridge, + bridge: &Bridge, ) { if reply.is_some() { warn!( diff --git a/rsworkspace/crates/acp-nats/src/client/mod.rs b/rsworkspace/crates/acp-nats/src/client/mod.rs index 34ccc2330..5b26a53d0 100644 --- a/rsworkspace/crates/acp-nats/src/client/mod.rs +++ b/rsworkspace/crates/acp-nats/src/client/mod.rs @@ -61,11 +61,12 @@ pub async fn run< N: SubscribeClient + RequestClient + PublishClient + FlushClient, Cl: Client + 'static, C: GetElapsed + 'static, + J: 'static, S: Clone + JsonSerialize + 'static, >( nats: N, client: Rc, - bridge: Rc>, + bridge: Rc>, serializer: S, ) { let wildcard = crate::nats::session::wildcards::all_client(bridge.config.acp_prefix()); @@ -102,12 +103,13 @@ async fn process_message< N: SubscribeClient + RequestClient + PublishClient + FlushClient, Cl: Client + 'static, C: GetElapsed + 'static, + J: 'static, S: Clone + JsonSerialize + 'static, >( msg: Message, nats: &N, client: Rc, - bridge: Rc>, + bridge: Rc>, in_flight: &Rc>, max_concurrent: usize, serializer: &S, @@ -159,14 +161,14 @@ async fn process_message< }); } -struct DispatchContext<'a, N, Cl, C, S> +struct DispatchContext<'a, N, Cl, C, J, S> where N: SubscribeClient + RequestClient + PublishClient + FlushClient, C: GetElapsed + 'static, { nats: &'a N, client: &'a Cl, - bridge: &'a Bridge, + bridge: &'a Bridge, serializer: &'a S, } @@ -175,13 +177,14 @@ async fn dispatch_client_method< N: SubscribeClient + RequestClient + PublishClient + FlushClient, Cl: Client, C: GetElapsed + 'static, + J: 'static, S: JsonSerialize, >( subject: &str, parsed: crate::nats::ParsedClientSubject, payload: Bytes, reply: Option, - ctx: &DispatchContext<'_, N, Cl, C, S>, + ctx: &DispatchContext<'_, N, Cl, C, J, S>, ) { Span::current().record("session_id", parsed.session_id.as_str()); diff --git a/rsworkspace/crates/acp-nats/src/config.rs b/rsworkspace/crates/acp-nats/src/config.rs index 4a74f9ad9..0f1dd429f 100644 --- a/rsworkspace/crates/acp-nats/src/config.rs +++ b/rsworkspace/crates/acp-nats/src/config.rs @@ -82,6 +82,7 @@ impl Config { auth: trogon_nats::NatsAuth::None, }; Self::new(AcpPrefix::new(acp_prefix.to_string()).unwrap(), nats) + .with_prompt_timeout(crate::constants::TEST_PROMPT_TIMEOUT) } } diff --git a/rsworkspace/crates/acp-nats/src/constants.rs b/rsworkspace/crates/acp-nats/src/constants.rs index 3762750a7..ea4a6d5e3 100644 --- a/rsworkspace/crates/acp-nats/src/constants.rs +++ b/rsworkspace/crates/acp-nats/src/constants.rs @@ -15,6 +15,7 @@ pub const MIN_TIMEOUT_SECS: u64 = 1; pub const SESSION_READY_DELAY: Duration = Duration::from_millis(100); pub const PROMPT_TIMEOUT_WARNING_SUPPRESSION_WINDOW: Duration = Duration::from_secs(5); +pub const TEST_PROMPT_TIMEOUT: Duration = Duration::from_secs(5); pub const MAX_PREFIX_LENGTH: usize = 128; pub const MAX_SESSION_ID_LENGTH: usize = 128; diff --git a/rsworkspace/crates/acp-nats/src/jetstream/consumers.rs b/rsworkspace/crates/acp-nats/src/jetstream/consumers.rs index 471617c36..9ae8298f1 100644 --- a/rsworkspace/crates/acp-nats/src/jetstream/consumers.rs +++ b/rsworkspace/crates/acp-nats/src/jetstream/consumers.rs @@ -21,6 +21,16 @@ pub fn prompt_response_consumer(prefix: &str, session_id: &str, req_id: &str) -> } } +pub fn response_consumer(prefix: &str, session_id: &str, req_id: &str) -> Config { + Config { + filter_subject: format!("{prefix}.session.{session_id}.agent.response.{req_id}"), + deliver_policy: DeliverPolicy::All, + ack_policy: AckPolicy::Explicit, + replay_policy: ReplayPolicy::Instant, + ..Default::default() + } +} + /// Observer consumer for the COMMANDS stream. /// /// Acks messages for audit persistence. No filter needed — the stream-level @@ -77,6 +87,29 @@ mod tests { assert_eq!(config.filter_subject, String::new()); } + #[test] + fn response_consumer_filter() { + let config = response_consumer("acp", "sess-1", "req-abc"); + assert_eq!( + config.filter_subject, + "acp.session.sess-1.agent.response.req-abc" + ); + } + + #[test] + fn response_consumer_delivers_all() { + let config = response_consumer("acp", "s1", "r1"); + assert_eq!(config.deliver_policy, DeliverPolicy::All); + assert_eq!(config.ack_policy, AckPolicy::Explicit); + assert_eq!(config.replay_policy, ReplayPolicy::Instant); + } + + #[test] + fn response_consumer_custom_prefix() { + let config = response_consumer("myapp", "s1", "r1"); + assert_eq!(config.filter_subject, "myapp.session.s1.agent.response.r1"); + } + #[test] fn custom_prefix_in_consumers() { let config = prompt_response_consumer("myapp", "s1", "r1"); diff --git a/rsworkspace/crates/acp-nats/src/jetstream/streams.rs b/rsworkspace/crates/acp-nats/src/jetstream/streams.rs index 996f21d90..b93cd771a 100644 --- a/rsworkspace/crates/acp-nats/src/jetstream/streams.rs +++ b/rsworkspace/crates/acp-nats/src/jetstream/streams.rs @@ -9,6 +9,18 @@ fn stream_name(prefix: &str, suffix: &str) -> String { format!("{}_{}", prefix.to_uppercase(), suffix) } +pub fn notifications_stream_name(prefix: &str) -> String { + stream_name(prefix, "NOTIFICATIONS") +} + +pub fn responses_stream_name(prefix: &str) -> String { + stream_name(prefix, "RESPONSES") +} + +pub fn commands_stream_name(prefix: &str) -> String { + stream_name(prefix, "COMMANDS") +} + pub fn commands_config(prefix: &str) -> Config { Config { name: stream_name(prefix, "COMMANDS"), @@ -36,6 +48,7 @@ pub fn responses_config(prefix: &str) -> Config { name: stream_name(prefix, "RESPONSES"), subjects: vec![ format!("{prefix}.session.*.agent.prompt.response.>"), + format!("{prefix}.session.*.agent.response.>"), format!("{prefix}.session.*.agent.ext.ready"), format!("{prefix}.session.*.agent.cancelled"), ], @@ -199,6 +212,24 @@ mod tests { } } + #[test] + fn notifications_stream_name_formats_correctly() { + assert_eq!(notifications_stream_name("acp"), "ACP_NOTIFICATIONS"); + assert_eq!(notifications_stream_name("myapp"), "MYAPP_NOTIFICATIONS"); + } + + #[test] + fn responses_stream_name_formats_correctly() { + assert_eq!(responses_stream_name("acp"), "ACP_RESPONSES"); + assert_eq!(responses_stream_name("myapp"), "MYAPP_RESPONSES"); + } + + #[test] + fn commands_stream_name_formats_correctly() { + assert_eq!(commands_stream_name("acp"), "ACP_COMMANDS"); + assert_eq!(commands_stream_name("myapp"), "MYAPP_COMMANDS"); + } + #[test] fn all_configs_returns_four_streams() { assert_eq!(all_configs("acp").len(), 4); diff --git a/rsworkspace/crates/acp-nats/src/lib.rs b/rsworkspace/crates/acp-nats/src/lib.rs index 5edcacd13..d861ca904 100644 --- a/rsworkspace/crates/acp-nats/src/lib.rs +++ b/rsworkspace/crates/acp-nats/src/lib.rs @@ -25,6 +25,9 @@ pub use config::{ pub use error::AGENT_UNAVAILABLE; pub use nats::{FlushClient, PublishClient, RequestClient, SubscribeClient}; pub use session_id::AcpSessionId; +#[cfg(not(coverage))] +pub use trogon_nats::jetstream::NatsJetStreamClient; +pub use trogon_nats::jetstream::{JetStreamConsumerFactory, JetStreamPublisher}; pub use trogon_nats::{NatsAuth, NatsConfig}; pub use trogon_std::StdJsonSerialize; diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects.rs b/rsworkspace/crates/acp-nats/src/nats/subjects.rs index f05fbde2e..0bb3a2aa2 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects.rs @@ -86,6 +86,13 @@ pub mod session { prefix, session_id, req_id ) } + + pub fn response(prefix: &str, session_id: &str, req_id: &str) -> String { + format!( + "{}.session.{}.agent.response.{}", + prefix, session_id, req_id + ) + } } pub mod client { @@ -152,6 +159,10 @@ pub mod session { format!("{}.session.*.client.>", prefix) } + pub fn all_agent_ext(prefix: &str) -> String { + format!("{}.session.*.agent.ext.>", prefix) + } + pub fn one_agent(prefix: &str, session_id: &str) -> String { format!("{}.session.{}.agent.>", prefix, session_id) } @@ -313,6 +324,14 @@ mod tests { ); } + #[test] + fn session_agent_response() { + assert_eq!( + session::agent::response("acp", "s1", "req-abc"), + "acp.session.s1.agent.response.req-abc" + ); + } + #[test] fn session_client_fs_read() { assert_eq!( diff --git a/rsworkspace/crates/trogon-nats/src/jetstream/client.rs b/rsworkspace/crates/trogon-nats/src/jetstream/client.rs index d8aa08507..8158130b5 100644 --- a/rsworkspace/crates/trogon-nats/src/jetstream/client.rs +++ b/rsworkspace/crates/trogon-nats/src/jetstream/client.rs @@ -1,12 +1,13 @@ use async_nats::HeaderMap; use async_nats::jetstream; +use async_nats::jetstream::AckKind; use async_nats::jetstream::consumer::pull; use async_nats::jetstream::publish::PublishAck; use async_nats::jetstream::stream; use bytes::Bytes; use futures::StreamExt; -use super::message::JsMessage; +use super::message::{JsAck, JsAckWith, JsDoubleAck, JsDoubleAckWith, JsMessageRef}; use super::traits::{ JetStreamConsumer, JetStreamConsumerFactory, JetStreamContext, JetStreamPublisher, }; @@ -67,6 +68,54 @@ impl JetStreamPublisher for NatsJetStreamClient { } } +pub struct NatsJsMessage { + inner: jetstream::Message, +} + +impl NatsJsMessage { + pub fn new(inner: jetstream::Message) -> Self { + Self { inner } + } +} + +impl JsMessageRef for NatsJsMessage { + fn message(&self) -> &async_nats::Message { + &self.inner.message + } +} + +impl JsAck for NatsJsMessage { + type Error = async_nats::Error; + + async fn ack(&self) -> Result<(), async_nats::Error> { + self.inner.ack().await + } +} + +impl JsAckWith for NatsJsMessage { + type Error = async_nats::Error; + + async fn ack_with(&self, kind: AckKind) -> Result<(), async_nats::Error> { + self.inner.ack_with(kind).await + } +} + +impl JsDoubleAck for NatsJsMessage { + type Error = async_nats::Error; + + async fn double_ack(&self) -> Result<(), async_nats::Error> { + self.inner.double_ack().await + } +} + +impl JsDoubleAckWith for NatsJsMessage { + type Error = async_nats::Error; + + async fn double_ack_with(&self, kind: AckKind) -> Result<(), async_nats::Error> { + self.inner.double_ack_with(kind).await + } +} + pub struct NatsJetStreamConsumer { inner: jetstream::consumer::Consumer, } @@ -103,7 +152,8 @@ impl JetStreamConsumerFactory for NatsJetStreamClient { impl JetStreamConsumer for NatsJetStreamConsumer { type Error = JetStreamError; - type Messages = futures::stream::BoxStream<'static, Result>; + type Message = NatsJsMessage; + type Messages = futures::stream::BoxStream<'static, Result>; async fn messages(&self) -> Result { let messages = self @@ -115,7 +165,7 @@ impl JetStreamConsumer for NatsJetStreamConsumer { Ok(messages .map(|result| { result - .map(JsMessage::new) + .map(NatsJsMessage::new) .map_err(|e| JetStreamError(e.to_string())) }) .boxed()) diff --git a/rsworkspace/crates/trogon-nats/src/jetstream/message.rs b/rsworkspace/crates/trogon-nats/src/jetstream/message.rs index a63227013..4a4f0c8b0 100644 --- a/rsworkspace/crates/trogon-nats/src/jetstream/message.rs +++ b/rsworkspace/crates/trogon-nats/src/jetstream/message.rs @@ -1,79 +1,41 @@ -use std::time::Duration; +use std::fmt; +use std::future::Future; -use async_nats::jetstream; +use async_nats::jetstream::AckKind; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum JsSignal { - Ack, - DoubleAck, - Nak, - NakWithDelay(Duration), - Progress, - Term, +pub trait JsMessageRef: Send + 'static { + fn message(&self) -> &async_nats::Message; } -pub struct JsMessage { - #[cfg_attr(coverage, allow(dead_code))] - inner: jetstream::Message, -} - -#[cfg(not(coverage))] -impl JsMessage { - pub fn new(inner: jetstream::Message) -> Self { - Self { inner } - } - - pub fn into_inner(self) -> jetstream::Message { - self.inner - } - - pub fn message(&self) -> &async_nats::Message { - &self.inner.message - } - - pub fn payload(&self) -> &bytes::Bytes { - &self.inner.payload - } - - pub fn subject(&self) -> &str { - self.inner.subject.as_str() - } +pub trait JsAck: Send + 'static { + type Error: fmt::Display + fmt::Debug + Send + Sync; - pub fn headers(&self) -> Option<&async_nats::HeaderMap> { - self.inner.headers.as_ref() - } + fn ack(&self) -> impl Future> + Send; +} - pub fn reply(&self) -> Option<&async_nats::Subject> { - self.inner.reply.as_ref() - } +pub trait JsAckWith: Send + 'static { + type Error: fmt::Display + fmt::Debug + Send + Sync; - pub async fn ack(&self) -> Result<(), async_nats::Error> { - self.inner.ack().await - } + fn ack_with(&self, kind: AckKind) -> impl Future> + Send; +} - pub async fn double_ack(&self) -> Result<(), async_nats::Error> { - self.inner.double_ack().await - } +pub trait JsDoubleAck: Send + 'static { + type Error: fmt::Display + fmt::Debug + Send + Sync; - pub async fn nak(&self) -> Result<(), async_nats::Error> { - self.inner.ack_with(jetstream::AckKind::Nak(None)).await - } + fn double_ack(&self) -> impl Future> + Send; +} - pub async fn nak_with_delay(&self, delay: Duration) -> Result<(), async_nats::Error> { - self.inner - .ack_with(jetstream::AckKind::Nak(Some(delay))) - .await - } +pub trait JsDoubleAckWith: Send + 'static { + type Error: fmt::Display + fmt::Debug + Send + Sync; - pub async fn term(&self) -> Result<(), async_nats::Error> { - self.inner.ack_with(jetstream::AckKind::Term).await - } + fn double_ack_with( + &self, + kind: AckKind, + ) -> impl Future> + Send; +} - pub async fn in_progress(&self) -> Result<(), async_nats::Error> { - self.inner.ack_with(jetstream::AckKind::Progress).await - } +pub trait JsRequestMessage: JsMessageRef + JsAck + JsAckWith {} +impl JsRequestMessage for T {} - pub fn info(&self) -> Result, async_nats::Error> { - self.inner.info() - } -} +pub trait JsDispatchMessage: JsMessageRef + JsAck + JsAckWith {} +impl JsDispatchMessage for T {} diff --git a/rsworkspace/crates/trogon-nats/src/jetstream/mocks.rs b/rsworkspace/crates/trogon-nats/src/jetstream/mocks.rs index f76dc8e19..b2a6c37d2 100644 --- a/rsworkspace/crates/trogon-nats/src/jetstream/mocks.rs +++ b/rsworkspace/crates/trogon-nats/src/jetstream/mocks.rs @@ -1,8 +1,8 @@ use std::collections::VecDeque; use std::sync::{Arc, Mutex}; -use std::time::Duration; use async_nats::HeaderMap; +use async_nats::jetstream::AckKind; use async_nats::jetstream::consumer::pull; use async_nats::jetstream::publish::PublishAck; use async_nats::jetstream::stream; @@ -10,7 +10,7 @@ use bytes::Bytes; use futures::channel::mpsc; use futures::stream::BoxStream; -use super::message::{JsMessage, JsSignal}; +use super::message::{JsAck, JsAckWith, JsDoubleAck, JsDoubleAckWith, JsMessageRef}; use super::traits::{ JetStreamConsumer, JetStreamConsumerFactory, JetStreamContext, JetStreamPublisher, }; @@ -19,48 +19,121 @@ use crate::mocks::MockError; // --- MockJsMessage --- pub struct MockJsMessage { - pub message: async_nats::Message, - signals: Arc>>, + inner: async_nats::Message, + signals: Arc>>, + fail_signals: bool, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AckKindSnapshot { + Ack, + AckWith(AckKindValue), + DoubleAck, + DoubleAckWith(AckKindValue), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum AckKindValue { + Ack, + Nak(Option), + Progress, + Next, + Term, +} + +impl From for AckKindValue { + fn from(kind: AckKind) -> Self { + match kind { + AckKind::Ack => AckKindValue::Ack, + AckKind::Nak(d) => AckKindValue::Nak(d), + AckKind::Progress => AckKindValue::Progress, + AckKind::Next => AckKindValue::Next, + AckKind::Term => AckKindValue::Term, + } + } } impl MockJsMessage { pub fn new(message: async_nats::Message) -> Self { Self { - message, + inner: message, + signals: Arc::new(Mutex::new(Vec::new())), + fail_signals: false, + } + } + + pub fn with_failing_signals(message: async_nats::Message) -> Self { + Self { + inner: message, signals: Arc::new(Mutex::new(Vec::new())), + fail_signals: true, } } - pub fn signals(&self) -> Vec { + pub fn signals(&self) -> Vec { self.signals.lock().unwrap().clone() } - fn record(&self, signal: JsSignal) { + fn record(&self, signal: AckKindSnapshot) { self.signals.lock().unwrap().push(signal); } +} - pub fn ack(&self) { - self.record(JsSignal::Ack); +impl JsMessageRef for MockJsMessage { + fn message(&self) -> &async_nats::Message { + &self.inner } +} - pub fn double_ack(&self) { - self.record(JsSignal::DoubleAck); - } +impl JsAck for MockJsMessage { + type Error = MockError; - pub fn nak(&self) { - self.record(JsSignal::Nak); + async fn ack(&self) -> Result<(), MockError> { + self.record(AckKindSnapshot::Ack); + if self.fail_signals { + Err(MockError("ack failed".into())) + } else { + Ok(()) + } } +} - pub fn nak_with_delay(&self, delay: Duration) { - self.record(JsSignal::NakWithDelay(delay)); +impl JsAckWith for MockJsMessage { + type Error = MockError; + + async fn ack_with(&self, kind: AckKind) -> Result<(), MockError> { + self.record(AckKindSnapshot::AckWith(kind.into())); + if self.fail_signals { + Err(MockError("ack_with failed".into())) + } else { + Ok(()) + } } +} + +impl JsDoubleAck for MockJsMessage { + type Error = MockError; - pub fn term(&self) { - self.record(JsSignal::Term); + async fn double_ack(&self) -> Result<(), MockError> { + self.record(AckKindSnapshot::DoubleAck); + if self.fail_signals { + Err(MockError("double_ack failed".into())) + } else { + Ok(()) + } } +} - pub fn in_progress(&self) { - self.record(JsSignal::Progress); +impl JsDoubleAckWith for MockJsMessage { + type Error = MockError; + + async fn double_ack_with(&self, kind: AckKind) -> Result<(), MockError> { + self.record(AckKindSnapshot::DoubleAckWith(kind.into())); + if self.fail_signals { + Err(MockError("double_ack_with failed".into())) + } else { + Ok(()) + } } } @@ -269,11 +342,14 @@ impl JetStreamConsumerFactory for MockJetStreamConsumerFactory { } pub struct MockJetStreamConsumer { - rx: Mutex>>, + rx: Mutex>>>, } impl MockJetStreamConsumer { - pub fn new() -> (Self, mpsc::UnboundedSender) { + pub fn new() -> ( + Self, + mpsc::UnboundedSender>, + ) { let (tx, rx) = mpsc::unbounded(); ( Self { @@ -282,33 +358,28 @@ impl MockJetStreamConsumer { tx, ) } + + pub fn failing() -> Self { + Self { + rx: Mutex::new(None), + } + } } impl JetStreamConsumer for MockJetStreamConsumer { type Error = MockError; - type Messages = BoxStream<'static, Result>; + type Message = MockJsMessage; + type Messages = BoxStream<'static, Result>; async fn messages(&self) -> Result { - // MockJsMessage cannot be converted to JsMessage without a real jetstream::Message. - // Tests that need signal assertions should use raw_messages() instead. - Err(MockError( - "MockJetStreamConsumer.messages() not supported; use raw_messages() instead" - .to_string(), - )) - } -} - -impl MockJetStreamConsumer { - pub fn raw_messages( - self, - ) -> Result + Unpin + Send + 'static, MockError> - { + use futures::StreamExt; let rx = self .rx - .into_inner() + .lock() .unwrap() - .ok_or_else(|| MockError("raw_messages() already called".to_string()))?; - Ok(rx) + .take() + .ok_or_else(|| MockError("messages() already called".to_string()))?; + Ok(rx.boxed()) } } @@ -329,33 +400,51 @@ mod tests { } } - #[test] - fn mock_js_message_records_signals() { + #[tokio::test] + async fn mock_js_message_records_signals() { let msg = MockJsMessage::new(make_nats_msg("test", b"payload")); - msg.ack(); - msg.in_progress(); - msg.term(); + msg.ack().await.unwrap(); + msg.ack_with(AckKind::Progress).await.unwrap(); + msg.ack_with(AckKind::Term).await.unwrap(); assert_eq!( msg.signals(), - vec![JsSignal::Ack, JsSignal::Progress, JsSignal::Term] + vec![ + AckKindSnapshot::Ack, + AckKindSnapshot::AckWith(AckKindValue::Progress), + AckKindSnapshot::AckWith(AckKindValue::Term), + ] ); } - #[test] - fn mock_js_message_records_nak_with_delay() { + #[tokio::test] + async fn mock_js_message_records_nak_with_delay() { let msg = MockJsMessage::new(make_nats_msg("test", b"")); - msg.nak_with_delay(Duration::from_secs(5)); + msg.ack_with(AckKind::Nak(Some(std::time::Duration::from_secs(5)))) + .await + .unwrap(); assert_eq!( msg.signals(), - vec![JsSignal::NakWithDelay(Duration::from_secs(5))] + vec![AckKindSnapshot::AckWith(AckKindValue::Nak(Some( + std::time::Duration::from_secs(5) + )))] ); } - #[test] - fn mock_js_message_records_double_ack() { + #[tokio::test] + async fn mock_js_message_records_double_ack() { let msg = MockJsMessage::new(make_nats_msg("test", b"")); - msg.double_ack(); - assert_eq!(msg.signals(), vec![JsSignal::DoubleAck]); + msg.double_ack().await.unwrap(); + assert_eq!(msg.signals(), vec![AckKindSnapshot::DoubleAck]); + } + + #[tokio::test] + async fn mock_js_message_records_double_ack_with() { + let msg = MockJsMessage::new(make_nats_msg("test", b"")); + msg.double_ack_with(AckKind::Ack).await.unwrap(); + assert_eq!( + msg.signals(), + vec![AckKindSnapshot::DoubleAckWith(AckKindValue::Ack)] + ); } #[tokio::test] @@ -449,16 +538,16 @@ mod tests { } #[tokio::test] - async fn mock_consumer_raw_messages_streams() { + async fn mock_consumer_messages_streams() { let (consumer, tx) = MockJetStreamConsumer::new(); let msg = MockJsMessage::new(make_nats_msg("test.subject", b"data")); - tx.unbounded_send(msg).unwrap(); + tx.unbounded_send(Ok(msg)).unwrap(); drop(tx); - let mut stream = consumer.raw_messages().unwrap(); - let received = stream.next().await.unwrap(); - assert_eq!(received.message.subject.as_str(), "test.subject"); - assert_eq!(received.message.payload.as_ref(), b"data"); + let mut stream = JetStreamConsumer::messages(&consumer).await.unwrap(); + let received = stream.next().await.unwrap().unwrap(); + assert_eq!(received.message().subject.as_str(), "test.subject"); + assert_eq!(received.message().payload.as_ref(), b"data"); } #[test] @@ -478,11 +567,14 @@ mod tests { let _factory = MockJetStreamConsumerFactory::default(); } - #[test] - fn mock_js_message_records_nak() { + #[tokio::test] + async fn mock_js_message_records_nak() { let msg = MockJsMessage::new(make_nats_msg("test", b"")); - msg.nak(); - assert_eq!(msg.signals(), vec![JsSignal::Nak]); + msg.ack_with(AckKind::Nak(None)).await.unwrap(); + assert_eq!( + msg.signals(), + vec![AckKindSnapshot::AckWith(AckKindValue::Nak(None))] + ); } #[tokio::test] @@ -509,6 +601,41 @@ mod tests { ); } + #[test] + fn mock_js_message_payload_subject_headers_reply() { + let mut headers = async_nats::HeaderMap::new(); + headers.insert("X-Test", "value"); + let msg = async_nats::Message { + subject: "test.subject".into(), + reply: Some("_INBOX.reply".into()), + payload: Bytes::from("hello"), + headers: Some(headers), + status: None, + description: None, + length: 5, + }; + let mock = MockJsMessage::new(msg); + + let inner = mock.message(); + assert_eq!(inner.payload.as_ref(), b"hello"); + assert_eq!(inner.subject.as_str(), "test.subject"); + assert!(inner.headers.is_some()); + assert_eq!( + inner.reply.as_ref().map(|s| s.as_str()), + Some("_INBOX.reply") + ); + } + + #[test] + fn mock_js_message_no_headers_no_reply() { + let msg = make_nats_msg("sub", b"data"); + let mock = MockJsMessage::new(msg); + + let inner = mock.message(); + assert!(inner.headers.is_none()); + assert!(inner.reply.is_none()); + } + #[test] fn mock_consumer_factory_clone() { let factory = MockJetStreamConsumerFactory::new(); @@ -519,9 +646,36 @@ mod tests { } #[tokio::test] - async fn mock_consumer_messages_returns_error() { + async fn mock_js_message_failing_signals() { + let msg = MockJsMessage::with_failing_signals(make_nats_msg("test", b"")); + assert!(msg.ack().await.is_err()); + assert!(msg.ack_with(AckKind::Term).await.is_err()); + assert!(msg.double_ack().await.is_err()); + assert!(msg.double_ack_with(AckKind::Ack).await.is_err()); + } + + #[tokio::test] + async fn mock_consumer_messages_returns_stream() { let (consumer, _tx) = MockJetStreamConsumer::new(); let result = JetStreamConsumer::messages(&consumer).await; - assert!(result.is_err()); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn mock_consumer_messages_called_twice_returns_error() { + let (consumer, _tx) = MockJetStreamConsumer::new(); + let _first = JetStreamConsumer::messages(&consumer).await.unwrap(); + let second = JetStreamConsumer::messages(&consumer).await; + assert!(second.is_err()); + } + + #[tokio::test] + async fn mock_js_message_ack_with_next() { + let msg = MockJsMessage::new(make_nats_msg("test", b"")); + msg.ack_with(AckKind::Next).await.unwrap(); + assert_eq!( + msg.signals(), + vec![AckKindSnapshot::AckWith(AckKindValue::Next)] + ); } } diff --git a/rsworkspace/crates/trogon-nats/src/jetstream/mod.rs b/rsworkspace/crates/trogon-nats/src/jetstream/mod.rs index 93db4c59d..00deb048c 100644 --- a/rsworkspace/crates/trogon-nats/src/jetstream/mod.rs +++ b/rsworkspace/crates/trogon-nats/src/jetstream/mod.rs @@ -7,14 +7,18 @@ pub mod traits; pub mod mocks; #[cfg(not(coverage))] -pub use client::{JetStreamError, NatsJetStreamClient, NatsJetStreamConsumer}; -pub use message::{JsMessage, JsSignal}; +pub use client::{JetStreamError, NatsJetStreamClient, NatsJetStreamConsumer, NatsJsMessage}; +pub use message::{ + JsAck, JsAckWith, JsDispatchMessage, JsDoubleAck, JsDoubleAckWith, JsMessageRef, + JsRequestMessage, +}; pub use traits::{ - JetStreamConsumer, JetStreamConsumerFactory, JetStreamContext, JetStreamPublisher, + JetStreamConsumer, JetStreamConsumerFactory, JetStreamContext, JetStreamPublisher, NoJetStream, + NoOpConsumer, NoOpMessage, }; #[cfg(feature = "test-support")] pub use mocks::{ - MockJetStreamConsumer, MockJetStreamConsumerFactory, MockJetStreamContext, - MockJetStreamPublisher, MockJsMessage, + AckKindSnapshot, AckKindValue, MockJetStreamConsumer, MockJetStreamConsumerFactory, + MockJetStreamContext, MockJetStreamPublisher, MockJsMessage, }; diff --git a/rsworkspace/crates/trogon-nats/src/jetstream/traits.rs b/rsworkspace/crates/trogon-nats/src/jetstream/traits.rs index b93e7f910..e56a5adfa 100644 --- a/rsworkspace/crates/trogon-nats/src/jetstream/traits.rs +++ b/rsworkspace/crates/trogon-nats/src/jetstream/traits.rs @@ -8,8 +8,6 @@ use async_nats::jetstream::stream; use bytes::Bytes; use futures::Stream; -use super::message::JsMessage; - pub trait JetStreamContext: Send + Sync + Clone + 'static { type Error: Error + Send + Sync; @@ -43,7 +41,194 @@ pub trait JetStreamConsumerFactory: Send + Sync + Clone + 'static { pub trait JetStreamConsumer: Send + Sync + 'static { type Error: Error + Send + Sync; - type Messages: Stream> + Unpin + Send + 'static; + type Message: Send + 'static; + type Messages: Stream> + Unpin + Send + 'static; fn messages(&self) -> impl Future> + Send; } + +/// No-op error for the `()` JetStream impls. These impls exist only to satisfy +/// trait bounds when `Bridge` is used without JetStream — the methods +/// are never called because `bridge.js()` returns `None`. +#[derive(Debug)] +pub struct NoJetStream; + +impl std::fmt::Display for NoJetStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JetStream not configured") + } +} + +impl Error for NoJetStream {} + +impl JetStreamPublisher for () { + type PublishError = NoJetStream; + + async fn js_publish_with_headers( + &self, + _subject: String, + _headers: HeaderMap, + _payload: Bytes, + ) -> Result { + Err(NoJetStream) + } +} + +impl JetStreamConsumerFactory for () { + type Error = NoJetStream; + type Consumer = NoOpConsumer; + + async fn create_consumer( + &self, + _stream_name: &str, + _config: pull::Config, + ) -> Result { + Err(NoJetStream) + } +} + +/// Stub consumer for `J = ()`. Never produced at runtime because +/// `()::create_consumer()` always returns `Err`. +pub struct NoOpConsumer; + +impl JetStreamConsumer for NoOpConsumer { + type Error = NoJetStream; + type Message = NoOpMessage; + type Messages = futures::stream::Empty>; + + async fn messages(&self) -> Result { + Err(NoJetStream) + } +} + +/// Stub message for [`NoOpConsumer`]. Never produced at runtime because +/// `NoOpConsumer::messages()` always returns `Err`. Exists only to satisfy +/// generic bounds when `J = ()`. +#[derive(Debug)] +pub struct NoOpMessage { + inner: async_nats::Message, +} + +impl NoOpMessage { + #[cfg(test)] + fn stub() -> Self { + Self { + inner: async_nats::Message { + subject: "".into(), + reply: None, + payload: bytes::Bytes::new(), + headers: None, + status: None, + description: None, + length: 0, + }, + } + } +} + +impl super::message::JsMessageRef for NoOpMessage { + fn message(&self) -> &async_nats::Message { + &self.inner + } +} + +impl super::message::JsAck for NoOpMessage { + type Error = NoJetStream; + + async fn ack(&self) -> Result<(), NoJetStream> { + Err(NoJetStream) + } +} + +impl super::message::JsAckWith for NoOpMessage { + type Error = NoJetStream; + + async fn ack_with(&self, _kind: async_nats::jetstream::AckKind) -> Result<(), NoJetStream> { + Err(NoJetStream) + } +} + +impl super::message::JsDoubleAck for NoOpMessage { + type Error = NoJetStream; + + async fn double_ack(&self) -> Result<(), NoJetStream> { + Err(NoJetStream) + } +} + +impl super::message::JsDoubleAckWith for NoOpMessage { + type Error = NoJetStream; + + async fn double_ack_with( + &self, + _kind: async_nats::jetstream::AckKind, + ) -> Result<(), NoJetStream> { + Err(NoJetStream) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn no_jetstream_display() { + let err = NoJetStream; + assert_eq!(err.to_string(), "JetStream not configured"); + } + + #[test] + fn no_jetstream_is_error() { + let err: &dyn Error = &NoJetStream; + assert!(err.source().is_none()); + } + + #[tokio::test] + async fn no_op_consumer_messages_returns_err() { + let consumer = NoOpConsumer; + let result = JetStreamConsumer::messages(&consumer).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "JetStream not configured"); + } + + #[tokio::test] + async fn unit_publisher_returns_err() { + let result = ().js_publish_with_headers("s".into(), HeaderMap::new(), Bytes::new()).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn unit_consumer_factory_returns_err() { + let result = ().create_consumer("s", pull::Config::default()).await; + assert!(result.is_err()); + } + + #[test] + fn no_op_message_ref() { + use super::super::message::JsMessageRef; + let msg = NoOpMessage::stub(); + let inner = msg.message(); + assert!(inner.payload.is_empty()); + assert_eq!(inner.subject.as_str(), ""); + assert!(inner.headers.is_none()); + assert!(inner.reply.is_none()); + } + + #[tokio::test] + async fn no_op_message_signals() { + use super::super::message::{JsAck, JsAckWith, JsDoubleAck, JsDoubleAckWith}; + let msg = NoOpMessage::stub(); + assert!(msg.ack().await.is_err()); + assert!( + msg.ack_with(async_nats::jetstream::AckKind::Ack) + .await + .is_err() + ); + assert!(msg.double_ack().await.is_err()); + assert!( + msg.double_ack_with(async_nats::jetstream::AckKind::Ack) + .await + .is_err() + ); + } +}