From 78c06d2598788dcae0ab5c39040b551192f7905e Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Thu, 21 May 2026 14:05:11 -0400 Subject: [PATCH 1/6] feat: implement nemo_guardrails remote backend Signed-off-by: Alex Fournier --- crates/core/Cargo.toml | 8 +- crates/core/src/plugin.rs | 8 +- .../nemo_guardrails/plugin_component.rs | 1052 ++++++++- .../nemo_guardrails/plugin_component_tests.rs | 1872 ++++++++++++++++- 4 files changed, 2903 insertions(+), 37 deletions(-) diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 898b1614..9d23c61c 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -14,8 +14,12 @@ readme = "README.md" workspace = true [features] -default = ["otel", "openinference"] +default = ["otel", "openinference", "guardrails-remote"] schema = ["dep:schemars"] +guardrails-remote = [ + "dep:reqwest", + "dep:rustls", +] otel = [ "dep:async-trait", "dep:getrandom", @@ -24,7 +28,6 @@ otel = [ "dep:opentelemetry-http", "dep:opentelemetry-otlp", "dep:opentelemetry_sdk", - "dep:reqwest", "dep:rustls", "dep:tonic", "dep:web-sys", @@ -40,7 +43,6 @@ openinference = [ "dep:opentelemetry-http", "dep:opentelemetry-otlp", "dep:opentelemetry_sdk", - "dep:reqwest", "dep:rustls", "dep:tonic", "dep:web-sys", diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index 1b48b267..d2ff0bdf 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -762,9 +762,11 @@ pub fn register_plugin(plugin: Arc) -> Result<()> { /// Built-in plugins are available to validation and initialization without a /// binding or application-specific registration call. pub fn ensure_builtin_plugins_registered() -> Result<()> { - match BUILTIN_PLUGIN_REGISTRATION - .get_or_init(crate::observability::plugin_component::register_observability_component) - { + let register_builtins = || { + crate::observability::plugin_component::register_observability_component()?; + crate::plugins::nemo_guardrails::plugin_component::register_nemo_guardrails_component() + }; + match BUILTIN_PLUGIN_REGISTRATION.get_or_init(register_builtins) { Ok(()) => Ok(()), Err(err) => Err(clone_cached_plugin_error(err)), } diff --git a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs b/crates/core/src/plugins/nemo_guardrails/plugin_component.rs index 5617a743..b0de4ed8 100644 --- a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs +++ b/crates/core/src/plugins/nemo_guardrails/plugin_component.rs @@ -7,15 +7,39 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use std::time::Duration; use serde::{Deserialize, Serialize}; -use serde_json::{Map, Value as Json}; - +use serde_json::{Map, Value as Json, json}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio::sync::mpsc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio_stream::wrappers::ReceiverStream; + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::llm::LlmRequest; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::scope::{EmitMarkEventParams, ScopeHandle, event, get_handle}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::openai_chat::OpenAIChatCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::streaming::SseEventDecoder; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::traits::LlmCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::error::FlowError; use crate::plugin::{ ConfigDiagnostic, ConfigPolicy, DiagnosticLevel, Plugin, PluginComponentSpec, PluginError, PluginRegistrationContext, Result as PluginResult, UnsupportedBehavior, deregister_plugin, lookup_plugin, register_plugin, }; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use rustls::crypto::ring; /// The plugin kind reserved for the planned first-party component. pub const NEMO_GUARDRAILS_PLUGIN_KIND: &str = "nemo_guardrails"; @@ -182,6 +206,12 @@ pub struct RequestDefaultsConfig { /// Default context object passed into Guardrails requests. #[serde(default, skip_serializing_if = "Option::is_none")] pub context: Option, + /// Default remote thread identifier for continuation-aware requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + /// Default remote Guardrails state payload for continuation-aware requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state: Option, /// Default request-time rail selection. #[serde(default, skip_serializing_if = "Option::is_none")] pub rails: Option, @@ -307,6 +337,8 @@ crate::editor_config! { crate::editor_config! { impl RequestDefaultsConfig { context => { label: "context", kind: Json, optional: true }, + thread_id => { label: "thread_id", kind: String, optional: true }, + state => { label: "state", kind: Json, optional: true }, rails => { label: "rails", kind: Section, @@ -349,13 +381,13 @@ impl Plugin for NeMoGuardrailsPlugin { fn register<'a>( &'a self, - _plugin_config: &Map, - _ctx: &'a mut PluginRegistrationContext, + plugin_config: &Map, + ctx: &'a mut PluginRegistrationContext, ) -> Pin> + Send + 'a>> { - Box::pin(async { - Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails plugin backend is not implemented yet".to_string(), - )) + let parsed = parse_nemo_guardrails_config(plugin_config); + Box::pin(async move { + let config = parsed?; + register_nemo_guardrails_backend(config, ctx) }) } } @@ -419,6 +451,932 @@ fn string_enum_schema( schema.into() } +fn register_nemo_guardrails_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + match config.mode.as_str() { + "remote" => register_remote_backend(config, ctx), + "local" => Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails local backend is not implemented yet".to_string(), + )), + other => Err(PluginError::InvalidConfig(format!( + "unsupported NeMo Guardrails mode '{other}'" + ))), + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone)] +// PR 2 intentionally implements the first honest remote slice: +// OpenAI chat requests, non-streaming + streaming execution, managed tool +// input/output checks, broad request-defaults pass-through, and response +// passthrough from the Guardrails server. The local backend remains out of +// scope. +struct RemoteBackendRuntime { + endpoint: String, + client: reqwest::Client, + config_id: Option, + config_ids: Vec, + request_defaults: Option, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone, Copy)] +enum RemoteCheckKind { + Input, + Output, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +impl RemoteBackendRuntime { + fn new(config: &NeMoGuardrailsConfig, remote: &RemoteBackendConfig) -> PluginResult { + let endpoint = remote.endpoint.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote.endpoint is required in remote mode".to_string()) + })?; + let mut default_headers = HeaderMap::new(); + for (name, value) in &remote.headers { + let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers contains invalid header name '{name}': {err}" + )) + })?; + let header_value = HeaderValue::from_str(value).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers[{name}] has an invalid value: {err}" + )) + })?; + default_headers.insert(header_name, header_value); + } + + let _ = ring::default_provider().install_default(); + + let client = reqwest::Client::builder() + .default_headers(default_headers) + .timeout(Duration::from_millis(remote.timeout_millis)) + .build() + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to construct NeMo Guardrails remote client: {err}" + )) + })?; + + Ok(Self { + endpoint: endpoint.trim_end_matches('/').to_string(), + client, + config_id: remote.config_id.clone(), + config_ids: remote.config_ids.clone(), + request_defaults: config.request_defaults.clone(), + }) + } + + async fn execute(&self, request: LlmRequest, stream: bool) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(stream, &self.config_id, &self.config_ids, None, None), + ); + let body = self + .build_request_body(&request, stream) + .inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("failed to serialize remote request body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote request body: {err}" + )) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote request failed: {err}")), + ), + ); + FlowError::Internal(format!("nemo_guardrails remote request failed: {err}")) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote response body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote response body: {err}" + )) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to parse remote response JSON: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to parse remote response JSON: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + // Preserve the server's OpenAI-compatible chat response and nested + // guardrails payload verbatim in the first remote slice. + Ok(response_json) + } + + async fn execute_stream(&self, request: LlmRequest) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(true, &self.config_id, &self.config_ids, None, None), + ); + let body = self.build_request_body(&request, true).inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!( + "failed to serialize remote stream request body: {err}" + )), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote stream request body: {err}" + )) + })?; + let mut response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote stream request failed: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails remote stream request failed: {err}" + )) + })?; + let status = response.status(); + if !status.is_success() { + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream error body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream error body: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote stream request failed with status {status}: {payload}" + ))); + } + + let (tx, rx) = mpsc::channel(16); + let parent_for_task = parent.clone(); + let config_id = self.config_id.clone(); + let config_ids = self.config_ids.clone(); + tokio::spawn(async move { + let mut decoder = SseEventDecoder::new(); + loop { + let bytes = match response.chunk().await { + Ok(Some(bytes)) => bytes, + Ok(None) => break, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream chunk: {err}")), + ), + ); + let _ = tx + .send(Err(FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream chunk: {err}" + )))) + .await; + return; + } + }; + let events = match decoder.push_bytes(&bytes) { + Ok(events) => events, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + }; + for event in events { + if tx.send(Ok(event.data)).await.is_err() { + return; + } + } + } + + match decoder.finish() { + Ok(Some(event)) => { + let _ = tx.send(Ok(event.data)).await; + } + Ok(None) => {} + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + } + + emit_remote_mark( + "nemo_guardrails.remote.end", + &parent_for_task, + remote_mark_data(true, &config_id, &config_ids, Some(status.as_u16()), None), + ); + }); + + Ok(Box::pin(ReceiverStream::new(rx)) as LlmJsonStream) + } + + async fn check_tool_input(&self, tool_name: &str, args: &Json) -> crate::error::Result { + let original_content = tool_input_content(tool_name, args); + let messages = vec![json!({"role": "user", "content": original_content.clone()})]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Input, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_input rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(args.clone()); + } + + modified_tool_payload(&result_content, tool_name, "arguments") + } + + async fn check_tool_output( + &self, + tool_name: &str, + args: &Json, + result: &Json, + ) -> crate::error::Result { + let input_content = tool_input_content(tool_name, args); + let original_content = tool_output_content(tool_name, args, result); + let messages = vec![ + json!({"role": "user", "content": input_content}), + json!({"role": "assistant", "content": original_content.clone()}), + ]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Output, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_output rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(result.clone()); + } + + modified_tool_payload(&result_content, tool_name, "result") + } + + fn build_request_body(&self, request: &LlmRequest, stream: bool) -> crate::error::Result { + // Remote mode currently accepts only NeMo Flow's OpenAI chat request + // shape and forwards it to the Guardrails server with a nested + // `guardrails` envelope. + let annotated = OpenAIChatCodec.decode(request)?; + if annotated.tools.is_some() || annotated.tool_choice.is_some() { + return Err(FlowError::Internal( + "nemo_guardrails remote backend does not support OpenAI tool definitions or tool_choice yet" + .to_string(), + )); + } + + let mut body = request.content.as_object().cloned().ok_or_else(|| { + FlowError::Internal("LLM request content is not a JSON object".to_string()) + })?; + body.insert("stream".to_string(), Json::Bool(stream)); + if let Some(guardrails) = self.build_guardrails_config() { + body.insert("guardrails".to_string(), Json::Object(guardrails)); + } + Ok(Json::Object(body)) + } + + fn build_guardrails_config(&self) -> Option> { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + let mut options = Map::new(); + if let Some(rails) = &request_defaults.rails { + options.insert( + "rails".to_string(), + serde_json::to_value(rails) + .expect("request rails config should serialize to JSON"), + ); + } + if let Some(llm_params) = &request_defaults.llm_params { + options.insert("llm_params".to_string(), llm_params.clone()); + } + if let Some(llm_output) = request_defaults.llm_output { + options.insert("llm_output".to_string(), Json::Bool(llm_output)); + } + if let Some(output_vars) = &request_defaults.output_vars { + options.insert("output_vars".to_string(), output_vars.clone()); + } + if let Some(log) = &request_defaults.log { + options.insert("log".to_string(), log.clone()); + } + if !options.is_empty() { + guardrails.insert("options".to_string(), Json::Object(options)); + } + } + + (!guardrails.is_empty()).then_some(guardrails) + } + + fn chat_completions_url(&self) -> String { + format!("{}/v1/chat/completions", self.endpoint) + } + + fn emit_mark(&self, name: &str, parent: &Option, data: Json) { + emit_remote_mark(name, parent, data); + } + + async fn execute_remote_check( + &self, + messages: Vec, + kind: RemoteCheckKind, + tool_name: &str, + ) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + None, + ), + ); + let mut body = Map::new(); + body.insert("model".to_string(), Json::String(String::new())); + body.insert("messages".to_string(), Json::Array(messages)); + body.insert("stream".to_string(), Json::Bool(false)); + body.insert( + "guardrails".to_string(), + Json::Object(self.build_tool_check_guardrails(kind)), + ); + let serialized = serde_json::to_vec(&Json::Object(body)).map_err(|err| { + let message = format!("nemo_guardrails failed to serialize remote request body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + let message = format!("nemo_guardrails remote request failed: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + let message = format!("nemo_guardrails failed to read remote response body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + let message = format!("nemo_guardrails failed to parse remote response JSON: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + Ok(response_json) + } + + fn build_tool_check_guardrails(&self, kind: RemoteCheckKind) -> Map { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + } + + let mut options = Map::new(); + let rails = match kind { + RemoteCheckKind::Input => json!({ + "input": true, + "output": false, + "dialog": false, + "retrieval": false, + "tool_input": false, + "tool_output": false, + }), + RemoteCheckKind::Output => json!({ + "input": false, + "output": true, + "dialog": false, + "retrieval": false, + "tool_input": false, + "tool_output": false, + }), + }; + options.insert("rails".to_string(), rails); + let mut log = self + .request_defaults + .as_ref() + .and_then(|defaults| defaults.log.as_ref()) + .and_then(Json::as_object) + .cloned() + .unwrap_or_default(); + log.insert("activated_rails".to_string(), Json::Bool(true)); + options.insert("log".to_string(), Json::Object(log)); + guardrails.insert("options".to_string(), Json::Object(options)); + guardrails + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn emit_remote_mark(name: &str, parent: &Option, data: Json) { + let _ = event( + EmitMarkEventParams::builder() + .name(name) + .parent_opt(parent.as_ref()) + .data(data) + .build(), + ); +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_input_content(tool_name: &str, args: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + })) + .expect("tool input payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + "result": result, + })) + .expect("tool output payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn modified_tool_payload( + content: &str, + expected_tool_name: &str, + field: &str, +) -> crate::error::Result { + let value: Json = serde_json::from_str(content).map_err(|err| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content that is not valid JSON: {err}" + )) + })?; + let Json::Object(object) = value else { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + ))); + }; + if let Some(tool_name) = object.get("tool_name").and_then(Json::as_str) + && tool_name != expected_tool_name + { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content for unexpected tool '{tool_name}'" + ))); + } + object.get(field).cloned().ok_or_else(|| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + )) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn chat_completion_content(response: &Json) -> crate::error::Result { + response + .get("choices") + .and_then(Json::as_array) + .and_then(|choices| choices.first()) + .and_then(|choice| choice.get("message")) + .and_then(|message| message.get("content")) + .and_then(Json::as_str) + .map(str::to_string) + .ok_or_else(|| { + FlowError::Internal( + "nemo_guardrails remote response did not contain choices[0].message.content" + .to_string(), + ) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn blocking_rail_name(response: &Json) -> Option { + response + .get("guardrails") + .and_then(|guardrails| guardrails.get("log")) + .and_then(|log| log.get("activated_rails")) + .and_then(Json::as_array) + .and_then(|activated| { + activated.iter().find_map(|rail| { + if rail.get("stop").and_then(Json::as_bool) == Some(true) { + rail.get("name").and_then(Json::as_str).map(str::to_string) + } else { + None + } + }) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn remote_mark_data( + stream: bool, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = Map::new(); + data.insert("stream".to_string(), Json::Bool(stream)); + if let Some(config_id) = config_id { + data.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !config_ids.is_empty() { + data.insert( + "config_ids".to_string(), + Json::Array(config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(status) = status { + data.insert( + "http_status".to_string(), + Json::Number(serde_json::Number::from(status)), + ); + } + if let Some(error) = error { + data.insert("error".to_string(), Json::String(error)); + } + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_remote_mark_data( + kind: RemoteCheckKind, + tool_name: &str, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = match remote_mark_data(false, config_id, config_ids, status, error) { + Json::Object(data) => data, + _ => unreachable!("remote_mark_data always returns an object"), + }; + data.insert( + "surface".to_string(), + Json::String(match kind { + RemoteCheckKind::Input => "tool_input".to_string(), + RemoteCheckKind::Output => "tool_output".to_string(), + }), + ); + data.insert("tool_name".to_string(), Json::String(tool_name.to_string())); + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn register_remote_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let remote = config.remote.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote config is required when mode is 'remote'".to_string()) + })?; + let runtime = Arc::new(RemoteBackendRuntime::new(&config, &remote)?); + + if config.input || config.output { + let llm_execution_runtime = Arc::clone(&runtime); + let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_execution_runtime); + Box::pin(async move { runtime.execute(request, false).await }) + }); + ctx.register_llm_execution_intercept("llm_remote_backend", config.priority, llm_execution)?; + + let llm_stream_runtime = Arc::clone(&runtime); + let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_stream_runtime); + Box::pin(async move { runtime.execute_stream(request).await }) + }); + ctx.register_llm_stream_execution_intercept( + "llm_stream_remote_backend", + config.priority, + llm_stream_execution, + )?; + } + + if config.tool_input || config.tool_output { + let tool_runtime = Arc::clone(&runtime); + let enable_tool_input = config.tool_input; + let enable_tool_output = config.tool_output; + let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { + let runtime = Arc::clone(&tool_runtime); + let tool_name = tool_name.to_string(); + Box::pin(async move { + let current_args = if enable_tool_input { + runtime.check_tool_input(&tool_name, &args).await? + } else { + args + }; + + let tool_result = next(current_args.clone()).await?; + if !enable_tool_output { + return Ok(tool_result); + } + + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await + }) + }); + ctx.register_tool_execution_intercept( + "tool_remote_backend", + config.priority, + tool_execution, + )?; + } + + Ok(()) +} + +#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] +fn register_remote_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), + )) +} + fn parse_nemo_guardrails_config( plugin_config: &Map, ) -> PluginResult { @@ -497,6 +1455,8 @@ fn validate_nemo_guardrails_plugin_config( "request_defaults", &[ "context", + "thread_id", + "state", "rails", "llm_params", "llm_output", @@ -526,6 +1486,7 @@ fn validate_nemo_guardrails_plugin_config( validate_config_shape(&mut diagnostics, &config.policy, &config); validate_codec_requirements(&mut diagnostics, &config.policy, &config); validate_surface_selection(&mut diagnostics, &config.policy, &config); + validate_remote_backend_support(&mut diagnostics, &config.policy, &config); validate_request_defaults(&mut diagnostics, &config.policy, &config); diagnostics @@ -869,6 +1830,32 @@ fn validate_surface_selection( ); } +fn validate_remote_backend_support( + diagnostics: &mut Vec, + policy: &ConfigPolicy, + config: &NeMoGuardrailsConfig, +) { + if config.mode != "remote" { + return; + } + + if (config.input || config.output) + && config + .codec + .as_deref() + .is_some_and(|codec| codec != "openai_chat") + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("codec".to_string()), + "remote mode currently supports only codec = 'openai_chat'".to_string(), + ); + } +} + fn validate_request_defaults( diagnostics: &mut Vec, policy: &ConfigPolicy, @@ -885,6 +1872,55 @@ fn validate_request_defaults( "request_defaults.context", "request_defaults.context must be a JSON object", ); + if let Some(thread_id) = &request_defaults.thread_id + && thread_id.trim().is_empty() + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must not be empty".to_string(), + ); + } + if let Some(thread_id) = &request_defaults.thread_id + && !thread_id.trim().is_empty() + && thread_id.len() < 16 + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must be at least 16 characters long".to_string(), + ); + } + validate_json_object_field( + diagnostics, + policy, + request_defaults.state.as_ref(), + "request_defaults.state", + "request_defaults.state must be a JSON object", + ); + if let Some(state) = request_defaults + .state + .as_ref() + .and_then(|value| value.as_object()) + && !state.is_empty() + && !state.contains_key("events") + && !state.contains_key("state") + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.state".to_string()), + "request_defaults.state must be empty or contain 'events' or 'state'".to_string(), + ); + } validate_json_object_field( diagnostics, policy, diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs index 22e721b4..ba11c6ad 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs @@ -2,10 +2,31 @@ // SPDX-License-Identifier: Apache-2.0 //! Unit tests for the planned NeMo Guardrails plugin component contract. +#![allow(clippy::await_holding_lock)] use super::*; use crate::api::runtime::NemoRelayContextState; +use std::io::{Read, Write}; +use std::net::TcpListener; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, mpsc}; +use std::thread; + +use crate::api::event::Event; +use crate::api::llm::{ + LlmAttributes, LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute, + llm_stream_call_execute, +}; use crate::api::runtime::global_context; +use crate::api::runtime::{ + LlmExecutionNextFn, LlmJsonStream, LlmStreamExecutionNextFn, create_scope_stack, + set_thread_scope_stack, +}; +use crate::api::subscriber::{deregister_subscriber, register_subscriber}; +use crate::api::tool::{ToolCallExecuteParams, tool_call_execute}; +use crate::codec::openai_chat::{OpenAIChatCodec, OpenAIChatStreamingCodec}; +use crate::codec::streaming::StreamingCodec; +use crate::codec::traits::LlmResponseCodec; use crate::config_editor::{EditorConfig, EditorFieldKind}; #[cfg(feature = "schema")] use crate::plugin::plugin_config_schema; @@ -13,18 +34,19 @@ use crate::plugin::{ PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, list_plugin_kinds, lookup_plugin, validate_plugin_config, }; +use futures::StreamExt; use serde_json::json; fn reset_runtime() { let _ = clear_plugin_configuration(); - let _ = deregister_nemo_guardrails_component(); crate::shared_runtime::reset_runtime_owner_for_tests(); let context = global_context(); *context.write().unwrap() = NemoRelayContextState::new(); } -fn ensure_registered() { - register_nemo_guardrails_component().unwrap(); +fn setup_isolated_thread() { + let stack = create_scope_stack(); + set_thread_scope_stack(stack); } fn component(config: Json) -> PluginComponentSpec { @@ -68,6 +90,137 @@ fn remote_valid_config() -> Json { }) } +#[derive(Debug)] +struct CapturedHttpRequest { + path: String, + content_type: String, + body: Vec, +} + +fn spawn_http_responder( + listener: TcpListener, + response: Vec, + request_tx: mpsc::Sender, +) { + thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let request = read_http_request(&mut stream); + stream.write_all(&response).unwrap(); + request_tx.send(request).unwrap(); + }); +} + +fn spawn_http_responder_sequence( + listener: TcpListener, + responses: Vec>, + request_tx: mpsc::Sender, +) { + thread::spawn(move || { + for response in responses { + let (mut stream, _) = listener.accept().unwrap(); + let request = read_http_request(&mut stream); + stream.write_all(&response).unwrap(); + request_tx.send(request).unwrap(); + } + }); +} + +fn read_http_request(stream: &mut impl Read) -> CapturedHttpRequest { + let mut bytes = Vec::new(); + let mut buf = [0_u8; 4096]; + let (header_end, content_length) = read_http_headers(stream, &mut bytes, &mut buf); + read_http_body(stream, &mut bytes, &mut buf, header_end + content_length); + + let headers_text = String::from_utf8_lossy(&bytes[..header_end]); + let request_line = headers_text.lines().next().unwrap(); + CapturedHttpRequest { + path: request_line.split_whitespace().nth(1).unwrap().to_string(), + content_type: header_value(&headers_text, "content-type") + .unwrap_or_default() + .to_string(), + body: bytes[header_end..header_end + content_length].to_vec(), + } +} + +fn read_http_headers( + stream: &mut impl Read, + bytes: &mut Vec, + buf: &mut [u8; 4096], +) -> (usize, usize) { + loop { + let read = stream.read(buf).unwrap(); + if read == 0 { + panic!("remote responder closed before receiving request"); + } + bytes.extend_from_slice(&buf[..read]); + + if let Some(header_end) = bytes.windows(4).position(|window| window == b"\r\n\r\n") { + let header_end = header_end + 4; + let headers_text = String::from_utf8_lossy(&bytes[..header_end]); + let content_length = header_value(&headers_text, "content-length") + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + return (header_end, content_length); + } + } +} + +fn read_http_body( + stream: &mut impl Read, + bytes: &mut Vec, + buf: &mut [u8; 4096], + expected_total: usize, +) { + while bytes.len() < expected_total { + let read = stream.read(buf).unwrap(); + if read == 0 { + panic!("remote responder closed before full request body"); + } + bytes.extend_from_slice(&buf[..read]); + } +} + +fn header_value<'a>(headers_text: &'a str, header_name: &str) -> Option<&'a str> { + headers_text.lines().find_map(|line| { + let (name, value) = line.split_once(':')?; + if name.eq_ignore_ascii_case(header_name) { + Some(value.trim()) + } else { + None + } + }) +} + +fn make_chat_request(stream: bool) -> LlmRequest { + LlmRequest { + headers: serde_json::Map::new(), + content: json!({ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.2, + "stream": stream + }), + } +} + +fn capture_events(name: &str) -> Arc>> { + let events = Arc::new(Mutex::new(Vec::new())); + let sink = Arc::clone(&events); + register_subscriber( + name, + Arc::new(move |event| sink.lock().unwrap().push(event.clone())), + ) + .unwrap(); + events +} + +fn unused_local_endpoint() -> String { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + drop(listener); + format!("http://{address}") +} + #[test] fn editor_schema_tracks_nemo_guardrails_config_types() { let schema = NeMoGuardrailsConfig::editor_schema(); @@ -210,6 +363,8 @@ fn schema_contains_every_supported_nemo_guardrails_option() { "timeout_millis", "python_module", "context", + "thread_id", + "state", "rails", "llm_params", "llm_output", @@ -264,23 +419,14 @@ fn plugin_schema_contains_generic_plugin_surface() { } #[test] -fn registration_is_explicit_not_automatic() { +fn builtin_registration_is_automatic() { let _guard = crate::plugins::nemo_guardrails::test_mutex() .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - assert!(!list_plugin_kinds().contains(&NEMO_GUARDRAILS_PLUGIN_KIND.to_string())); - assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_none()); - - ensure_registered(); assert!(list_plugin_kinds().contains(&NEMO_GUARDRAILS_PLUGIN_KIND.to_string())); assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_some()); - - ensure_registered(); - assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_some()); - assert!(deregister_nemo_guardrails_component()); - assert!(!deregister_nemo_guardrails_component()); } #[test] @@ -289,7 +435,6 @@ fn disabled_component_validates_and_initializes_without_runtime_work() { .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - ensure_registered(); let config = PluginConfig { version: 1, @@ -306,7 +451,6 @@ fn duplicate_component_is_rejected_as_singleton() { .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - ensure_registered(); let config = PluginConfig { version: 1, @@ -332,7 +476,6 @@ fn invalid_shapes_and_values_are_reported() { .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - ensure_registered(); let invalid_shape = validate_plugin_config(&plugin_config(json!({ "version": "one", @@ -424,6 +567,50 @@ fn invalid_shapes_and_values_are_reported() { .contains("codec must be 'openai_chat', 'openai_responses', or 'anthropic_messages'") })); + let unsupported_remote_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_responses", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(unsupported_remote_codec.has_errors()); + assert!(unsupported_remote_codec.diagnostics.iter().any(|diag| { + diag.message + .contains("remote mode currently supports only codec = 'openai_chat'") + })); + + let unsupported_remote_anthropic_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "anthropic_messages", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(unsupported_remote_anthropic_codec.has_errors()); + assert!( + unsupported_remote_anthropic_codec + .diagnostics + .iter() + .any(|diag| { + diag.message + .contains("remote mode currently supports only codec = 'openai_chat'") + }) + ); + + let supported_remote_tool_surface = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "tool_input": true, + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(!supported_remote_tool_surface.has_errors()); + let remote_empty_fields = validate_plugin_config(&plugin_config(json!({ "mode": "remote", "codec": "openai_chat", @@ -527,6 +714,8 @@ fn invalid_shapes_and_values_are_reported() { }, "request_defaults": { "context": true, + "thread_id": "short", + "state": {"foo": "bar"}, "llm_params": [], "log": "verbose", "output_vars": 7, @@ -542,6 +731,26 @@ fn invalid_shapes_and_values_are_reported() { .iter() .any(|diag| diag.field.as_deref() == Some("request_defaults.context")) ); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.thread_id")) + ); + assert!(invalid_request_defaults.diagnostics.iter().any(|diag| { + diag.message + .contains("request_defaults.thread_id must be at least 16 characters long") + })); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.state")) + ); + assert!(invalid_request_defaults.diagnostics.iter().any(|diag| { + diag.message + .contains("request_defaults.state must be empty or contain 'events' or 'state'") + })); assert!( invalid_request_defaults .diagnostics @@ -574,7 +783,6 @@ fn unknown_fields_follow_policy() { .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - ensure_registered(); let warn_report = validate_plugin_config(&plugin_config(json!({ "mode": "remote", @@ -618,21 +826,1639 @@ fn unknown_fields_follow_policy() { } #[test] -fn enabled_initialization_fails_fast_until_backend_exists() { +fn enabled_local_initialization_fails_fast_until_backend_exists() { let _guard = crate::plugins::nemo_guardrails::test_mutex() .lock() .unwrap_or_else(|err| err.into_inner()); reset_runtime(); - ensure_registered(); - let error = - futures::executor::block_on(initialize_plugins(plugin_config(remote_valid_config()))) - .unwrap_err(); + let error = futures::executor::block_on(initialize_plugins(plugin_config(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })))) + .unwrap_err(); match error { crate::plugin::PluginError::RegistrationFailed(message) => { - assert!(message.contains("not implemented yet")); + assert!(message.contains("local backend")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_initialization_installs_non_streaming_execution_intercept() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-execution-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "state": {"state": {"conversation": "server-state"}}, + "output_data": {"decision": "allow"} + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default", + "headers": {"x-guardrails-auth": "token"}, + "timeout_millis": 5_000 + }, + "request_defaults": { + "context": {"tenant": "acme"}, + "thread_id": "thread-1234567890", + "state": {"state": {"conversation": "client-state"}}, + "rails": {"input": true, "retrieval": ["kb"]}, + "llm_params": {"temperature": 0.1}, + "llm_output": true, + "output_vars": ["answer"], + "log": {"activated_rails": true} + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"response": "original"})) }) + }); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert!(!original_called.load(Ordering::SeqCst)); + assert_eq!(response["id"], json!("chatcmpl-remote")); + assert_eq!(response["object"], json!("chat.completion")); + assert_eq!(response["model"], json!("gpt-4o-mini")); + assert_eq!( + response["choices"][0]["message"]["content"], + json!("guarded") + ); + assert_eq!( + response["guardrails"]["output_data"]["decision"], + json!("allow") + ); + assert_eq!( + response["guardrails"]["state"]["state"]["conversation"], + json!("server-state") + ); + + let captured = request_rx.recv().unwrap(); + assert_eq!(captured.path, "/v1/chat/completions"); + assert!(captured.content_type.starts_with("application/json")); + + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["messages"][0]["content"], json!("hello")); + assert_eq!(request_json["stream"], json!(false)); + assert_eq!( + request_json["guardrails"]["config_id"], + json!("safety-default") + ); + assert_eq!( + request_json["guardrails"]["context"]["tenant"], + json!("acme") + ); + assert_eq!( + request_json["guardrails"]["thread_id"], + json!("thread-1234567890") + ); + assert_eq!( + request_json["guardrails"]["state"]["state"]["conversation"], + json!("client-state") + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["retrieval"], + json!(["kb"]) + ); + assert_eq!( + request_json["guardrails"]["options"]["llm_output"], + json!(true) + ); + + let captured_events = events.lock().unwrap().clone(); + let mark_names: Vec<_> = captured_events + .iter() + .filter(|event| event.kind() == "mark") + .map(|event| event.name().to_string()) + .collect(); + assert!(mark_names.contains(&"nemo_guardrails.remote.start".to_string())); + assert!(mark_names.contains(&"nemo_guardrails.remote.end".to_string())); + + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!( + start_mark.data().unwrap()["config_id"], + json!("safety-default") + ); + assert_eq!(start_mark.data().unwrap()["stream"], json!(false)); + + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(end_mark.data().unwrap()["stream"], json!(false)); + + deregister_subscriber("nemo-guardrails-remote-execution-events").unwrap(); +} + +#[tokio::test] +async fn remote_request_uses_config_ids_when_config_id_is_not_set() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }] + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_ids": ["safety-a", "safety-b"] + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let _ = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["config_ids"], + json!(["safety-a", "safety-b"]) + ); + assert!(request_json["guardrails"].get("config_id").is_none()); +} + +#[tokio::test] +async fn remote_initialization_installs_stream_execution_intercept() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-stream-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let sse_body = concat!( + "data: {\"id\":\"chatcmpl-remote\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"guard\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-remote\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ed\"},\"finish_reason\":\"stop\"}]}\n\n", + "data: [DONE]\n\n" + ); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + sse_body.len(), + sse_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + .unwrap(); + + let mut chunks = Vec::new(); + while let Some(chunk) = stream.next().await { + chunks.push(chunk.unwrap()); + } + + assert!(!original_called.load(Ordering::SeqCst)); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0]["choices"][0]["delta"]["content"], json!("guard")); + assert_eq!(chunks[1]["choices"][0]["delta"]["content"], json!("ed")); + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["stream"], json!(true)); + assert_eq!( + request_json["guardrails"]["config_id"], + json!("safety-default") + ); + + let captured_events = events.lock().unwrap().clone(); + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!(start_mark.data().unwrap()["stream"], json!(true)); + + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(end_mark.data().unwrap()["stream"], json!(true)); + + deregister_subscriber("nemo-guardrails-remote-stream-events").unwrap(); +} + +#[tokio::test] +async fn remote_non_streaming_http_errors_are_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-error-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = r#"{"error":"backend unavailable"}"#; + let http_response = format!( + "HTTP/1.1 502 Bad Gateway\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"response": "original"})) }) + }); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("status 502")); + assert!(message.contains("backend unavailable")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(502)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("backend unavailable") + ); + + deregister_subscriber("nemo-guardrails-remote-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_streaming_http_errors_are_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-stream-error-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = r#"{"error":"stream backend unavailable"}"#; + let http_response = format!( + "HTTP/1.1 503 Service Unavailable\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let error = match llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + { + Ok(_) => panic!("expected remote streaming request to fail"), + Err(error) => error, + }; + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("status 503")); + assert!(message.contains("stream backend unavailable")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(503)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(true)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("stream backend unavailable") + ); + + deregister_subscriber("nemo-guardrails-remote-stream-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_non_streaming_invalid_json_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-invalid-json-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = "{not-json}"; + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("failed to parse remote response JSON")); } other => panic!("unexpected error: {other}"), } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + + deregister_subscriber("nemo-guardrails-remote-invalid-json-events").unwrap(); +} + +#[tokio::test] +async fn remote_streaming_malformed_chunk_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-malformed-stream-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let sse_body = "data: {not-json}\n\n"; + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + sse_body.len(), + sse_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + .unwrap(); + + let error = stream.next().await.unwrap().unwrap_err(); + match error { + crate::error::FlowError::Internal(message) => { + assert!(!message.is_empty()); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(true)); + + deregister_subscriber("nemo-guardrails-remote-malformed-stream-events").unwrap(); +} + +#[tokio::test] +async fn remote_preflight_tool_choice_failure_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-preflight-error-events"); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + let request = LlmRequest { + headers: serde_json::Map::new(), + content: json!({ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "hello"}], + "tools": [{ + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup data", + "parameters": {"type": "object"} + } + }] + }), + }; + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(request) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("does not support OpenAI tool definitions or tool_choice")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("does not support OpenAI tool definitions or tool_choice") + ); + + deregister_subscriber("nemo-guardrails-remote-preflight-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_transport_failure_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-transport-error-events"); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default", + "timeout_millis": 50 + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("remote request failed")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!(error_mark.data().unwrap().get("http_status").is_none()); + + deregister_subscriber("nemo-guardrails-remote-transport-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_success_without_guardrails_payload_is_allowed() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }] + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert_eq!(response["id"], json!("chatcmpl-remote")); + assert!(response.get("guardrails").is_none()); +} + +#[tokio::test] +async fn remote_tool_input_block_rejects_before_tool_execution() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-tool-input-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-blocked", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "blocked"}, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [{ + "name": "tool_input_block", + "stop": true + }] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::GuardrailRejected(message) => { + assert!(message.contains("tool_input")); + } + other => panic!("unexpected error: {other}"), + } + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["input"], + json!(true) + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["output"], + json!(false) + ); + + let captured_events = events.lock().unwrap().clone(); + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!(start_mark.data().unwrap()["surface"], json!("tool_input")); + assert_eq!( + start_mark.data().unwrap()["tool_name"], + json!("weather_lookup") + ); + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["surface"], json!("tool_input")); + + deregister_subscriber("nemo-guardrails-remote-tool-input-events").unwrap(); +} + +#[tokio::test] +async fn remote_tool_input_can_rewrite_tool_arguments() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let seen_args = Arc::new(Mutex::new(None::)); + let seen_args_for_call = Arc::clone(&seen_args); + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |args| { + *seen_args_for_call.lock().unwrap() = Some(args.clone()); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "sunny"})); + assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["messages"][0]["role"], json!("user")); +} + +#[tokio::test] +async fn remote_tool_output_can_rewrite_tool_result() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-output-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"},\"result\":{\"forecast\":\"cloudy\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "cloudy"})); + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["input"], + json!(false) + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["output"], + json!(true) + ); +} + +#[tokio::test] +async fn remote_tool_input_invalid_modified_arguments_are_reported() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-invalid", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{not-json}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("modified tool arguments content that is not valid JSON")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_output_missing_result_field_is_reported() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-output-missing-result", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("without a 'result' field")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_output_does_not_run_when_tool_callback_errors() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { + Err(crate::error::FlowError::Internal( + "tool callback failed".to_string(), + )) + }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert_eq!(message, "tool callback failed"); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_input_rewrite_with_mismatched_tool_name_is_rejected() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-mismatch", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"different_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("unexpected tool 'different_lookup'")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_input_and_output_run_in_order() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let input_response_body = json!({ + "id": "chatcmpl-tool-input-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let output_response_body = json!({ + "id": "chatcmpl-tool-output-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"},\"result\":{\"forecast\":\"cloudy\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let input_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + input_response_body.len(), + input_response_body + ) + .into_bytes(); + let output_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + output_response_body.len(), + output_response_body + ) + .into_bytes(); + spawn_http_responder_sequence(listener, vec![input_response, output_response], request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let seen_args = Arc::new(Mutex::new(None::)); + let seen_args_for_call = Arc::clone(&seen_args); + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |args| { + *seen_args_for_call.lock().unwrap() = Some(args.clone()); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); + assert_eq!(result, json!({"forecast": "cloudy"})); + + let first_request = request_rx.recv().unwrap(); + let first_request_json: Json = serde_json::from_slice(&first_request.body).unwrap(); + assert_eq!(first_request_json["messages"][0]["role"], json!("user")); + assert_eq!( + first_request_json["guardrails"]["options"]["rails"]["input"], + json!(true) + ); + assert_eq!( + first_request_json["guardrails"]["options"]["rails"]["output"], + json!(false) + ); + + let second_request = request_rx.recv().unwrap(); + let second_request_json: Json = serde_json::from_slice(&second_request.body).unwrap(); + assert_eq!(second_request_json["messages"][0]["role"], json!("user")); + assert_eq!( + second_request_json["messages"][1]["role"], + json!("assistant") + ); + assert_eq!( + second_request_json["guardrails"]["options"]["rails"]["input"], + json!(false) + ); + assert_eq!( + second_request_json["guardrails"]["options"]["rails"]["output"], + json!(true) + ); +} + +#[tokio::test] +async fn remote_tool_checks_forward_context_state_and_thread_id() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-context", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + }, + "request_defaults": { + "context": {"tenant": "smoke"}, + "thread_id": "1234567890abcdef", + "state": {"events": []} + } + }))) + .await + .unwrap(); + + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "sunny"})); + + let captured = request_rx.recv().unwrap(); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["context"], + json!({"tenant": "smoke"}) + ); + assert_eq!( + request_json["guardrails"]["thread_id"], + json!("1234567890abcdef") + ); + assert_eq!(request_json["guardrails"]["state"], json!({"events": []})); +} + +#[tokio::test] +async fn remote_tool_only_configuration_does_not_intercept_llm_calls() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let expected = json!({"response": "original"}); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + let expected = expected.clone(); + Box::pin(async move { Ok(expected) }) + }); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert_eq!(response, json!({"response": "original"})); } From 09f3d21d1248875c8753b0255c43faf533ea45dd Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 22 May 2026 08:59:14 -0400 Subject: [PATCH 2/6] refactor: rename guardrails component module Signed-off-by: Alex Fournier --- crates/core/src/plugin.rs | 2 +- .../nemo_guardrails/{plugin_component.rs => component.rs} | 2 +- crates/core/src/plugins/nemo_guardrails/mod.rs | 2 +- .../{plugin_component_tests.rs => component_tests.rs} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename crates/core/src/plugins/nemo_guardrails/{plugin_component.rs => component.rs} (99%) rename crates/core/tests/unit/plugins/nemo_guardrails/{plugin_component_tests.rs => component_tests.rs} (100%) diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index d2ff0bdf..4d6c5a51 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -764,7 +764,7 @@ pub fn register_plugin(plugin: Arc) -> Result<()> { pub fn ensure_builtin_plugins_registered() -> Result<()> { let register_builtins = || { crate::observability::plugin_component::register_observability_component()?; - crate::plugins::nemo_guardrails::plugin_component::register_nemo_guardrails_component() + crate::plugins::nemo_guardrails::component::register_nemo_guardrails_component() }; match BUILTIN_PLUGIN_REGISTRATION.get_or_init(register_builtins) { Ok(()) => Ok(()), diff --git a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs similarity index 99% rename from crates/core/src/plugins/nemo_guardrails/plugin_component.rs rename to crates/core/src/plugins/nemo_guardrails/component.rs index b0de4ed8..94a68e30 100644 --- a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -2174,5 +2174,5 @@ fn default_timeout_millis() -> u64 { } #[cfg(test)] -#[path = "../../../tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs"] +#[path = "../../../tests/unit/plugins/nemo_guardrails/component_tests.rs"] mod tests; diff --git a/crates/core/src/plugins/nemo_guardrails/mod.rs b/crates/core/src/plugins/nemo_guardrails/mod.rs index 9a7689d8..05136350 100644 --- a/crates/core/src/plugins/nemo_guardrails/mod.rs +++ b/crates/core/src/plugins/nemo_guardrails/mod.rs @@ -11,4 +11,4 @@ pub(crate) fn test_mutex() -> &'static Mutex<()> { crate::shared_runtime::runtime_owner_test_mutex() } -pub mod plugin_component; +pub mod component; diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs similarity index 100% rename from crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs rename to crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs From 9903e39dd2b0c30685332305df1c771a63576da7 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 22 May 2026 09:12:39 -0400 Subject: [PATCH 3/6] refactor: split guardrails remote backend module Signed-off-by: Alex Fournier --- .../src/plugins/nemo_guardrails/component.rs | 943 +----------------- .../src/plugins/nemo_guardrails/remote.rs | 941 +++++++++++++++++ 2 files changed, 947 insertions(+), 937 deletions(-) create mode 100644 crates/core/src/plugins/nemo_guardrails/remote.rs diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 94a68e30..5823a2da 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -7,39 +7,19 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use std::time::Duration; use serde::{Deserialize, Serialize}; -use serde_json::{Map, Value as Json, json}; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use tokio::sync::mpsc; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use tokio_stream::wrappers::ReceiverStream; - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::api::llm::LlmRequest; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::api::scope::{EmitMarkEventParams, ScopeHandle, event, get_handle}; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::codec::openai_chat::OpenAIChatCodec; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::codec::streaming::SseEventDecoder; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::codec::traits::LlmCodec; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use crate::error::FlowError; +use serde_json::{Map, Value as Json}; + use crate::plugin::{ ConfigDiagnostic, ConfigPolicy, DiagnosticLevel, Plugin, PluginComponentSpec, PluginError, PluginRegistrationContext, Result as PluginResult, UnsupportedBehavior, deregister_plugin, lookup_plugin, register_plugin, }; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -use rustls::crypto::ring; + +#[path = "remote.rs"] +mod remote; +use remote::register_remote_backend; /// The plugin kind reserved for the planned first-party component. pub const NEMO_GUARDRAILS_PLUGIN_KIND: &str = "nemo_guardrails"; @@ -466,917 +446,6 @@ fn register_nemo_guardrails_backend( } } -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -#[derive(Clone)] -// PR 2 intentionally implements the first honest remote slice: -// OpenAI chat requests, non-streaming + streaming execution, managed tool -// input/output checks, broad request-defaults pass-through, and response -// passthrough from the Guardrails server. The local backend remains out of -// scope. -struct RemoteBackendRuntime { - endpoint: String, - client: reqwest::Client, - config_id: Option, - config_ids: Vec, - request_defaults: Option, -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -#[derive(Clone, Copy)] -enum RemoteCheckKind { - Input, - Output, -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -impl RemoteBackendRuntime { - fn new(config: &NeMoGuardrailsConfig, remote: &RemoteBackendConfig) -> PluginResult { - let endpoint = remote.endpoint.clone().ok_or_else(|| { - PluginError::InvalidConfig("remote.endpoint is required in remote mode".to_string()) - })?; - let mut default_headers = HeaderMap::new(); - for (name, value) in &remote.headers { - let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| { - PluginError::InvalidConfig(format!( - "remote.headers contains invalid header name '{name}': {err}" - )) - })?; - let header_value = HeaderValue::from_str(value).map_err(|err| { - PluginError::InvalidConfig(format!( - "remote.headers[{name}] has an invalid value: {err}" - )) - })?; - default_headers.insert(header_name, header_value); - } - - let _ = ring::default_provider().install_default(); - - let client = reqwest::Client::builder() - .default_headers(default_headers) - .timeout(Duration::from_millis(remote.timeout_millis)) - .build() - .map_err(|err| { - PluginError::RegistrationFailed(format!( - "failed to construct NeMo Guardrails remote client: {err}" - )) - })?; - - Ok(Self { - endpoint: endpoint.trim_end_matches('/').to_string(), - client, - config_id: remote.config_id.clone(), - config_ids: remote.config_ids.clone(), - request_defaults: config.request_defaults.clone(), - }) - } - - async fn execute(&self, request: LlmRequest, stream: bool) -> crate::error::Result { - let parent = get_handle().ok(); - self.emit_mark( - "nemo_guardrails.remote.start", - &parent, - remote_mark_data(stream, &self.config_id, &self.config_ids, None, None), - ); - let body = self - .build_request_body(&request, stream) - .inspect_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - None, - Some(err.to_string()), - ), - ); - })?; - let serialized = serde_json::to_vec(&body).map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - None, - Some(format!("failed to serialize remote request body: {err}")), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails failed to serialize remote request body: {err}" - )) - })?; - let response = self - .client - .post(self.chat_completions_url()) - .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(serialized) - .send() - .await - .map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - None, - Some(format!("remote request failed: {err}")), - ), - ); - FlowError::Internal(format!("nemo_guardrails remote request failed: {err}")) - })?; - let status = response.status(); - let payload = response.text().await.map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(format!("failed to read remote response body: {err}")), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails failed to read remote response body: {err}" - )) - })?; - if !status.is_success() { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(payload.clone()), - ), - ); - return Err(FlowError::Internal(format!( - "nemo_guardrails remote request failed with status {status}: {payload}" - ))); - } - let response_json = serde_json::from_str(&payload).map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(format!("failed to parse remote response JSON: {err}")), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails failed to parse remote response JSON: {err}" - )) - })?; - self.emit_mark( - "nemo_guardrails.remote.end", - &parent, - remote_mark_data( - stream, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - None, - ), - ); - // Preserve the server's OpenAI-compatible chat response and nested - // guardrails payload verbatim in the first remote slice. - Ok(response_json) - } - - async fn execute_stream(&self, request: LlmRequest) -> crate::error::Result { - let parent = get_handle().ok(); - self.emit_mark( - "nemo_guardrails.remote.start", - &parent, - remote_mark_data(true, &self.config_id, &self.config_ids, None, None), - ); - let body = self.build_request_body(&request, true).inspect_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - true, - &self.config_id, - &self.config_ids, - None, - Some(err.to_string()), - ), - ); - })?; - let serialized = serde_json::to_vec(&body).map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - true, - &self.config_id, - &self.config_ids, - None, - Some(format!( - "failed to serialize remote stream request body: {err}" - )), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails failed to serialize remote stream request body: {err}" - )) - })?; - let mut response = self - .client - .post(self.chat_completions_url()) - .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(serialized) - .send() - .await - .map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - true, - &self.config_id, - &self.config_ids, - None, - Some(format!("remote stream request failed: {err}")), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails remote stream request failed: {err}" - )) - })?; - let status = response.status(); - if !status.is_success() { - let payload = response.text().await.map_err(|err| { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - true, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(format!("failed to read remote stream error body: {err}")), - ), - ); - FlowError::Internal(format!( - "nemo_guardrails failed to read remote stream error body: {err}" - )) - })?; - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - remote_mark_data( - true, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(payload.clone()), - ), - ); - return Err(FlowError::Internal(format!( - "nemo_guardrails remote stream request failed with status {status}: {payload}" - ))); - } - - let (tx, rx) = mpsc::channel(16); - let parent_for_task = parent.clone(); - let config_id = self.config_id.clone(); - let config_ids = self.config_ids.clone(); - tokio::spawn(async move { - let mut decoder = SseEventDecoder::new(); - loop { - let bytes = match response.chunk().await { - Ok(Some(bytes)) => bytes, - Ok(None) => break, - Err(err) => { - emit_remote_mark( - "nemo_guardrails.remote.error", - &parent_for_task, - remote_mark_data( - true, - &config_id, - &config_ids, - Some(status.as_u16()), - Some(format!("failed to read remote stream chunk: {err}")), - ), - ); - let _ = tx - .send(Err(FlowError::Internal(format!( - "nemo_guardrails failed to read remote stream chunk: {err}" - )))) - .await; - return; - } - }; - let events = match decoder.push_bytes(&bytes) { - Ok(events) => events, - Err(err) => { - emit_remote_mark( - "nemo_guardrails.remote.error", - &parent_for_task, - remote_mark_data( - true, - &config_id, - &config_ids, - Some(status.as_u16()), - Some(err.to_string()), - ), - ); - let _ = tx.send(Err(err)).await; - return; - } - }; - for event in events { - if tx.send(Ok(event.data)).await.is_err() { - return; - } - } - } - - match decoder.finish() { - Ok(Some(event)) => { - let _ = tx.send(Ok(event.data)).await; - } - Ok(None) => {} - Err(err) => { - emit_remote_mark( - "nemo_guardrails.remote.error", - &parent_for_task, - remote_mark_data( - true, - &config_id, - &config_ids, - Some(status.as_u16()), - Some(err.to_string()), - ), - ); - let _ = tx.send(Err(err)).await; - return; - } - } - - emit_remote_mark( - "nemo_guardrails.remote.end", - &parent_for_task, - remote_mark_data(true, &config_id, &config_ids, Some(status.as_u16()), None), - ); - }); - - Ok(Box::pin(ReceiverStream::new(rx)) as LlmJsonStream) - } - - async fn check_tool_input(&self, tool_name: &str, args: &Json) -> crate::error::Result { - let original_content = tool_input_content(tool_name, args); - let messages = vec![json!({"role": "user", "content": original_content.clone()})]; - let response = self - .execute_remote_check(messages, RemoteCheckKind::Input, tool_name) - .await?; - if let Some(blocking_rail) = blocking_rail_name(&response) { - return Err(FlowError::GuardrailRejected(format!( - "nemo_guardrails tool_input rail blocked tool call by rail '{blocking_rail}'" - ))); - } - - let result_content = chat_completion_content(&response)?; - if result_content == original_content { - return Ok(args.clone()); - } - - modified_tool_payload(&result_content, tool_name, "arguments") - } - - async fn check_tool_output( - &self, - tool_name: &str, - args: &Json, - result: &Json, - ) -> crate::error::Result { - let input_content = tool_input_content(tool_name, args); - let original_content = tool_output_content(tool_name, args, result); - let messages = vec![ - json!({"role": "user", "content": input_content}), - json!({"role": "assistant", "content": original_content.clone()}), - ]; - let response = self - .execute_remote_check(messages, RemoteCheckKind::Output, tool_name) - .await?; - if let Some(blocking_rail) = blocking_rail_name(&response) { - return Err(FlowError::GuardrailRejected(format!( - "nemo_guardrails tool_output rail blocked tool call by rail '{blocking_rail}'" - ))); - } - - let result_content = chat_completion_content(&response)?; - if result_content == original_content { - return Ok(result.clone()); - } - - modified_tool_payload(&result_content, tool_name, "result") - } - - fn build_request_body(&self, request: &LlmRequest, stream: bool) -> crate::error::Result { - // Remote mode currently accepts only NeMo Flow's OpenAI chat request - // shape and forwards it to the Guardrails server with a nested - // `guardrails` envelope. - let annotated = OpenAIChatCodec.decode(request)?; - if annotated.tools.is_some() || annotated.tool_choice.is_some() { - return Err(FlowError::Internal( - "nemo_guardrails remote backend does not support OpenAI tool definitions or tool_choice yet" - .to_string(), - )); - } - - let mut body = request.content.as_object().cloned().ok_or_else(|| { - FlowError::Internal("LLM request content is not a JSON object".to_string()) - })?; - body.insert("stream".to_string(), Json::Bool(stream)); - if let Some(guardrails) = self.build_guardrails_config() { - body.insert("guardrails".to_string(), Json::Object(guardrails)); - } - Ok(Json::Object(body)) - } - - fn build_guardrails_config(&self) -> Option> { - let mut guardrails = Map::new(); - if let Some(config_id) = &self.config_id { - guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); - } - if !self.config_ids.is_empty() { - guardrails.insert( - "config_ids".to_string(), - Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), - ); - } - if let Some(request_defaults) = &self.request_defaults { - if let Some(context) = &request_defaults.context { - guardrails.insert("context".to_string(), context.clone()); - } - if let Some(thread_id) = &request_defaults.thread_id { - guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); - } - if let Some(state) = &request_defaults.state { - guardrails.insert("state".to_string(), state.clone()); - } - let mut options = Map::new(); - if let Some(rails) = &request_defaults.rails { - options.insert( - "rails".to_string(), - serde_json::to_value(rails) - .expect("request rails config should serialize to JSON"), - ); - } - if let Some(llm_params) = &request_defaults.llm_params { - options.insert("llm_params".to_string(), llm_params.clone()); - } - if let Some(llm_output) = request_defaults.llm_output { - options.insert("llm_output".to_string(), Json::Bool(llm_output)); - } - if let Some(output_vars) = &request_defaults.output_vars { - options.insert("output_vars".to_string(), output_vars.clone()); - } - if let Some(log) = &request_defaults.log { - options.insert("log".to_string(), log.clone()); - } - if !options.is_empty() { - guardrails.insert("options".to_string(), Json::Object(options)); - } - } - - (!guardrails.is_empty()).then_some(guardrails) - } - - fn chat_completions_url(&self) -> String { - format!("{}/v1/chat/completions", self.endpoint) - } - - fn emit_mark(&self, name: &str, parent: &Option, data: Json) { - emit_remote_mark(name, parent, data); - } - - async fn execute_remote_check( - &self, - messages: Vec, - kind: RemoteCheckKind, - tool_name: &str, - ) -> crate::error::Result { - let parent = get_handle().ok(); - self.emit_mark( - "nemo_guardrails.remote.start", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - None, - None, - ), - ); - let mut body = Map::new(); - body.insert("model".to_string(), Json::String(String::new())); - body.insert("messages".to_string(), Json::Array(messages)); - body.insert("stream".to_string(), Json::Bool(false)); - body.insert( - "guardrails".to_string(), - Json::Object(self.build_tool_check_guardrails(kind)), - ); - let serialized = serde_json::to_vec(&Json::Object(body)).map_err(|err| { - let message = format!("nemo_guardrails failed to serialize remote request body: {err}"); - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - None, - Some(message.clone()), - ), - ); - FlowError::Internal(message) - })?; - let response = self - .client - .post(self.chat_completions_url()) - .header(reqwest::header::CONTENT_TYPE, "application/json") - .body(serialized) - .send() - .await - .map_err(|err| { - let message = format!("nemo_guardrails remote request failed: {err}"); - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - None, - Some(message.clone()), - ), - ); - FlowError::Internal(message) - })?; - let status = response.status(); - let payload = response.text().await.map_err(|err| { - let message = format!("nemo_guardrails failed to read remote response body: {err}"); - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(message.clone()), - ), - ); - FlowError::Internal(message) - })?; - if !status.is_success() { - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(payload.clone()), - ), - ); - return Err(FlowError::Internal(format!( - "nemo_guardrails remote request failed with status {status}: {payload}" - ))); - } - let response_json = serde_json::from_str(&payload).map_err(|err| { - let message = format!("nemo_guardrails failed to parse remote response JSON: {err}"); - self.emit_mark( - "nemo_guardrails.remote.error", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - Some(message.clone()), - ), - ); - FlowError::Internal(message) - })?; - self.emit_mark( - "nemo_guardrails.remote.end", - &parent, - tool_remote_mark_data( - kind, - tool_name, - &self.config_id, - &self.config_ids, - Some(status.as_u16()), - None, - ), - ); - Ok(response_json) - } - - fn build_tool_check_guardrails(&self, kind: RemoteCheckKind) -> Map { - let mut guardrails = Map::new(); - if let Some(config_id) = &self.config_id { - guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); - } - if !self.config_ids.is_empty() { - guardrails.insert( - "config_ids".to_string(), - Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), - ); - } - if let Some(request_defaults) = &self.request_defaults { - if let Some(context) = &request_defaults.context { - guardrails.insert("context".to_string(), context.clone()); - } - if let Some(thread_id) = &request_defaults.thread_id { - guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); - } - if let Some(state) = &request_defaults.state { - guardrails.insert("state".to_string(), state.clone()); - } - } - - let mut options = Map::new(); - let rails = match kind { - RemoteCheckKind::Input => json!({ - "input": true, - "output": false, - "dialog": false, - "retrieval": false, - "tool_input": false, - "tool_output": false, - }), - RemoteCheckKind::Output => json!({ - "input": false, - "output": true, - "dialog": false, - "retrieval": false, - "tool_input": false, - "tool_output": false, - }), - }; - options.insert("rails".to_string(), rails); - let mut log = self - .request_defaults - .as_ref() - .and_then(|defaults| defaults.log.as_ref()) - .and_then(Json::as_object) - .cloned() - .unwrap_or_default(); - log.insert("activated_rails".to_string(), Json::Bool(true)); - options.insert("log".to_string(), Json::Object(log)); - guardrails.insert("options".to_string(), Json::Object(options)); - guardrails - } -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn emit_remote_mark(name: &str, parent: &Option, data: Json) { - let _ = event( - EmitMarkEventParams::builder() - .name(name) - .parent_opt(parent.as_ref()) - .data(data) - .build(), - ); -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn tool_input_content(tool_name: &str, args: &Json) -> String { - serde_json::to_string(&json!({ - "tool_name": tool_name, - "arguments": args, - })) - .expect("tool input payload should serialize to JSON") -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String { - serde_json::to_string(&json!({ - "tool_name": tool_name, - "arguments": args, - "result": result, - })) - .expect("tool output payload should serialize to JSON") -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn modified_tool_payload( - content: &str, - expected_tool_name: &str, - field: &str, -) -> crate::error::Result { - let value: Json = serde_json::from_str(content).map_err(|err| { - FlowError::Internal(format!( - "nemo_guardrails returned modified tool {field} content that is not valid JSON: {err}" - )) - })?; - let Json::Object(object) = value else { - return Err(FlowError::Internal(format!( - "nemo_guardrails returned modified tool {field} content without a '{field}' field" - ))); - }; - if let Some(tool_name) = object.get("tool_name").and_then(Json::as_str) - && tool_name != expected_tool_name - { - return Err(FlowError::Internal(format!( - "nemo_guardrails returned modified tool {field} content for unexpected tool '{tool_name}'" - ))); - } - object.get(field).cloned().ok_or_else(|| { - FlowError::Internal(format!( - "nemo_guardrails returned modified tool {field} content without a '{field}' field" - )) - }) -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn chat_completion_content(response: &Json) -> crate::error::Result { - response - .get("choices") - .and_then(Json::as_array) - .and_then(|choices| choices.first()) - .and_then(|choice| choice.get("message")) - .and_then(|message| message.get("content")) - .and_then(Json::as_str) - .map(str::to_string) - .ok_or_else(|| { - FlowError::Internal( - "nemo_guardrails remote response did not contain choices[0].message.content" - .to_string(), - ) - }) -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn blocking_rail_name(response: &Json) -> Option { - response - .get("guardrails") - .and_then(|guardrails| guardrails.get("log")) - .and_then(|log| log.get("activated_rails")) - .and_then(Json::as_array) - .and_then(|activated| { - activated.iter().find_map(|rail| { - if rail.get("stop").and_then(Json::as_bool) == Some(true) { - rail.get("name").and_then(Json::as_str).map(str::to_string) - } else { - None - } - }) - }) -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn remote_mark_data( - stream: bool, - config_id: &Option, - config_ids: &[String], - status: Option, - error: Option, -) -> Json { - let mut data = Map::new(); - data.insert("stream".to_string(), Json::Bool(stream)); - if let Some(config_id) = config_id { - data.insert("config_id".to_string(), Json::String(config_id.clone())); - } - if !config_ids.is_empty() { - data.insert( - "config_ids".to_string(), - Json::Array(config_ids.iter().cloned().map(Json::String).collect()), - ); - } - if let Some(status) = status { - data.insert( - "http_status".to_string(), - Json::Number(serde_json::Number::from(status)), - ); - } - if let Some(error) = error { - data.insert("error".to_string(), Json::String(error)); - } - Json::Object(data) -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn tool_remote_mark_data( - kind: RemoteCheckKind, - tool_name: &str, - config_id: &Option, - config_ids: &[String], - status: Option, - error: Option, -) -> Json { - let mut data = match remote_mark_data(false, config_id, config_ids, status, error) { - Json::Object(data) => data, - _ => unreachable!("remote_mark_data always returns an object"), - }; - data.insert( - "surface".to_string(), - Json::String(match kind { - RemoteCheckKind::Input => "tool_input".to_string(), - RemoteCheckKind::Output => "tool_output".to_string(), - }), - ); - data.insert("tool_name".to_string(), Json::String(tool_name.to_string())); - Json::Object(data) -} - -#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] -fn register_remote_backend( - config: NeMoGuardrailsConfig, - ctx: &mut PluginRegistrationContext, -) -> PluginResult<()> { - let remote = config.remote.clone().ok_or_else(|| { - PluginError::InvalidConfig("remote config is required when mode is 'remote'".to_string()) - })?; - let runtime = Arc::new(RemoteBackendRuntime::new(&config, &remote)?); - - if config.input || config.output { - let llm_execution_runtime = Arc::clone(&runtime); - let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, _next| { - let runtime = Arc::clone(&llm_execution_runtime); - Box::pin(async move { runtime.execute(request, false).await }) - }); - ctx.register_llm_execution_intercept("llm_remote_backend", config.priority, llm_execution)?; - - let llm_stream_runtime = Arc::clone(&runtime); - let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, _next| { - let runtime = Arc::clone(&llm_stream_runtime); - Box::pin(async move { runtime.execute_stream(request).await }) - }); - ctx.register_llm_stream_execution_intercept( - "llm_stream_remote_backend", - config.priority, - llm_stream_execution, - )?; - } - - if config.tool_input || config.tool_output { - let tool_runtime = Arc::clone(&runtime); - let enable_tool_input = config.tool_input; - let enable_tool_output = config.tool_output; - let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { - let runtime = Arc::clone(&tool_runtime); - let tool_name = tool_name.to_string(); - Box::pin(async move { - let current_args = if enable_tool_input { - runtime.check_tool_input(&tool_name, &args).await? - } else { - args - }; - - let tool_result = next(current_args.clone()).await?; - if !enable_tool_output { - return Ok(tool_result); - } - - runtime - .check_tool_output(&tool_name, ¤t_args, &tool_result) - .await - }) - }); - ctx.register_tool_execution_intercept( - "tool_remote_backend", - config.priority, - tool_execution, - )?; - } - - Ok(()) -} - -#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] -fn register_remote_backend( - _config: NeMoGuardrailsConfig, - _ctx: &mut PluginRegistrationContext, -) -> PluginResult<()> { - Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), - )) -} - fn parse_nemo_guardrails_config( plugin_config: &Map, ) -> PluginResult { diff --git a/crates/core/src/plugins/nemo_guardrails/remote.rs b/crates/core/src/plugins/nemo_guardrails/remote.rs new file mode 100644 index 00000000..5f0a5f42 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/remote.rs @@ -0,0 +1,941 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use std::sync::Arc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use std::time::Duration; + +use serde_json::{Map, Value as Json, json}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio::sync::mpsc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio_stream::wrappers::ReceiverStream; + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::llm::LlmRequest; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::scope::{EmitMarkEventParams, ScopeHandle, event, get_handle}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::openai_chat::OpenAIChatCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::streaming::SseEventDecoder; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::traits::LlmCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::error::FlowError; +use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use rustls::crypto::ring; + +use super::{NeMoGuardrailsConfig, RemoteBackendConfig, RequestDefaultsConfig}; + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone)] +// PR 2 intentionally implements the first honest remote slice: +// OpenAI chat requests, non-streaming + streaming execution, managed tool +// input/output checks, broad request-defaults pass-through, and response +// passthrough from the Guardrails server. The local backend remains out of +// scope. +struct RemoteBackendRuntime { + endpoint: String, + client: reqwest::Client, + config_id: Option, + config_ids: Vec, + request_defaults: Option, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone, Copy)] +enum RemoteCheckKind { + Input, + Output, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +impl RemoteBackendRuntime { + fn new(config: &NeMoGuardrailsConfig, remote: &RemoteBackendConfig) -> PluginResult { + let endpoint = remote.endpoint.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote.endpoint is required in remote mode".to_string()) + })?; + let mut default_headers = HeaderMap::new(); + for (name, value) in &remote.headers { + let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers contains invalid header name '{name}': {err}" + )) + })?; + let header_value = HeaderValue::from_str(value).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers[{name}] has an invalid value: {err}" + )) + })?; + default_headers.insert(header_name, header_value); + } + + let _ = ring::default_provider().install_default(); + + let client = reqwest::Client::builder() + .default_headers(default_headers) + .timeout(Duration::from_millis(remote.timeout_millis)) + .build() + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to construct NeMo Guardrails remote client: {err}" + )) + })?; + + Ok(Self { + endpoint: endpoint.trim_end_matches('/').to_string(), + client, + config_id: remote.config_id.clone(), + config_ids: remote.config_ids.clone(), + request_defaults: config.request_defaults.clone(), + }) + } + + async fn execute(&self, request: LlmRequest, stream: bool) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(stream, &self.config_id, &self.config_ids, None, None), + ); + let body = self + .build_request_body(&request, stream) + .inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("failed to serialize remote request body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote request body: {err}" + )) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote request failed: {err}")), + ), + ); + FlowError::Internal(format!("nemo_guardrails remote request failed: {err}")) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote response body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote response body: {err}" + )) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to parse remote response JSON: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to parse remote response JSON: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + Ok(response_json) + } + + async fn execute_stream(&self, request: LlmRequest) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(true, &self.config_id, &self.config_ids, None, None), + ); + let body = self.build_request_body(&request, true).inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!( + "failed to serialize remote stream request body: {err}" + )), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote stream request body: {err}" + )) + })?; + let mut response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote stream request failed: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails remote stream request failed: {err}" + )) + })?; + let status = response.status(); + if !status.is_success() { + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream error body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream error body: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote stream request failed with status {status}: {payload}" + ))); + } + + let (tx, rx) = mpsc::channel(16); + let parent_for_task = parent.clone(); + let config_id = self.config_id.clone(); + let config_ids = self.config_ids.clone(); + tokio::spawn(async move { + let mut decoder = SseEventDecoder::new(); + loop { + let bytes = match response.chunk().await { + Ok(Some(bytes)) => bytes, + Ok(None) => break, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream chunk: {err}")), + ), + ); + let _ = tx + .send(Err(FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream chunk: {err}" + )))) + .await; + return; + } + }; + let events = match decoder.push_bytes(&bytes) { + Ok(events) => events, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + }; + for event in events { + if tx.send(Ok(event.data)).await.is_err() { + return; + } + } + } + + match decoder.finish() { + Ok(Some(event)) => { + let _ = tx.send(Ok(event.data)).await; + } + Ok(None) => {} + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + } + + emit_remote_mark( + "nemo_guardrails.remote.end", + &parent_for_task, + remote_mark_data(true, &config_id, &config_ids, Some(status.as_u16()), None), + ); + }); + + Ok(Box::pin(ReceiverStream::new(rx)) as LlmJsonStream) + } + + async fn check_tool_input(&self, tool_name: &str, args: &Json) -> crate::error::Result { + let original_content = tool_input_content(tool_name, args); + let messages = vec![json!({"role": "user", "content": original_content.clone()})]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Input, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_input rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(args.clone()); + } + + modified_tool_payload(&result_content, tool_name, "arguments") + } + + async fn check_tool_output( + &self, + tool_name: &str, + args: &Json, + result: &Json, + ) -> crate::error::Result { + let input_content = tool_input_content(tool_name, args); + let original_content = tool_output_content(tool_name, args, result); + let messages = vec![ + json!({"role": "user", "content": input_content}), + json!({"role": "assistant", "content": original_content.clone()}), + ]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Output, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_output rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(result.clone()); + } + + modified_tool_payload(&result_content, tool_name, "result") + } + + fn build_request_body(&self, request: &LlmRequest, stream: bool) -> crate::error::Result { + let annotated = OpenAIChatCodec.decode(request)?; + if annotated.tools.is_some() || annotated.tool_choice.is_some() { + return Err(FlowError::Internal( + "nemo_guardrails remote backend does not support OpenAI tool definitions or tool_choice yet" + .to_string(), + )); + } + + let mut body = request.content.as_object().cloned().ok_or_else(|| { + FlowError::Internal("LLM request content is not a JSON object".to_string()) + })?; + body.insert("stream".to_string(), Json::Bool(stream)); + if let Some(guardrails) = self.build_guardrails_config() { + body.insert("guardrails".to_string(), Json::Object(guardrails)); + } + Ok(Json::Object(body)) + } + + fn build_guardrails_config(&self) -> Option> { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + let mut options = Map::new(); + if let Some(rails) = &request_defaults.rails { + options.insert( + "rails".to_string(), + serde_json::to_value(rails) + .expect("request rails config should serialize to JSON"), + ); + } + if let Some(llm_params) = &request_defaults.llm_params { + options.insert("llm_params".to_string(), llm_params.clone()); + } + if let Some(llm_output) = request_defaults.llm_output { + options.insert("llm_output".to_string(), Json::Bool(llm_output)); + } + if let Some(output_vars) = &request_defaults.output_vars { + options.insert("output_vars".to_string(), output_vars.clone()); + } + if let Some(log) = &request_defaults.log { + options.insert("log".to_string(), log.clone()); + } + if !options.is_empty() { + guardrails.insert("options".to_string(), Json::Object(options)); + } + } + + (!guardrails.is_empty()).then_some(guardrails) + } + + fn chat_completions_url(&self) -> String { + format!("{}/v1/chat/completions", self.endpoint) + } + + fn emit_mark(&self, name: &str, parent: &Option, data: Json) { + emit_remote_mark(name, parent, data); + } + + async fn execute_remote_check( + &self, + messages: Vec, + kind: RemoteCheckKind, + tool_name: &str, + ) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + None, + ), + ); + let mut body = Map::new(); + body.insert("model".to_string(), Json::String(String::new())); + body.insert("messages".to_string(), Json::Array(messages)); + body.insert("stream".to_string(), Json::Bool(false)); + body.insert( + "guardrails".to_string(), + Json::Object(self.build_tool_check_guardrails(kind)), + ); + let serialized = serde_json::to_vec(&Json::Object(body)).map_err(|err| { + let message = format!("nemo_guardrails failed to serialize remote request body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + let message = format!("nemo_guardrails remote request failed: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + let message = format!("nemo_guardrails failed to read remote response body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(payload.clone()), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + let message = format!("nemo_guardrails failed to parse remote response JSON: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + Ok(response_json) + } + + fn build_tool_check_guardrails(&self, kind: RemoteCheckKind) -> Map { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + } + + let mut options = Map::new(); + let rails = match kind { + RemoteCheckKind::Input => json!({ + "input": true, + "output": false, + "dialog": false, + "retrieval": false, + "tool_input": false, + "tool_output": false, + }), + RemoteCheckKind::Output => json!({ + "input": false, + "output": true, + "dialog": false, + "retrieval": false, + "tool_input": false, + "tool_output": false, + }), + }; + options.insert("rails".to_string(), rails); + let mut log = self + .request_defaults + .as_ref() + .and_then(|defaults| defaults.log.as_ref()) + .and_then(Json::as_object) + .cloned() + .unwrap_or_default(); + log.insert("activated_rails".to_string(), Json::Bool(true)); + options.insert("log".to_string(), Json::Object(log)); + guardrails.insert("options".to_string(), Json::Object(options)); + guardrails + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn emit_remote_mark(name: &str, parent: &Option, data: Json) { + let _ = event( + EmitMarkEventParams::builder() + .name(name) + .parent_opt(parent.as_ref()) + .data(data) + .build(), + ); +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_input_content(tool_name: &str, args: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + })) + .expect("tool input payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + "result": result, + })) + .expect("tool output payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn modified_tool_payload( + content: &str, + expected_tool_name: &str, + field: &str, +) -> crate::error::Result { + let value: Json = serde_json::from_str(content).map_err(|err| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content that is not valid JSON: {err}" + )) + })?; + let Json::Object(object) = value else { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + ))); + }; + if let Some(tool_name) = object.get("tool_name").and_then(Json::as_str) + && tool_name != expected_tool_name + { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content for unexpected tool '{tool_name}'" + ))); + } + object.get(field).cloned().ok_or_else(|| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + )) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn chat_completion_content(response: &Json) -> crate::error::Result { + response + .get("choices") + .and_then(Json::as_array) + .and_then(|choices| choices.first()) + .and_then(|choice| choice.get("message")) + .and_then(|message| message.get("content")) + .and_then(Json::as_str) + .map(str::to_string) + .ok_or_else(|| { + FlowError::Internal( + "nemo_guardrails remote response did not contain choices[0].message.content" + .to_string(), + ) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn blocking_rail_name(response: &Json) -> Option { + response + .get("guardrails") + .and_then(|guardrails| guardrails.get("log")) + .and_then(|log| log.get("activated_rails")) + .and_then(Json::as_array) + .and_then(|activated| { + activated.iter().find_map(|rail| { + if rail.get("stop").and_then(Json::as_bool) == Some(true) { + rail.get("name").and_then(Json::as_str).map(str::to_string) + } else { + None + } + }) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn remote_mark_data( + stream: bool, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = Map::new(); + data.insert("stream".to_string(), Json::Bool(stream)); + if let Some(config_id) = config_id { + data.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !config_ids.is_empty() { + data.insert( + "config_ids".to_string(), + Json::Array(config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(status) = status { + data.insert( + "http_status".to_string(), + Json::Number(serde_json::Number::from(status)), + ); + } + if let Some(error) = error { + data.insert("error".to_string(), Json::String(error)); + } + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_remote_mark_data( + kind: RemoteCheckKind, + tool_name: &str, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = match remote_mark_data(false, config_id, config_ids, status, error) { + Json::Object(data) => data, + _ => unreachable!("remote_mark_data always returns an object"), + }; + data.insert( + "surface".to_string(), + Json::String(match kind { + RemoteCheckKind::Input => "tool_input".to_string(), + RemoteCheckKind::Output => "tool_output".to_string(), + }), + ); + data.insert("tool_name".to_string(), Json::String(tool_name.to_string())); + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +pub(super) fn register_remote_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let remote = config.remote.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote config is required when mode is 'remote'".to_string()) + })?; + let runtime = Arc::new(RemoteBackendRuntime::new(&config, &remote)?); + + if config.input || config.output { + let llm_execution_runtime = Arc::clone(&runtime); + let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_execution_runtime); + Box::pin(async move { runtime.execute(request, false).await }) + }); + ctx.register_llm_execution_intercept("llm_remote_backend", config.priority, llm_execution)?; + + let llm_stream_runtime = Arc::clone(&runtime); + let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_stream_runtime); + Box::pin(async move { runtime.execute_stream(request).await }) + }); + ctx.register_llm_stream_execution_intercept( + "llm_stream_remote_backend", + config.priority, + llm_stream_execution, + )?; + } + + if config.tool_input || config.tool_output { + let tool_runtime = Arc::clone(&runtime); + let enable_tool_input = config.tool_input; + let enable_tool_output = config.tool_output; + let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { + let runtime = Arc::clone(&tool_runtime); + let tool_name = tool_name.to_string(); + Box::pin(async move { + let current_args = if enable_tool_input { + runtime.check_tool_input(&tool_name, &args).await? + } else { + args + }; + + let tool_result = next(current_args.clone()).await?; + if !enable_tool_output { + return Ok(tool_result); + } + + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await + }) + }); + ctx.register_tool_execution_intercept( + "tool_remote_backend", + config.priority, + tool_execution, + )?; + } + + Ok(()) +} + +#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] +pub(super) fn register_remote_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), + )) +} From 41c7fe22fcd1dc1b54394c2b88944095cc1c2027 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 22 May 2026 09:26:05 -0400 Subject: [PATCH 4/6] Tighten remote backend review diff Signed-off-by: Alex Fournier --- crates/core/Cargo.toml | 2 ++ crates/core/src/plugins/nemo_guardrails/remote.rs | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 9d23c61c..6dcf8778 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -28,6 +28,7 @@ otel = [ "dep:opentelemetry-http", "dep:opentelemetry-otlp", "dep:opentelemetry_sdk", + "dep:reqwest", "dep:rustls", "dep:tonic", "dep:web-sys", @@ -43,6 +44,7 @@ openinference = [ "dep:opentelemetry-http", "dep:opentelemetry-otlp", "dep:opentelemetry_sdk", + "dep:reqwest", "dep:rustls", "dep:tonic", "dep:web-sys", diff --git a/crates/core/src/plugins/nemo_guardrails/remote.rs b/crates/core/src/plugins/nemo_guardrails/remote.rs index 5f0a5f42..df726782 100644 --- a/crates/core/src/plugins/nemo_guardrails/remote.rs +++ b/crates/core/src/plugins/nemo_guardrails/remote.rs @@ -36,11 +36,6 @@ use super::{NeMoGuardrailsConfig, RemoteBackendConfig, RequestDefaultsConfig}; #[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] #[derive(Clone)] -// PR 2 intentionally implements the first honest remote slice: -// OpenAI chat requests, non-streaming + streaming execution, managed tool -// input/output checks, broad request-defaults pass-through, and response -// passthrough from the Guardrails server. The local backend remains out of -// scope. struct RemoteBackendRuntime { endpoint: String, client: reqwest::Client, From 3c2276c15c878d35667647457ab9c50d94dfa124 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 22 May 2026 09:32:40 -0400 Subject: [PATCH 5/6] Fix guardrails remote review nits Signed-off-by: Alex Fournier --- .../src/plugins/nemo_guardrails/component.rs | 57 +++++++++++-------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 5823a2da..4d25d77e 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -17,13 +17,25 @@ use crate::plugin::{ lookup_plugin, register_plugin, }; +#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] #[path = "remote.rs"] mod remote; +#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] use remote::register_remote_backend; /// The plugin kind reserved for the planned first-party component. pub const NEMO_GUARDRAILS_PLUGIN_KIND: &str = "nemo_guardrails"; +#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] +fn register_remote_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), + )) +} + /// Top-level NeMo Guardrails component wrapper. #[derive(Debug, Clone)] pub struct ComponentSpec { @@ -941,30 +953,27 @@ fn validate_request_defaults( "request_defaults.context", "request_defaults.context must be a JSON object", ); - if let Some(thread_id) = &request_defaults.thread_id - && thread_id.trim().is_empty() - { - push_policy_diag( - diagnostics, - policy.unsupported_value, - "nemo_guardrails.unsupported_value", - Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), - Some("request_defaults.thread_id".to_string()), - "request_defaults.thread_id must not be empty".to_string(), - ); - } - if let Some(thread_id) = &request_defaults.thread_id - && !thread_id.trim().is_empty() - && thread_id.len() < 16 - { - push_policy_diag( - diagnostics, - policy.unsupported_value, - "nemo_guardrails.unsupported_value", - Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), - Some("request_defaults.thread_id".to_string()), - "request_defaults.thread_id must be at least 16 characters long".to_string(), - ); + if let Some(thread_id) = &request_defaults.thread_id { + let trimmed_thread_id = thread_id.trim(); + if trimmed_thread_id.is_empty() { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must not be empty".to_string(), + ); + } else if trimmed_thread_id.len() < 16 { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must be at least 16 characters long".to_string(), + ); + } } validate_json_object_field( diagnostics, From 60999aca0f886f72d2bad9aafb6ec6fc7d113fd6 Mon Sep 17 00:00:00 2001 From: Alex Fournier Date: Fri, 22 May 2026 10:19:15 -0400 Subject: [PATCH 6/6] Address guardrails remote review feedback Signed-off-by: Alex Fournier --- .../src/plugins/nemo_guardrails/component.rs | 24 ++++---- .../src/plugins/nemo_guardrails/remote.rs | 22 ++++--- .../nemo_guardrails/component_tests.rs | 60 ++++++++++++------- 3 files changed, 66 insertions(+), 40 deletions(-) diff --git a/crates/core/src/plugins/nemo_guardrails/component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs index 4d25d77e..1b16cfa1 100644 --- a/crates/core/src/plugins/nemo_guardrails/component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -986,18 +986,20 @@ fn validate_request_defaults( .state .as_ref() .and_then(|value| value.as_object()) - && !state.is_empty() - && !state.contains_key("events") - && !state.contains_key("state") { - push_policy_diag( - diagnostics, - policy.unsupported_value, - "nemo_guardrails.unsupported_value", - Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), - Some("request_defaults.state".to_string()), - "request_defaults.state must be empty or contain 'events' or 'state'".to_string(), - ); + let contains_supported_key = state.contains_key("events") || state.contains_key("state"); + let contains_unsupported_key = state.keys().any(|key| key != "events" && key != "state"); + if (!state.is_empty() && !contains_supported_key) || contains_unsupported_key { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.state".to_string()), + "request_defaults.state must be empty or contain only 'events' or 'state'" + .to_string(), + ); + } } validate_json_object_field( diagnostics, diff --git a/crates/core/src/plugins/nemo_guardrails/remote.rs b/crates/core/src/plugins/nemo_guardrails/remote.rs index df726782..ac8d22a8 100644 --- a/crates/core/src/plugins/nemo_guardrails/remote.rs +++ b/crates/core/src/plugins/nemo_guardrails/remote.rs @@ -178,7 +178,7 @@ impl RemoteBackendRuntime { &self.config_id, &self.config_ids, Some(status.as_u16()), - Some(payload.clone()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), ), ); return Err(FlowError::Internal(format!( @@ -302,7 +302,7 @@ impl RemoteBackendRuntime { &self.config_id, &self.config_ids, Some(status.as_u16()), - Some(payload.clone()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), ), ); return Err(FlowError::Internal(format!( @@ -615,7 +615,7 @@ impl RemoteBackendRuntime { &self.config_id, &self.config_ids, Some(status.as_u16()), - Some(payload.clone()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), ), ); return Err(FlowError::Internal(format!( @@ -679,20 +679,20 @@ impl RemoteBackendRuntime { let mut options = Map::new(); let rails = match kind { RemoteCheckKind::Input => json!({ - "input": true, + "input": false, "output": false, "dialog": false, "retrieval": false, - "tool_input": false, + "tool_input": true, "tool_output": false, }), RemoteCheckKind::Output => json!({ "input": false, - "output": true, + "output": false, "dialog": false, "retrieval": false, "tool_input": false, - "tool_output": false, + "tool_output": true, }), }; options.insert("rails".to_string(), rails); @@ -730,6 +730,14 @@ fn tool_input_content(tool_name: &str, args: &Json) -> String { .expect("tool input payload should serialize to JSON") } +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn redact_remote_error_payload(status: u16, payload: &str) -> String { + format!( + "remote request failed with status {status}; error body omitted from marks ({} bytes)", + payload.len() + ) +} + #[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String { serde_json::to_string(&json!({ diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs index ba11c6ad..dd526020 100644 --- a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs +++ b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs @@ -11,6 +11,7 @@ use std::net::TcpListener; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, mpsc}; use std::thread; +use std::time::Duration; use crate::api::event::Event; use crate::api::llm::{ @@ -37,6 +38,8 @@ use crate::plugin::{ use futures::StreamExt; use serde_json::json; +const TEST_TIMEOUT: Duration = Duration::from_secs(5); + fn reset_runtime() { let _ = clear_plugin_configuration(); crate::shared_runtime::reset_runtime_owner_for_tests(); @@ -191,6 +194,12 @@ fn header_value<'a>(headers_text: &'a str, header_name: &str) -> Option<&'a str> }) } +fn recv_captured_request(request_rx: &mpsc::Receiver) -> CapturedHttpRequest { + request_rx + .recv_timeout(TEST_TIMEOUT) + .expect("timed out waiting for captured HTTP request") +} + fn make_chat_request(stream: bool) -> LlmRequest { LlmRequest { headers: serde_json::Map::new(), @@ -749,7 +758,7 @@ fn invalid_shapes_and_values_are_reported() { ); assert!(invalid_request_defaults.diagnostics.iter().any(|diag| { diag.message - .contains("request_defaults.state must be empty or contain 'events' or 'state'") + .contains("request_defaults.state must be empty or contain only 'events' or 'state'") })); assert!( invalid_request_defaults @@ -943,7 +952,7 @@ async fn remote_initialization_installs_non_streaming_execution_intercept() { json!("server-state") ); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); assert_eq!(captured.path, "/v1/chat/completions"); assert!(captured.content_type.starts_with("application/json")); @@ -1061,7 +1070,7 @@ async fn remote_request_uses_config_ids_when_config_id_is_not_set() { .await .unwrap(); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!( request_json["guardrails"]["config_ids"], @@ -1136,7 +1145,10 @@ async fn remote_initialization_installs_stream_execution_intercept() { .unwrap(); let mut chunks = Vec::new(); - while let Some(chunk) = stream.next().await { + while let Some(chunk) = tokio::time::timeout(TEST_TIMEOUT, stream.next()) + .await + .expect("timed out waiting for remote stream chunk") + { chunks.push(chunk.unwrap()); } @@ -1145,7 +1157,7 @@ async fn remote_initialization_installs_stream_execution_intercept() { assert_eq!(chunks[0]["choices"][0]["delta"]["content"], json!("guard")); assert_eq!(chunks[1]["choices"][0]["delta"]["content"], json!("ed")); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!(request_json["stream"], json!(true)); assert_eq!( @@ -1246,7 +1258,7 @@ async fn remote_non_streaming_http_errors_are_reported_and_marked() { error_mark.data().unwrap()["error"] .as_str() .unwrap() - .contains("backend unavailable") + .contains("error body omitted from marks") ); deregister_subscriber("nemo-guardrails-remote-error-events").unwrap(); @@ -1341,7 +1353,7 @@ async fn remote_streaming_http_errors_are_reported_and_marked() { error_mark.data().unwrap()["error"] .as_str() .unwrap() - .contains("stream backend unavailable") + .contains("error body omitted from marks") ); deregister_subscriber("nemo-guardrails-remote-stream-error-events").unwrap(); @@ -1470,7 +1482,11 @@ async fn remote_streaming_malformed_chunk_is_reported_and_marked() { .await .unwrap(); - let error = stream.next().await.unwrap().unwrap_err(); + let error = tokio::time::timeout(TEST_TIMEOUT, stream.next()) + .await + .expect("timed out waiting for remote stream error") + .unwrap() + .unwrap_err(); match error { crate::error::FlowError::Internal(message) => { assert!(!message.is_empty()); @@ -1759,14 +1775,14 @@ async fn remote_tool_input_block_rejects_before_tool_execution() { other => panic!("unexpected error: {other}"), } - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!( - request_json["guardrails"]["options"]["rails"]["input"], + request_json["guardrails"]["options"]["rails"]["tool_input"], json!(true) ); assert_eq!( - request_json["guardrails"]["options"]["rails"]["output"], + request_json["guardrails"]["options"]["rails"]["tool_output"], json!(false) ); @@ -1860,7 +1876,7 @@ async fn remote_tool_input_can_rewrite_tool_arguments() { assert_eq!(result, json!({"forecast": "sunny"})); assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!(request_json["messages"][0]["role"], json!("user")); } @@ -1932,14 +1948,14 @@ async fn remote_tool_output_can_rewrite_tool_result() { assert_eq!(result, json!({"forecast": "cloudy"})); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!( - request_json["guardrails"]["options"]["rails"]["input"], + request_json["guardrails"]["options"]["rails"]["tool_input"], json!(false) ); assert_eq!( - request_json["guardrails"]["options"]["rails"]["output"], + request_json["guardrails"]["options"]["rails"]["tool_output"], json!(true) ); } @@ -2307,19 +2323,19 @@ async fn remote_tool_input_and_output_run_in_order() { assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); assert_eq!(result, json!({"forecast": "cloudy"})); - let first_request = request_rx.recv().unwrap(); + let first_request = recv_captured_request(&request_rx); let first_request_json: Json = serde_json::from_slice(&first_request.body).unwrap(); assert_eq!(first_request_json["messages"][0]["role"], json!("user")); assert_eq!( - first_request_json["guardrails"]["options"]["rails"]["input"], + first_request_json["guardrails"]["options"]["rails"]["tool_input"], json!(true) ); assert_eq!( - first_request_json["guardrails"]["options"]["rails"]["output"], + first_request_json["guardrails"]["options"]["rails"]["tool_output"], json!(false) ); - let second_request = request_rx.recv().unwrap(); + let second_request = recv_captured_request(&request_rx); let second_request_json: Json = serde_json::from_slice(&second_request.body).unwrap(); assert_eq!(second_request_json["messages"][0]["role"], json!("user")); assert_eq!( @@ -2327,11 +2343,11 @@ async fn remote_tool_input_and_output_run_in_order() { json!("assistant") ); assert_eq!( - second_request_json["guardrails"]["options"]["rails"]["input"], + second_request_json["guardrails"]["options"]["rails"]["tool_input"], json!(false) ); assert_eq!( - second_request_json["guardrails"]["options"]["rails"]["output"], + second_request_json["guardrails"]["options"]["rails"]["tool_output"], json!(true) ); } @@ -2408,7 +2424,7 @@ async fn remote_tool_checks_forward_context_state_and_thread_id() { assert_eq!(result, json!({"forecast": "sunny"})); - let captured = request_rx.recv().unwrap(); + let captured = recv_captured_request(&request_rx); let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); assert_eq!( request_json["guardrails"]["context"],