From 813c23bc511cb245a99c88bc45bf4582b8806fde Mon Sep 17 00:00:00 2001 From: Yordis Prieto Date: Wed, 25 Mar 2026 11:51:32 -0400 Subject: [PATCH] feat(acp-nats-agent): add server-side ACP agent framework Signed-off-by: Yordis Prieto --- rsworkspace/Cargo.lock | 20 +- rsworkspace/crates/acp-nats-agent/Cargo.toml | 37 + .../crates/acp-nats-agent/src/connection.rs | 879 ++++++++++++++++++ rsworkspace/crates/acp-nats-agent/src/lib.rs | 3 + .../crates/acp-nats/src/client_proxy.rs | 335 +++++++ rsworkspace/crates/acp-nats/src/lib.rs | 4 +- rsworkspace/crates/acp-nats/src/nats/mod.rs | 7 +- .../crates/acp-nats/src/nats/parsing.rs | 277 +++++- .../crates/acp-nats/src/nats/subjects.rs | 259 +++++- 9 files changed, 1810 insertions(+), 11 deletions(-) create mode 100644 rsworkspace/crates/acp-nats-agent/Cargo.toml create mode 100644 rsworkspace/crates/acp-nats-agent/src/connection.rs create mode 100644 rsworkspace/crates/acp-nats-agent/src/lib.rs create mode 100644 rsworkspace/crates/acp-nats/src/client_proxy.rs diff --git a/rsworkspace/Cargo.lock b/rsworkspace/Cargo.lock index 9c5a1da3a..3485dab79 100644 --- a/rsworkspace/Cargo.lock +++ b/rsworkspace/Cargo.lock @@ -23,9 +23,27 @@ dependencies = [ "uuid", ] +[[package]] +name = "acp-nats-agent" +version = "0.0.1" +dependencies = [ + "acp-nats", + "agent-client-protocol", + "async-nats", + "async-trait", + "bytes", + "futures", + "serde", + "serde_json", + "tokio", + "tracing", + "trogon-nats", + "trogon-std", +] + [[package]] name = "acp-nats-stdio" -version = "0.1.0" +version = "0.0.1" dependencies = [ "acp-nats", "acp-telemetry", diff --git a/rsworkspace/crates/acp-nats-agent/Cargo.toml b/rsworkspace/crates/acp-nats-agent/Cargo.toml new file mode 100644 index 000000000..f26c63418 --- /dev/null +++ b/rsworkspace/crates/acp-nats-agent/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "acp-nats-agent" +version = "0.0.1" +edition = "2024" + +[lints] +workspace = true + +[dependencies] +acp-nats = { workspace = true } +agent-client-protocol = { workspace = true, features = [ + "unstable_auth_methods", + "unstable_boolean_config", + "unstable_cancel_request", + "unstable_message_id", + "unstable_session_close", + "unstable_session_fork", + "unstable_session_model", + "unstable_session_resume", + "unstable_session_usage", +] } +async-nats = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["rt", "macros", "sync"] } +tracing = { workspace = true } +trogon-nats = { workspace = true } +trogon-std = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["test-util"] } +tracing-subscriber = { workspace = true, features = ["fmt"] } +trogon-nats = { workspace = true, features = ["test-support"] } +trogon-std = { workspace = true, features = ["test-support"] } diff --git a/rsworkspace/crates/acp-nats-agent/src/connection.rs b/rsworkspace/crates/acp-nats-agent/src/connection.rs new file mode 100644 index 000000000..4750546dc --- /dev/null +++ b/rsworkspace/crates/acp-nats-agent/src/connection.rs @@ -0,0 +1,879 @@ +use acp_nats::acp_prefix::AcpPrefix; +use acp_nats::client_proxy::NatsClientProxy; +use acp_nats::nats::{AgentMethod, parse_agent_subject}; +use acp_nats::session_id::AcpSessionId; +use agent_client_protocol::{ + Agent, AuthenticateRequest, CancelNotification, CloseSessionRequest, ExtNotification, + ExtRequest, ForkSessionRequest, InitializeRequest, ListSessionsRequest, LoadSessionRequest, + NewSessionRequest, PromptRequest, ResumeSessionRequest, SetSessionConfigOptionRequest, + SetSessionModeRequest, SetSessionModelRequest, +}; +use async_nats::Message; +#[cfg(test)] +use bytes::Bytes; +use futures::StreamExt; +use futures::future::LocalBoxFuture; +use std::rc::Rc; +use std::time::Duration; +use tracing::{info, warn}; +use trogon_nats::{FlushClient, PublishClient, RequestClient, SubscribeClient}; + +pub enum ConnectionError { + Subscribe(Box), +} + +impl std::fmt::Debug for ConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Subscribe(e) => f.debug_tuple("Subscribe").field(e).finish(), + } + } +} + +impl std::fmt::Display for ConnectionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Subscribe(e) => write!(f, "failed to subscribe: {}", e), + } + } +} + +impl std::error::Error for ConnectionError {} + +#[derive(Debug)] +enum DispatchError { + NoReplySubject, + DeserializeRequest(serde_json::Error), + DeserializeNotification(serde_json::Error), + Reply(trogon_nats::NatsError), + NotificationHandler(agent_client_protocol::Error), +} + +impl std::fmt::Display for DispatchError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NoReplySubject => write!(f, "no reply subject"), + Self::DeserializeRequest(e) => write!(f, "deserialize request: {}", e), + Self::DeserializeNotification(e) => write!(f, "deserialize notification: {}", e), + Self::Reply(e) => write!(f, "reply: {}", e), + Self::NotificationHandler(e) => write!(f, "notification handler: {}", e), + } + } +} + +const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30); + +pub struct AgentSideNatsConnection { + nats: N, + acp_prefix: AcpPrefix, + operation_timeout: Duration, +} + +impl AgentSideNatsConnection +where + N: SubscribeClient + RequestClient + PublishClient + FlushClient + Clone + 'static, +{ + pub fn new( + agent: impl Agent + 'static, + nats: N, + acp_prefix: AcpPrefix, + spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static, + ) -> ( + Self, + impl std::future::Future>, + ) { + let nats_for_serve = nats.clone(); + let prefix = acp_prefix.as_str().to_string(); + + let io_task = async move { serve(agent, nats_for_serve, &prefix, spawn).await }; + + let conn = Self { + nats, + acp_prefix, + operation_timeout: DEFAULT_OPERATION_TIMEOUT, + }; + (conn, io_task) + } + + pub fn client_for_session(&self, session_id: AcpSessionId) -> NatsClientProxy { + NatsClientProxy::new( + self.nats.clone(), + session_id, + self.acp_prefix.clone(), + self.operation_timeout, + ) + } +} + +async fn serve( + agent: A, + nats: N, + prefix: &str, + spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static, +) -> Result<(), ConnectionError> +where + N: SubscribeClient + PublishClient + FlushClient + Clone + 'static, + A: Agent + 'static, +{ + // TODO: These two wildcards overlap when session_id == "agent", causing duplicate + // dispatch. A single {prefix}.> avoids duplicates but consumes client messages. + // Revisit the subject topology to eliminate both problems. + let global_wildcard = acp_nats::nats::agent::wildcards::all(prefix); + let session_wildcard = acp_nats::nats::agent::wildcards::all_sessions(prefix); + + info!( + global = %global_wildcard, + session = %session_wildcard, + "Starting agent-side NATS connection" + ); + + let global_sub = nats + .subscribe(global_wildcard) + .await + .map_err(|e| ConnectionError::Subscribe(Box::new(e)))?; + + let session_sub = nats + .subscribe(session_wildcard) + .await + .map_err(|e| ConnectionError::Subscribe(Box::new(e)))?; + + let mut subscriber = futures::stream::select(global_sub, session_sub); + + let agent = Rc::new(agent); + let nats = Rc::new(nats); + + while let Some(msg) = subscriber.next().await { + let agent = agent.clone(); + let nats = nats.clone(); + spawn(Box::pin(async move { + dispatch_message(msg, agent.as_ref(), nats.as_ref()).await; + })); + } + + info!("Agent-side NATS connection ended"); + Ok(()) +} + +async fn dispatch_message( + msg: Message, + agent: &A, + nats: &N, +) { + let subject = msg.subject.as_str(); + + let parsed = match parse_agent_subject(subject) { + Some(p) => p, + None => return, + }; + + let result = match parsed.method { + AgentMethod::Initialize => { + handle_request(&msg, nats, |req: InitializeRequest| agent.initialize(req)).await + } + AgentMethod::Authenticate => { + handle_request(&msg, nats, |req: AuthenticateRequest| { + agent.authenticate(req) + }) + .await + } + AgentMethod::SessionNew => { + handle_request(&msg, nats, |req: NewSessionRequest| agent.new_session(req)).await + } + AgentMethod::SessionList => { + handle_request(&msg, nats, |req: ListSessionsRequest| { + agent.list_sessions(req) + }) + .await + } + AgentMethod::SessionLoad => { + handle_request(&msg, nats, |req: LoadSessionRequest| { + agent.load_session(req) + }) + .await + } + AgentMethod::SessionPrompt => { + handle_request(&msg, nats, |req: PromptRequest| agent.prompt(req)).await + } + AgentMethod::SessionCancel => { + handle_notification(&msg, |req: CancelNotification| agent.cancel(req)).await + } + AgentMethod::SessionSetMode => { + handle_request(&msg, nats, |req: SetSessionModeRequest| { + agent.set_session_mode(req) + }) + .await + } + AgentMethod::SessionSetConfigOption => { + handle_request(&msg, nats, |req: SetSessionConfigOptionRequest| { + agent.set_session_config_option(req) + }) + .await + } + AgentMethod::SessionSetModel => { + handle_request(&msg, nats, |req: SetSessionModelRequest| { + agent.set_session_model(req) + }) + .await + } + AgentMethod::SessionFork => { + handle_request(&msg, nats, |req: ForkSessionRequest| { + agent.fork_session(req) + }) + .await + } + AgentMethod::SessionResume => { + handle_request(&msg, nats, |req: ResumeSessionRequest| { + agent.resume_session(req) + }) + .await + } + AgentMethod::SessionClose => { + handle_request(&msg, nats, |req: CloseSessionRequest| { + agent.close_session(req) + }) + .await + } + AgentMethod::Ext(_) => { + if msg.reply.is_some() { + handle_request(&msg, nats, |req: ExtRequest| agent.ext_method(req)).await + } else { + handle_notification(&msg, |req: ExtNotification| agent.ext_notification(req)).await + } + } + }; + + if let Err(e) = result { + let sid = parsed + .session_id + .as_ref() + .map(|s| s.as_str()) + .unwrap_or("-"); + warn!(subject, session_id = sid, error = %e, "Error handling agent request"); + } +} + +async fn reply( + nats: &N, + reply_to: &str, + value: &T, +) -> Result<(), DispatchError> { + trogon_nats::publish( + nats, + reply_to, + value, + trogon_nats::PublishOptions::builder() + .flush_policy(trogon_nats::FlushPolicy::standard()) + .build(), + ) + .await + .map_err(DispatchError::Reply) +} + +async fn handle_request( + msg: &Message, + nats: &N, + handler: impl FnOnce(ReqT) -> F, +) -> Result<(), DispatchError> +where + N: PublishClient + FlushClient, + ReqT: serde::de::DeserializeOwned, + F: std::future::Future>, + Resp: serde::Serialize, +{ + let reply_to = msg.reply.as_deref().ok_or(DispatchError::NoReplySubject)?; + + let request: ReqT = match serde_json::from_slice(&msg.payload) { + Ok(req) => req, + Err(e) => { + let error = agent_client_protocol::Error::new( + agent_client_protocol::ErrorCode::InvalidParams.into(), + format!("Failed to deserialize request: {}", e), + ); + let _ = reply(nats, reply_to, &error).await; + return Err(DispatchError::DeserializeRequest(e)); + } + }; + + match handler(request).await { + Ok(resp) => reply(nats, reply_to, &resp).await, + Err(err) => reply(nats, reply_to, &err).await, + } +} + +async fn handle_notification( + msg: &Message, + handler: impl FnOnce(ReqT) -> F, +) -> Result<(), DispatchError> +where + ReqT: serde::de::DeserializeOwned, + F: std::future::Future>, +{ + let request: ReqT = + serde_json::from_slice(&msg.payload).map_err(DispatchError::DeserializeNotification)?; + + handler(request) + .await + .map_err(DispatchError::NotificationHandler) +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_client_protocol::{ + AuthenticateResponse, Error as AcpError, ErrorCode, InitializeResponse, PromptResponse, + StopReason, + }; + use std::cell::RefCell; + use trogon_nats::MockNatsClient; + + struct MockAgent { + initialized: RefCell, + cancelled: RefCell>, + } + + impl MockAgent { + fn new() -> Self { + Self { + initialized: RefCell::new(false), + cancelled: RefCell::new(Vec::new()), + } + } + } + + #[async_trait::async_trait(?Send)] + impl Agent for MockAgent { + async fn initialize( + &self, + _args: InitializeRequest, + ) -> agent_client_protocol::Result { + *self.initialized.borrow_mut() = true; + Ok(InitializeResponse::new( + agent_client_protocol::ProtocolVersion::V0, + )) + } + + async fn authenticate( + &self, + _args: AuthenticateRequest, + ) -> agent_client_protocol::Result { + Err(AcpError::method_not_found()) + } + + async fn new_session( + &self, + _args: NewSessionRequest, + ) -> agent_client_protocol::Result { + Ok(agent_client_protocol::NewSessionResponse::new("sess-1")) + } + + async fn prompt( + &self, + _args: PromptRequest, + ) -> agent_client_protocol::Result { + Ok(PromptResponse::new(StopReason::EndTurn)) + } + + async fn cancel(&self, args: CancelNotification) -> agent_client_protocol::Result<()> { + self.cancelled + .borrow_mut() + .push(args.session_id.to_string()); + Ok(()) + } + } + + fn make_nats_message(subject: &str, payload: &[u8], reply: Option<&str>) -> Message { + Message { + subject: subject.into(), + reply: reply.map(|r| r.into()), + payload: Bytes::copy_from_slice(payload), + headers: None, + status: None, + description: None, + length: 0, + } + } + + fn serialize(value: &T) -> Vec { + serde_json::to_vec(value).unwrap() + } + + #[tokio::test] + async fn dispatch_initialize_calls_agent_and_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.1")); + + dispatch_message(msg, &agent, &nats).await; + + assert!(*agent.initialized.borrow()); + let published = nats.published_payloads(); + assert_eq!(published.len(), 1); + let response: InitializeResponse = serde_json::from_slice(&published[0]).unwrap(); + assert_eq!( + response.protocol_version, + agent_client_protocol::ProtocolVersion::V0 + ); + } + + #[tokio::test] + async fn dispatch_authenticate_error_publishes_acp_error() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&AuthenticateRequest::new("basic")); + let msg = make_nats_message("acp.agent.authenticate", &payload, Some("_INBOX.2")); + + dispatch_message(msg, &agent, &nats).await; + + let published = nats.published_payloads(); + assert_eq!(published.len(), 1); + let error: AcpError = serde_json::from_slice(&published[0]).unwrap(); + assert_eq!(error.code, ErrorCode::MethodNotFound); + } + + #[tokio::test] + async fn dispatch_cancel_is_notification_no_reply_published() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&CancelNotification::new("sess-1")); + let msg = make_nats_message("acp.s1.agent.session.cancel", &payload, None); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(agent.cancelled.borrow().len(), 1); + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_invalid_payload_publishes_error_reply() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let msg = make_nats_message("acp.agent.initialize", b"not json", Some("_INBOX.err")); + + dispatch_message(msg, &agent, &nats).await; + + assert!(!*agent.initialized.borrow()); + let published = nats.published_payloads(); + assert_eq!(published.len(), 1); + let error: AcpError = serde_json::from_slice(&published[0]).unwrap(); + assert_eq!(error.code, ErrorCode::InvalidParams); + } + + #[tokio::test] + async fn dispatch_request_without_reply_subject_does_not_publish() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, None); + + dispatch_message(msg, &agent, &nats).await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_unknown_subject_is_silently_ignored() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let msg = make_nats_message("acp.something.else", b"{}", Some("_INBOX.1")); + + dispatch_message(msg, &agent, &nats).await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_prompt_returns_stop_reason() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&PromptRequest::new("sess-1", vec![])); + let msg = make_nats_message("acp.s1.agent.session.prompt", &payload, Some("_INBOX.3")); + + dispatch_message(msg, &agent, &nats).await; + + let published = nats.published_payloads(); + assert_eq!(published.len(), 1); + let response: PromptResponse = serde_json::from_slice(&published[0]).unwrap(); + assert_eq!(response.stop_reason, StopReason::EndTurn); + } + + #[tokio::test] + async fn dispatch_publishes_to_correct_reply_subject() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, Some("_INBOX.specific")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.specific"]); + } + + #[test] + fn connection_error_display() { + let err = ConnectionError::Subscribe(Box::new(std::io::Error::other("test"))); + assert!(err.to_string().contains("failed to subscribe")); + assert!(err.to_string().contains("test")); + } + + #[test] + fn connection_error_debug() { + let err = ConnectionError::Subscribe(Box::new(std::io::Error::other("test"))); + let debug = format!("{:?}", err); + assert!(debug.contains("Subscribe")); + } + + #[test] + fn dispatch_error_display_variants() { + assert_eq!( + DispatchError::NoReplySubject.to_string(), + "no reply subject" + ); + + let json_err = serde_json::from_slice::<()>(b"bad").unwrap_err(); + assert!( + DispatchError::DeserializeRequest(json_err) + .to_string() + .contains("deserialize request") + ); + + let json_err = serde_json::from_slice::<()>(b"bad").unwrap_err(); + assert!( + DispatchError::DeserializeNotification(json_err) + .to_string() + .contains("deserialize notification") + ); + + let acp_err = agent_client_protocol::Error::internal_error(); + assert!( + DispatchError::NotificationHandler(acp_err) + .to_string() + .contains("notification handler") + ); + } + + #[tokio::test] + async fn dispatch_ext_with_reply_calls_ext_method() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&agent_client_protocol::ExtRequest::new( + "my_tool", + std::sync::Arc::from( + serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), + ), + )); + let msg = make_nats_message("acp.agent.ext.my_tool", &payload, Some("_INBOX.ext")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.ext"]); + } + + #[tokio::test] + async fn dispatch_ext_without_reply_calls_ext_notification() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&agent_client_protocol::ExtNotification::new( + "my_tool", + std::sync::Arc::from( + serde_json::value::RawValue::from_string("{}".to_string()).unwrap(), + ), + )); + let msg = make_nats_message("acp.agent.ext.my_tool", &payload, None); + + dispatch_message(msg, &agent, &nats).await; + + assert!(nats.published_messages().is_empty()); + } + + #[tokio::test] + async fn dispatch_new_session_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&NewSessionRequest::new("/tmp")); + let msg = make_nats_message("acp.agent.session.new", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_session_load_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&LoadSessionRequest::new("sess-1", "/tmp")); + let msg = make_nats_message("acp.s1.agent.session.load", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_list_sessions_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&ListSessionsRequest::new()); + let msg = make_nats_message("acp.agent.session.list", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_set_session_mode_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionModeRequest::new("sess-1", "code")); + let msg = make_nats_message("acp.s1.agent.session.set_mode", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_set_session_config_option_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionConfigOptionRequest::new("sess-1", "key", "val")); + let msg = make_nats_message( + "acp.s1.agent.session.set_config_option", + &payload, + Some("_INBOX.r"), + ); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_set_session_model_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&SetSessionModelRequest::new("sess-1", "gpt-4")); + let msg = make_nats_message("acp.s1.agent.session.set_model", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_fork_session_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&ForkSessionRequest::new("sess-1", "/tmp")); + let msg = make_nats_message("acp.s1.agent.session.fork", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_resume_session_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&ResumeSessionRequest::new("sess-1", "/tmp")); + let msg = make_nats_message("acp.s1.agent.session.resume", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[tokio::test] + async fn dispatch_close_session_publishes_response() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&CloseSessionRequest::new("sess-1")); + let msg = make_nats_message("acp.s1.agent.session.close", &payload, Some("_INBOX.r")); + + dispatch_message(msg, &agent, &nats).await; + + assert_eq!(nats.published_messages(), vec!["_INBOX.r"]); + } + + #[test] + fn dispatch_error_display_reply_variant() { + let err = DispatchError::Reply(trogon_nats::NatsError::Timeout { + subject: "test".to_string(), + }); + assert!(err.to_string().contains("reply")); + } + + #[tokio::test] + async fn new_runs_io_task_to_completion() { + let nats = MockNatsClient::new(); + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + drop(global_tx); + drop(session_tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (conn, io_task) = AgentSideNatsConnection::new( + MockAgent::new(), + nats, + AcpPrefix::new("acp").unwrap(), + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + assert_eq!(conn.acp_prefix.as_str(), "acp"); + + let result = io_task.await; + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn client_for_session_returns_proxy() { + let nats = MockNatsClient::new(); + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + drop(global_tx); + drop(session_tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (conn, io_task) = AgentSideNatsConnection::new( + MockAgent::new(), + nats, + AcpPrefix::new("acp").unwrap(), + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + let _client = conn.client_for_session(AcpSessionId::new("sess-1").unwrap()); + + let result = io_task.await; + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn dispatch_error_logs_warning_with_subscriber() { + use tracing_subscriber::util::SubscriberInitExt; + let _guard = tracing_subscriber::fmt().with_test_writer().set_default(); + + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = make_nats_message("acp.agent.initialize", &payload, None); + + dispatch_message(msg, &agent, &nats).await; + } + + #[tokio::test] + async fn serve_subscribes_and_dispatches_messages() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let payload = serialize(&InitializeRequest::new( + agent_client_protocol::ProtocolVersion::V0, + )); + let msg = Message { + subject: "acp.agent.initialize".into(), + reply: Some("_INBOX.serve".into()), + payload: Bytes::copy_from_slice(&payload), + headers: None, + status: None, + description: None, + length: 0, + }; + + global_tx.unbounded_send(msg).unwrap(); + drop(global_tx); + drop(session_tx); + + let result = serve(agent, nats.clone(), "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + assert!(result.is_ok()); + + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + assert_eq!(nats.published_messages().len(), 1); + assert_eq!(nats.published_messages()[0], "_INBOX.serve"); + }) + .await; + } + + #[tokio::test] + async fn serve_returns_ok_when_subscription_ends() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + + drop(global_tx); + drop(session_tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let result = serve(agent, nats, "acp", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn serve_subscribes_to_correct_subjects() { + let nats = MockNatsClient::new(); + let agent = MockAgent::new(); + + let global_tx = nats.inject_messages(); + let session_tx = nats.inject_messages(); + + drop(global_tx); + drop(session_tx); + + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let _ = serve(agent, nats.clone(), "myprefix", |fut| { + tokio::task::spawn_local(fut); + }) + .await; + + let subjects = nats.subscribed_to(); + assert!(subjects.contains(&"myprefix.agent.>".to_string())); + assert!(subjects.contains(&"myprefix.*.agent.>".to_string())); + }) + .await; + } +} diff --git a/rsworkspace/crates/acp-nats-agent/src/lib.rs b/rsworkspace/crates/acp-nats-agent/src/lib.rs new file mode 100644 index 000000000..57282f296 --- /dev/null +++ b/rsworkspace/crates/acp-nats-agent/src/lib.rs @@ -0,0 +1,3 @@ +mod connection; + +pub use connection::{AgentSideNatsConnection, ConnectionError}; diff --git a/rsworkspace/crates/acp-nats/src/client_proxy.rs b/rsworkspace/crates/acp-nats/src/client_proxy.rs new file mode 100644 index 000000000..336f6d1ef --- /dev/null +++ b/rsworkspace/crates/acp-nats/src/client_proxy.rs @@ -0,0 +1,335 @@ +use crate::acp_prefix::AcpPrefix; +use crate::nats::client_subjects; +use crate::session_id::AcpSessionId; +use agent_client_protocol::{ + Client, CreateTerminalRequest, CreateTerminalResponse, Error, ErrorCode, KillTerminalRequest, + KillTerminalResponse, ReadTextFileRequest, ReadTextFileResponse, ReleaseTerminalRequest, + ReleaseTerminalResponse, RequestPermissionRequest, RequestPermissionResponse, Result, + SessionNotification, TerminalOutputRequest, TerminalOutputResponse, WaitForTerminalExitRequest, + WaitForTerminalExitResponse, WriteTextFileRequest, WriteTextFileResponse, +}; +use std::time::Duration; +use trogon_nats::{FlushClient, PublishClient, RequestClient, publish, request_with_timeout}; + +pub struct NatsClientProxy { + nats: N, + session_id: AcpSessionId, + prefix: AcpPrefix, + timeout: Duration, +} + +impl NatsClientProxy { + pub fn new(nats: N, session_id: AcpSessionId, prefix: AcpPrefix, timeout: Duration) -> Self { + Self { + nats, + session_id, + prefix, + timeout, + } + } +} + +fn to_acp_error(e: impl std::fmt::Display) -> Error { + Error::new(ErrorCode::InternalError.into(), e.to_string()) +} + +impl NatsClientProxy { + fn prefix(&self) -> &str { + self.prefix.as_str() + } + + fn session_id(&self) -> &str { + self.session_id.as_str() + } + + async fn request( + &self, + subject: &str, + args: &Req, + ) -> Result { + request_with_timeout(&self.nats, subject, args, self.timeout) + .await + .map_err(to_acp_error) + } + + async fn notify(&self, subject: &str, args: &Req) -> Result<()> { + publish( + &self.nats, + subject, + args, + trogon_nats::PublishOptions::default(), + ) + .await + .map_err(to_acp_error) + } +} + +#[async_trait::async_trait(?Send)] +impl Client for NatsClientProxy { + async fn request_permission( + &self, + args: RequestPermissionRequest, + ) -> Result { + let s = client_subjects::session_request_permission(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn session_notification(&self, args: SessionNotification) -> Result<()> { + let s = client_subjects::session_update(self.prefix(), self.session_id()); + self.notify(&s, &args).await + } + + async fn read_text_file(&self, args: ReadTextFileRequest) -> Result { + let s = client_subjects::fs_read_text_file(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn write_text_file(&self, args: WriteTextFileRequest) -> Result { + let s = client_subjects::fs_write_text_file(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn create_terminal(&self, args: CreateTerminalRequest) -> Result { + let s = client_subjects::terminal_create(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn terminal_output(&self, args: TerminalOutputRequest) -> Result { + let s = client_subjects::terminal_output(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn release_terminal( + &self, + args: ReleaseTerminalRequest, + ) -> Result { + let s = client_subjects::terminal_release(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn wait_for_terminal_exit( + &self, + args: WaitForTerminalExitRequest, + ) -> Result { + let s = client_subjects::terminal_wait_for_exit(self.prefix(), self.session_id()); + self.request(&s, &args).await + } + + async fn kill_terminal(&self, args: KillTerminalRequest) -> Result { + let s = client_subjects::terminal_kill(self.prefix(), self.session_id()); + self.request(&s, &args).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_client_protocol::{ + Client, ContentBlock, ContentChunk, ReadTextFileResponse, RequestPermissionOutcome, + RequestPermissionResponse, SessionNotification, SessionUpdate, ToolCallUpdate, + ToolCallUpdateFields, + }; + use trogon_nats::AdvancedMockNatsClient; + + fn proxy(nats: AdvancedMockNatsClient) -> NatsClientProxy { + NatsClientProxy::new( + nats, + AcpSessionId::new("s1").unwrap(), + AcpPrefix::new("acp").unwrap(), + Duration::from_secs(5), + ) + } + + #[tokio::test] + async fn request_permission_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = RequestPermissionResponse::new(RequestPermissionOutcome::Cancelled); + nats.set_response( + "acp.s1.client.session.request_permission", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let tool_call = ToolCallUpdate::new("tc-1", ToolCallUpdateFields::new()); + let result = p + .request_permission(RequestPermissionRequest::new("s1", tool_call, vec![])) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().outcome, RequestPermissionOutcome::Cancelled); + } + + #[tokio::test] + async fn session_notification_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let p = proxy(nats.clone()); + + let notif = SessionNotification::new( + "s1", + SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hello"))), + ); + let result = p.session_notification(notif).await; + + assert!(result.is_ok()); + assert_eq!( + nats.published_messages(), + vec!["acp.s1.client.session.update"] + ); + } + + #[tokio::test] + async fn read_text_file_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = ReadTextFileResponse::new("file contents"); + nats.set_response( + "acp.s1.client.fs.read_text_file", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .read_text_file(ReadTextFileRequest::new("s1", "/test.txt")) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().content, "file contents"); + } + + #[tokio::test] + async fn request_returns_error_when_nats_fails() { + let nats = AdvancedMockNatsClient::new(); + nats.fail_next_request(); + + let p = proxy(nats); + let result = p + .read_text_file(ReadTextFileRequest::new("s1", "/test.txt")) + .await; + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().code, ErrorCode::InternalError); + } + + #[tokio::test] + async fn write_text_file_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = agent_client_protocol::WriteTextFileResponse::default(); + nats.set_response( + "acp.s1.client.fs.write_text_file", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .write_text_file(WriteTextFileRequest::new("s1", "/test.txt", "content")) + .await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn create_terminal_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = CreateTerminalResponse::new("t1"); + nats.set_response( + "acp.s1.client.terminal.create", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .create_terminal(CreateTerminalRequest::new("s1", "echo")) + .await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn terminal_output_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = TerminalOutputResponse::new("output", false); + nats.set_response( + "acp.s1.client.terminal.output", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .terminal_output(TerminalOutputRequest::new("s1", "t1")) + .await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn release_terminal_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = ReleaseTerminalResponse::default(); + nats.set_response( + "acp.s1.client.terminal.release", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .release_terminal(ReleaseTerminalRequest::new("s1", "t1")) + .await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn kill_terminal_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = KillTerminalResponse::default(); + nats.set_response( + "acp.s1.client.terminal.kill", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p.kill_terminal(KillTerminalRequest::new("s1", "t1")).await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn wait_for_terminal_exit_publishes_to_correct_subject() { + let nats = AdvancedMockNatsClient::new(); + let response = + WaitForTerminalExitResponse::new(agent_client_protocol::TerminalExitStatus::new()); + nats.set_response( + "acp.s1.client.terminal.wait_for_exit", + serde_json::to_vec(&response).unwrap().into(), + ); + + let p = proxy(nats.clone()); + let result = p + .wait_for_terminal_exit(WaitForTerminalExitRequest::new("s1", "t1")) + .await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn notification_returns_error_when_publish_fails() { + let nats = AdvancedMockNatsClient::new(); + nats.fail_next_publish(); + + let p = proxy(nats); + let notif = SessionNotification::new( + "s1", + SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::from("hello"))), + ); + let result = p.session_notification(notif).await; + + assert!(result.is_err()); + assert_eq!(result.unwrap_err().code, ErrorCode::InternalError); + } + + #[test] + fn to_acp_error_preserves_message() { + let err = to_acp_error("something went wrong"); + assert_eq!(err.code, ErrorCode::InternalError); + assert_eq!(err.message, "something went wrong"); + } +} diff --git a/rsworkspace/crates/acp-nats/src/lib.rs b/rsworkspace/crates/acp-nats/src/lib.rs index 237fdef1a..d8f8f8924 100644 --- a/rsworkspace/crates/acp-nats/src/lib.rs +++ b/rsworkspace/crates/acp-nats/src/lib.rs @@ -1,9 +1,10 @@ pub mod acp_prefix; pub mod agent; pub mod client; +pub mod client_proxy; pub mod config; pub mod error; -pub(crate) mod ext_method_name; +pub mod ext_method_name; pub(crate) mod in_flight_slot_guard; pub(crate) mod jsonrpc; pub mod nats; @@ -15,6 +16,7 @@ pub(crate) mod telemetry; pub use acp_prefix::{AcpPrefix, AcpPrefixError}; pub use agent::Bridge; pub use agent::REQ_ID_HEADER; +pub use client_proxy::NatsClientProxy; pub use config::{ Config, DEFAULT_ACP_PREFIX, ENV_ACP_PREFIX, apply_timeout_overrides, nats_connect_timeout, }; diff --git a/rsworkspace/crates/acp-nats/src/nats/mod.rs b/rsworkspace/crates/acp-nats/src/nats/mod.rs index bf7f5a55f..b92d07954 100644 --- a/rsworkspace/crates/acp-nats/src/nats/mod.rs +++ b/rsworkspace/crates/acp-nats/src/nats/mod.rs @@ -4,8 +4,11 @@ mod subjects; pub(crate) mod token; pub use extensions::ExtSessionReady; -pub use parsing::{ClientMethod, ParsedClientSubject, parse_client_subject}; -pub use subjects::{agent, client as client_subjects}; +pub use parsing::{ + AgentMethod, ClientMethod, ParsedAgentSubject, ParsedClientSubject, parse_agent_subject, + parse_client_subject, +}; +pub use subjects::{agent, client as client_subjects, wildcards}; pub use trogon_nats::{ FlushClient, FlushPolicy, PublishClient, PublishOptions, RequestClient, RetryPolicy, SubscribeClient, client, connect, headers_with_trace_context, inject_trace_context, publish, diff --git a/rsworkspace/crates/acp-nats/src/nats/parsing.rs b/rsworkspace/crates/acp-nats/src/nats/parsing.rs index 348917e5b..d6934119d 100644 --- a/rsworkspace/crates/acp-nats/src/nats/parsing.rs +++ b/rsworkspace/crates/acp-nats/src/nats/parsing.rs @@ -1,6 +1,91 @@ use crate::ext_method_name::ExtMethodName; use crate::session_id::AcpSessionId; +const AGENT_MARKER: &str = ".agent."; +const AGENT_EXT_PREFIX: &str = "agent.ext."; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentMethod { + Initialize, + Authenticate, + SessionNew, + SessionList, + SessionLoad, + SessionPrompt, + SessionCancel, + SessionSetMode, + SessionSetConfigOption, + SessionSetModel, + SessionFork, + SessionResume, + SessionClose, + Ext(ExtMethodName), +} + +impl AgentMethod { + pub fn is_session_scoped(&self) -> bool { + !matches!( + self, + Self::Initialize + | Self::Authenticate + | Self::SessionNew + | Self::SessionList + | Self::Ext(_) + ) + } + + fn from_suffix(suffix: &str) -> Option { + match suffix { + "agent.initialize" => Some(Self::Initialize), + "agent.authenticate" => Some(Self::Authenticate), + "agent.session.new" => Some(Self::SessionNew), + "agent.session.list" => Some(Self::SessionList), + "agent.session.load" => Some(Self::SessionLoad), + "agent.session.prompt" => Some(Self::SessionPrompt), + "agent.session.cancel" => Some(Self::SessionCancel), + "agent.session.set_mode" => Some(Self::SessionSetMode), + "agent.session.set_config_option" => Some(Self::SessionSetConfigOption), + "agent.session.set_model" => Some(Self::SessionSetModel), + "agent.session.fork" => Some(Self::SessionFork), + "agent.session.resume" => Some(Self::SessionResume), + "agent.session.close" => Some(Self::SessionClose), + other => { + let ext_name = other.strip_prefix(AGENT_EXT_PREFIX)?; + Some(Self::Ext(ExtMethodName::new(ext_name).ok()?)) + } + } + } +} + +#[derive(Debug)] +pub struct ParsedAgentSubject { + pub session_id: Option, + pub method: AgentMethod, +} + +pub fn parse_agent_subject(subject: &str) -> Option { + for (agent_byte_pos, _) in subject.match_indices(AGENT_MARKER) { + let suffix = &subject[agent_byte_pos + 1..]; + let method = match AgentMethod::from_suffix(suffix) { + Some(m) => m, + None => continue, + }; + + let session_id = if method.is_session_scoped() { + let before_agent = &subject[..agent_byte_pos]; + let session_dot = before_agent.rfind('.')?; + let raw = &before_agent[session_dot + 1..]; + Some(AcpSessionId::new(raw).ok()?) + } else { + None + }; + + return Some(ParsedAgentSubject { session_id, method }); + } + + None +} + /// NATS subject prefix for generic extension methods. /// `client.ext.{name}` — the `ext` token makes extensions explicit in subjects. /// `ExtSessionPromptResponse` is matched first as a specific ext, so it won't @@ -62,12 +147,6 @@ pub fn parse_client_subject(subject: &str) -> Option { let method = ClientMethod::from_subject_suffix(suffix)?; - if let ClientMethod::Ext(ref name) = method - && name.contains(".client.") - { - return None; - } - Some(ParsedClientSubject { session_id, method }) } @@ -356,4 +435,190 @@ mod tests { assert_eq!(ClientMethod::FsReadTextFile, ClientMethod::FsReadTextFile); assert_ne!(ClientMethod::FsReadTextFile, ClientMethod::FsWriteTextFile); } + + #[test] + fn test_agent_parse_initialize() { + let parsed = parse_agent_subject("acp.agent.initialize").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::Initialize); + } + + #[test] + fn test_agent_parse_authenticate() { + let parsed = parse_agent_subject("acp.agent.authenticate").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::Authenticate); + } + + #[test] + fn test_agent_parse_session_new() { + let parsed = parse_agent_subject("acp.agent.session.new").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::SessionNew); + } + + #[test] + fn test_agent_parse_session_list() { + let parsed = parse_agent_subject("acp.agent.session.list").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::SessionList); + } + + #[test] + fn test_agent_parse_session_load() { + let parsed = parse_agent_subject("acp.s1.agent.session.load").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionLoad); + } + + #[test] + fn test_agent_parse_session_prompt() { + let parsed = parse_agent_subject("acp.s1.agent.session.prompt").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionPrompt); + } + + #[test] + fn test_agent_parse_session_cancel() { + let parsed = parse_agent_subject("acp.s1.agent.session.cancel").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionCancel); + } + + #[test] + fn test_agent_parse_session_set_mode() { + let parsed = parse_agent_subject("acp.s1.agent.session.set_mode").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionSetMode); + } + + #[test] + fn test_agent_parse_session_set_config_option() { + let parsed = parse_agent_subject("acp.s1.agent.session.set_config_option").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionSetConfigOption); + } + + #[test] + fn test_agent_parse_session_set_model() { + let parsed = parse_agent_subject("acp.s1.agent.session.set_model").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionSetModel); + } + + #[test] + fn test_agent_parse_session_fork() { + let parsed = parse_agent_subject("acp.s1.agent.session.fork").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionFork); + } + + #[test] + fn test_agent_parse_session_resume() { + let parsed = parse_agent_subject("acp.s1.agent.session.resume").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionResume); + } + + #[test] + fn test_agent_parse_session_close() { + let parsed = parse_agent_subject("acp.s1.agent.session.close").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionClose); + } + + #[test] + fn test_agent_parse_ext_method() { + let parsed = parse_agent_subject("acp.agent.ext.my_tool").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!( + parsed.method, + AgentMethod::Ext(ExtMethodName::new("my_tool").unwrap()) + ); + } + + #[test] + fn test_agent_parse_ext_dotted_namespace() { + let parsed = parse_agent_subject("acp.agent.ext.vendor.operation").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!( + parsed.method, + AgentMethod::Ext(ExtMethodName::new("vendor.operation").unwrap()) + ); + } + + #[test] + fn test_agent_parse_custom_prefix() { + let parsed = parse_agent_subject("myapp.agent.initialize").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::Initialize); + } + + #[test] + fn test_agent_parse_multi_part_prefix() { + let parsed = parse_agent_subject("my.multi.prefix.s1.agent.session.load").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionLoad); + } + + #[test] + fn test_agent_parse_empty_returns_none() { + assert!(parse_agent_subject("").is_none()); + } + + #[test] + fn test_agent_parse_no_agent_marker_returns_none() { + assert!(parse_agent_subject("acp.client.session.update").is_none()); + } + + #[test] + fn test_agent_parse_unknown_method_returns_none() { + assert!(parse_agent_subject("acp.agent.unknown.method").is_none()); + } + + #[test] + fn test_agent_parse_invalid_session_id_returns_none() { + assert!(parse_agent_subject("acp.sess*ion.agent.session.load").is_none()); + } + + #[test] + fn test_agent_parse_ext_empty_name_returns_none() { + assert!(parse_agent_subject("acp.agent.ext.").is_none()); + } + + #[test] + fn test_agent_parse_ext_wildcard_returns_none() { + assert!(parse_agent_subject("acp.agent.ext.*").is_none()); + } + + #[test] + fn test_agent_parse_multi_dot_prefix_global_method_has_no_session() { + let parsed = parse_agent_subject("my.multi.agent.initialize").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::Initialize); + } + + #[test] + fn test_agent_parse_prefix_containing_agent_word() { + let parsed = parse_agent_subject("org.agent.app.agent.initialize").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!(parsed.method, AgentMethod::Initialize); + } + + #[test] + fn test_agent_parse_ext_method_containing_agent_segment() { + let parsed = parse_agent_subject("acp.agent.ext.agent.foo").unwrap(); + assert!(parsed.session_id.is_none()); + assert_eq!( + parsed.method, + AgentMethod::Ext(ExtMethodName::new("agent.foo").unwrap()) + ); + } + + #[test] + fn test_agent_parse_multi_dot_prefix_session_scoped() { + let parsed = parse_agent_subject("my.multi.s1.agent.session.prompt").unwrap(); + assert_eq!(parsed.session_id.unwrap().as_str(), "s1"); + assert_eq!(parsed.method, AgentMethod::SessionPrompt); + } } diff --git a/rsworkspace/crates/acp-nats/src/nats/subjects.rs b/rsworkspace/crates/acp-nats/src/nats/subjects.rs index 39d1ad2ce..5c82aaf6b 100644 --- a/rsworkspace/crates/acp-nats/src/nats/subjects.rs +++ b/rsworkspace/crates/acp-nats/src/nats/subjects.rs @@ -77,9 +77,58 @@ pub mod agent { pub fn ext(prefix: &str, method: &str) -> String { format!("{}.agent.ext.{}", prefix, method) } + + pub mod wildcards { + pub fn all(prefix: &str) -> String { + format!("{}.agent.>", prefix) + } + + pub fn all_sessions(prefix: &str) -> String { + format!("{}.*.agent.>", prefix) + } + } } pub mod client { + pub fn fs_read_text_file(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.fs.read_text_file", prefix, session_id) + } + + pub fn fs_write_text_file(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.fs.write_text_file", prefix, session_id) + } + + pub fn session_request_permission(prefix: &str, session_id: &str) -> String { + format!( + "{}.{}.client.session.request_permission", + prefix, session_id + ) + } + + pub fn session_update(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.session.update", prefix, session_id) + } + + pub fn terminal_create(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.terminal.create", prefix, session_id) + } + + pub fn terminal_kill(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.terminal.kill", prefix, session_id) + } + + pub fn terminal_output(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.terminal.output", prefix, session_id) + } + + pub fn terminal_release(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.terminal.release", prefix, session_id) + } + + pub fn terminal_wait_for_exit(prefix: &str, session_id: &str) -> String { + format!("{}.{}.client.terminal.wait_for_exit", prefix, session_id) + } + pub mod wildcards { pub fn all(prefix: &str) -> String { format!("{}.*.client.>", prefix) @@ -87,9 +136,15 @@ pub mod client { } } +pub mod wildcards { + pub fn all(prefix: &str) -> String { + format!("{}.>", prefix) + } +} + #[cfg(test)] mod tests { - use super::{agent, client}; + use super::{agent, client, wildcards}; #[test] fn initialize_subject() { @@ -253,8 +308,210 @@ mod tests { assert!(agent::ext_session_ready(prefix, sid).starts_with(&expected_prefix)); } + #[test] + fn ext_subject() { + assert_eq!(agent::ext("acp", "my_tool"), "acp.agent.ext.my_tool"); + } + + #[test] + fn agent_wildcard_all_subject() { + assert_eq!(agent::wildcards::all("acp"), "acp.agent.>"); + } + + #[test] + fn agent_wildcard_all_sessions_subject() { + assert_eq!(agent::wildcards::all_sessions("acp"), "acp.*.agent.>"); + } + + #[test] + fn agent_wildcards_overlap_when_session_id_is_agent() { + let subject = "acp.agent.agent.session.load"; + let global = agent::wildcards::all("acp"); + let session = agent::wildcards::all_sessions("acp"); + + // Both wildcards match this subject in NATS: + // - "acp.agent.>" matches because "agent.session.load" falls under "acp.agent." + // - "acp.*.agent.>" matches because * = "agent", rest = "session.load" + // This is a known trade-off: using two subscriptions avoids consuming + // client messages, but causes duplicate delivery for session ID "agent". + assert!(nats_wildcard_matches(&global, subject)); + assert!(nats_wildcard_matches(&session, subject)); + } + + fn nats_wildcard_matches(pattern: &str, subject: &str) -> bool { + let pattern_parts: Vec<&str> = pattern.split('.').collect(); + let subject_parts: Vec<&str> = subject.split('.').collect(); + nats_match(&pattern_parts, &subject_parts) + } + + fn nats_match(pattern: &[&str], subject: &[&str]) -> bool { + match (pattern.first(), subject.first()) { + (Some(&">"), _) => true, + (Some(&"*"), Some(_)) => nats_match(&pattern[1..], &subject[1..]), + (Some(p), Some(s)) if p == s => nats_match(&pattern[1..], &subject[1..]), + (None, None) => true, + _ => false, + } + } + + #[test] + fn nats_match_exact_match() { + assert!(nats_wildcard_matches( + "acp.agent.initialize", + "acp.agent.initialize" + )); + } + + #[test] + fn nats_match_no_match() { + assert!(!nats_wildcard_matches( + "acp.agent.initialize", + "acp.agent.authenticate" + )); + } + + #[test] + fn nats_match_length_mismatch() { + assert!(!nats_wildcard_matches("acp.agent", "acp.agent.initialize")); + } + + #[test] + fn prefix_wildcard_all() { + assert_eq!(wildcards::all("acp"), "acp.>"); + } + + #[test] + fn prefix_wildcard_custom_prefix() { + assert_eq!(wildcards::all("myapp"), "myapp.>"); + } + + #[test] + fn agent_wildcard_custom_prefix() { + assert_eq!(agent::wildcards::all("myapp"), "myapp.agent.>"); + assert_eq!(agent::wildcards::all_sessions("myapp"), "myapp.*.agent.>"); + } + + #[test] + fn ext_subject_dotted_method() { + assert_eq!( + agent::ext("acp", "vendor.operation"), + "acp.agent.ext.vendor.operation" + ); + } + + #[test] + fn ext_subject_custom_prefix() { + assert_eq!(agent::ext("myapp", "my_tool"), "myapp.agent.ext.my_tool"); + } + + #[test] + fn client_fs_read_text_file_subject() { + assert_eq!( + client::fs_read_text_file("acp", "s1"), + "acp.s1.client.fs.read_text_file" + ); + } + + #[test] + fn client_fs_write_text_file_subject() { + assert_eq!( + client::fs_write_text_file("acp", "s1"), + "acp.s1.client.fs.write_text_file" + ); + } + + #[test] + fn client_session_request_permission_subject() { + assert_eq!( + client::session_request_permission("acp", "s1"), + "acp.s1.client.session.request_permission" + ); + } + + #[test] + fn client_session_update_subject() { + assert_eq!( + client::session_update("acp", "s1"), + "acp.s1.client.session.update" + ); + } + + #[test] + fn client_terminal_create_subject() { + assert_eq!( + client::terminal_create("acp", "s1"), + "acp.s1.client.terminal.create" + ); + } + + #[test] + fn client_terminal_kill_subject() { + assert_eq!( + client::terminal_kill("acp", "s1"), + "acp.s1.client.terminal.kill" + ); + } + + #[test] + fn client_terminal_output_subject() { + assert_eq!( + client::terminal_output("acp", "s1"), + "acp.s1.client.terminal.output" + ); + } + + #[test] + fn client_terminal_release_subject() { + assert_eq!( + client::terminal_release("acp", "s1"), + "acp.s1.client.terminal.release" + ); + } + + #[test] + fn client_terminal_wait_for_exit_subject() { + assert_eq!( + client::terminal_wait_for_exit("acp", "s1"), + "acp.s1.client.terminal.wait_for_exit" + ); + } + #[test] fn client_wildcard_all_subject() { assert_eq!(client::wildcards::all("acp"), "acp.*.client.>"); } + + #[test] + fn client_wildcard_custom_prefix() { + assert_eq!(client::wildcards::all("myapp"), "myapp.*.client.>"); + } + + #[test] + fn client_subjects_share_token_layout() { + let prefix = "acp"; + let sid = "abc"; + let expected_prefix = format!("{}.{}.client.", prefix, sid); + + assert!(client::fs_read_text_file(prefix, sid).starts_with(&expected_prefix)); + assert!(client::fs_write_text_file(prefix, sid).starts_with(&expected_prefix)); + assert!(client::session_request_permission(prefix, sid).starts_with(&expected_prefix)); + assert!(client::session_update(prefix, sid).starts_with(&expected_prefix)); + assert!(client::terminal_create(prefix, sid).starts_with(&expected_prefix)); + assert!(client::terminal_kill(prefix, sid).starts_with(&expected_prefix)); + assert!(client::terminal_output(prefix, sid).starts_with(&expected_prefix)); + assert!(client::terminal_release(prefix, sid).starts_with(&expected_prefix)); + assert!(client::terminal_wait_for_exit(prefix, sid).starts_with(&expected_prefix)); + } + + #[test] + fn client_subjects_custom_prefix() { + assert_eq!( + client::fs_read_text_file("myapp", "sess-42"), + "myapp.sess-42.client.fs.read_text_file" + ); + assert_eq!( + client::terminal_create("myapp", "sess-42"), + "myapp.sess-42.client.terminal.create" + ); + } }