diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 57a0bbcd7..b946eb962 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -31,6 +31,7 @@ - Java single test: `cd java && mvn test -Dtest=CopilotClientTest` | single method: `mvn test -Dtest=ToolsTest#testToolInvocation` - Java format check only: `mvn spotless:check` | Build without tests: `mvn clean package -DskipTests` - **Java testing note:** Always use `mvn verify` without `-q` and without piping through `grep`. Never add `InternalsVisibleTo` equivalent — tests must only access public APIs. +- Use configured LSPs for supported operations like finding references instead of pattern matching, renaming symbols, etc. ## Testing & E2E tips ⚙️ diff --git a/.github/lsp.json b/.github/lsp.json index e58456ac4..753521284 100644 --- a/.github/lsp.json +++ b/.github/lsp.json @@ -21,6 +21,24 @@ ".go": "go" }, "rootUri": "go" + }, + "rust-analyzer": { + "command": "rust-analyzer", + "fileExtensions": { + ".rs": "rust" + }, + "initializationOptions": { + "cargo": { + "buildScripts": { + "enable": true + }, + "allFeatures": true + }, + "checkOnSave": true, + "check": { + "command": "clippy" + } + } } } } diff --git a/.github/skills/rust-coding-skill/SKILL.md b/.github/skills/rust-coding-skill/SKILL.md index 7e0342f06..b33cd2c43 100644 --- a/.github/skills/rust-coding-skill/SKILL.md +++ b/.github/skills/rust-coding-skill/SKILL.md @@ -13,11 +13,10 @@ Opinionated Rust rules for the Copilot Rust SDK (`rust/`). Priority order: ## Error handling -The SDK's public error type is `crate::Error` (`rust/src/error.rs`). Add new -variants there rather than introducing parallel error enums per module — every -public failure mode is part of the API contract and should be expressible in one -type. Internal modules can use `thiserror` enums when a richer local taxonomy -helps; convert at the boundary. +The SDK's public error type is `crate::Error` (`rust/src/errors.rs`). Add new +variants to `crate::ErrorKind` rather than introducing parallel error enums +per module — every public failure mode is part of the API contract and should +be expressible in one type. `anyhow` is reserved for binaries and example code. Library code never returns `anyhow::Result` — callers can't pattern-match on `anyhow::Error`, so it would @@ -42,7 +41,7 @@ it on shutdown. Fire-and-forget spawns silently swallow panics and outlive the session; don't. Blocking calls (filesystem, subprocess wait) belong in -`tokio::task::spawn_blocking`, *not* on the async runtime. The blocking pool is +`tokio::task::spawn_blocking`, _not_ on the async runtime. The blocking pool is bounded, so for genuinely long-lived workers (think: file watchers that run for the lifetime of a session) prefer `std::thread::spawn` with a channel back into async land. @@ -81,12 +80,12 @@ Trivial field re-shaping is best inlined. Closures should stay short (under ~10 **Channels, not callback closures, for event flow.** Closures fight `Send + Sync + 'static` and don't compose with `select!`. Channel choice by semantics: -| Use case | Primitive | -|---|---| -| One producer → one consumer with backpressure | `tokio::sync::mpsc` (cap 1) or `tokio::sync::oneshot` for single value | -| Many producers → one consumer | `tokio::sync::mpsc` | -| One producer → many consumers, every event delivered (pub/sub) | `tokio::sync::broadcast` | -| One producer → many consumers, only the latest value matters | `tokio::sync::watch` | +| Use case | Primitive | +| -------------------------------------------------------------- | ---------------------------------------------------------------------- | +| One producer → one consumer with backpressure | `tokio::sync::mpsc` (cap 1) or `tokio::sync::oneshot` for single value | +| Many producers → one consumer | `tokio::sync::mpsc` | +| One producer → many consumers, every event delivered (pub/sub) | `tokio::sync::broadcast` | +| One producer → many consumers, only the latest value matters | `tokio::sync::watch` | For the **public** API, prefer returning `impl Stream` (wrap a `broadcast::Receiver` in `tokio_stream::wrappers::BroadcastStream`). `Stream` composes with `select!`, `take`, `map`, `filter`, `timeout`. See `EventSubscription` and `LifecycleSubscription`. @@ -115,7 +114,7 @@ JSON: `#[serde(rename_all = "camelCase")]` at the type level, per-field `#[serde Banned via `clippy.toml`. Use manual spans with `error_span!`: - **Almost always use `error_span!`**, not `info_span!`. Span level controls - the *minimum* filter at which the span appears. An `info_span` disappears when + the _minimum_ filter at which the span appears. An `info_span` disappears when the filter is `warn` or `error` — taking all child events with it, even errors. `error_span!` ensures the span is always present. - **Spawned tasks lose parent context.** Attach a span with `.instrument()` or @@ -239,9 +238,9 @@ Match those exact commands locally before pushing. JSON-RPC and session-event types are generated from the Copilot CLI schema: -| Source | Output | -|---|---| -| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | +| Source | Output | +| ------------------------------------------------------------------------ | -------------------------------------- | +| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | | `nodejs/node_modules/@github/copilot/schemas/session-events.schema.json` | `rust/src/generated/session_events.rs` | Regenerate with: diff --git a/.vscode/settings.json b/.vscode/settings.json index 105fec4d7..d0d8465c3 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,6 +14,11 @@ "python.testing.pytestEnabled": true, "python.testing.unittestEnabled": false, "python.testing.pytestArgs": ["python"], + "rust-analyzer.cargo.features": "all", + "rust-analyzer.check.command": "clippy", + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer" + }, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff" }, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 56e658ad1..1676f2f91 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -354,7 +354,6 @@ dependencies = [ "sha2", "tar", "tempfile", - "thiserror 2.0.18", "tokio", "tokio-stream", "tokio-util", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4f16c7bf5..44c4b369e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -43,7 +43,6 @@ async-trait = "0.1" schemars = { version = "1", optional = true } serde = { version = "1", features = ["derive"] } serde_json = "1" -thiserror = "2" tokio = { version = "1", features = ["io-util", "sync", "rt", "process", "net", "time", "macros"] } tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = { version = "0.7", default-features = false } @@ -68,6 +67,18 @@ serial_test = "3" tempfile = "3" tokio = { version = "1", features = ["rt-multi-thread"] } +# Integration tests that call test-support-only Client methods (e.g. +# `from_streams_with_connection_token`, `from_streams_with_trace_provider`) +# require the `test-support` feature because `cfg(test)` is not set on the +# library when Cargo compiles it for integration tests. +[[test]] +name = "session_test" +required-features = ["test-support"] + +[[test]] +name = "protocol_version_test" +required-features = ["test-support"] + [build-dependencies] dirs = "5" flate2 = "1" diff --git a/rust/examples/manual_tool_resume.rs b/rust/examples/manual_tool_resume.rs index dfb2b6232..9ce9f0964 100644 --- a/rust/examples/manual_tool_resume.rs +++ b/rust/examples/manual_tool_resume.rs @@ -9,8 +9,9 @@ use github_copilot_sdk::generated::api_types::{ use github_copilot_sdk::generated::session_events::{ AssistantMessageData, ExternalToolRequestedData, PermissionRequestedData, SessionEventType, }; +use github_copilot_sdk::subscription::RecvError; use github_copilot_sdk::{ - Client, ClientOptions, EventSubscription, RecvError, ResumeSessionConfig, SessionConfig, + Client, ClientOptions, EventSubscription, ResumeSessionConfig, SessionConfig, }; use serde_json::json; diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs index 924e6947f..ad31f6849 100644 --- a/rust/examples/session_fs.rs +++ b/rust/examples/session_fs.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::handler::ApproveAllHandler; use github_copilot_sdk::session_fs::{ - DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, + DirEntry, DirEntryKind, FileInfo, FsError, FsErrorKind, SessionFsConfig, SessionFsConventions, SessionFsProvider, }; use github_copilot_sdk::types::{MessageOptions, SessionConfig}; @@ -46,7 +46,7 @@ impl SessionFsProvider for InMemoryProvider { .lock() .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -69,7 +69,7 @@ impl SessionFsProvider for InMemoryProvider { let files = self.files.lock(); let content = files .get(path) - .ok_or_else(|| FsError::NotFound(path.to_string()))?; + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string())))?; Ok(FileInfo::new( true, false, @@ -101,7 +101,7 @@ impl SessionFsProvider for InMemoryProvider { async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { if self.files.lock().remove(path).is_none() && !force { - return Err(FsError::NotFound(path.to_string())); + return Err(FsError::from(FsErrorKind::NotFound(path.to_string()))); } Ok(()) } diff --git a/rust/src/canvas.rs b/rust/src/canvas.rs index 2d5b5b035..0f9c0ecf4 100644 --- a/rust/src/canvas.rs +++ b/rust/src/canvas.rs @@ -3,7 +3,6 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::Value; -use thiserror::Error; use crate::generated::api_types::CanvasAction; @@ -54,9 +53,8 @@ impl CanvasDeclaration { } /// Structured error returned from canvas handlers. -#[derive(Debug, Clone, Error, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "camelCase")] -#[error("{code}: {message}")] pub struct CanvasError { /// Machine-readable error code. pub code: String, @@ -64,6 +62,14 @@ pub struct CanvasError { pub message: String, } +impl std::fmt::Display for CanvasError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.code, self.message) + } +} + +impl std::error::Error for CanvasError {} + impl CanvasError { /// Construct a new error envelope with the given code and message. pub fn new(code: impl Into, message: impl Into) -> Self { diff --git a/rust/src/embeddedcli.rs b/rust/src/embeddedcli.rs index 504edbf67..a92f37d46 100644 --- a/rust/src/embeddedcli.rs +++ b/rust/src/embeddedcli.rs @@ -17,7 +17,7 @@ use std::fs; #[cfg(all(has_bundled_cli, not(windows)))] use std::io::Read; #[cfg(has_bundled_cli)] -use std::io::{self, Write}; +use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::OnceLock; @@ -132,7 +132,8 @@ fn default_install_dir(version: &str) -> PathBuf { fn install(install_dir: &Path, archive: &[u8]) -> Result { let verbose = std::env::var("COPILOT_CLI_INSTALL_VERBOSE").ok().as_deref() == Some("1"); - fs::create_dir_all(install_dir).map_err(EmbeddedCliError::CreateDir)?; + fs::create_dir_all(install_dir) + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::CreateDir, e))?; let final_path = install_dir.join(CLI_BINARY_NAME); @@ -164,35 +165,45 @@ fn install(install_dir: &Path, archive: &[u8]) -> Result Result, EmbeddedCliError> { let gz = flate2::read::GzDecoder::new(archive); let mut tar = tar::Archive::new(gz); - for entry in tar.entries().map_err(EmbeddedCliError::Archive)? { - let mut entry = entry.map_err(EmbeddedCliError::Archive)?; - let path = entry.path().map_err(EmbeddedCliError::Archive)?; + for entry in tar + .entries() + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))? + { + let mut entry = + entry.map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; + let path = entry + .path() + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; let name = path.to_string_lossy(); if name == binary_name || name.ends_with(&format!("/{binary_name}")) { let mut bytes = Vec::with_capacity(entry.size() as usize); entry .read_to_end(&mut bytes) - .map_err(EmbeddedCliError::Archive)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; return Ok(bytes); } } - Err(EmbeddedCliError::BinaryNotFoundInArchive) + Err(EmbeddedCliErrorKind::BinaryNotFoundInArchive.into()) } #[cfg(all(has_bundled_cli, windows))] fn extract_binary(archive: &[u8], binary_name: &str) -> Result, EmbeddedCliError> { let cursor = std::io::Cursor::new(archive); - let mut zip = zip::ZipArchive::new(cursor).map_err(EmbeddedCliError::Zip)?; + let mut zip = zip::ZipArchive::new(cursor) + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Zip, e))?; for i in 0..zip.len() { - let mut entry = zip.by_index(i).map_err(EmbeddedCliError::Zip)?; + let mut entry = zip + .by_index(i) + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Zip, e))?; let name = entry.name().to_string(); if name == binary_name || name.ends_with(&format!("/{binary_name}")) { let mut bytes = Vec::with_capacity(entry.size() as usize); - std::io::copy(&mut entry, &mut bytes).map_err(EmbeddedCliError::Io)?; + std::io::copy(&mut entry, &mut bytes) + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; return Ok(bytes); } } - Err(EmbeddedCliError::BinaryNotFoundInArchive) + Err(EmbeddedCliErrorKind::BinaryNotFoundInArchive.into()) } #[cfg(has_bundled_cli)] @@ -213,38 +224,107 @@ fn write_binary(path: &Path, data: &[u8]) -> Result<(), EmbeddedCliError> { .create(true) .truncate(true) .open(path) - .map_err(EmbeddedCliError::Io)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; - file.write_all(data).map_err(EmbeddedCliError::Io)?; + file.write_all(data) + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; fs::set_permissions(path, fs::Permissions::from_mode(0o755)) - .map_err(EmbeddedCliError::Io)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; } Ok(()) } #[cfg(has_bundled_cli)] -#[derive(Debug, thiserror::Error)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[allow(dead_code)] -enum EmbeddedCliError { - #[error("failed to create install directory: {0}")] - CreateDir(io::Error), - +enum EmbeddedCliErrorKind { + CreateDir, #[cfg(not(windows))] - #[error("failed to read archive entry: {0}")] - Archive(io::Error), - + Archive, #[cfg(windows)] - #[error("failed to read zip archive: {0}")] - Zip(zip::result::ZipError), - - #[error("CLI binary not found in embedded archive")] + Zip, BinaryNotFoundInArchive, + Io, +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Display for EmbeddedCliErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EmbeddedCliErrorKind::CreateDir => f.write_str("failed to create install directory"), + #[cfg(not(windows))] + EmbeddedCliErrorKind::Archive => f.write_str("failed to read archive entry"), + #[cfg(windows)] + EmbeddedCliErrorKind::Zip => f.write_str("failed to read zip archive"), + EmbeddedCliErrorKind::BinaryNotFoundInArchive => { + f.write_str("CLI binary not found in embedded archive") + } + EmbeddedCliErrorKind::Io => f.write_str("I/O error"), + } + } +} + +#[cfg(has_bundled_cli)] +#[allow(dead_code)] +struct EmbeddedCliError { + repr: crate::errors::Repr, +} - #[error("I/O error: {0}")] - Io(io::Error), +#[cfg(has_bundled_cli)] +impl EmbeddedCliError { + fn new(kind: EmbeddedCliErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: crate::errors::Repr::Custom(crate::errors::Custom { + kind, + error: error.into(), + }), + } + } +} + +#[cfg(has_bundled_cli)] +impl From for EmbeddedCliError { + fn from(kind: EmbeddedCliErrorKind) -> Self { + Self { + repr: crate::errors::Repr::Simple(kind), + } + } +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Display for EmbeddedCliError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.repr { + crate::errors::Repr::Simple(kind) => write!(f, "{kind}"), + crate::errors::Repr::SimpleMessage(_, msg) => write!(f, "{msg}"), + crate::errors::Repr::Custom(crate::errors::Custom { kind, error }) => { + write!(f, "{kind}: {error}") + } + } + } +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Debug for EmbeddedCliError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EmbeddedCliError({self})") + } +} + +#[cfg(has_bundled_cli)] +impl std::error::Error for EmbeddedCliError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + crate::errors::Repr::Custom(crate::errors::Custom { error, .. }) => Some(&**error), + _ => None, + } + } } diff --git a/rust/src/errors.rs b/rust/src/errors.rs new file mode 100644 index 000000000..5690f6412 --- /dev/null +++ b/rust/src/errors.rs @@ -0,0 +1,434 @@ +//! Crate errors. + +use std::backtrace::{Backtrace, BacktraceStatus}; +use std::borrow::{Borrow, Cow}; +use std::fmt; +use std::time::Duration; + +use crate::types::SessionId; + +/// Crate-specific [`Result`](std::result::Result). +pub type Result = std::result::Result; + +// ── Repr / Custom ───────────────────────────────────────────────────────────── + +/// Internal representation shared by all SDK error structs. +/// +/// `T` is the `*Kind` enum specific to each error struct. Shared across +/// [`Error`], [`ProtocolError`], [`SessionError`], [`FsError`], +/// [`RecvError`], and the crate-internal `EmbeddedCliError`. +#[derive(Debug)] +pub(crate) enum Repr { + Simple(T), + SimpleMessage(T, Cow<'static, str>), + Custom(Custom), + // CustomMessage(Custom, Cow<'static, str>), +} + +/// Custom error representation: a kind tag plus a boxed source error. +#[derive(Debug)] +pub(crate) struct Custom { + pub(crate) kind: T, + pub(crate) error: Box, +} + +// ── ProtocolErrorKind ───────────────────────────────────────── + +/// Specific protocol-level error kind in the JSON-RPC transport or CLI lifecycle. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ProtocolErrorKind { + /// Missing `Content-Length` header in a JSON-RPC message. + MissingContentLength, + + /// Invalid `Content-Length` header value. + InvalidContentLength(String), + + /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped). + RequestCancelled, + + /// The CLI process did not report a listening port within the timeout. + CliStartupTimeout, + + /// The CLI process exited before reporting a listening port. + CliStartupFailed, + + /// The CLI server's protocol version is outside the SDK's supported range. + VersionMismatch { + /// Version reported by the server. + server: u32, + /// Minimum version supported by this SDK. + min: u32, + /// Maximum version supported by this SDK. + max: u32, + }, + + /// The CLI server's protocol version changed between calls. + VersionChanged { + /// Previously negotiated version. + previous: u32, + /// Newly reported version. + current: u32, + }, +} + +impl fmt::Display for ProtocolErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ProtocolErrorKind::MissingContentLength => { + write!(f, "missing Content-Length header") + } + ProtocolErrorKind::InvalidContentLength(v) => { + write!(f, "invalid Content-Length value: \"{v}\"") + } + ProtocolErrorKind::RequestCancelled => write!(f, "request cancelled"), + ProtocolErrorKind::CliStartupTimeout => { + write!(f, "timed out waiting for CLI to report listening port") + } + ProtocolErrorKind::CliStartupFailed => { + write!(f, "CLI exited before reporting listening port") + } + ProtocolErrorKind::VersionMismatch { server, min, max } => { + write!( + f, + "version mismatch: server={server}, supported={min}\u{2013}{max}" + ) + } + ProtocolErrorKind::VersionChanged { previous, current } => { + write!(f, "version changed: was {previous}, now {current}") + } + } + } +} + +// ── SessionErrorKind ─────────────────────────────────────────── + +/// Session-scoped error kind. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum SessionErrorKind { + /// The CLI could not find the requested session. + NotFound(SessionId), + + /// The CLI reported an error during agent execution (via `session.error` event). + AgentError, + + /// A `send_and_wait` call exceeded its timeout. + Timeout(Duration), + + /// `send` was called while a `send_and_wait` is in flight. + SendWhileWaiting, + + /// The session event loop exited before a pending `send_and_wait` completed. + EventLoopClosed, + + /// Elicitation is not supported by the host. + /// Check `session.capabilities().ui.elicitation` before calling UI methods. + ElicitationNotSupported, + + /// The client was started with [`crate::ClientOptions::session_fs`] but this + /// session was created without a [`crate::session_fs::SessionFsProvider`]. Set one via + /// [`crate::SessionConfig::with_session_fs_provider`] (or + /// [`crate::ResumeSessionConfig::with_session_fs_provider`]). + SessionFsProviderRequired, + + /// [`crate::ClientOptions::session_fs`] was provided with empty or invalid + /// fields. All of `initial_cwd` and `session_state_path` must be non-empty. + InvalidSessionFsConfig, + + /// The CLI returned a different session ID than the one the SDK registered. + SessionIdMismatch { + /// Session ID registered by the SDK before the RPC was sent. + requested: SessionId, + /// Session ID returned by the CLI. + returned: SessionId, + }, +} + +impl fmt::Display for SessionErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SessionErrorKind::NotFound(id) => write!(f, "session not found: {id}"), + SessionErrorKind::AgentError => write!(f, "agent error"), + SessionErrorKind::Timeout(d) => write!(f, "timed out after {d:?}"), + SessionErrorKind::SendWhileWaiting => { + write!(f, "cannot send while send_and_wait is in flight") + } + SessionErrorKind::EventLoopClosed => { + write!(f, "event loop closed before session reached idle") + } + SessionErrorKind::ElicitationNotSupported => write!( + f, + "elicitation not supported by host \ + \u{2014} check session.capabilities().ui.elicitation first" + ), + SessionErrorKind::SessionFsProviderRequired => write!( + f, + "session was created on a client with session_fs configured \ + but no SessionFsProvider was supplied" + ), + SessionErrorKind::InvalidSessionFsConfig => { + write!(f, "invalid SessionFsConfig") + } + SessionErrorKind::SessionIdMismatch { + requested, + returned, + } => write!( + f, + "CLI returned session ID {returned} after SDK registered {requested}" + ), + } + } +} + +// ── ErrorKind ───────────────────────────────────────────────────────────────── + +/// The kind of [`Error`]. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ErrorKind { + /// JSON-RPC transport or protocol violation. + Protocol(ProtocolErrorKind), + /// The CLI returned a JSON-RPC error response. + Rpc { + /// JSON-RPC error code. + code: i32, + }, + /// Session-scoped error (not found, agent error, timeout, etc.). + Session(SessionErrorKind), + /// I/O error on the stdio transport or during process spawn. + Io, + /// Failed to serialize or deserialize a JSON-RPC message. + Json, + /// A required binary was not found on the system. + BinaryNotFound { + /// Name of the binary. + name: String, + /// Optional hint for how to resolve the issue. + hint: Option, + }, + /// Invalid combination of options or configuration. + InvalidConfig, +} + +impl fmt::Display for ErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorKind::Protocol(k) => write!(f, "{k}"), + ErrorKind::Rpc { code } => write!(f, "RPC error {code}"), + ErrorKind::Session(k) => write!(f, "{k}"), + ErrorKind::Io => write!(f, "I/O error"), + ErrorKind::Json => write!(f, "JSON error"), + ErrorKind::BinaryNotFound { + name, + hint: Some(h), + } => { + write!(f, "binary not found: {name} ({h})") + } + ErrorKind::BinaryNotFound { name, hint: None } => { + write!(f, "binary not found: {name}") + } + ErrorKind::InvalidConfig => write!(f, "invalid configuration"), + } + } +} + +/// Errors returned by the SDK. +pub struct Error { + repr: Repr, + // Only `Some` when `RUST_BACKTRACE` is set; boxed so the `Some` variant + // doesn't inflate `Error` beyond `clippy::result_large_err` limits. + backtrace: Option>, +} + +impl Error { + /// Constructs a new `Error` boxing another [`std::error::Error`]. + pub(crate) fn new(kind: ErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: Repr::Custom(Custom { + kind, + error: error.into(), + }), + backtrace: capture_backtrace(), + } + } + + /// The [`ErrorKind`] of this `Error`. + pub fn kind(&self) -> &ErrorKind { + match &self.repr { + Repr::Simple(kind) + | Repr::SimpleMessage(kind, ..) + | Repr::Custom(Custom { kind, .. }) => kind, + } + } + + /// The message provided when this `Error` was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, message) => Some(message.borrow()), + _ => None, + } + } + + /// Create an `Error` with a message. + #[must_use] + pub fn with_message(kind: ErrorKind, message: C) -> Self + where + C: Into>, + { + Self { + repr: Repr::SimpleMessage(kind, message.into()), + backtrace: capture_backtrace(), + } + } + + /// Returns `true` if this error indicates the transport is broken — the CLI + /// process exited, the connection was lost, or an I/O failure occurred. + /// Callers should discard the client and create a fresh one. + pub fn is_transport_failure(&self) -> bool { + matches!(self.kind(), ErrorKind::Io) + || matches!( + self.kind(), + ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled) + ) + } + + /// Returns the JSON-RPC error code if this is an [`ErrorKind::Rpc`] error. + pub fn rpc_code(&self) -> Option { + match self.kind() { + ErrorKind::Rpc { code } => Some(*code), + _ => None, + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(kind) => write!(f, "{kind}"), + Repr::SimpleMessage(kind, message) if matches!(kind, ErrorKind::Rpc { code: _ }) => { + write!(f, "{kind}: {message}") + } + Repr::SimpleMessage(_, message) => write!(f, "{message}"), + Repr::Custom(Custom { kind, error }) if matches!(kind, ErrorKind::Rpc { code: _ }) => { + write!(f, "{kind}: {error}") + } + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut dbg = f.debug_struct("Error"); + dbg.field("context", &self.repr); + if let Some(backtrace) = &self.backtrace { + return dbg.field("backtrace", backtrace).finish(); + } + dbg.finish_non_exhaustive() + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { + repr: Repr::Simple(kind), + backtrace: capture_backtrace(), + } + } +} + +impl From for Error { + fn from(kind: ProtocolErrorKind) -> Self { + Self::from(ErrorKind::Protocol(kind)) + } +} + +impl From for Error { + fn from(kind: SessionErrorKind) -> Self { + Self::from(ErrorKind::Session(kind)) + } +} + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Self::new(ErrorKind::Io, error) + } +} + +impl From for Error { + fn from(error: serde_json::Error) -> Self { + Self::new(ErrorKind::Json, error) + } +} + +#[inline(always)] +fn capture_backtrace() -> Option> { + let backtrace = Backtrace::capture(); + if backtrace.status() == BacktraceStatus::Captured { + Some(Box::new(backtrace)) + } else { + None + } +} + +/// Aggregate of errors collected during [`crate::Client::stop`]. +/// +/// `Client::stop` performs cooperative shutdown across every active +/// session before killing the CLI child process. Errors from any +/// per-session `session.destroy` RPC and from the terminal child-kill +/// step are collected here rather than short-circuiting on the first +/// failure, so callers see the full picture of what went wrong during +/// teardown. +/// +/// Implements [`std::error::Error`] and forwards to `Display` for the +/// first error, with a count suffix when there are more. +#[derive(Debug)] +pub struct StopErrors(pub(crate) Vec); + +impl StopErrors { + /// Borrow the collected errors as a slice, in the order they + /// occurred (per-session destroys first, then child-kill last). + pub fn errors(&self) -> &[Error] { + &self.0 + } + + /// Consume the aggregate and return the underlying error vector. + pub fn into_errors(self) -> Vec { + self.0 + } +} + +impl fmt::Display for StopErrors { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.as_slice() { + [] => write!(f, "stop completed with no errors"), + [only] => write!(f, "stop failed: {only}"), + [first, rest @ ..] => write!( + f, + "stop failed with {n} errors; first: {first}", + n = 1 + rest.len(), + ), + } + } +} + +impl std::error::Error for StopErrors { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0 + .first() + .map(|e| e as &(dyn std::error::Error + 'static)) + } +} diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs index 88a9670cd..55375b09b 100644 --- a/rust/src/jsonrpc.rs +++ b/rust/src/jsonrpc.rs @@ -11,7 +11,7 @@ use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::task::JoinHandle; use tracing::{Instrument, debug, error, warn}; -use crate::{Error, ProtocolError}; +use crate::{Error, ErrorKind, ProtocolErrorKind}; /// A JSON-RPC 2.0 request message. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -352,15 +352,15 @@ impl JsonRpcClient { if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) { content_length = Some(value.trim().parse::().map_err(|_| { - Error::Protocol(ProtocolError::InvalidContentLength( - value.trim().to_string(), + Error::from(ErrorKind::Protocol( + ProtocolErrorKind::InvalidContentLength(value.trim().to_string()), )) })?); } } let Some(length) = content_length else { - return Err(Error::Protocol(ProtocolError::MissingContentLength)); + return Err(ErrorKind::Protocol(ProtocolErrorKind::MissingContentLength).into()); }; let mut body = vec![0u8; length]; @@ -420,7 +420,7 @@ impl JsonRpcClient { let response = match rx.await { Ok(response) => response, Err(_) => { - let error = Error::Protocol(ProtocolError::RequestCancelled); + let error = ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled).into(); warn!( elapsed_ms = request_start.elapsed().as_millis(), method = %method, @@ -475,7 +475,7 @@ impl JsonRpcClient { self.write_tx .send(WriteCommand { frame, ack: ack_tx }) .map_err(|_| { - Error::Io(std::io::Error::new( + Error::from(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "writer actor has shut down", )) @@ -483,8 +483,8 @@ impl JsonRpcClient { match ack_rx.await { Ok(Ok(())) => Ok(()), - Ok(Err(e)) => Err(Error::Io(e)), - Err(_) => Err(Error::Io(std::io::Error::new( + Ok(Err(e)) => Err(Error::from(e)), + Err(_) => Err(Error::from(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "writer actor dropped ack without responding", ))), diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 0852d98a8..a7238a248 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -9,6 +9,8 @@ mod canvas_dispatch; /// Bundled CLI binary extraction and caching. #[cfg(feature = "bundled-cli")] pub(crate) mod embeddedcli; +mod errors; +pub use errors::*; /// Event handler traits for session lifecycle. pub mod handler; /// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). @@ -75,219 +77,11 @@ pub use types::*; mod sdk_protocol_version; pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version}; -pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError}; +pub use subscription::{EventSubscription, LifecycleSubscription}; /// Minimum protocol version this SDK can communicate with. const MIN_PROTOCOL_VERSION: u32 = 3; -/// Errors returned by the SDK. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum Error { - /// JSON-RPC transport or protocol violation. - #[error("protocol error: {0}")] - Protocol(ProtocolError), - - /// The CLI returned a JSON-RPC error response. - #[error("RPC error {code}: {message}")] - Rpc { - /// JSON-RPC error code. - code: i32, - /// Human-readable error message. - message: String, - }, - - /// Session-scoped error (not found, agent error, timeout, etc.). - #[error("session error: {0}")] - Session(SessionError), - - /// I/O error on the stdio transport or during process spawn. - #[error(transparent)] - Io(#[from] std::io::Error), - - /// Failed to serialize or deserialize a JSON-RPC message. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// A required binary was not found on the system. - #[error("binary not found: {name} ({hint})")] - BinaryNotFound { - /// Binary name that was searched for. - name: &'static str, - /// Guidance on how to install or configure the binary. - hint: &'static str, - }, - - /// Invalid combination of [`ClientOptions`] supplied to [`Client::start`]. - /// Surfaces consumer-side configuration errors that would otherwise - /// produce confusing runtime failures (e.g. a connection token paired - /// with stdio transport). - #[error("invalid client configuration: {0}")] - InvalidConfig(String), -} - -impl Error { - /// Returns true if this error indicates the transport is broken — the CLI - /// process exited, the connection was lost, or an I/O failure occurred. - /// Callers should discard the client and create a fresh one. - pub fn is_transport_failure(&self) -> bool { - matches!( - self, - Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_) - ) - } -} - -/// Aggregate of errors collected during [`Client::stop`]. -/// -/// `Client::stop` performs cooperative shutdown across every active -/// session before killing the CLI child process. Errors from any -/// per-session `session.destroy` RPC and from the terminal child-kill -/// step are collected here rather than short-circuiting on the first -/// failure, so callers see the full picture of what went wrong during -/// teardown. -/// -/// Implements [`std::error::Error`] and forwards to `Display` for the -/// first error, with a count suffix when there are more. -#[derive(Debug)] -pub struct StopErrors(Vec); - -impl StopErrors { - /// Borrow the collected errors as a slice, in the order they - /// occurred (per-session destroys first, then child-kill last). - pub fn errors(&self) -> &[Error] { - &self.0 - } - - /// Consume the aggregate and return the underlying error vector. - pub fn into_errors(self) -> Vec { - self.0 - } -} - -impl std::fmt::Display for StopErrors { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0.as_slice() { - [] => write!(f, "stop completed with no errors"), - [only] => write!(f, "stop failed: {only}"), - [first, rest @ ..] => write!( - f, - "stop failed with {n} errors; first: {first}", - n = 1 + rest.len(), - ), - } - } -} - -impl std::error::Error for StopErrors { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.0 - .first() - .map(|e| e as &(dyn std::error::Error + 'static)) - } -} - -/// Specific protocol-level errors in the JSON-RPC transport or CLI lifecycle. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum ProtocolError { - /// Missing `Content-Length` header in a JSON-RPC message. - #[error("missing Content-Length header")] - MissingContentLength, - - /// Invalid `Content-Length` header value. - #[error("invalid Content-Length value: \"{0}\"")] - InvalidContentLength(String), - - /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped). - #[error("request cancelled")] - RequestCancelled, - - /// The CLI process did not report a listening port within the timeout. - #[error("timed out waiting for CLI to report listening port")] - CliStartupTimeout, - - /// The CLI process exited before reporting a listening port. - #[error("CLI exited before reporting listening port")] - CliStartupFailed, - - /// The CLI server's protocol version is outside the SDK's supported range. - #[error("version mismatch: server={server}, supported={min}–{max}")] - VersionMismatch { - /// Version reported by the server. - server: u32, - /// Minimum version supported by this SDK. - min: u32, - /// Maximum version supported by this SDK. - max: u32, - }, - - /// The CLI server's protocol version changed between calls. - #[error("version changed: was {previous}, now {current}")] - VersionChanged { - /// Previously negotiated version. - previous: u32, - /// Newly reported version. - current: u32, - }, -} - -/// Session-scoped errors. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SessionError { - /// The CLI could not find the requested session. - #[error("session not found: {0}")] - NotFound(SessionId), - - /// The CLI reported an error during agent execution (via `session.error` event). - #[error("{0}")] - AgentError(String), - - /// A `send_and_wait` call exceeded its timeout. - #[error("timed out after {0:?}")] - Timeout(std::time::Duration), - - /// `send` was called while a `send_and_wait` is in flight. - #[error("cannot send while send_and_wait is in flight")] - SendWhileWaiting, - - /// The session event loop exited before a pending `send_and_wait` completed. - #[error("event loop closed before session reached idle")] - EventLoopClosed, - - /// Elicitation is not supported by the host. - /// Check `session.capabilities().ui.elicitation` before calling UI methods. - #[error( - "elicitation not supported by host — check session.capabilities().ui.elicitation first" - )] - ElicitationNotSupported, - - /// The client was started with [`ClientOptions::session_fs`] but this - /// session was created without a [`SessionFsProvider`]. Set one via - /// [`SessionConfig::with_session_fs_provider`] (or - /// [`ResumeSessionConfig::with_session_fs_provider`]). - #[error( - "session was created on a client with session_fs configured but no SessionFsProvider was supplied" - )] - SessionFsProviderRequired, - - /// [`ClientOptions::session_fs`] was provided with empty or invalid - /// fields. All of `initial_cwd` and `session_state_path` must be - /// non-empty. - #[error("invalid SessionFsConfig: {0}")] - InvalidSessionFsConfig(String), - - /// The CLI returned a different session ID than the one the SDK registered. - #[error("CLI returned session ID {returned} after SDK registered {requested}")] - SessionIdMismatch { - /// Session ID registered by the SDK before the RPC was sent. - requested: SessionId, - /// Session ID returned by the CLI. - returned: SessionId, - }, -} - /// How the SDK communicates with the CLI server. #[derive(Debug, Default)] #[non_exhaustive] @@ -491,7 +285,7 @@ impl std::fmt::Debug for ClientOptions { #[async_trait] pub trait ListModelsHandler: Send + Sync + 'static { /// Return the list of available models. - async fn list_models(&self) -> Result, Error>; + async fn list_models(&self) -> Result>; } /// Log verbosity for the CLI server (passed via `--log-level`). @@ -866,16 +660,18 @@ impl ClientOptions { } /// Validate a [`SessionFsConfig`] before sending `sessionFs.setProvider`. -fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> { +fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<()> { if cfg.initial_cwd.trim().is_empty() { - return Err(Error::Session(SessionError::InvalidSessionFsConfig( - "initial_cwd must not be empty".to_string(), - ))); + return Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig), + "invalid SessionFsConfig: initial_cwd must not be empty", + )); } if cfg.session_state_path.trim().is_empty() { - return Err(Error::Session(SessionError::InvalidSessionFsConfig( - "session_state_path must not be empty".to_string(), - ))); + return Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig), + "invalid SessionFsConfig: session_state_path must not be empty", + )); } Ok(()) } @@ -954,16 +750,16 @@ impl Client { /// When [`ClientOptions::session_fs`] is set, also calls /// `sessionFs.setProvider` to register the SDK as the filesystem /// backend. - pub async fn start(options: ClientOptions) -> Result { + pub async fn start(options: ClientOptions) -> Result { let start_time = Instant::now(); if options.mode == ClientMode::Empty && options.base_directory.is_none() && options.session_fs.is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "ClientMode::Empty requires either `base_directory` or \ - `session_fs` to be set (no implicit ~/.copilot fallback)." - .to_string(), + `session_fs` to be set (no implicit ~/.copilot fallback).", )); } if let Some(cfg) = &options.session_fs { @@ -973,17 +769,17 @@ impl Client { // external server, the server manages its own auth. if matches!(options.transport, Transport::External { .. }) { if options.github_token.is_some() { - return Err(Error::InvalidConfig( - "github_token cannot be used with Transport::External \ - (external server manages its own auth)" - .to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: github_token cannot be used with \ + Transport::External (external server manages its own auth)", )); } if options.use_logged_in_user == Some(true) { - return Err(Error::InvalidConfig( - "use_logged_in_user cannot be used with Transport::External \ - (external server manages its own auth)" - .to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: use_logged_in_user cannot be used with \ + Transport::External (external server manages its own auth)", )); } } @@ -999,8 +795,9 @@ impl Client { connection_token: Some(t), .. } if t.is_empty() => { - return Err(Error::InvalidConfig( - "connection_token must be a non-empty string".to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: connection_token must be a non-empty string", )); } _ => {} @@ -1173,7 +970,7 @@ impl Client { reader: impl AsyncRead + Unpin + Send + 'static, writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, - ) -> Result { + ) -> Result { Self::from_transport( reader, writer, @@ -1201,7 +998,7 @@ impl Client { writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, provider: Arc, - ) -> Result { + ) -> Result { Self::from_transport( reader, writer, @@ -1225,7 +1022,7 @@ impl Client { writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, token: Option, - ) -> Result { + ) -> Result { Self::from_transport( reader, writer, @@ -1262,7 +1059,7 @@ impl Client { on_get_trace_context: Option>, effective_connection_token: Option, mode: ClientMode, - ) -> Result { + ) -> Result { let setup_start = Instant::now(); let (request_tx, request_rx) = mpsc::unbounded_channel::(); let (notification_broadcast_tx, _) = broadcast::channel::(1024); @@ -1463,7 +1260,7 @@ impl Client { } } - fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { + fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { info!(cwd = ?options.working_directory, program = %program.display(), "spawning copilot CLI (stdio)"); let mut command = Self::build_command(program, options); command @@ -1483,11 +1280,7 @@ impl Client { Ok(child) } - async fn spawn_tcp( - program: &Path, - options: &ClientOptions, - port: u16, - ) -> Result<(Child, u16), Error> { + async fn spawn_tcp(program: &Path, options: &ClientOptions, port: u16) -> Result<(Child, u16)> { info!(cwd = ?options.working_directory, program = %program.display(), port = %port, "spawning copilot CLI (tcp)"); let mut command = Self::build_command(program, options); command @@ -1535,8 +1328,8 @@ impl Client { let port_wait_start = Instant::now(); let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx) .await - .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))? - .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?; + .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout)))? + .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupFailed)))?; debug!( elapsed_ms = port_wait_start.elapsed().as_millis(), @@ -1591,7 +1384,7 @@ impl Client { &self, method: &str, params: Option, - ) -> Result { + ) -> Result { self.inner.rpc.send_request(method, params).await } @@ -1618,7 +1411,7 @@ impl Client { &self, method: &str, params: Option, - ) -> Result { + ) -> Result { let session_id: Option = params .as_ref() .and_then(|p| p.get("sessionId")) @@ -1627,20 +1420,21 @@ impl Client { let response = self.send_request(method, params).await?; if let Some(err) = response.error { if err.message.contains("Session not found") { - return Err(Error::Session(SessionError::NotFound( + return Err(ErrorKind::Session(SessionErrorKind::NotFound( session_id.unwrap_or_else(|| "unknown".into()), - ))); + )) + .into()); } - return Err(Error::Rpc { - code: err.code, - message: err.message, - }); + return Err(Error::with_message( + ErrorKind::Rpc { code: err.code }, + err.message, + )); } Ok(response.result.unwrap_or(serde_json::Value::Null)) } /// Send a JSON-RPC response back to the CLI (e.g. for permission or tool call requests). - pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> { + pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<()> { self.inner.rpc.write(response).await } @@ -1707,7 +1501,7 @@ impl Client { /// Returns an error if the negotiated `protocolVersion` is outside /// `MIN_PROTOCOL_VERSION`..=[`SDK_PROTOCOL_VERSION`]. If the server /// doesn't report a version, logs a warning and succeeds. - pub async fn verify_protocol_version(&self) -> Result<(), Error> { + pub async fn verify_protocol_version(&self) -> Result<()> { let handshake_start = Instant::now(); let mut used_fallback_ping = false; // Try the new `connect` handshake first (sends the connection @@ -1715,7 +1509,7 @@ impl Client { // that don't expose `connect` (-32601 MethodNotFound). let server_version = match self.connect_handshake().await { Ok(v) => v, - Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => { + Err(ref e) if e.rpc_code() == Some(error_codes::METHOD_NOT_FOUND) => { used_fallback_ping = true; self.ping(None).await?.protocol_version } @@ -1727,19 +1521,21 @@ impl Client { warn!("CLI server did not report protocolVersion; skipping version check"); } Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => { - return Err(Error::Protocol(ProtocolError::VersionMismatch { + return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionMismatch { server: v, min: MIN_PROTOCOL_VERSION, max: SDK_PROTOCOL_VERSION, - })); + }) + .into()); } Some(v) => { if let Some(&existing) = self.inner.negotiated_protocol_version.get() { if existing != v { - return Err(Error::Protocol(ProtocolError::VersionChanged { + return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionChanged { previous: existing, current: v, - })); + }) + .into()); } } else { let _ = self.inner.negotiated_protocol_version.set(v); @@ -1762,7 +1558,7 @@ impl Client { /// auto-generated token for SDK-spawned TCP servers) as the `token` /// param. Server-side, the token is required when the server was /// started with `COPILOT_CONNECTION_TOKEN`. - async fn connect_handshake(&self) -> Result, Error> { + async fn connect_handshake(&self) -> Result> { let result = self .rpc() .connect(crate::generated::api_types::ConnectRequest { @@ -1779,7 +1575,7 @@ impl Client { /// the CLI reports one. /// /// [`PingResponse`]: crate::types::PingResponse - pub async fn ping(&self, message: Option<&str>) -> Result { + pub async fn ping(&self, message: Option<&str>) -> Result { let params = match message { Some(m) => serde_json::json!({ "message": m }), None => serde_json::json!({}), @@ -1795,7 +1591,7 @@ impl Client { pub async fn list_sessions( &self, filter: Option, - ) -> Result, Error> { + ) -> Result> { let params = match filter { Some(f) => serde_json::json!({ "filter": f }), None => serde_json::json!({}), @@ -1825,7 +1621,7 @@ impl Client { pub async fn get_session_metadata( &self, session_id: &SessionId, - ) -> Result, Error> { + ) -> Result> { let result = self .call( "session.getMetadata", @@ -1837,7 +1633,7 @@ impl Client { } /// Delete a persisted session by ID. - pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> { + pub async fn delete_session(&self, session_id: &SessionId) -> Result<()> { self.call( "session.delete", Some(serde_json::json!({ "sessionId": session_id })), @@ -1861,7 +1657,7 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn get_last_session_id(&self) -> Result, Error> { + pub async fn get_last_session_id(&self) -> Result> { let result = self .call("session.getLastId", Some(serde_json::json!({}))) .await?; @@ -1873,7 +1669,7 @@ impl Client { /// /// Only meaningful when connected to a server running in TUI+server mode /// (`--ui-server`). Returns `Ok(None)` if no foreground session is set. - pub async fn get_foreground_session_id(&self) -> Result, Error> { + pub async fn get_foreground_session_id(&self) -> Result> { let result = self .call("session.getForeground", Some(serde_json::json!({}))) .await?; @@ -1885,7 +1681,7 @@ impl Client { /// /// Only meaningful when connected to a server running in TUI+server mode /// (`--ui-server`). - pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> { + pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<()> { self.call( "session.setForeground", Some(serde_json::json!({ "sessionId": session_id })), @@ -1895,13 +1691,13 @@ impl Client { } /// Get the CLI server status. - pub async fn get_status(&self) -> Result { + pub async fn get_status(&self) -> Result { let result = self.call("status.get", Some(serde_json::json!({}))).await?; Ok(serde_json::from_value(result)?) } /// Get authentication status. - pub async fn get_auth_status(&self) -> Result { + pub async fn get_auth_status(&self) -> Result { let result = self .call("auth.getStatus", Some(serde_json::json!({}))) .await?; @@ -1912,7 +1708,7 @@ impl Client { /// /// When [`ClientOptions::on_list_models`] is set, returns the handler's /// result without making a `models.list` RPC. Otherwise queries the CLI. - pub async fn list_models(&self) -> Result, Error> { + pub async fn list_models(&self) -> Result> { let cache = self.inner.models_cache.lock().clone(); let models = cache .get_or_try_init(|| async { @@ -1966,7 +1762,7 @@ impl Client { /// or call `stop()` again with a fresh future. The documented /// `tokio::time::timeout(..., client.stop())` pattern in the example /// below uses `force_stop` as the fallback for exactly this case. - pub async fn stop(&self) -> Result<(), StopErrors> { + pub async fn stop(&self) -> std::result::Result<(), StopErrors> { let pid = self.pid(); info!(pid = ?pid, "stopping CLI process"); let mut errors: Vec = Vec::new(); @@ -2000,7 +1796,7 @@ impl Client { if let Some(mut child) = child && let Err(e) = child.kill().await { - errors.push(Error::Io(e)); + errors.push(e.into()); } info!(pid = ?pid, errors = errors.len(), "CLI process stopped"); @@ -2069,7 +1865,8 @@ impl Client { /// /// Each subscriber maintains its own queue. If a consumer cannot keep /// up, the oldest events are dropped and `recv` returns - /// [`RecvError::Lagged`] with the count of skipped events; consumers + /// [`RecvErrorKind::Lagged`](crate::subscription::RecvErrorKind::Lagged) + /// with the count of skipped events; consumers /// should match on it and continue. Slow consumers do not block the /// producer. /// @@ -2113,28 +1910,25 @@ mod tests { #[test] fn is_transport_failure_matches_request_cancelled() { - let err = Error::Protocol(ProtocolError::RequestCancelled); + let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled)); assert!(err.is_transport_failure()); } #[test] fn is_transport_failure_matches_io_error() { - let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); + let err = Error::from(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); assert!(err.is_transport_failure()); } #[test] fn is_transport_failure_rejects_rpc_error() { - let err = Error::Rpc { - code: -1, - message: "bad".into(), - }; + let err = Error::with_message(ErrorKind::Rpc { code: -1 }, "bad"); assert!(!err.is_transport_failure()); } #[test] fn is_transport_failure_rejects_session_error() { - let err = Error::Session(SessionError::NotFound("s1".into())); + let err = Error::from(ErrorKind::Session(SessionErrorKind::NotFound("s1".into()))); assert!(!err.is_transport_failure()); } @@ -2173,7 +1967,7 @@ mod tests { #[test] fn is_transport_failure_rejects_other_protocol_errors() { - let err = Error::Protocol(ProtocolError::CliStartupTimeout); + let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout)); assert!(!err.is_transport_failure()); } @@ -2409,7 +2203,10 @@ mod tests { }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); - assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + assert!( + matches!(err.kind(), ErrorKind::InvalidConfig), + "got {err:?}" + ); } #[tokio::test] @@ -2422,7 +2219,10 @@ mod tests { }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); - assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + assert!( + matches!(err.kind(), ErrorKind::InvalidConfig), + "got {err:?}" + ); } #[test] @@ -2540,7 +2340,7 @@ mod tests { struct StubHandler; #[async_trait] impl ListModelsHandler for StubHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { Ok(vec![]) } } @@ -2565,7 +2365,7 @@ mod tests { } #[async_trait] impl ListModelsHandler for CountingHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { self.calls.fetch_add(1, Ordering::SeqCst); Ok(self.models.clone()) } @@ -2600,7 +2400,7 @@ mod tests { } #[async_trait] impl ListModelsHandler for SlowCountingHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { self.calls.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(std::time::Duration::from_millis(25)).await; Ok(self.models.clone()) diff --git a/rust/src/mode.rs b/rust/src/mode.rs index 01d1038b1..c86b03071 100644 --- a/rust/src/mode.rs +++ b/rust/src/mode.rs @@ -47,10 +47,13 @@ fn validate_name(kind: &str, name: &str) -> Result<(), crate::Error> { return Ok(()); } if !is_valid_tool_name(name) { - return Err(crate::Error::InvalidConfig(format!( - "Invalid {kind} tool name '{name}': tool names must match \ + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!( + "Invalid {kind} tool name '{name}': tool names must match \ /^[a-zA-Z0-9_-]+$/ or be the wildcard '*'." - ))); + ), + )); } Ok(()) } @@ -185,11 +188,14 @@ pub(crate) fn validate_tool_filter_list( let Some(list) = list else { return Ok(()) }; for item in list { if item == "*" { - return Err(crate::Error::InvalidConfig(format!( - "{field} contains a bare '*' which matches no tool. Use \ + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!( + "{field} contains a bare '*' which matches no tool. Use \ source-qualified wildcards instead: \ ToolSet::new().add_builtin(\"*\").add_mcp(\"*\").add_custom(\"*\")." - ))); + ), + )); } } Ok(()) diff --git a/rust/src/resolve.rs b/rust/src/resolve.rs index cd813407a..1c88283a2 100644 --- a/rust/src/resolve.rs +++ b/rust/src/resolve.rs @@ -13,14 +13,14 @@ //! There is no PATH scanning and no walking of standard install locations. //! If none of the above resolves to a real file, //! [`Client::start`](crate::Client::start) returns -//! [`Error::BinaryNotFound`](crate::Error::BinaryNotFound). +//! an [`ErrorKind::BinaryNotFound`](crate::ErrorKind::BinaryNotFound) error. use std::env; use std::path::{Path, PathBuf}; use tracing::warn; -use crate::Error; +use crate::{Error, ErrorKind}; /// Resolve the CLI binary, optionally overriding the directory the bundled /// CLI is extracted to. Called by `Client::start` to thread @@ -66,13 +66,17 @@ pub(crate) fn copilot_binary_with_extract_dir( } } - Err(Error::BinaryNotFound { - name: "copilot", - hint: "the Copilot CLI is not bundled in this build of github-copilot-sdk and \ - COPILOT_CLI_PATH is not set. Either keep the default `bundled-cli` cargo \ - feature enabled, set COPILOT_CLI_PATH, or supply an explicit path via \ - `CliProgram::Path(...)` on `ClientOptions::program`.", - }) + Err(ErrorKind::BinaryNotFound { + name: "copilot".into(), + hint: Some( + "the Copilot CLI is not bundled in this build of github-copilot-sdk and \ + COPILOT_CLI_PATH is not set. Either keep the default `bundled-cli` cargo \ + feature enabled, set COPILOT_CLI_PATH, or supply an explicit path via \ + `CliProgram::Path(...)` on `ClientOptions::program`." + .into(), + ), + } + .into()) } /// Path to the CLI extracted into the per-user cache by `build.rs` when diff --git a/rust/src/session.rs b/rust/src/session.rs index 7527b6c8a..e735e405e 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -32,7 +32,10 @@ use crate::types::{ SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, ensure_attachment_display_names, }; -use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +use crate::{ + Client, Error, ErrorKind, JsonRpcResponse, SessionErrorKind, SessionEventNotification, + error_codes, +}; /// Bundle of the per-session callbacks the SDK dispatches to. Built from a /// [`SessionConfig`] / [`ResumeSessionConfig`] at @@ -68,7 +71,7 @@ struct IdleWaiter { /// Without this, an outer cancellation between "install waiter" and /// "drain channel" would leave the slot occupied, causing all subsequent /// `send` and `send_and_wait` calls on the session to return -/// [`SendWhileWaiting`](SessionError::SendWhileWaiting). Closes RFD-400 +/// [`SendWhileWaiting`](SessionErrorKind::SendWhileWaiting). Closes RFD-400 /// review finding #2. struct WaiterGuard { slot: Arc>>, @@ -257,7 +260,7 @@ impl Session { /// /// Each subscriber maintains its own queue. If a consumer cannot keep /// up, the oldest events are dropped and `recv` returns - /// [`RecvError::Lagged`](crate::subscription::RecvError::Lagged) + /// [`RecvErrorKind::Lagged`](crate::subscription::RecvErrorKind::Lagged) /// reporting the count of skipped events. Slow consumers do not block /// the session's event loop. /// @@ -311,9 +314,9 @@ impl Session { } // Fail any pending send_and_wait so it returns immediately. if let Some(waiter) = self.idle_waiter.lock().take() { - let _ = waiter - .tx - .send(Err(Error::Session(SessionError::EventLoopClosed))); + let _ = waiter.tx.send(Err( + ErrorKind::Session(SessionErrorKind::EventLoopClosed).into() + )); } } @@ -343,7 +346,7 @@ impl Session { /// message ID. pub async fn send(&self, opts: impl Into) -> Result { if self.idle_waiter.lock().is_some() { - return Err(Error::Session(SessionError::SendWhileWaiting)); + return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into()); } self.send_inner(opts.into()).await } @@ -424,7 +427,7 @@ impl Session { { let mut guard = self.idle_waiter.lock(); if guard.is_some() { - return Err(Error::Session(SessionError::SendWhileWaiting)); + return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into()); } *guard = Some(IdleWaiter { tx, @@ -446,7 +449,7 @@ impl Session { self.send_inner(opts).await?; match rx.await { Ok(result) => result, - Err(_) => Err(Error::Session(SessionError::EventLoopClosed)), + Err(_) => Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()), } }) .await; @@ -468,7 +471,7 @@ impl Session { completed_by = "timeout", "Session::send_and_wait failed" ); - Err(Error::Session(SessionError::Timeout(timeout_duration))) + Err(ErrorKind::Session(SessionErrorKind::Timeout(timeout_duration)).into()) } } } @@ -605,7 +608,7 @@ impl Session { .and_then(|u| u.elicitation) != Some(true) { - return Err(Error::Session(SessionError::ElicitationNotSupported)); + return Err(ErrorKind::Session(SessionErrorKind::ElicitationNotSupported).into()); } Ok(()) } @@ -803,11 +806,11 @@ impl Client { } let mode = self.inner.mode; if mode == crate::ClientMode::Empty && config.available_tools.is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "ClientMode::Empty requires available_tools to be set on the session config. \ Use ToolSet to specify which tools the session may use (e.g. \ - ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED))." - .to_string(), + ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).", )); } crate::mode::validate_tool_filter_list( @@ -847,16 +850,16 @@ impl Client { let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); + return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } if self.inner.session_fs_sqlite_declared && let Some(ref provider) = session_fs_provider && provider.sqlite().is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), + does not implement SessionFsSqliteProvider", )); } @@ -917,10 +920,11 @@ impl Client { }; if create_result.session_id != session_id { registration.cleanup(event_loop).await; - return Err(Error::Session(SessionError::SessionIdMismatch { + return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch { requested: session_id, returned: create_result.session_id, - })); + }) + .into()); } *capabilities.write() = create_result.capabilities.unwrap_or_default(); @@ -976,11 +980,11 @@ impl Client { } let mode = self.inner.mode; if mode == crate::ClientMode::Empty && config.available_tools.is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "ClientMode::Empty requires available_tools to be set on the session config. \ Use ToolSet to specify which tools the session may use (e.g. \ - ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED))." - .to_string(), + ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).", )); } crate::mode::validate_tool_filter_list( @@ -1020,16 +1024,16 @@ impl Client { let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); + return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } if self.inner.session_fs_sqlite_declared && let Some(ref provider) = session_fs_provider && provider.sqlite().is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), + does not implement SessionFsSqliteProvider", )); } @@ -1096,10 +1100,11 @@ impl Client { .unwrap_or_else(|| session_id.clone()); if cli_session_id != session_id { registration.cleanup(event_loop).await; - return Err(Error::Session(SessionError::SessionIdMismatch { + return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch { requested: session_id, returned: cli_session_id, - })); + }) + .into()); } // Reload skills after resume (best-effort). @@ -1284,7 +1289,7 @@ fn spawn_event_loop( if let Some(waiter) = idle_waiter.lock().take() { let _ = waiter .tx - .send(Err(Error::Session(SessionError::EventLoopClosed))); + .send(Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into())); } } .instrument(span), @@ -1387,9 +1392,10 @@ async fn handle_notification( .map(|s| s.to_string()) }) .unwrap_or_else(|| "session error".to_string()); - let _ = waiter - .tx - .send(Err(Error::Session(SessionError::AgentError(error_msg)))); + let _ = waiter.tx.send(Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::AgentError), + error_msg, + ))); } } } diff --git a/rust/src/session_fs.rs b/rust/src/session_fs.rs index 0e13be7d7..da4d3e3c9 100644 --- a/rust/src/session_fs.rs +++ b/rust/src/session_fs.rs @@ -17,8 +17,9 @@ //! //! Provider methods return [`Result`]. The SDK adapts these into //! the schema's `{ ..., error: Option }` payload, mapping -//! [`FsError::NotFound`] to the wire's `ENOENT` and everything else to -//! `UNKNOWN`. A [`From`] conversion is provided so handlers +//! [`FsErrorKind::NotFound`](crate::session_fs::FsErrorKind::NotFound) to +//! the wire's `ENOENT` and everything else to `UNKNOWN`. +//! A [`From`] conversion is provided so handlers //! backed by [`tokio::fs`](https://docs.rs/tokio/latest/tokio/fs/index.html) //! can propagate `io::Error` with `?`. //! @@ -40,7 +41,9 @@ //! } //! ``` +use std::borrow::{Borrow, Cow}; use std::collections::HashMap; +use std::fmt; use async_trait::async_trait; @@ -49,6 +52,7 @@ use crate::generated::api_types::{ SessionFsError, SessionFsErrorCode, SessionFsReaddirWithTypesEntry, SessionFsReaddirWithTypesEntryType, SessionFsSetProviderConventions, SessionFsStatResult, }; +use crate::{Custom, Repr}; /// Optional capabilities declared by a session filesystem provider. #[non_exhaustive] @@ -135,45 +139,126 @@ impl SessionFsConventions { } } -/// Error returned by a [`SessionFsProvider`] method. +/// Error kind returned by a [`SessionFsProvider`] method. /// -/// The SDK maps this onto the wire schema's [`SessionFsError`]: -/// [`FsError::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +/// The SDK maps this onto the wire schema's `SessionFsError`: +/// [`FsErrorKind::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +#[derive(Clone, Debug, PartialEq, Eq)] #[non_exhaustive] -#[derive(Debug, Clone, thiserror::Error)] -pub enum FsError { +pub enum FsErrorKind { /// File or directory does not exist. - #[error("not found: {0}")] NotFound(String), /// Any other filesystem error (permission denied, I/O error, etc.). - /// - /// The wire mapping always uses `UNKNOWN` as the code; the message is - /// preserved for diagnostics. - #[error("{0}")] - Other(String), + Other, +} + +impl fmt::Display for FsErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FsErrorKind::NotFound(path) => write!(f, "not found: {path}"), + FsErrorKind::Other => write!(f, "filesystem error"), + } + } +} + +/// Error returned by a [`crate::session_fs::SessionFsProvider`] method. +/// +/// The SDK maps this onto the wire schema's `SessionFsError`: +/// [`FsErrorKind::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +#[derive(Debug)] +pub struct FsError { + repr: Repr, } impl FsError { + /// Construct a `FsError` wrapping a source error. + pub fn new(kind: FsErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: Repr::Custom(Custom { + kind, + error: error.into(), + }), + } + } + + /// The [`FsErrorKind`] of this error. + pub fn kind(&self) -> &FsErrorKind { + match &self.repr { + Repr::Simple(k) | Repr::SimpleMessage(k, ..) | Repr::Custom(Custom { kind: k, .. }) => { + k + } + } + } + + /// The message provided when this error was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, m) => Some(m.borrow()), + _ => None, + } + } + + /// Create a `FsError` with a custom message. + #[must_use] + pub fn with_message(kind: FsErrorKind, message: C) -> Self + where + C: Into>, + { + Self { + repr: Repr::SimpleMessage(kind, message.into()), + } + } + pub(crate) fn into_wire(self) -> SessionFsError { - match self { - Self::NotFound(message) => SessionFsError { + match self.kind() { + FsErrorKind::NotFound(message) => SessionFsError { code: SessionFsErrorCode::ENOENT, - message: Some(message), + message: Some(message.clone()), }, - Self::Other(message) => SessionFsError { + FsErrorKind::Other => SessionFsError { code: SessionFsErrorCode::UNKNOWN, - message: Some(message), + message: Some(self.to_string()), }, } } } +impl fmt::Display for FsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for FsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for FsError { + fn from(kind: FsErrorKind) -> Self { + Self { + repr: Repr::Simple(kind), + } + } +} + impl From for FsError { fn from(err: std::io::Error) -> Self { match err.kind() { - std::io::ErrorKind::NotFound => Self::NotFound(err.to_string()), - _ => Self::Other(err.to_string()), + std::io::ErrorKind::NotFound => Self::new(FsErrorKind::NotFound(err.to_string()), err), + _ => Self::new(FsErrorKind::Other, err), } } } @@ -296,7 +381,7 @@ impl DirEntry { /// # Forward compatibility /// /// Methods on this trait have default implementations that return -/// `Err(FsError::Other("operation not supported".into()))`. When the CLI +/// `Err(FsError::with_message(FsErrorKind::Other, "operation not supported"))`. When the CLI /// schema grows new `sessionFs.*` methods, the SDK adds them to this trait /// with default impls so existing implementations continue to compile. /// Override only the methods relevant to your backing store. @@ -305,7 +390,10 @@ pub trait SessionFsProvider: Send + Sync + 'static { /// Read the full contents of a file as UTF-8. async fn read_file(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("read_file not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "read_file not supported", + )) } /// Write content to a file, creating parent directories if needed. @@ -316,7 +404,10 @@ pub trait SessionFsProvider: Send + Sync + 'static { mode: Option, ) -> Result<(), FsError> { let _ = (path, content, mode); - Err(FsError::Other("write_file not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "write_file not supported", + )) } /// Append content to a file, creating parent directories if needed. @@ -327,40 +418,56 @@ pub trait SessionFsProvider: Send + Sync + 'static { mode: Option, ) -> Result<(), FsError> { let _ = (path, content, mode); - Err(FsError::Other("append_file not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "append_file not supported", + )) } /// Check whether a path exists. /// - /// Returns `Ok(false)` for non-existent paths, not [`FsError::NotFound`]. + /// Returns `Ok(false)` for non-existent paths, not [`FsErrorKind::NotFound`]. async fn exists(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("exists not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "exists not supported", + )) } /// Get metadata about a file or directory. async fn stat(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("stat not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "stat not supported", + )) } /// Create a directory. When `recursive`, missing parents are also created. async fn mkdir(&self, path: &str, recursive: bool, mode: Option) -> Result<(), FsError> { let _ = (path, recursive, mode); - Err(FsError::Other("mkdir not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "mkdir not supported", + )) } /// List entry names in a directory. async fn readdir(&self, path: &str) -> Result, FsError> { let _ = path; - Err(FsError::Other("readdir not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "readdir not supported", + )) } /// List directory entries with type information. async fn readdir_with_types(&self, path: &str) -> Result, FsError> { let _ = path; - Err(FsError::Other( - "readdir_with_types not supported".to_string(), + Err(FsError::with_message( + FsErrorKind::Other, + "readdir_with_types not supported", )) } @@ -368,13 +475,19 @@ pub trait SessionFsProvider: Send + Sync + 'static { /// error. When `recursive`, directory contents are removed as well. async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { let _ = (path, recursive, force); - Err(FsError::Other("rm not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "rm not supported", + )) } /// Rename or move a file or directory. async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { let _ = (src, dest); - Err(FsError::Other("rename not supported".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "rename not supported", + )) } /// Return a reference to the SQLite provider, if this provider supports @@ -443,7 +556,9 @@ mod tests { fn fs_error_maps_io_not_found_to_enoent() { let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing.txt"); let fs_err: FsError = io_err.into(); - assert!(matches!(fs_err, FsError::NotFound(_))); + assert!( + matches!(fs_err.kind(), FsErrorKind::NotFound(message) if message == "missing.txt") + ); let wire = fs_err.into_wire(); assert_eq!(wire.code, SessionFsErrorCode::ENOENT); } @@ -452,7 +567,7 @@ mod tests { fn fs_error_maps_other_io_to_unknown() { let io_err = std::io::Error::other("disk full"); let fs_err: FsError = io_err.into(); - assert!(matches!(fs_err, FsError::Other(_))); + assert!(matches!(fs_err.kind(), FsErrorKind::Other)); let wire = fs_err.into_wire(); assert_eq!(wire.code, SessionFsErrorCode::UNKNOWN); assert!(wire.message.unwrap().contains("disk full")); @@ -478,6 +593,8 @@ mod tests { async fn default_impls_return_unsupported() { let p = DefaultProvider; let err = p.read_file("/x").await.unwrap_err(); - assert!(matches!(err, FsError::Other(ref m) if m.contains("not supported"))); + assert!( + matches!(err.kind(), FsErrorKind::Other) && err.to_string().contains("not supported") + ); } } diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs index 69886a195..c3fc83b8b 100644 --- a/rust/src/subscription.rs +++ b/rust/src/subscription.rs @@ -23,9 +23,10 @@ //! //! Each subscriber maintains its own internal queue. If a consumer cannot //! keep up, the oldest events are dropped and the next call yields -//! [`Lagged`] reporting how many events were skipped. Slow subscribers do -//! not block the producer. +//! [`Lagged`](crate::subscription::Lagged) reporting how many events were skipped. +//! Slow subscribers do not block the producer. +use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; @@ -35,6 +36,7 @@ use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::{Stream, StreamExt as _}; use crate::types::{SessionEvent, SessionLifecycleEvent}; +use crate::{Custom, Repr}; /// The subscription fell behind the producer. /// @@ -43,9 +45,8 @@ use crate::types::{SessionEvent, SessionLifecycleEvent}; /// after this error, starting from the next live event — callers who care /// about lag should match on it and decide whether to resync, re-fetch, or /// log and continue. -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] -#[error("subscription lagged behind by {0} events")] -pub struct Lagged(u64); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Lagged(pub(crate) u64); impl Lagged { /// Number of events skipped before this consumer could read them. @@ -54,19 +55,84 @@ impl Lagged { } } -/// Error returned by [`EventSubscription::recv`] and -/// [`LifecycleSubscription::recv`]. -#[derive(Debug, thiserror::Error)] +impl fmt::Display for Lagged { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "subscription lagged behind by {} events", self.0) + } +} + +impl std::error::Error for Lagged {} + +/// Error kind for subscription receive operations. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] -pub enum RecvError { +pub enum RecvErrorKind { /// The producer is gone — the session has shut down or the client has /// stopped. No further events will be delivered. - #[error("subscription closed")] Closed, /// The subscriber fell behind. See [`Lagged`]. - #[error(transparent)] - Lagged(#[from] Lagged), + Lagged(Lagged), +} + +impl fmt::Display for RecvErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvErrorKind::Closed => write!(f, "subscription closed"), + RecvErrorKind::Lagged(l) => write!(f, "{l}"), + } + } +} + +/// Error returned by [`crate::subscription::EventSubscription::recv`] and +/// [`crate::subscription::LifecycleSubscription::recv`]. +#[derive(Debug)] +pub struct RecvError { + repr: Repr, +} + +impl RecvError { + /// The [`RecvErrorKind`] of this error. + pub fn kind(&self) -> &RecvErrorKind { + match &self.repr { + Repr::Simple(k) | Repr::SimpleMessage(k, ..) | Repr::Custom(Custom { kind: k, .. }) => { + k + } + } + } +} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for RecvError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for RecvError { + fn from(kind: RecvErrorKind) -> Self { + Self { + repr: Repr::Simple(kind), + } + } +} + +impl From for RecvError { + fn from(lagged: Lagged) -> Self { + Self::from(RecvErrorKind::Lagged(lagged)) + } } macro_rules! define_subscription { @@ -92,9 +158,9 @@ macro_rules! define_subscription { /// Returns: /// /// - `Ok(event)` for the next delivered event. - /// - `Err(`[`RecvError::Lagged`]`)` if the subscriber fell behind; + /// - `Err(`[`RecvError`]`)` with [`RecvError::kind()`] [`RecvErrorKind::Lagged`] if the subscriber fell behind; /// call `recv` again to continue from the next live event. - /// - `Err(`[`RecvError::Closed`]`)` once the producer is gone. + /// - `Err(`[`RecvError`]`)` with [`RecvError::kind()`] [`RecvErrorKind::Closed`] once the producer is gone. /// /// # Cancel safety /// @@ -107,9 +173,9 @@ macro_rules! define_subscription { match self.inner.next().await { Some(Ok(event)) => Ok(event), Some(Err(BroadcastStreamRecvError::Lagged(n))) => { - Err(RecvError::Lagged(Lagged(n))) + Err(Lagged(n).into()) } - None => Err(RecvError::Closed), + None => Err(RecvErrorKind::Closed.into()), } } } @@ -184,7 +250,10 @@ mod tests { assert_eq!(sub.recv().await.unwrap().id, "a"); assert_eq!(sub.recv().await.unwrap().id, "b"); - assert!(matches!(sub.recv().await, Err(RecvError::Closed))); + assert!(matches!( + sub.recv().await.unwrap_err().kind(), + RecvErrorKind::Closed + )); } #[tokio::test] @@ -194,10 +263,11 @@ mod tests { for id in ["a", "b", "c", "d"] { tx.send(make_event(id)).unwrap(); } - match sub.recv().await { - Err(RecvError::Lagged(l)) => assert_eq!(l.skipped(), 2), - other => panic!("expected Lagged, got {other:?}"), - } + let err = sub.recv().await.expect_err("expected a Lagged error"); + let RecvErrorKind::Lagged(l) = err.kind() else { + panic!("expected Lagged, got {:?}", err.kind()); + }; + assert_eq!(l.skipped(), 2); // Subscription continues with the live tail. assert_eq!(sub.recv().await.unwrap().id, "c"); assert_eq!(sub.recv().await.unwrap().id, "d"); diff --git a/rust/src/tool.rs b/rust/src/tool.rs index b9b44bc0a..189bc6f21 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -621,7 +621,7 @@ mod tests { use serde::Deserialize; use super::super::*; - use crate::SessionId; + use crate::{ErrorKind, SessionId}; #[derive(Deserialize, schemars::JsonSchema)] struct GetWeatherParams { @@ -712,7 +712,7 @@ mod tests { }; let err = tool.call(inv).await.unwrap_err(); - assert!(matches!(err, Error::Json(_))); + assert!(matches!(err.kind(), ErrorKind::Json)); } #[tokio::test] diff --git a/rust/src/types.rs b/rust/src/types.rs index f9b29600d..dfa5ff22d 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1443,10 +1443,10 @@ impl SessionConfig { if let Some(handler) = tool.handler.take() && tool_handlers.insert(tool.name.clone(), handler).is_some() { - return Err(crate::Error::InvalidConfig(format!( - "duplicate tool handler registered for name {:?}", - tool.name - ))); + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!("duplicate tool handler registered for name {:?}", tool.name), + )); } } } @@ -2128,10 +2128,10 @@ impl ResumeSessionConfig { if let Some(handler) = tool.handler.take() && tool_handlers.insert(tool.name.clone(), handler).is_some() { - return Err(crate::Error::InvalidConfig(format!( - "duplicate tool handler registered for name {:?}", - tool.name - ))); + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!("duplicate tool handler registered for name {:?}", tool.name), + )); } } } diff --git a/rust/tests/cli_resolution_test.rs b/rust/tests/cli_resolution_test.rs index 72bebdcb2..0abd0f94e 100644 --- a/rust/tests/cli_resolution_test.rs +++ b/rust/tests/cli_resolution_test.rs @@ -8,7 +8,7 @@ use std::path::PathBuf; -use github_copilot_sdk::{CliProgram, Client, ClientOptions, Error}; +use github_copilot_sdk::{CliProgram, Client, ClientOptions, ErrorKind}; use serial_test::serial; fn unset_env(key: &str) { @@ -87,7 +87,7 @@ async fn stale_env_override_falls_through() { // here would mean fallthrough is broken. if let Err(e) = &result { assert!( - !matches!(e, Error::BinaryNotFound { .. }), + !matches!(e.kind(), ErrorKind::BinaryNotFound { .. }), "stale COPILOT_CLI_PATH should fall through; got BinaryNotFound: {e}" ); } @@ -147,7 +147,7 @@ async fn unbundled_resolver_finds_extracted_binary() { let result = Client::start(opts).await; if let Err(e) = result { assert!( - !matches!(e, Error::BinaryNotFound { .. }), + !matches!(e.kind(), ErrorKind::BinaryNotFound { .. }), "resolver returned BinaryNotFound with `bundled-cli` off: {e}" ); } @@ -185,7 +185,7 @@ async fn extract_dir_runtime_override_is_honored() { if let Err(e) = result { assert!( - !matches!(e, Error::BinaryNotFound { .. }), + !matches!(e.kind(), ErrorKind::BinaryNotFound { .. }), "EXTRACT_DIR-redirected resolver returned BinaryNotFound: {e}" ); } diff --git a/rust/tests/e2e/session_fs_sqlite.rs b/rust/tests/e2e/session_fs_sqlite.rs index cd8758c31..0b99d951b 100644 --- a/rust/tests/e2e/session_fs_sqlite.rs +++ b/rust/tests/e2e/session_fs_sqlite.rs @@ -2,8 +2,9 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use async_trait::async_trait; +use github_copilot_sdk::session_fs::{FsError, FsErrorKind}; use github_copilot_sdk::{ - Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsCapabilities, + Client, DirEntry, DirEntryKind, FileInfo, SessionConfig, SessionFsCapabilities, SessionFsConfig, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, }; @@ -53,9 +54,10 @@ impl InMemorySqliteProvider { fn get_or_create_db(db: &mut Option) -> Result<&mut Connection, FsError> { if db.is_none() { - let conn = Connection::open_in_memory().map_err(|e| FsError::Other(e.to_string()))?; + let conn = + Connection::open_in_memory().map_err(|e| FsError::new(FsErrorKind::Other, e))?; conn.execute_batch("PRAGMA busy_timeout = 5000;") - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; *db = Some(conn); } Ok(db.as_mut().unwrap()) @@ -69,7 +71,7 @@ impl SessionFsProvider for InMemorySqliteProvider { files .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -114,7 +116,7 @@ impl SessionFsProvider for InMemorySqliteProvider { } else if let Some(content) = files.get(path) { Ok(FileInfo::new(true, false, content.len() as i64, now, now)) } else { - Err(FsError::NotFound(path.to_string())) + Err(FsError::from(FsErrorKind::NotFound(path.to_string()))) } } @@ -244,7 +246,7 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { match query_type { SessionFsSqliteQueryType::Exec => { db.execute_batch(trimmed) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; Ok(Some(SessionFsSqliteQueryResult { columns: vec![], rows: vec![], @@ -255,21 +257,24 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { SessionFsSqliteQueryType::Query => { let mut stmt = db .prepare(trimmed) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; let col_count = stmt.column_count(); let columns: Vec = (0..col_count) .map(|i| stmt.column_name(i).unwrap().to_string()) .collect(); let mut rows = vec![]; - let mut query_rows = stmt.query([]).map_err(|e| FsError::Other(e.to_string()))?; + let mut query_rows = stmt + .query([]) + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; while let Some(row) = query_rows .next() - .map_err(|e| FsError::Other(e.to_string()))? + .map_err(|e| FsError::new(FsErrorKind::Other, e))? { let mut map = HashMap::new(); for (i, col) in columns.iter().enumerate() { - let val: rusqlite::types::Value = - row.get(i).map_err(|e| FsError::Other(e.to_string()))?; + let val: rusqlite::types::Value = row + .get(i) + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; let json_val = match val { rusqlite::types::Value::Null => serde_json::Value::Null, rusqlite::types::Value::Integer(n) => { @@ -297,7 +302,7 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { SessionFsSqliteQueryType::Run => { let affected = db .execute(trimmed, []) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; let last_id = db.last_insert_rowid(); Ok(Some(SessionFsSqliteQueryResult { columns: vec![], diff --git a/rust/tests/protocol_version_test.rs b/rust/tests/protocol_version_test.rs index fd4eecada..9d613d8d7 100644 --- a/rust/tests/protocol_version_test.rs +++ b/rust/tests/protocol_version_test.rs @@ -91,11 +91,10 @@ async fn rejected_when_version_out_of_range() { let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 1 })).await; let err = res.unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Protocol(github_copilot_sdk::ProtocolError::VersionMismatch { - server: 1, - .. - }) + err.kind(), + github_copilot_sdk::ErrorKind::Protocol( + github_copilot_sdk::ProtocolErrorKind::VersionMismatch { server: 1, .. } + ) )); assert_eq!(version, None); } diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index bb4e602e0..487f00bf1 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -1720,7 +1720,12 @@ async fn send_and_wait_returns_error_on_session_error() { .unwrap() .unwrap_err(); assert!( - matches!(err, github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::AgentError(ref msg)) if msg.contains("something went wrong")) + matches!( + err.kind(), + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::AgentError + ) + ) && err.to_string().contains("something went wrong") ); } @@ -1749,8 +1754,8 @@ async fn send_and_wait_times_out() { .unwrap() .unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::Timeout(_)) + err.kind(), + github_copilot_sdk::ErrorKind::Session(github_copilot_sdk::SessionErrorKind::Timeout(_)) )); } @@ -2594,17 +2599,17 @@ async fn elicitation_methods_fail_without_capability() { .await .unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::ElicitationNotSupported + err.kind(), + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::ElicitationNotSupported ) )); let err = session.ui().confirm("ok?").await.unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::ElicitationNotSupported + err.kind(), + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::ElicitationNotSupported ) )); } @@ -3081,8 +3086,11 @@ impl CommandHandler for CountingCommandHandler { async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { *self.last_ctx.lock() = Some(ctx); if let Some(message) = &self.error_to_return { - Err(github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::AgentError(message.clone()), + Err(github_copilot_sdk::Error::with_message( + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::AgentError, + ), + message.clone(), )) } else { Ok(()) @@ -3303,8 +3311,9 @@ async fn command_execute_handler_error_propagates_to_ack() { // SessionFsProvider tests -------------------------------------------------- use github_copilot_sdk::session_fs::{ - DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConventions, SessionFsProvider, - SessionFsSqliteProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, + DirEntry, DirEntryKind, FileInfo, FsError, FsErrorKind, SessionFsConventions, + SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, + SessionFsSqliteQueryType, }; struct RecordingFsProvider { @@ -3333,7 +3342,7 @@ impl SessionFsProvider for RecordingFsProvider { .lock() .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -3352,7 +3361,7 @@ impl SessionFsProvider for RecordingFsProvider { let files = self.files.lock(); let content = files .get(path) - .ok_or_else(|| FsError::NotFound(path.to_string()))?; + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string())))?; Ok(FileInfo::new( true, false, @@ -3372,7 +3381,7 @@ impl SessionFsProvider for RecordingFsProvider { async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { let mut files = self.files.lock(); if files.remove(path).is_none() && !force { - return Err(FsError::NotFound(path.to_string())); + return Err(FsError::from(FsErrorKind::NotFound(path.to_string()))); } Ok(()) } @@ -3514,7 +3523,10 @@ async fn session_fs_maps_other_to_unknown() { #[async_trait] impl SessionFsProvider for AlwaysFails { async fn stat(&self, _path: &str) -> Result { - Err(FsError::Other("backing store unavailable".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "backing store unavailable", + )) } } @@ -3605,11 +3617,17 @@ async fn session_fs_maps_sqlite_errors_to_results() { _query: &str, _params: Option<&std::collections::HashMap>, ) -> Result, FsError> { - Err(FsError::Other("sqlite unavailable".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "sqlite unavailable", + )) } async fn sqlite_exists(&self) -> Result { - Err(FsError::Other("sqlite unavailable".to_string())) + Err(FsError::with_message( + FsErrorKind::Other, + "sqlite unavailable", + )) } } @@ -3737,7 +3755,7 @@ async fn create_session_errors_when_provider_required_but_missing() { // through Client::start; the unit-level behavior is covered by the // SessionError::SessionFsProviderRequired variant being constructible. // This test asserts the error type's display formatting is stable. - let err = github_copilot_sdk::SessionError::SessionFsProviderRequired; + let err = github_copilot_sdk::SessionErrorKind::SessionFsProviderRequired; assert!(format!("{err}").contains("session_fs")); }