diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 898b1614..6dcf8778 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -14,8 +14,12 @@ readme = "README.md" workspace = true [features] -default = ["otel", "openinference"] +default = ["otel", "openinference", "guardrails-remote"] schema = ["dep:schemars"] +guardrails-remote = [ + "dep:reqwest", + "dep:rustls", +] otel = [ "dep:async-trait", "dep:getrandom", diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index 1b48b267..4d6c5a51 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -762,9 +762,11 @@ pub fn register_plugin(plugin: Arc) -> Result<()> { /// Built-in plugins are available to validation and initialization without a /// binding or application-specific registration call. pub fn ensure_builtin_plugins_registered() -> Result<()> { - match BUILTIN_PLUGIN_REGISTRATION - .get_or_init(crate::observability::plugin_component::register_observability_component) - { + let register_builtins = || { + crate::observability::plugin_component::register_observability_component()?; + crate::plugins::nemo_guardrails::component::register_nemo_guardrails_component() + }; + match BUILTIN_PLUGIN_REGISTRATION.get_or_init(register_builtins) { Ok(()) => Ok(()), Err(err) => Err(clone_cached_plugin_error(err)), } diff --git a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs b/crates/core/src/plugins/nemo_guardrails/component.rs similarity index 88% rename from crates/core/src/plugins/nemo_guardrails/plugin_component.rs rename to crates/core/src/plugins/nemo_guardrails/component.rs index 5617a743..1b16cfa1 100644 --- a/crates/core/src/plugins/nemo_guardrails/plugin_component.rs +++ b/crates/core/src/plugins/nemo_guardrails/component.rs @@ -17,9 +17,25 @@ use crate::plugin::{ lookup_plugin, register_plugin, }; +#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] +#[path = "remote.rs"] +mod remote; +#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))] +use remote::register_remote_backend; + /// The plugin kind reserved for the planned first-party component. pub const NEMO_GUARDRAILS_PLUGIN_KIND: &str = "nemo_guardrails"; +#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] +fn register_remote_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), + )) +} + /// Top-level NeMo Guardrails component wrapper. #[derive(Debug, Clone)] pub struct ComponentSpec { @@ -182,6 +198,12 @@ pub struct RequestDefaultsConfig { /// Default context object passed into Guardrails requests. #[serde(default, skip_serializing_if = "Option::is_none")] pub context: Option, + /// Default remote thread identifier for continuation-aware requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + /// Default remote Guardrails state payload for continuation-aware requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state: Option, /// Default request-time rail selection. #[serde(default, skip_serializing_if = "Option::is_none")] pub rails: Option, @@ -307,6 +329,8 @@ crate::editor_config! { crate::editor_config! { impl RequestDefaultsConfig { context => { label: "context", kind: Json, optional: true }, + thread_id => { label: "thread_id", kind: String, optional: true }, + state => { label: "state", kind: Json, optional: true }, rails => { label: "rails", kind: Section, @@ -349,13 +373,13 @@ impl Plugin for NeMoGuardrailsPlugin { fn register<'a>( &'a self, - _plugin_config: &Map, - _ctx: &'a mut PluginRegistrationContext, + plugin_config: &Map, + ctx: &'a mut PluginRegistrationContext, ) -> Pin> + Send + 'a>> { - Box::pin(async { - Err(PluginError::RegistrationFailed( - "built-in NeMo Guardrails plugin backend is not implemented yet".to_string(), - )) + let parsed = parse_nemo_guardrails_config(plugin_config); + Box::pin(async move { + let config = parsed?; + register_nemo_guardrails_backend(config, ctx) }) } } @@ -419,6 +443,21 @@ fn string_enum_schema( schema.into() } +fn register_nemo_guardrails_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + match config.mode.as_str() { + "remote" => register_remote_backend(config, ctx), + "local" => Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails local backend is not implemented yet".to_string(), + )), + other => Err(PluginError::InvalidConfig(format!( + "unsupported NeMo Guardrails mode '{other}'" + ))), + } +} + fn parse_nemo_guardrails_config( plugin_config: &Map, ) -> PluginResult { @@ -497,6 +536,8 @@ fn validate_nemo_guardrails_plugin_config( "request_defaults", &[ "context", + "thread_id", + "state", "rails", "llm_params", "llm_output", @@ -526,6 +567,7 @@ fn validate_nemo_guardrails_plugin_config( validate_config_shape(&mut diagnostics, &config.policy, &config); validate_codec_requirements(&mut diagnostics, &config.policy, &config); validate_surface_selection(&mut diagnostics, &config.policy, &config); + validate_remote_backend_support(&mut diagnostics, &config.policy, &config); validate_request_defaults(&mut diagnostics, &config.policy, &config); diagnostics @@ -869,6 +911,32 @@ fn validate_surface_selection( ); } +fn validate_remote_backend_support( + diagnostics: &mut Vec, + policy: &ConfigPolicy, + config: &NeMoGuardrailsConfig, +) { + if config.mode != "remote" { + return; + } + + if (config.input || config.output) + && config + .codec + .as_deref() + .is_some_and(|codec| codec != "openai_chat") + { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("codec".to_string()), + "remote mode currently supports only codec = 'openai_chat'".to_string(), + ); + } +} + fn validate_request_defaults( diagnostics: &mut Vec, policy: &ConfigPolicy, @@ -885,6 +953,54 @@ fn validate_request_defaults( "request_defaults.context", "request_defaults.context must be a JSON object", ); + if let Some(thread_id) = &request_defaults.thread_id { + let trimmed_thread_id = thread_id.trim(); + if trimmed_thread_id.is_empty() { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must not be empty".to_string(), + ); + } else if trimmed_thread_id.len() < 16 { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.thread_id".to_string()), + "request_defaults.thread_id must be at least 16 characters long".to_string(), + ); + } + } + validate_json_object_field( + diagnostics, + policy, + request_defaults.state.as_ref(), + "request_defaults.state", + "request_defaults.state must be a JSON object", + ); + if let Some(state) = request_defaults + .state + .as_ref() + .and_then(|value| value.as_object()) + { + let contains_supported_key = state.contains_key("events") || state.contains_key("state"); + let contains_unsupported_key = state.keys().any(|key| key != "events" && key != "state"); + if (!state.is_empty() && !contains_supported_key) || contains_unsupported_key { + push_policy_diag( + diagnostics, + policy.unsupported_value, + "nemo_guardrails.unsupported_value", + Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()), + Some("request_defaults.state".to_string()), + "request_defaults.state must be empty or contain only 'events' or 'state'" + .to_string(), + ); + } + } validate_json_object_field( diagnostics, policy, @@ -1138,5 +1254,5 @@ fn default_timeout_millis() -> u64 { } #[cfg(test)] -#[path = "../../../tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs"] +#[path = "../../../tests/unit/plugins/nemo_guardrails/component_tests.rs"] mod tests; diff --git a/crates/core/src/plugins/nemo_guardrails/mod.rs b/crates/core/src/plugins/nemo_guardrails/mod.rs index 9a7689d8..05136350 100644 --- a/crates/core/src/plugins/nemo_guardrails/mod.rs +++ b/crates/core/src/plugins/nemo_guardrails/mod.rs @@ -11,4 +11,4 @@ pub(crate) fn test_mutex() -> &'static Mutex<()> { crate::shared_runtime::runtime_owner_test_mutex() } -pub mod plugin_component; +pub mod component; diff --git a/crates/core/src/plugins/nemo_guardrails/remote.rs b/crates/core/src/plugins/nemo_guardrails/remote.rs new file mode 100644 index 00000000..ac8d22a8 --- /dev/null +++ b/crates/core/src/plugins/nemo_guardrails/remote.rs @@ -0,0 +1,944 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use std::sync::Arc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use std::time::Duration; + +use serde_json::{Map, Value as Json, json}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio::sync::mpsc; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use tokio_stream::wrappers::ReceiverStream; + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::llm::LlmRequest; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::runtime::{LlmExecutionFn, LlmJsonStream, LlmStreamExecutionFn, ToolExecutionFn}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::api::scope::{EmitMarkEventParams, ScopeHandle, event, get_handle}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::openai_chat::OpenAIChatCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::streaming::SseEventDecoder; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::codec::traits::LlmCodec; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use crate::error::FlowError; +use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +use rustls::crypto::ring; + +use super::{NeMoGuardrailsConfig, RemoteBackendConfig, RequestDefaultsConfig}; + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone)] +struct RemoteBackendRuntime { + endpoint: String, + client: reqwest::Client, + config_id: Option, + config_ids: Vec, + request_defaults: Option, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +#[derive(Clone, Copy)] +enum RemoteCheckKind { + Input, + Output, +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +impl RemoteBackendRuntime { + fn new(config: &NeMoGuardrailsConfig, remote: &RemoteBackendConfig) -> PluginResult { + let endpoint = remote.endpoint.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote.endpoint is required in remote mode".to_string()) + })?; + let mut default_headers = HeaderMap::new(); + for (name, value) in &remote.headers { + let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers contains invalid header name '{name}': {err}" + )) + })?; + let header_value = HeaderValue::from_str(value).map_err(|err| { + PluginError::InvalidConfig(format!( + "remote.headers[{name}] has an invalid value: {err}" + )) + })?; + default_headers.insert(header_name, header_value); + } + + let _ = ring::default_provider().install_default(); + + let client = reqwest::Client::builder() + .default_headers(default_headers) + .timeout(Duration::from_millis(remote.timeout_millis)) + .build() + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "failed to construct NeMo Guardrails remote client: {err}" + )) + })?; + + Ok(Self { + endpoint: endpoint.trim_end_matches('/').to_string(), + client, + config_id: remote.config_id.clone(), + config_ids: remote.config_ids.clone(), + request_defaults: config.request_defaults.clone(), + }) + } + + async fn execute(&self, request: LlmRequest, stream: bool) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(stream, &self.config_id, &self.config_ids, None, None), + ); + let body = self + .build_request_body(&request, stream) + .inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("failed to serialize remote request body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote request body: {err}" + )) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote request failed: {err}")), + ), + ); + FlowError::Internal(format!("nemo_guardrails remote request failed: {err}")) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote response body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote response body: {err}" + )) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to parse remote response JSON: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to parse remote response JSON: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + remote_mark_data( + stream, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + Ok(response_json) + } + + async fn execute_stream(&self, request: LlmRequest) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + remote_mark_data(true, &self.config_id, &self.config_ids, None, None), + ); + let body = self.build_request_body(&request, true).inspect_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(err.to_string()), + ), + ); + })?; + let serialized = serde_json::to_vec(&body).map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!( + "failed to serialize remote stream request body: {err}" + )), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to serialize remote stream request body: {err}" + )) + })?; + let mut response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + None, + Some(format!("remote stream request failed: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails remote stream request failed: {err}" + )) + })?; + let status = response.status(); + if !status.is_success() { + let payload = response.text().await.map_err(|err| { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream error body: {err}")), + ), + ); + FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream error body: {err}" + )) + })?; + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + remote_mark_data( + true, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote stream request failed with status {status}: {payload}" + ))); + } + + let (tx, rx) = mpsc::channel(16); + let parent_for_task = parent.clone(); + let config_id = self.config_id.clone(); + let config_ids = self.config_ids.clone(); + tokio::spawn(async move { + let mut decoder = SseEventDecoder::new(); + loop { + let bytes = match response.chunk().await { + Ok(Some(bytes)) => bytes, + Ok(None) => break, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(format!("failed to read remote stream chunk: {err}")), + ), + ); + let _ = tx + .send(Err(FlowError::Internal(format!( + "nemo_guardrails failed to read remote stream chunk: {err}" + )))) + .await; + return; + } + }; + let events = match decoder.push_bytes(&bytes) { + Ok(events) => events, + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + }; + for event in events { + if tx.send(Ok(event.data)).await.is_err() { + return; + } + } + } + + match decoder.finish() { + Ok(Some(event)) => { + let _ = tx.send(Ok(event.data)).await; + } + Ok(None) => {} + Err(err) => { + emit_remote_mark( + "nemo_guardrails.remote.error", + &parent_for_task, + remote_mark_data( + true, + &config_id, + &config_ids, + Some(status.as_u16()), + Some(err.to_string()), + ), + ); + let _ = tx.send(Err(err)).await; + return; + } + } + + emit_remote_mark( + "nemo_guardrails.remote.end", + &parent_for_task, + remote_mark_data(true, &config_id, &config_ids, Some(status.as_u16()), None), + ); + }); + + Ok(Box::pin(ReceiverStream::new(rx)) as LlmJsonStream) + } + + async fn check_tool_input(&self, tool_name: &str, args: &Json) -> crate::error::Result { + let original_content = tool_input_content(tool_name, args); + let messages = vec![json!({"role": "user", "content": original_content.clone()})]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Input, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_input rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(args.clone()); + } + + modified_tool_payload(&result_content, tool_name, "arguments") + } + + async fn check_tool_output( + &self, + tool_name: &str, + args: &Json, + result: &Json, + ) -> crate::error::Result { + let input_content = tool_input_content(tool_name, args); + let original_content = tool_output_content(tool_name, args, result); + let messages = vec![ + json!({"role": "user", "content": input_content}), + json!({"role": "assistant", "content": original_content.clone()}), + ]; + let response = self + .execute_remote_check(messages, RemoteCheckKind::Output, tool_name) + .await?; + if let Some(blocking_rail) = blocking_rail_name(&response) { + return Err(FlowError::GuardrailRejected(format!( + "nemo_guardrails tool_output rail blocked tool call by rail '{blocking_rail}'" + ))); + } + + let result_content = chat_completion_content(&response)?; + if result_content == original_content { + return Ok(result.clone()); + } + + modified_tool_payload(&result_content, tool_name, "result") + } + + fn build_request_body(&self, request: &LlmRequest, stream: bool) -> crate::error::Result { + let annotated = OpenAIChatCodec.decode(request)?; + if annotated.tools.is_some() || annotated.tool_choice.is_some() { + return Err(FlowError::Internal( + "nemo_guardrails remote backend does not support OpenAI tool definitions or tool_choice yet" + .to_string(), + )); + } + + let mut body = request.content.as_object().cloned().ok_or_else(|| { + FlowError::Internal("LLM request content is not a JSON object".to_string()) + })?; + body.insert("stream".to_string(), Json::Bool(stream)); + if let Some(guardrails) = self.build_guardrails_config() { + body.insert("guardrails".to_string(), Json::Object(guardrails)); + } + Ok(Json::Object(body)) + } + + fn build_guardrails_config(&self) -> Option> { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + let mut options = Map::new(); + if let Some(rails) = &request_defaults.rails { + options.insert( + "rails".to_string(), + serde_json::to_value(rails) + .expect("request rails config should serialize to JSON"), + ); + } + if let Some(llm_params) = &request_defaults.llm_params { + options.insert("llm_params".to_string(), llm_params.clone()); + } + if let Some(llm_output) = request_defaults.llm_output { + options.insert("llm_output".to_string(), Json::Bool(llm_output)); + } + if let Some(output_vars) = &request_defaults.output_vars { + options.insert("output_vars".to_string(), output_vars.clone()); + } + if let Some(log) = &request_defaults.log { + options.insert("log".to_string(), log.clone()); + } + if !options.is_empty() { + guardrails.insert("options".to_string(), Json::Object(options)); + } + } + + (!guardrails.is_empty()).then_some(guardrails) + } + + fn chat_completions_url(&self) -> String { + format!("{}/v1/chat/completions", self.endpoint) + } + + fn emit_mark(&self, name: &str, parent: &Option, data: Json) { + emit_remote_mark(name, parent, data); + } + + async fn execute_remote_check( + &self, + messages: Vec, + kind: RemoteCheckKind, + tool_name: &str, + ) -> crate::error::Result { + let parent = get_handle().ok(); + self.emit_mark( + "nemo_guardrails.remote.start", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + None, + ), + ); + let mut body = Map::new(); + body.insert("model".to_string(), Json::String(String::new())); + body.insert("messages".to_string(), Json::Array(messages)); + body.insert("stream".to_string(), Json::Bool(false)); + body.insert( + "guardrails".to_string(), + Json::Object(self.build_tool_check_guardrails(kind)), + ); + let serialized = serde_json::to_vec(&Json::Object(body)).map_err(|err| { + let message = format!("nemo_guardrails failed to serialize remote request body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let response = self + .client + .post(self.chat_completions_url()) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serialized) + .send() + .await + .map_err(|err| { + let message = format!("nemo_guardrails remote request failed: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + None, + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + let status = response.status(); + let payload = response.text().await.map_err(|err| { + let message = format!("nemo_guardrails failed to read remote response body: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + if !status.is_success() { + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(redact_remote_error_payload(status.as_u16(), &payload)), + ), + ); + return Err(FlowError::Internal(format!( + "nemo_guardrails remote request failed with status {status}: {payload}" + ))); + } + let response_json = serde_json::from_str(&payload).map_err(|err| { + let message = format!("nemo_guardrails failed to parse remote response JSON: {err}"); + self.emit_mark( + "nemo_guardrails.remote.error", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + Some(message.clone()), + ), + ); + FlowError::Internal(message) + })?; + self.emit_mark( + "nemo_guardrails.remote.end", + &parent, + tool_remote_mark_data( + kind, + tool_name, + &self.config_id, + &self.config_ids, + Some(status.as_u16()), + None, + ), + ); + Ok(response_json) + } + + fn build_tool_check_guardrails(&self, kind: RemoteCheckKind) -> Map { + let mut guardrails = Map::new(); + if let Some(config_id) = &self.config_id { + guardrails.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !self.config_ids.is_empty() { + guardrails.insert( + "config_ids".to_string(), + Json::Array(self.config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(request_defaults) = &self.request_defaults { + if let Some(context) = &request_defaults.context { + guardrails.insert("context".to_string(), context.clone()); + } + if let Some(thread_id) = &request_defaults.thread_id { + guardrails.insert("thread_id".to_string(), Json::String(thread_id.clone())); + } + if let Some(state) = &request_defaults.state { + guardrails.insert("state".to_string(), state.clone()); + } + } + + let mut options = Map::new(); + let rails = match kind { + RemoteCheckKind::Input => json!({ + "input": false, + "output": false, + "dialog": false, + "retrieval": false, + "tool_input": true, + "tool_output": false, + }), + RemoteCheckKind::Output => json!({ + "input": false, + "output": false, + "dialog": false, + "retrieval": false, + "tool_input": false, + "tool_output": true, + }), + }; + options.insert("rails".to_string(), rails); + let mut log = self + .request_defaults + .as_ref() + .and_then(|defaults| defaults.log.as_ref()) + .and_then(Json::as_object) + .cloned() + .unwrap_or_default(); + log.insert("activated_rails".to_string(), Json::Bool(true)); + options.insert("log".to_string(), Json::Object(log)); + guardrails.insert("options".to_string(), Json::Object(options)); + guardrails + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn emit_remote_mark(name: &str, parent: &Option, data: Json) { + let _ = event( + EmitMarkEventParams::builder() + .name(name) + .parent_opt(parent.as_ref()) + .data(data) + .build(), + ); +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_input_content(tool_name: &str, args: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + })) + .expect("tool input payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn redact_remote_error_payload(status: u16, payload: &str) -> String { + format!( + "remote request failed with status {status}; error body omitted from marks ({} bytes)", + payload.len() + ) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String { + serde_json::to_string(&json!({ + "tool_name": tool_name, + "arguments": args, + "result": result, + })) + .expect("tool output payload should serialize to JSON") +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn modified_tool_payload( + content: &str, + expected_tool_name: &str, + field: &str, +) -> crate::error::Result { + let value: Json = serde_json::from_str(content).map_err(|err| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content that is not valid JSON: {err}" + )) + })?; + let Json::Object(object) = value else { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + ))); + }; + if let Some(tool_name) = object.get("tool_name").and_then(Json::as_str) + && tool_name != expected_tool_name + { + return Err(FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content for unexpected tool '{tool_name}'" + ))); + } + object.get(field).cloned().ok_or_else(|| { + FlowError::Internal(format!( + "nemo_guardrails returned modified tool {field} content without a '{field}' field" + )) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn chat_completion_content(response: &Json) -> crate::error::Result { + response + .get("choices") + .and_then(Json::as_array) + .and_then(|choices| choices.first()) + .and_then(|choice| choice.get("message")) + .and_then(|message| message.get("content")) + .and_then(Json::as_str) + .map(str::to_string) + .ok_or_else(|| { + FlowError::Internal( + "nemo_guardrails remote response did not contain choices[0].message.content" + .to_string(), + ) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn blocking_rail_name(response: &Json) -> Option { + response + .get("guardrails") + .and_then(|guardrails| guardrails.get("log")) + .and_then(|log| log.get("activated_rails")) + .and_then(Json::as_array) + .and_then(|activated| { + activated.iter().find_map(|rail| { + if rail.get("stop").and_then(Json::as_bool) == Some(true) { + rail.get("name").and_then(Json::as_str).map(str::to_string) + } else { + None + } + }) + }) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn remote_mark_data( + stream: bool, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = Map::new(); + data.insert("stream".to_string(), Json::Bool(stream)); + if let Some(config_id) = config_id { + data.insert("config_id".to_string(), Json::String(config_id.clone())); + } + if !config_ids.is_empty() { + data.insert( + "config_ids".to_string(), + Json::Array(config_ids.iter().cloned().map(Json::String).collect()), + ); + } + if let Some(status) = status { + data.insert( + "http_status".to_string(), + Json::Number(serde_json::Number::from(status)), + ); + } + if let Some(error) = error { + data.insert("error".to_string(), Json::String(error)); + } + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +fn tool_remote_mark_data( + kind: RemoteCheckKind, + tool_name: &str, + config_id: &Option, + config_ids: &[String], + status: Option, + error: Option, +) -> Json { + let mut data = match remote_mark_data(false, config_id, config_ids, status, error) { + Json::Object(data) => data, + _ => unreachable!("remote_mark_data always returns an object"), + }; + data.insert( + "surface".to_string(), + Json::String(match kind { + RemoteCheckKind::Input => "tool_input".to_string(), + RemoteCheckKind::Output => "tool_output".to_string(), + }), + ); + data.insert("tool_name".to_string(), Json::String(tool_name.to_string())); + Json::Object(data) +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))] +pub(super) fn register_remote_backend( + config: NeMoGuardrailsConfig, + ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + let remote = config.remote.clone().ok_or_else(|| { + PluginError::InvalidConfig("remote config is required when mode is 'remote'".to_string()) + })?; + let runtime = Arc::new(RemoteBackendRuntime::new(&config, &remote)?); + + if config.input || config.output { + let llm_execution_runtime = Arc::clone(&runtime); + let llm_execution: LlmExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_execution_runtime); + Box::pin(async move { runtime.execute(request, false).await }) + }); + ctx.register_llm_execution_intercept("llm_remote_backend", config.priority, llm_execution)?; + + let llm_stream_runtime = Arc::clone(&runtime); + let llm_stream_execution: LlmStreamExecutionFn = Arc::new(move |_name, request, _next| { + let runtime = Arc::clone(&llm_stream_runtime); + Box::pin(async move { runtime.execute_stream(request).await }) + }); + ctx.register_llm_stream_execution_intercept( + "llm_stream_remote_backend", + config.priority, + llm_stream_execution, + )?; + } + + if config.tool_input || config.tool_output { + let tool_runtime = Arc::clone(&runtime); + let enable_tool_input = config.tool_input; + let enable_tool_output = config.tool_output; + let tool_execution: ToolExecutionFn = Arc::new(move |tool_name, args, next| { + let runtime = Arc::clone(&tool_runtime); + let tool_name = tool_name.to_string(); + Box::pin(async move { + let current_args = if enable_tool_input { + runtime.check_tool_input(&tool_name, &args).await? + } else { + args + }; + + let tool_result = next(current_args.clone()).await?; + if !enable_tool_output { + return Ok(tool_result); + } + + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await + }) + }); + ctx.register_tool_execution_intercept( + "tool_remote_backend", + config.priority, + tool_execution, + )?; + } + + Ok(()) +} + +#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))] +pub(super) fn register_remote_backend( + _config: NeMoGuardrailsConfig, + _ctx: &mut PluginRegistrationContext, +) -> PluginResult<()> { + Err(PluginError::RegistrationFailed( + "built-in NeMo Guardrails remote backend is unavailable in this build".to_string(), + )) +} diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs new file mode 100644 index 00000000..dd526020 --- /dev/null +++ b/crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs @@ -0,0 +1,2480 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Unit tests for the planned NeMo Guardrails plugin component contract. +#![allow(clippy::await_holding_lock)] + +use super::*; +use crate::api::runtime::NemoRelayContextState; +use std::io::{Read, Write}; +use std::net::TcpListener; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, mpsc}; +use std::thread; +use std::time::Duration; + +use crate::api::event::Event; +use crate::api::llm::{ + LlmAttributes, LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute, + llm_stream_call_execute, +}; +use crate::api::runtime::global_context; +use crate::api::runtime::{ + LlmExecutionNextFn, LlmJsonStream, LlmStreamExecutionNextFn, create_scope_stack, + set_thread_scope_stack, +}; +use crate::api::subscriber::{deregister_subscriber, register_subscriber}; +use crate::api::tool::{ToolCallExecuteParams, tool_call_execute}; +use crate::codec::openai_chat::{OpenAIChatCodec, OpenAIChatStreamingCodec}; +use crate::codec::streaming::StreamingCodec; +use crate::codec::traits::LlmResponseCodec; +use crate::config_editor::{EditorConfig, EditorFieldKind}; +#[cfg(feature = "schema")] +use crate::plugin::plugin_config_schema; +use crate::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, + list_plugin_kinds, lookup_plugin, validate_plugin_config, +}; +use futures::StreamExt; +use serde_json::json; + +const TEST_TIMEOUT: Duration = Duration::from_secs(5); + +fn reset_runtime() { + let _ = clear_plugin_configuration(); + crate::shared_runtime::reset_runtime_owner_for_tests(); + let context = global_context(); + *context.write().unwrap() = NemoRelayContextState::new(); +} + +fn setup_isolated_thread() { + let stack = create_scope_stack(); + set_thread_scope_stack(stack); +} + +fn component(config: Json) -> PluginComponentSpec { + let Json::Object(config) = config else { + panic!("component config must be an object"); + }; + PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: true, + config, + } +} + +fn disabled_component(config: Json) -> PluginComponentSpec { + let Json::Object(config) = config else { + panic!("component config must be an object"); + }; + PluginComponentSpec { + kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), + enabled: false, + config, + } +} + +fn plugin_config(config: Json) -> PluginConfig { + PluginConfig { + version: 1, + components: vec![component(config)], + policy: Default::default(), + } +} + +fn remote_valid_config() -> Json { + json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "safety-default" + } + }) +} + +#[derive(Debug)] +struct CapturedHttpRequest { + path: String, + content_type: String, + body: Vec, +} + +fn spawn_http_responder( + listener: TcpListener, + response: Vec, + request_tx: mpsc::Sender, +) { + thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let request = read_http_request(&mut stream); + stream.write_all(&response).unwrap(); + request_tx.send(request).unwrap(); + }); +} + +fn spawn_http_responder_sequence( + listener: TcpListener, + responses: Vec>, + request_tx: mpsc::Sender, +) { + thread::spawn(move || { + for response in responses { + let (mut stream, _) = listener.accept().unwrap(); + let request = read_http_request(&mut stream); + stream.write_all(&response).unwrap(); + request_tx.send(request).unwrap(); + } + }); +} + +fn read_http_request(stream: &mut impl Read) -> CapturedHttpRequest { + let mut bytes = Vec::new(); + let mut buf = [0_u8; 4096]; + let (header_end, content_length) = read_http_headers(stream, &mut bytes, &mut buf); + read_http_body(stream, &mut bytes, &mut buf, header_end + content_length); + + let headers_text = String::from_utf8_lossy(&bytes[..header_end]); + let request_line = headers_text.lines().next().unwrap(); + CapturedHttpRequest { + path: request_line.split_whitespace().nth(1).unwrap().to_string(), + content_type: header_value(&headers_text, "content-type") + .unwrap_or_default() + .to_string(), + body: bytes[header_end..header_end + content_length].to_vec(), + } +} + +fn read_http_headers( + stream: &mut impl Read, + bytes: &mut Vec, + buf: &mut [u8; 4096], +) -> (usize, usize) { + loop { + let read = stream.read(buf).unwrap(); + if read == 0 { + panic!("remote responder closed before receiving request"); + } + bytes.extend_from_slice(&buf[..read]); + + if let Some(header_end) = bytes.windows(4).position(|window| window == b"\r\n\r\n") { + let header_end = header_end + 4; + let headers_text = String::from_utf8_lossy(&bytes[..header_end]); + let content_length = header_value(&headers_text, "content-length") + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + return (header_end, content_length); + } + } +} + +fn read_http_body( + stream: &mut impl Read, + bytes: &mut Vec, + buf: &mut [u8; 4096], + expected_total: usize, +) { + while bytes.len() < expected_total { + let read = stream.read(buf).unwrap(); + if read == 0 { + panic!("remote responder closed before full request body"); + } + bytes.extend_from_slice(&buf[..read]); + } +} + +fn header_value<'a>(headers_text: &'a str, header_name: &str) -> Option<&'a str> { + headers_text.lines().find_map(|line| { + let (name, value) = line.split_once(':')?; + if name.eq_ignore_ascii_case(header_name) { + Some(value.trim()) + } else { + None + } + }) +} + +fn recv_captured_request(request_rx: &mpsc::Receiver) -> CapturedHttpRequest { + request_rx + .recv_timeout(TEST_TIMEOUT) + .expect("timed out waiting for captured HTTP request") +} + +fn make_chat_request(stream: bool) -> LlmRequest { + LlmRequest { + headers: serde_json::Map::new(), + content: json!({ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.2, + "stream": stream + }), + } +} + +fn capture_events(name: &str) -> Arc>> { + let events = Arc::new(Mutex::new(Vec::new())); + let sink = Arc::clone(&events); + register_subscriber( + name, + Arc::new(move |event| sink.lock().unwrap().push(event.clone())), + ) + .unwrap(); + events +} + +fn unused_local_endpoint() -> String { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + drop(listener); + format!("http://{address}") +} + +#[test] +fn editor_schema_tracks_nemo_guardrails_config_types() { + let schema = NeMoGuardrailsConfig::editor_schema(); + let mode = schema.field("mode").expect("mode field"); + assert_eq!(mode.kind, EditorFieldKind::Enum); + assert_eq!(mode.enum_values, &["remote", "local"]); + + let remote = schema.field("remote").expect("remote section"); + assert_eq!(remote.kind, EditorFieldKind::Section); + assert!(remote.optional); + + let remote_schema = remote.schema().expect("remote editor schema"); + let headers = remote_schema.field("headers").expect("headers field"); + assert_eq!(headers.kind, EditorFieldKind::StringMap); + + let request_defaults = schema + .field("request_defaults") + .expect("request_defaults section"); + assert_eq!(request_defaults.kind, EditorFieldKind::Section); + assert!(request_defaults.optional); + + let request_defaults_schema = request_defaults + .schema() + .expect("request_defaults editor schema"); + let rails = request_defaults_schema.field("rails").expect("rails field"); + assert_eq!(rails.kind, EditorFieldKind::Section); + + let rails_schema = rails.schema().expect("request rails editor schema"); + let retrieval = rails_schema.field("retrieval").expect("retrieval field"); + assert_eq!(retrieval.kind, EditorFieldKind::Json); +} + +#[test] +fn default_config_and_component_conversion_cover_public_shape() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let defaults = NeMoGuardrailsConfig::default(); + assert_eq!(defaults.version, 1); + assert_eq!(defaults.mode, "remote"); + assert!(defaults.input); + assert!(defaults.output); + assert!(!defaults.tool_input); + assert!(!defaults.tool_output); + assert_eq!(defaults.priority, 100); + assert!(defaults.remote.is_none()); + assert!(defaults.local.is_none()); + assert!(defaults.request_defaults.is_none()); + + let remote = RemoteBackendConfig::default(); + assert_eq!(remote.timeout_millis, 3_000); + assert!(remote.headers.is_empty()); + assert!(remote.config_ids.is_empty()); + + let generic: PluginComponentSpec = ComponentSpec::new(NeMoGuardrailsConfig { + remote: Some(RemoteBackendConfig { + endpoint: Some("http://localhost:8000".into()), + config_id: Some("default".into()), + ..RemoteBackendConfig::default() + }), + ..NeMoGuardrailsConfig::default() + }) + .into(); + assert_eq!(generic.kind, NEMO_GUARDRAILS_PLUGIN_KIND); + assert!(generic.enabled); + assert_eq!(generic.config["mode"], json!("remote")); + assert_eq!(generic.config["remote"]["config_id"], json!("default")); +} + +#[cfg(feature = "schema")] +fn schema_has_property(schema: &Json, name: &str) -> bool { + schema_property(schema, name).is_some() +} + +#[cfg(feature = "schema")] +fn schema_property_has_enum(schema: &Json, name: &str, expected: &[&str]) -> bool { + schema_property(schema, name) + .and_then(|property| property.get("enum")) + .and_then(Json::as_array) + .is_some_and(|values| { + expected + .iter() + .all(|expected| values.iter().any(|value| value == *expected)) + }) +} + +#[cfg(feature = "schema")] +fn schema_property_has_default(schema: &Json, name: &str, expected: Json) -> bool { + schema_property(schema, name) + .and_then(|property| property.get("default")) + .is_some_and(|default| default == &expected) +} + +#[cfg(feature = "schema")] +fn schema_property<'a>(schema: &'a Json, name: &str) -> Option<&'a Json> { + match schema { + Json::Object(object) => { + if let Some(property) = object + .get("properties") + .and_then(Json::as_object) + .and_then(|properties| properties.get(name)) + { + return Some(property); + } + object + .values() + .find_map(|value| schema_property(value, name)) + } + Json::Array(values) => values.iter().find_map(|value| schema_property(value, name)), + _ => None, + } +} + +#[cfg(feature = "schema")] +#[test] +fn schema_contains_every_supported_nemo_guardrails_option() { + let schema = nemo_guardrails_config_schema(); + for field in [ + "version", + "mode", + "config_path", + "config_yaml", + "colang_content", + "codec", + "input", + "output", + "tool_input", + "tool_output", + "priority", + "remote", + "local", + "request_defaults", + "policy", + "endpoint", + "config_id", + "config_ids", + "headers", + "timeout_millis", + "python_module", + "context", + "thread_id", + "state", + "rails", + "llm_params", + "llm_output", + "output_vars", + "log", + "retrieval", + "dialog", + "unknown_component", + "unknown_field", + "unsupported_value", + ] { + assert!( + schema_has_property(&schema, field), + "schema missing property `{field}`:\n{}", + serde_json::to_string_pretty(&schema).unwrap() + ); + } + assert!(schema_property_has_enum( + &schema, + "mode", + &["remote", "local"] + )); + assert!(schema_property_has_enum( + &schema, + "codec", + &["openai_chat", "openai_responses", "anthropic_messages"] + )); + assert!(schema_property_has_default( + &schema, + "mode", + json!("remote") + )); +} + +#[cfg(feature = "schema")] +#[test] +fn plugin_schema_contains_generic_plugin_surface() { + let schema = plugin_config_schema(); + for field in [ + "version", + "components", + "policy", + "kind", + "enabled", + "config", + ] { + assert!( + schema_has_property(&schema, field), + "plugin schema missing property `{field}`" + ); + } +} + +#[test] +fn builtin_registration_is_automatic() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + assert!(list_plugin_kinds().contains(&NEMO_GUARDRAILS_PLUGIN_KIND.to_string())); + assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_some()); +} + +#[test] +fn disabled_component_validates_and_initializes_without_runtime_work() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let config = PluginConfig { + version: 1, + components: vec![disabled_component(remote_valid_config())], + policy: Default::default(), + }; + assert!(!validate_plugin_config(&config).has_errors()); + futures::executor::block_on(initialize_plugins(config)).unwrap(); +} + +#[test] +fn duplicate_component_is_rejected_as_singleton() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let config = PluginConfig { + version: 1, + components: vec![ + component(remote_valid_config()), + component(remote_valid_config()), + ], + policy: Default::default(), + }; + let report = validate_plugin_config(&config); + assert!(report.has_errors()); + assert!( + report + .diagnostics + .iter() + .any(|diag| diag.code == "plugin.duplicate_component") + ); +} + +#[test] +fn invalid_shapes_and_values_are_reported() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let invalid_shape = validate_plugin_config(&plugin_config(json!({ + "version": "one", + }))); + assert!(invalid_shape.has_errors()); + assert!( + invalid_shape + .diagnostics + .iter() + .any(|diag| diag.code == "nemo_guardrails.invalid_plugin_config") + ); + + let local_missing_source = validate_plugin_config(&plugin_config(json!({ + "mode": "local", + "codec": "openai_chat", + }))); + assert!(local_missing_source.has_errors()); + assert!(local_missing_source.diagnostics.iter().any(|diag| { + diag.message + .contains("exactly one of config_path or config_yaml is required in local mode") + })); + + let local_bad_colang = validate_plugin_config(&plugin_config(json!({ + "mode": "local", + "config_path": "./rails", + "colang_content": "define flow x", + "codec": "openai_chat", + }))); + assert!(local_bad_colang.has_errors()); + assert!( + local_bad_colang + .diagnostics + .iter() + .any(|diag| diag.message.contains("colang_content can only be used")) + ); + + let remote_missing_identity = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": {"endpoint": "http://localhost:8000"}, + }))); + assert!(remote_missing_identity.has_errors()); + assert!(remote_missing_identity.diagnostics.iter().any(|diag| { + diag.message + .contains("remote mode requires remote.config_id or remote.config_ids") + })); + + let remote_conflicting_ids = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "one", + "config_ids": ["two"] + }, + }))); + assert!(remote_conflicting_ids.has_errors()); + assert!(remote_conflicting_ids.diagnostics.iter().any(|diag| { + diag.message + .contains("remote.config_id and remote.config_ids cannot be used together") + })); + + let missing_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(missing_codec.has_errors()); + assert!( + missing_codec + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("codec")) + ); + + let bad_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_agents", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(bad_codec.has_errors()); + assert!(bad_codec.diagnostics.iter().any(|diag| { + diag.message + .contains("codec must be 'openai_chat', 'openai_responses', or 'anthropic_messages'") + })); + + let unsupported_remote_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_responses", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(unsupported_remote_codec.has_errors()); + assert!(unsupported_remote_codec.diagnostics.iter().any(|diag| { + diag.message + .contains("remote mode currently supports only codec = 'openai_chat'") + })); + + let unsupported_remote_anthropic_codec = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "anthropic_messages", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(unsupported_remote_anthropic_codec.has_errors()); + assert!( + unsupported_remote_anthropic_codec + .diagnostics + .iter() + .any(|diag| { + diag.message + .contains("remote mode currently supports only codec = 'openai_chat'") + }) + ); + + let supported_remote_tool_surface = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "tool_input": true, + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + } + }))); + assert!(!supported_remote_tool_surface.has_errors()); + + let remote_empty_fields = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": "", + "config_id": "", + "config_ids": ["default", ""] + } + }))); + assert!(remote_empty_fields.has_errors()); + assert!( + remote_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("remote.endpoint")) + ); + assert!( + remote_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("remote.config_id")) + ); + assert!( + remote_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("remote.config_ids[1]")) + ); + + let remote_local_mix = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "config_path": "./rails", + "codec": "openai_chat", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + }, + "local": {"python_module": "nemoguardrails"} + }))); + assert!(remote_local_mix.has_errors()); + assert!( + remote_local_mix + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("local")) + ); + assert!(remote_local_mix.diagnostics.iter().any(|diag| { + diag.message + .contains("remote mode uses remote config identity") + })); + + let no_surfaces = validate_plugin_config(&plugin_config(json!({ + "mode": "local", + "config_path": "./rails", + "input": false, + "output": false, + "tool_input": false, + "tool_output": false + }))); + assert!(no_surfaces.has_errors()); + assert!( + no_surfaces + .diagnostics + .iter() + .any(|diag| diag.message.contains("at least one Guardrails surface")) + ); + + let local_empty_fields = validate_plugin_config(&plugin_config(json!({ + "mode": "local", + "config_yaml": "", + "colang_content": "", + "codec": "openai_chat", + "local": {"python_module": ""} + }))); + assert!(local_empty_fields.has_errors()); + assert!( + local_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("config_yaml")) + ); + assert!( + local_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("colang_content")) + ); + assert!( + local_empty_fields + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("local.python_module")) + ); + + let invalid_request_defaults = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": "http://localhost:8000", + "config_id": "default" + }, + "request_defaults": { + "context": true, + "thread_id": "short", + "state": {"foo": "bar"}, + "llm_params": [], + "log": "verbose", + "output_vars": 7, + "rails": { + "retrieval": [""] + } + } + }))); + assert!(invalid_request_defaults.has_errors()); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.context")) + ); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.thread_id")) + ); + assert!(invalid_request_defaults.diagnostics.iter().any(|diag| { + diag.message + .contains("request_defaults.thread_id must be at least 16 characters long") + })); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.state")) + ); + assert!(invalid_request_defaults.diagnostics.iter().any(|diag| { + diag.message + .contains("request_defaults.state must be empty or contain only 'events' or 'state'") + })); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.llm_params")) + ); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.log")) + ); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.output_vars")) + ); + assert!( + invalid_request_defaults + .diagnostics + .iter() + .any(|diag| diag.field.as_deref() == Some("request_defaults.rails.retrieval[0]")) + ); +} + +#[test] +fn unknown_fields_follow_policy() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let warn_report = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, + "bogus": true + }))); + assert!( + warn_report + .diagnostics + .iter() + .any(|diag| diag.code == "nemo_guardrails.unknown_field") + ); + + let nested_warn_report = validate_plugin_config(&plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, + "request_defaults": { + "rails": { + "bogus": true + } + } + }))); + assert!( + nested_warn_report + .diagnostics + .iter() + .any(|diag| diag.component.as_deref() == Some("request_defaults.rails")) + ); + + let ignored = validate_plugin_config(&plugin_config(json!({ + "policy": {"unknown_field": "ignore", "unsupported_value": "ignore"}, + "mode": "remote", + "codec": "openai_chat", + "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, + "bogus": true + }))); + assert!(!ignored.has_errors()); + assert!(ignored.diagnostics.is_empty()); +} + +#[test] +fn enabled_local_initialization_fails_fast_until_backend_exists() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + + let error = futures::executor::block_on(initialize_plugins(plugin_config(json!({ + "mode": "local", + "codec": "openai_chat", + "config_path": "./rails" + })))) + .unwrap_err(); + + match error { + crate::plugin::PluginError::RegistrationFailed(message) => { + assert!(message.contains("local backend")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_initialization_installs_non_streaming_execution_intercept() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-execution-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "state": {"state": {"conversation": "server-state"}}, + "output_data": {"decision": "allow"} + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default", + "headers": {"x-guardrails-auth": "token"}, + "timeout_millis": 5_000 + }, + "request_defaults": { + "context": {"tenant": "acme"}, + "thread_id": "thread-1234567890", + "state": {"state": {"conversation": "client-state"}}, + "rails": {"input": true, "retrieval": ["kb"]}, + "llm_params": {"temperature": 0.1}, + "llm_output": true, + "output_vars": ["answer"], + "log": {"activated_rails": true} + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"response": "original"})) }) + }); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert!(!original_called.load(Ordering::SeqCst)); + assert_eq!(response["id"], json!("chatcmpl-remote")); + assert_eq!(response["object"], json!("chat.completion")); + assert_eq!(response["model"], json!("gpt-4o-mini")); + assert_eq!( + response["choices"][0]["message"]["content"], + json!("guarded") + ); + assert_eq!( + response["guardrails"]["output_data"]["decision"], + json!("allow") + ); + assert_eq!( + response["guardrails"]["state"]["state"]["conversation"], + json!("server-state") + ); + + let captured = recv_captured_request(&request_rx); + assert_eq!(captured.path, "/v1/chat/completions"); + assert!(captured.content_type.starts_with("application/json")); + + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["messages"][0]["content"], json!("hello")); + assert_eq!(request_json["stream"], json!(false)); + assert_eq!( + request_json["guardrails"]["config_id"], + json!("safety-default") + ); + assert_eq!( + request_json["guardrails"]["context"]["tenant"], + json!("acme") + ); + assert_eq!( + request_json["guardrails"]["thread_id"], + json!("thread-1234567890") + ); + assert_eq!( + request_json["guardrails"]["state"]["state"]["conversation"], + json!("client-state") + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["retrieval"], + json!(["kb"]) + ); + assert_eq!( + request_json["guardrails"]["options"]["llm_output"], + json!(true) + ); + + let captured_events = events.lock().unwrap().clone(); + let mark_names: Vec<_> = captured_events + .iter() + .filter(|event| event.kind() == "mark") + .map(|event| event.name().to_string()) + .collect(); + assert!(mark_names.contains(&"nemo_guardrails.remote.start".to_string())); + assert!(mark_names.contains(&"nemo_guardrails.remote.end".to_string())); + + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!( + start_mark.data().unwrap()["config_id"], + json!("safety-default") + ); + assert_eq!(start_mark.data().unwrap()["stream"], json!(false)); + + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(end_mark.data().unwrap()["stream"], json!(false)); + + deregister_subscriber("nemo-guardrails-remote-execution-events").unwrap(); +} + +#[tokio::test] +async fn remote_request_uses_config_ids_when_config_id_is_not_set() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }] + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_ids": ["safety-a", "safety-b"] + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let _ = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["config_ids"], + json!(["safety-a", "safety-b"]) + ); + assert!(request_json["guardrails"].get("config_id").is_none()); +} + +#[tokio::test] +async fn remote_initialization_installs_stream_execution_intercept() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-stream-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let sse_body = concat!( + "data: {\"id\":\"chatcmpl-remote\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"guard\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-remote\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ed\"},\"finish_reason\":\"stop\"}]}\n\n", + "data: [DONE]\n\n" + ); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + sse_body.len(), + sse_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + .unwrap(); + + let mut chunks = Vec::new(); + while let Some(chunk) = tokio::time::timeout(TEST_TIMEOUT, stream.next()) + .await + .expect("timed out waiting for remote stream chunk") + { + chunks.push(chunk.unwrap()); + } + + assert!(!original_called.load(Ordering::SeqCst)); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0]["choices"][0]["delta"]["content"], json!("guard")); + assert_eq!(chunks[1]["choices"][0]["delta"]["content"], json!("ed")); + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["stream"], json!(true)); + assert_eq!( + request_json["guardrails"]["config_id"], + json!("safety-default") + ); + + let captured_events = events.lock().unwrap().clone(); + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!(start_mark.data().unwrap()["stream"], json!(true)); + + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(end_mark.data().unwrap()["stream"], json!(true)); + + deregister_subscriber("nemo-guardrails-remote-stream-events").unwrap(); +} + +#[tokio::test] +async fn remote_non_streaming_http_errors_are_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-error-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = r#"{"error":"backend unavailable"}"#; + let http_response = format!( + "HTTP/1.1 502 Bad Gateway\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"response": "original"})) }) + }); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("status 502")); + assert!(message.contains("backend unavailable")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(502)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("error body omitted from marks") + ); + + deregister_subscriber("nemo-guardrails-remote-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_streaming_http_errors_are_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-stream-error-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = r#"{"error":"stream backend unavailable"}"#; + let http_response = format!( + "HTTP/1.1 503 Service Unavailable\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let error = match llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + { + Ok(_) => panic!("expected remote streaming request to fail"), + Err(error) => error, + }; + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("status 503")); + assert!(message.contains("stream backend unavailable")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(503)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(true)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("error body omitted from marks") + ); + + deregister_subscriber("nemo-guardrails-remote-stream-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_non_streaming_invalid_json_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-invalid-json-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = "{not-json}"; + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("failed to parse remote response JSON")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + + deregister_subscriber("nemo-guardrails-remote-invalid-json-events").unwrap(); +} + +#[tokio::test] +async fn remote_streaming_malformed_chunk_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-malformed-stream-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let sse_body = "data: {not-json}\n\n"; + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + sse_body.len(), + sse_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmStreamExecutionNextFn = Arc::new(move |_req| { + Box::pin(async move { + let stream = tokio_stream::iter(vec![Ok(json!({"chunk": "original"}))]); + Ok(Box::pin(stream) as LlmJsonStream) + }) + }); + + let streaming_codec = OpenAIChatStreamingCodec::new(); + let collector = streaming_codec.collector(); + let finalizer = streaming_codec.finalizer(); + let response_codec: Arc = Arc::new(OpenAIChatCodec); + + let mut stream = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(true)) + .func(func) + .collector(collector) + .finalizer(finalizer) + .attributes(LlmAttributes::STREAMING) + .response_codec(response_codec) + .build(), + ) + .await + .unwrap(); + + let error = tokio::time::timeout(TEST_TIMEOUT, stream.next()) + .await + .expect("timed out waiting for remote stream error") + .unwrap() + .unwrap_err(); + match error { + crate::error::FlowError::Internal(message) => { + assert!(!message.is_empty()); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["http_status"], json!(200)); + assert_eq!(error_mark.data().unwrap()["stream"], json!(true)); + + deregister_subscriber("nemo-guardrails-remote-malformed-stream-events").unwrap(); +} + +#[tokio::test] +async fn remote_preflight_tool_choice_failure_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-preflight-error-events"); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + let request = LlmRequest { + headers: serde_json::Map::new(), + content: json!({ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "hello"}], + "tools": [{ + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup data", + "parameters": {"type": "object"} + } + }] + }), + }; + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(request) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("does not support OpenAI tool definitions or tool_choice")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + assert!( + captured_events + .iter() + .any(|event| event.name() == "nemo_guardrails.remote.start") + ); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!( + error_mark.data().unwrap()["error"] + .as_str() + .unwrap() + .contains("does not support OpenAI tool definitions or tool_choice") + ); + + deregister_subscriber("nemo-guardrails-remote-preflight-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_transport_failure_is_reported_and_marked() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-transport-error-events"); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default", + "timeout_millis": 50 + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("remote request failed")); + } + other => panic!("unexpected error: {other}"), + } + + let captured_events = events.lock().unwrap().clone(); + let error_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.error") + .unwrap(); + assert_eq!(error_mark.data().unwrap()["stream"], json!(false)); + assert!(error_mark.data().unwrap().get("http_status").is_none()); + + deregister_subscriber("nemo-guardrails-remote-transport-error-events").unwrap(); +} + +#[tokio::test] +async fn remote_success_without_guardrails_payload_is_allowed() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-remote", + "object": "chat.completion", + "created": 1, + "model": "gpt-4o-mini", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "guarded"}, + "finish_reason": "stop" + }] + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "codec": "openai_chat", + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let func: LlmExecutionNextFn = + Arc::new(move |_req| Box::pin(async move { Ok(json!({"response": "original"})) })); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert_eq!(response["id"], json!("chatcmpl-remote")); + assert!(response.get("guardrails").is_none()); +} + +#[tokio::test] +async fn remote_tool_input_block_rejects_before_tool_execution() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + let events = capture_events("nemo-guardrails-remote-tool-input-events"); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-blocked", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "blocked"}, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [{ + "name": "tool_input_block", + "stop": true + }] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let original_called = Arc::new(AtomicBool::new(false)); + let called = Arc::clone(&original_called); + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + called.store(true, Ordering::SeqCst); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + assert!(!original_called.load(Ordering::SeqCst)); + match error { + crate::error::FlowError::GuardrailRejected(message) => { + assert!(message.contains("tool_input")); + } + other => panic!("unexpected error: {other}"), + } + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["tool_input"], + json!(true) + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["tool_output"], + json!(false) + ); + + let captured_events = events.lock().unwrap().clone(); + let start_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.start") + .unwrap(); + assert_eq!(start_mark.data().unwrap()["surface"], json!("tool_input")); + assert_eq!( + start_mark.data().unwrap()["tool_name"], + json!("weather_lookup") + ); + let end_mark = captured_events + .iter() + .find(|event| event.name() == "nemo_guardrails.remote.end") + .unwrap(); + assert_eq!(end_mark.data().unwrap()["surface"], json!("tool_input")); + + deregister_subscriber("nemo-guardrails-remote-tool-input-events").unwrap(); +} + +#[tokio::test] +async fn remote_tool_input_can_rewrite_tool_arguments() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let seen_args = Arc::new(Mutex::new(None::)); + let seen_args_for_call = Arc::clone(&seen_args); + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |args| { + *seen_args_for_call.lock().unwrap() = Some(args.clone()); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "sunny"})); + assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!(request_json["messages"][0]["role"], json!("user")); +} + +#[tokio::test] +async fn remote_tool_output_can_rewrite_tool_result() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-output-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"},\"result\":{\"forecast\":\"cloudy\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "cloudy"})); + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["tool_input"], + json!(false) + ); + assert_eq!( + request_json["guardrails"]["options"]["rails"]["tool_output"], + json!(true) + ); +} + +#[tokio::test] +async fn remote_tool_input_invalid_modified_arguments_are_reported() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-invalid", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{not-json}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("modified tool arguments content that is not valid JSON")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_output_missing_result_field_is_reported() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-output-missing-result", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("without a 'result' field")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_output_does_not_run_when_tool_callback_errors() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_output": true, + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { + Err(crate::error::FlowError::Internal( + "tool callback failed".to_string(), + )) + }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert_eq!(message, "tool callback failed"); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_input_rewrite_with_mismatched_tool_name_is_rejected() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, _request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-mismatch", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"different_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap_err(); + + match error { + crate::error::FlowError::Internal(message) => { + assert!(message.contains("unexpected tool 'different_lookup'")); + } + other => panic!("unexpected error: {other}"), + } +} + +#[tokio::test] +async fn remote_tool_input_and_output_run_in_order() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let input_response_body = json!({ + "id": "chatcmpl-tool-input-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let output_response_body = json!({ + "id": "chatcmpl-tool-output-modified", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Boston\"},\"result\":{\"forecast\":\"cloudy\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let input_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + input_response_body.len(), + input_response_body + ) + .into_bytes(); + let output_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + output_response_body.len(), + output_response_body + ) + .into_bytes(); + spawn_http_responder_sequence(listener, vec![input_response, output_response], request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "tool_output": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let seen_args = Arc::new(Mutex::new(None::)); + let seen_args_for_call = Arc::clone(&seen_args); + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |args| { + *seen_args_for_call.lock().unwrap() = Some(args.clone()); + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"}))); + assert_eq!(result, json!({"forecast": "cloudy"})); + + let first_request = recv_captured_request(&request_rx); + let first_request_json: Json = serde_json::from_slice(&first_request.body).unwrap(); + assert_eq!(first_request_json["messages"][0]["role"], json!("user")); + assert_eq!( + first_request_json["guardrails"]["options"]["rails"]["tool_input"], + json!(true) + ); + assert_eq!( + first_request_json["guardrails"]["options"]["rails"]["tool_output"], + json!(false) + ); + + let second_request = recv_captured_request(&request_rx); + let second_request_json: Json = serde_json::from_slice(&second_request.body).unwrap(); + assert_eq!(second_request_json["messages"][0]["role"], json!("user")); + assert_eq!( + second_request_json["messages"][1]["role"], + json!("assistant") + ); + assert_eq!( + second_request_json["guardrails"]["options"]["rails"]["tool_input"], + json!(false) + ); + assert_eq!( + second_request_json["guardrails"]["options"]["rails"]["tool_output"], + json!(true) + ); +} + +#[tokio::test] +async fn remote_tool_checks_forward_context_state_and_thread_id() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + let (request_tx, request_rx) = mpsc::channel(); + let response_body = json!({ + "id": "chatcmpl-tool-input-context", + "object": "chat.completion", + "created": 1, + "model": "", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{\"tool_name\":\"weather_lookup\",\"arguments\":{\"city\":\"Phoenix\"}}" + }, + "finish_reason": "stop" + }], + "guardrails": { + "config_id": "safety-default", + "log": { + "activated_rails": [] + } + } + }) + .to_string(); + let http_response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes(); + spawn_http_responder(listener, http_response, request_tx); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": format!("http://{address}"), + "config_id": "safety-default" + }, + "request_defaults": { + "context": {"tenant": "smoke"}, + "thread_id": "1234567890abcdef", + "state": {"events": []} + } + }))) + .await + .unwrap(); + + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("weather_lookup") + .args(json!({"city": "Phoenix"})) + .func(Arc::new(move |_args| { + Box::pin(async move { Ok(json!({"forecast": "sunny"})) }) + })) + .build(), + ) + .await + .unwrap(); + + assert_eq!(result, json!({"forecast": "sunny"})); + + let captured = recv_captured_request(&request_rx); + let request_json: Json = serde_json::from_slice(&captured.body).unwrap(); + assert_eq!( + request_json["guardrails"]["context"], + json!({"tenant": "smoke"}) + ); + assert_eq!( + request_json["guardrails"]["thread_id"], + json!("1234567890abcdef") + ); + assert_eq!(request_json["guardrails"]["state"], json!({"events": []})); +} + +#[tokio::test] +async fn remote_tool_only_configuration_does_not_intercept_llm_calls() { + let _guard = crate::plugins::nemo_guardrails::test_mutex() + .lock() + .unwrap_or_else(|err| err.into_inner()); + reset_runtime(); + setup_isolated_thread(); + + initialize_plugins(plugin_config(json!({ + "mode": "remote", + "input": false, + "output": false, + "tool_input": true, + "remote": { + "endpoint": unused_local_endpoint(), + "config_id": "safety-default" + } + }))) + .await + .unwrap(); + + let expected = json!({"response": "original"}); + let func: LlmExecutionNextFn = Arc::new(move |_req| { + let expected = expected.clone(); + Box::pin(async move { Ok(expected) }) + }); + + let response = llm_call_execute( + LlmCallExecuteParams::builder() + .name("openai") + .request(make_chat_request(false)) + .func(func) + .attributes(LlmAttributes::empty()) + .response_codec(Arc::new(OpenAIChatCodec) as Arc) + .build(), + ) + .await + .unwrap(); + + assert_eq!(response, json!({"response": "original"})); +} diff --git a/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs b/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs deleted file mode 100644 index 22e721b4..00000000 --- a/crates/core/tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs +++ /dev/null @@ -1,638 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Unit tests for the planned NeMo Guardrails plugin component contract. - -use super::*; -use crate::api::runtime::NemoRelayContextState; -use crate::api::runtime::global_context; -use crate::config_editor::{EditorConfig, EditorFieldKind}; -#[cfg(feature = "schema")] -use crate::plugin::plugin_config_schema; -use crate::plugin::{ - PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, - list_plugin_kinds, lookup_plugin, validate_plugin_config, -}; -use serde_json::json; - -fn reset_runtime() { - let _ = clear_plugin_configuration(); - let _ = deregister_nemo_guardrails_component(); - crate::shared_runtime::reset_runtime_owner_for_tests(); - let context = global_context(); - *context.write().unwrap() = NemoRelayContextState::new(); -} - -fn ensure_registered() { - register_nemo_guardrails_component().unwrap(); -} - -fn component(config: Json) -> PluginComponentSpec { - let Json::Object(config) = config else { - panic!("component config must be an object"); - }; - PluginComponentSpec { - kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), - enabled: true, - config, - } -} - -fn disabled_component(config: Json) -> PluginComponentSpec { - let Json::Object(config) = config else { - panic!("component config must be an object"); - }; - PluginComponentSpec { - kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(), - enabled: false, - config, - } -} - -fn plugin_config(config: Json) -> PluginConfig { - PluginConfig { - version: 1, - components: vec![component(config)], - policy: Default::default(), - } -} - -fn remote_valid_config() -> Json { - json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "safety-default" - } - }) -} - -#[test] -fn editor_schema_tracks_nemo_guardrails_config_types() { - let schema = NeMoGuardrailsConfig::editor_schema(); - let mode = schema.field("mode").expect("mode field"); - assert_eq!(mode.kind, EditorFieldKind::Enum); - assert_eq!(mode.enum_values, &["remote", "local"]); - - let remote = schema.field("remote").expect("remote section"); - assert_eq!(remote.kind, EditorFieldKind::Section); - assert!(remote.optional); - - let remote_schema = remote.schema().expect("remote editor schema"); - let headers = remote_schema.field("headers").expect("headers field"); - assert_eq!(headers.kind, EditorFieldKind::StringMap); - - let request_defaults = schema - .field("request_defaults") - .expect("request_defaults section"); - assert_eq!(request_defaults.kind, EditorFieldKind::Section); - assert!(request_defaults.optional); - - let request_defaults_schema = request_defaults - .schema() - .expect("request_defaults editor schema"); - let rails = request_defaults_schema.field("rails").expect("rails field"); - assert_eq!(rails.kind, EditorFieldKind::Section); - - let rails_schema = rails.schema().expect("request rails editor schema"); - let retrieval = rails_schema.field("retrieval").expect("retrieval field"); - assert_eq!(retrieval.kind, EditorFieldKind::Json); -} - -#[test] -fn default_config_and_component_conversion_cover_public_shape() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - - let defaults = NeMoGuardrailsConfig::default(); - assert_eq!(defaults.version, 1); - assert_eq!(defaults.mode, "remote"); - assert!(defaults.input); - assert!(defaults.output); - assert!(!defaults.tool_input); - assert!(!defaults.tool_output); - assert_eq!(defaults.priority, 100); - assert!(defaults.remote.is_none()); - assert!(defaults.local.is_none()); - assert!(defaults.request_defaults.is_none()); - - let remote = RemoteBackendConfig::default(); - assert_eq!(remote.timeout_millis, 3_000); - assert!(remote.headers.is_empty()); - assert!(remote.config_ids.is_empty()); - - let generic: PluginComponentSpec = ComponentSpec::new(NeMoGuardrailsConfig { - remote: Some(RemoteBackendConfig { - endpoint: Some("http://localhost:8000".into()), - config_id: Some("default".into()), - ..RemoteBackendConfig::default() - }), - ..NeMoGuardrailsConfig::default() - }) - .into(); - assert_eq!(generic.kind, NEMO_GUARDRAILS_PLUGIN_KIND); - assert!(generic.enabled); - assert_eq!(generic.config["mode"], json!("remote")); - assert_eq!(generic.config["remote"]["config_id"], json!("default")); -} - -#[cfg(feature = "schema")] -fn schema_has_property(schema: &Json, name: &str) -> bool { - schema_property(schema, name).is_some() -} - -#[cfg(feature = "schema")] -fn schema_property_has_enum(schema: &Json, name: &str, expected: &[&str]) -> bool { - schema_property(schema, name) - .and_then(|property| property.get("enum")) - .and_then(Json::as_array) - .is_some_and(|values| { - expected - .iter() - .all(|expected| values.iter().any(|value| value == *expected)) - }) -} - -#[cfg(feature = "schema")] -fn schema_property_has_default(schema: &Json, name: &str, expected: Json) -> bool { - schema_property(schema, name) - .and_then(|property| property.get("default")) - .is_some_and(|default| default == &expected) -} - -#[cfg(feature = "schema")] -fn schema_property<'a>(schema: &'a Json, name: &str) -> Option<&'a Json> { - match schema { - Json::Object(object) => { - if let Some(property) = object - .get("properties") - .and_then(Json::as_object) - .and_then(|properties| properties.get(name)) - { - return Some(property); - } - object - .values() - .find_map(|value| schema_property(value, name)) - } - Json::Array(values) => values.iter().find_map(|value| schema_property(value, name)), - _ => None, - } -} - -#[cfg(feature = "schema")] -#[test] -fn schema_contains_every_supported_nemo_guardrails_option() { - let schema = nemo_guardrails_config_schema(); - for field in [ - "version", - "mode", - "config_path", - "config_yaml", - "colang_content", - "codec", - "input", - "output", - "tool_input", - "tool_output", - "priority", - "remote", - "local", - "request_defaults", - "policy", - "endpoint", - "config_id", - "config_ids", - "headers", - "timeout_millis", - "python_module", - "context", - "rails", - "llm_params", - "llm_output", - "output_vars", - "log", - "retrieval", - "dialog", - "unknown_component", - "unknown_field", - "unsupported_value", - ] { - assert!( - schema_has_property(&schema, field), - "schema missing property `{field}`:\n{}", - serde_json::to_string_pretty(&schema).unwrap() - ); - } - assert!(schema_property_has_enum( - &schema, - "mode", - &["remote", "local"] - )); - assert!(schema_property_has_enum( - &schema, - "codec", - &["openai_chat", "openai_responses", "anthropic_messages"] - )); - assert!(schema_property_has_default( - &schema, - "mode", - json!("remote") - )); -} - -#[cfg(feature = "schema")] -#[test] -fn plugin_schema_contains_generic_plugin_surface() { - let schema = plugin_config_schema(); - for field in [ - "version", - "components", - "policy", - "kind", - "enabled", - "config", - ] { - assert!( - schema_has_property(&schema, field), - "plugin schema missing property `{field}`" - ); - } -} - -#[test] -fn registration_is_explicit_not_automatic() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - - assert!(!list_plugin_kinds().contains(&NEMO_GUARDRAILS_PLUGIN_KIND.to_string())); - assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_none()); - - ensure_registered(); - assert!(list_plugin_kinds().contains(&NEMO_GUARDRAILS_PLUGIN_KIND.to_string())); - assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_some()); - - ensure_registered(); - assert!(lookup_plugin(NEMO_GUARDRAILS_PLUGIN_KIND).is_some()); - assert!(deregister_nemo_guardrails_component()); - assert!(!deregister_nemo_guardrails_component()); -} - -#[test] -fn disabled_component_validates_and_initializes_without_runtime_work() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - ensure_registered(); - - let config = PluginConfig { - version: 1, - components: vec![disabled_component(remote_valid_config())], - policy: Default::default(), - }; - assert!(!validate_plugin_config(&config).has_errors()); - futures::executor::block_on(initialize_plugins(config)).unwrap(); -} - -#[test] -fn duplicate_component_is_rejected_as_singleton() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - ensure_registered(); - - let config = PluginConfig { - version: 1, - components: vec![ - component(remote_valid_config()), - component(remote_valid_config()), - ], - policy: Default::default(), - }; - let report = validate_plugin_config(&config); - assert!(report.has_errors()); - assert!( - report - .diagnostics - .iter() - .any(|diag| diag.code == "plugin.duplicate_component") - ); -} - -#[test] -fn invalid_shapes_and_values_are_reported() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - ensure_registered(); - - let invalid_shape = validate_plugin_config(&plugin_config(json!({ - "version": "one", - }))); - assert!(invalid_shape.has_errors()); - assert!( - invalid_shape - .diagnostics - .iter() - .any(|diag| diag.code == "nemo_guardrails.invalid_plugin_config") - ); - - let local_missing_source = validate_plugin_config(&plugin_config(json!({ - "mode": "local", - "codec": "openai_chat", - }))); - assert!(local_missing_source.has_errors()); - assert!(local_missing_source.diagnostics.iter().any(|diag| { - diag.message - .contains("exactly one of config_path or config_yaml is required in local mode") - })); - - let local_bad_colang = validate_plugin_config(&plugin_config(json!({ - "mode": "local", - "config_path": "./rails", - "colang_content": "define flow x", - "codec": "openai_chat", - }))); - assert!(local_bad_colang.has_errors()); - assert!( - local_bad_colang - .diagnostics - .iter() - .any(|diag| diag.message.contains("colang_content can only be used")) - ); - - let remote_missing_identity = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": {"endpoint": "http://localhost:8000"}, - }))); - assert!(remote_missing_identity.has_errors()); - assert!(remote_missing_identity.diagnostics.iter().any(|diag| { - diag.message - .contains("remote mode requires remote.config_id or remote.config_ids") - })); - - let remote_conflicting_ids = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "one", - "config_ids": ["two"] - }, - }))); - assert!(remote_conflicting_ids.has_errors()); - assert!(remote_conflicting_ids.diagnostics.iter().any(|diag| { - diag.message - .contains("remote.config_id and remote.config_ids cannot be used together") - })); - - let missing_codec = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "default" - } - }))); - assert!(missing_codec.has_errors()); - assert!( - missing_codec - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("codec")) - ); - - let bad_codec = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_agents", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "default" - } - }))); - assert!(bad_codec.has_errors()); - assert!(bad_codec.diagnostics.iter().any(|diag| { - diag.message - .contains("codec must be 'openai_chat', 'openai_responses', or 'anthropic_messages'") - })); - - let remote_empty_fields = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": { - "endpoint": "", - "config_id": "", - "config_ids": ["default", ""] - } - }))); - assert!(remote_empty_fields.has_errors()); - assert!( - remote_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("remote.endpoint")) - ); - assert!( - remote_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("remote.config_id")) - ); - assert!( - remote_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("remote.config_ids[1]")) - ); - - let remote_local_mix = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "config_path": "./rails", - "codec": "openai_chat", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "default" - }, - "local": {"python_module": "nemoguardrails"} - }))); - assert!(remote_local_mix.has_errors()); - assert!( - remote_local_mix - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("local")) - ); - assert!(remote_local_mix.diagnostics.iter().any(|diag| { - diag.message - .contains("remote mode uses remote config identity") - })); - - let no_surfaces = validate_plugin_config(&plugin_config(json!({ - "mode": "local", - "config_path": "./rails", - "input": false, - "output": false, - "tool_input": false, - "tool_output": false - }))); - assert!(no_surfaces.has_errors()); - assert!( - no_surfaces - .diagnostics - .iter() - .any(|diag| diag.message.contains("at least one Guardrails surface")) - ); - - let local_empty_fields = validate_plugin_config(&plugin_config(json!({ - "mode": "local", - "config_yaml": "", - "colang_content": "", - "codec": "openai_chat", - "local": {"python_module": ""} - }))); - assert!(local_empty_fields.has_errors()); - assert!( - local_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("config_yaml")) - ); - assert!( - local_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("colang_content")) - ); - assert!( - local_empty_fields - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("local.python_module")) - ); - - let invalid_request_defaults = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": { - "endpoint": "http://localhost:8000", - "config_id": "default" - }, - "request_defaults": { - "context": true, - "llm_params": [], - "log": "verbose", - "output_vars": 7, - "rails": { - "retrieval": [""] - } - } - }))); - assert!(invalid_request_defaults.has_errors()); - assert!( - invalid_request_defaults - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("request_defaults.context")) - ); - assert!( - invalid_request_defaults - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("request_defaults.llm_params")) - ); - assert!( - invalid_request_defaults - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("request_defaults.log")) - ); - assert!( - invalid_request_defaults - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("request_defaults.output_vars")) - ); - assert!( - invalid_request_defaults - .diagnostics - .iter() - .any(|diag| diag.field.as_deref() == Some("request_defaults.rails.retrieval[0]")) - ); -} - -#[test] -fn unknown_fields_follow_policy() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - ensure_registered(); - - let warn_report = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, - "bogus": true - }))); - assert!( - warn_report - .diagnostics - .iter() - .any(|diag| diag.code == "nemo_guardrails.unknown_field") - ); - - let nested_warn_report = validate_plugin_config(&plugin_config(json!({ - "mode": "remote", - "codec": "openai_chat", - "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, - "request_defaults": { - "rails": { - "bogus": true - } - } - }))); - assert!( - nested_warn_report - .diagnostics - .iter() - .any(|diag| diag.component.as_deref() == Some("request_defaults.rails")) - ); - - let ignored = validate_plugin_config(&plugin_config(json!({ - "policy": {"unknown_field": "ignore", "unsupported_value": "ignore"}, - "mode": "remote", - "codec": "openai_chat", - "remote": {"endpoint": "http://localhost:8000", "config_id": "default"}, - "bogus": true - }))); - assert!(!ignored.has_errors()); - assert!(ignored.diagnostics.is_empty()); -} - -#[test] -fn enabled_initialization_fails_fast_until_backend_exists() { - let _guard = crate::plugins::nemo_guardrails::test_mutex() - .lock() - .unwrap_or_else(|err| err.into_inner()); - reset_runtime(); - ensure_registered(); - - let error = - futures::executor::block_on(initialize_plugins(plugin_config(remote_valid_config()))) - .unwrap_err(); - - match error { - crate::plugin::PluginError::RegistrationFailed(message) => { - assert!(message.contains("not implemented yet")); - } - other => panic!("unexpected error: {other}"), - } -}