diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index bec3ddd1..68ab27c2 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -238,10 +238,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { // if not found, use a default empty JSON schema object // TODO: should be updated according to the new specifications syn::parse2::(quote! { - std::sync::Arc::new(serde_json::json!({ - "type": "object", - "properties": {} - }).as_object().unwrap().clone()) + rmcp::handler::server::common::schema_for_empty_input() })? } }; diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 54b27b91..60f5e02d 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -130,7 +130,7 @@ async fn calculate(&self, params: Parameters) -> Result() -> Arc { }) } +// TODO: should be updated according to the new specifications +/// Schema used when input is empty. +pub fn schema_for_empty_input() -> Arc { + std::sync::Arc::new( + serde_json::json!({ + "type": "object", + "properties": {} + }) + .as_object() + .unwrap() + .clone(), + ) +} + /// Generate and validate a JSON schema for outputSchema (must have root type "object"). pub fn schema_for_output() -> Result, String> { thread_local! { diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 51d49d97..5c1941bd 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -1,7 +1,132 @@ +//! Tools for MCP servers. +//! +//! It's straightforward to define tools using [`tool_router`][crate::tool_router] and +//! [`tool`][crate::tool] macro. +//! +//! ```rust +//! # use rmcp::{ +//! # tool_router, tool, +//! # handler::server::{wrapper::{Parameters, Json}, tool::ToolRouter}, +//! # schemars +//! # }; +//! # use serde::{Serialize, Deserialize}; +//! struct Server { +//! tool_router: ToolRouter, +//! } +//! #[derive(Deserialize, schemars::JsonSchema, Default)] +//! struct AddParameter { +//! left: usize, +//! right: usize +//! } +//! #[derive(Serialize, schemars::JsonSchema)] +//! struct AddOutput { +//! sum: usize +//! } +//! #[tool_router] +//! impl Server { +//! #[tool(name = "adder", description = "Modular add two integers")] +//! fn add( +//! &self, +//! Parameters(AddParameter { left, right }): Parameters +//! ) -> Json { +//! Json(AddOutput { sum: left.wrapping_add(right) }) +//! } +//! } +//! ``` +//! +//! Using the macro-based code pattern above is suitable for small MCP servers with simple interfaces. +//! When the business logic become larger, it is recommended that each tool should reside +//! in individual file, combined into MCP server using [`SyncTool`] and [`AsyncTool`] traits. +//! +//! ```rust +//! # use rmcp::{ +//! # handler::server::{ +//! # tool::ToolRouter, +//! # router::tool::{SyncTool, AsyncTool, ToolBase}, +//! # }, +//! # schemars, ErrorData +//! # }; +//! # pub struct MyCustomError; +//! # impl From for ErrorData { +//! # fn from(err: MyCustomError) -> ErrorData { unimplemented!() } +//! # } +//! # use serde::{Serialize, Deserialize}; +//! # use std::borrow::Cow; +//! // In tool1.rs +//! pub struct ComplexTool1; +//! #[derive(Deserialize, schemars::JsonSchema, Default)] +//! pub struct ComplexTool1Input { /* ... */ } +//! #[derive(Serialize, schemars::JsonSchema)] +//! pub struct ComplexTool1Output { /* ... */ } +//! +//! impl ToolBase for ComplexTool1 { +//! type Parameter = ComplexTool1Input; +//! type Output = ComplexTool1Output; +//! type Error = MyCustomError; +//! fn name() -> Cow<'static, str> { +//! "complex-tool1".into() +//! } +//! +//! fn description() -> Option> { +//! Some("...".into()) +//! } +//! } +//! impl SyncTool for ComplexTool1 { +//! fn invoke(service: &MyToolServer, param: Self::Parameter) -> Result { +//! // ... +//! # unimplemented!() +//! } +//! } +//! // In tool2.rs +//! pub struct ComplexTool2; +//! #[derive(Deserialize, schemars::JsonSchema, Default)] +//! pub struct ComplexTool2Input { /* ... */ } +//! #[derive(Serialize, schemars::JsonSchema)] +//! pub struct ComplexTool2Output { /* ... */ } +//! +//! impl ToolBase for ComplexTool2 { +//! type Parameter = ComplexTool2Input; +//! type Output = ComplexTool2Output; +//! type Error = MyCustomError; +//! fn name() -> Cow<'static, str> { +//! "complex-tool2".into() +//! } +//! +//! fn description() -> Option> { +//! Some("...".into()) +//! } +//! } +//! impl AsyncTool for ComplexTool2 { +//! async fn invoke(service: &MyToolServer, param: Self::Parameter) -> Result { +//! // ... +//! # unimplemented!() +//! } +//! } +//! +//! // In tool_router.rs +//! struct MyToolServer { +//! tool_router: ToolRouter, +//! } +//! impl MyToolServer { +//! pub fn tool_router() -> ToolRouter { +//! ToolRouter::new() +//! .with_sync_tool::() +//! .with_async_tool::() +//! } +//! } +//! ``` +//! +//! It's also possible to use macro-based and trait-based tool definition together: Since +//! [`ToolRouter`] implements [`Add`][std::ops::Add], you can add two tool routers into final +//! router as showed in [the documentation of `tool_router`][crate::tool_router]. + +mod tool_traits; + use std::{borrow::Cow, sync::Arc}; use futures::{FutureExt, future::BoxFuture}; use schemars::JsonSchema; +pub use tool_traits::{AsyncTool, SyncTool, ToolBase}; use crate::{ handler::server::{ @@ -219,6 +344,42 @@ where self } + /// Add a tool that implements [`SyncTool`] + pub fn with_sync_tool(self) -> Self + where + T: SyncTool + 'static, + { + if T::input_schema().is_some() { + self.with_route(( + tool_traits::tool_attribute::(), + tool_traits::sync_tool_wrapper::, + )) + } else { + self.with_route(( + tool_traits::tool_attribute::(), + tool_traits::sync_tool_wrapper_with_empty_params::, + )) + } + } + + /// Add a tool that implements [`AsyncTool`] + pub fn with_async_tool(self) -> Self + where + T: AsyncTool + 'static, + { + if T::input_schema().is_some() { + self.with_route(( + tool_traits::tool_attribute::(), + tool_traits::async_tool_wrapper::, + )) + } else { + self.with_route(( + tool_traits::tool_attribute::(), + tool_traits::async_tool_wrapper_with_empty_params::, + )) + } + } + pub fn add_route(&mut self, item: ToolRoute) { let new_name = &item.attr.name; validate_and_warn_tool_name(new_name); diff --git a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs new file mode 100644 index 00000000..60ac9cff --- /dev/null +++ b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs @@ -0,0 +1,341 @@ +use std::{borrow::Cow, pin::Pin, sync::Arc}; + +use serde::{Deserialize, Serialize}; + +use crate::{ + ErrorData, + handler::server::{ + common::schema_for_empty_input, + tool::{schema_for_output, schema_for_type}, + wrapper::{Json, Parameters}, + }, + model::{Icon, JsonObject, Meta, ToolAnnotations, ToolExecution}, + schemars::JsonSchema, +}; + +/// Base trait to define attributes of a tool. +/// +/// Tools implementing [`SyncTool`] or [`AsyncTool`] must implement this trait first. +/// +/// All methods are consistent with fields of [`Tool`][crate::model::Tool]. +pub trait ToolBase { + /// Parameter type, will used in the invoke parameter of [`SyncTool`] or [`AsyncTool`] trait + /// + /// If the tool does not have any parameters, you **MUST** override [`input_schema`][Self::input_schema] + /// method. See its documentation for more details. + type Parameter: for<'de> Deserialize<'de> + JsonSchema + Send + Default + 'static; + /// Output type, will used in the invoke output of [`SyncTool`] or [`AsyncTool`] trait + /// + /// If the tool does not have any output, you **MUST** override [`output_schema`][Self::output_schema] + /// method. See its documentation for more details. + type Output: Serialize + JsonSchema + Send + 'static; + /// Error type, will used in the invoke output of [`SyncTool`] or [`AsyncTool`] trait + type Error: Into + Send + 'static; + + fn name() -> Cow<'static, str>; + + fn title() -> Option { + None + } + fn description() -> Option> { + None + } + + /// Json schema for tool input. + /// + /// The default implementation generates schema based on [`Self::Parameter`] type. + /// + /// If the tool does not have any parameters, you should override this methods to return [`None`], + /// and when invoked, the parameter will get default values. + fn input_schema() -> Option> { + Some(schema_for_type::>()) + } + + /// Json schema for tool output. + /// + /// The default implementation generates schema based on [`Self::Output`] type. + /// + /// If the tool does not have any output, you should override this methods to return [`None`]. + fn output_schema() -> Option> { + Some(schema_for_output::().unwrap_or_else(|e| { + panic!( + "Invalid output schema for ToolBase::Output type `{0}`: {1}", + std::any::type_name::(), + e, + ); + })) + } + + fn annotations() -> Option { + None + } + fn execution() -> Option { + None + } + fn icons() -> Option> { + None + } + fn meta() -> Option { + None + } +} + +/// Synchronous version of a tool. +/// +/// Consider using [`AsyncTool`] if your workflow involves asynchronous operations. +/// Examples are shown in [the module-level documentation][crate::handler::server::router::tool]. +pub trait SyncTool: ToolBase { + fn invoke(service: &S, param: Self::Parameter) -> Result; +} + +/// Asynchronous version of a tool. +/// +/// Consider using [`SyncTool`] if your workflow does not involve asynchronous operations. +/// Examples are shown in [the module-level documentation][crate::handler::server::router::tool]. +pub trait AsyncTool: ToolBase { + fn invoke( + service: &S, + param: Self::Parameter, + ) -> impl Future> + Send; +} + +pub(crate) fn tool_attribute() -> crate::model::Tool { + crate::model::Tool { + name: T::name(), + title: T::title(), + description: T::description(), + input_schema: T::input_schema().unwrap_or_else(schema_for_empty_input), + output_schema: T::output_schema(), + annotations: T::annotations(), + execution: T::execution(), + icons: T::icons(), + meta: T::meta(), + } +} + +pub(crate) fn sync_tool_wrapper>( + service: &S, + Parameters(params): Parameters, +) -> Result, ErrorData> { + T::invoke(service, params).map(Json).map_err(Into::into) +} + +pub(crate) fn sync_tool_wrapper_with_empty_params>( + service: &S, +) -> Result, ErrorData> { + T::invoke(service, T::Parameter::default()) + .map(Json) + .map_err(Into::into) +} + +#[expect(clippy::type_complexity)] +pub(crate) fn async_tool_wrapper>( + service: &S, + Parameters(params): Parameters, +) -> Pin, ErrorData>> + Send + '_>> { + Box::pin(async move { + T::invoke(service, params) + .await + .map(Json) + .map_err(Into::into) + }) +} + +#[expect(clippy::type_complexity)] +pub(crate) fn async_tool_wrapper_with_empty_params>( + service: &S, +) -> Pin, ErrorData>> + Send + '_>> { + Box::pin(async move { + T::invoke(service, T::Parameter::default()) + .await + .map(Json) + .map_err(Into::into) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate as rmcp; + use crate::tool; // workaround for macros + + #[derive(Deserialize, schemars::JsonSchema, Default)] + struct AddParameter { + left: usize, + right: usize, + } + #[derive(Serialize, schemars::JsonSchema, PartialEq, Debug)] + struct AddOutput { + sum: usize, + } + + struct MacroBasedToolServer; + + impl MacroBasedToolServer { + #[expect(unused)] + #[tool(name = "adder", description = "Modular add two integers")] + fn add( + &self, + Parameters(AddParameter { left, right }): Parameters, + ) -> Json { + Json(AddOutput { + sum: left.wrapping_add(right), + }) + } + + #[expect(unused)] + #[tool(name = "empty", description = "Empty tool")] + fn empty(&self) {} + } + + struct AddTool; + impl ToolBase for AddTool { + type Parameter = AddParameter; + type Output = AddOutput; + type Error = ErrorData; + + fn name() -> Cow<'static, str> { + "adder".into() + } + + fn description() -> Option> { + Some("Modular add two integers".into()) + } + } + impl SyncTool for AddTool { + fn invoke( + _service: &TraitBasedToolServer, + AddParameter { left, right }: Self::Parameter, + ) -> Result { + Ok(AddOutput { + sum: left.wrapping_add(right), + }) + } + } + impl AsyncTool for AddTool { + async fn invoke( + _service: &TraitBasedToolServer, + AddParameter { left, right }: Self::Parameter, + ) -> Result { + Ok(AddOutput { + sum: left.wrapping_add(right), + }) + } + } + + enum EmptyToolCustomError { + Internal, + InvalidParams, + } + impl From for ErrorData { + fn from(value: EmptyToolCustomError) -> Self { + match value { + EmptyToolCustomError::Internal => Self::internal_error("internal error", None), + EmptyToolCustomError::InvalidParams => Self::invalid_params("invalid params", None), + } + } + } + + struct EmptyTool; + impl ToolBase for EmptyTool { + type Parameter = (); + type Output = (); + type Error = EmptyToolCustomError; + + fn name() -> Cow<'static, str> { + "empty".into() + } + + fn description() -> Option> { + Some("Empty tool".into()) + } + + fn input_schema() -> Option> { + None + } + + fn output_schema() -> Option> { + None + } + } + impl SyncTool for EmptyTool { + fn invoke( + _service: &TraitBasedToolServer, + _param: Self::Parameter, + ) -> Result { + Err(EmptyToolCustomError::Internal) + } + } + impl AsyncTool for EmptyTool { + async fn invoke( + _service: &TraitBasedToolServer, + _param: Self::Parameter, + ) -> Result { + Err(EmptyToolCustomError::InvalidParams) + } + } + + struct TraitBasedToolServer; + + #[test] + fn test_macro_and_trait_have_same_attrs() { + let macro_attrs = MacroBasedToolServer::add_tool_attr(); + let trait_attrs = tool_attribute::(); + assert_eq!(macro_attrs, trait_attrs); + } + + #[test] + fn test_macro_and_trait_have_same_attrs_for_empty_tool() { + let macro_attrs = MacroBasedToolServer::empty_tool_attr(); + let trait_attrs = tool_attribute::(); + assert_eq!(macro_attrs, trait_attrs); + } + + #[test] + fn test_sync_tool_wrapper_happy_path() { + let left = 1; + let right = 2; + let result = sync_tool_wrapper::<_, AddTool>( + &TraitBasedToolServer, + Parameters(AddParameter { left, right }), + ); + assert!(result.is_ok()); + if let Ok(result) = result { + assert_eq!(result.0, AddOutput { sum: 3 }); + } + } + + #[tokio::test] + async fn test_async_tool_wrapper_happy_path() { + let left = 1; + let right = 2; + let result = async_tool_wrapper::<_, AddTool>( + &TraitBasedToolServer, + Parameters(AddParameter { left, right }), + ) + .await; + assert!(result.is_ok()); + if let Ok(result) = result { + assert_eq!(result.0, AddOutput { sum: 3 }); + } + } + + #[test] + fn test_sync_tool_wrapper_error_conversion() { + let result = sync_tool_wrapper::<_, EmptyTool>(&TraitBasedToolServer, Parameters(())); + assert!(result.is_err()); + if let Err(result) = result { + assert_eq!(result, ErrorData::internal_error("internal error", None)); + } + } + + #[tokio::test] + async fn test_async_tool_wrapper_error_conversion() { + let result = + async_tool_wrapper::<_, EmptyTool>(&TraitBasedToolServer, Parameters(())).await; + assert!(result.is_err()); + if let Err(result) = result { + assert_eq!(result, ErrorData::invalid_params("invalid params", None)); + } + } +}