Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ readme = "README.md"
workspace = true

[features]
default = ["otel", "openinference"]
default = ["otel", "openinference", "guardrails-remote"]
schema = ["dep:schemars"]
guardrails-remote = [
"dep:reqwest",
"dep:rustls",
]
otel = [
"dep:async-trait",
"dep:getrandom",
Expand Down
8 changes: 5 additions & 3 deletions crates/core/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,9 +762,11 @@ pub fn register_plugin(plugin: Arc<dyn Plugin>) -> Result<()> {
/// Built-in plugins are available to validation and initialization without a
/// binding or application-specific registration call.
pub fn ensure_builtin_plugins_registered() -> Result<()> {
match BUILTIN_PLUGIN_REGISTRATION
.get_or_init(crate::observability::plugin_component::register_observability_component)
{
let register_builtins = || {
crate::observability::plugin_component::register_observability_component()?;
crate::plugins::nemo_guardrails::component::register_nemo_guardrails_component()
};
match BUILTIN_PLUGIN_REGISTRATION.get_or_init(register_builtins) {
Ok(()) => Ok(()),
Err(err) => Err(clone_cached_plugin_error(err)),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,25 @@ use crate::plugin::{
lookup_plugin, register_plugin,
};

#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))]
#[path = "remote.rs"]
mod remote;
Comment on lines +20 to +22
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, since you have this, I don't think you need all of the cfg entries in remote.rs

#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))]
use remote::register_remote_backend;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/// The plugin kind reserved for the planned first-party component.
pub const NEMO_GUARDRAILS_PLUGIN_KIND: &str = "nemo_guardrails";

#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))]
fn register_remote_backend(
_config: NeMoGuardrailsConfig,
_ctx: &mut PluginRegistrationContext,
) -> PluginResult<()> {
Err(PluginError::RegistrationFailed(
"built-in NeMo Guardrails remote backend is unavailable in this build".to_string(),
))
}

/// Top-level NeMo Guardrails component wrapper.
#[derive(Debug, Clone)]
pub struct ComponentSpec {
Expand Down Expand Up @@ -182,6 +198,12 @@ pub struct RequestDefaultsConfig {
/// Default context object passed into Guardrails requests.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub context: Option<Json>,
/// Default remote thread identifier for continuation-aware requests.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thread_id: Option<String>,
/// Default remote Guardrails state payload for continuation-aware requests.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub state: Option<Json>,
/// Default request-time rail selection.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rails: Option<RequestRailsConfig>,
Expand Down Expand Up @@ -307,6 +329,8 @@ crate::editor_config! {
crate::editor_config! {
impl RequestDefaultsConfig {
context => { label: "context", kind: Json, optional: true },
thread_id => { label: "thread_id", kind: String, optional: true },
state => { label: "state", kind: Json, optional: true },
rails => {
label: "rails",
kind: Section,
Expand Down Expand Up @@ -349,13 +373,13 @@ impl Plugin for NeMoGuardrailsPlugin {

fn register<'a>(
&'a self,
_plugin_config: &Map<String, Json>,
_ctx: &'a mut PluginRegistrationContext,
plugin_config: &Map<String, Json>,
ctx: &'a mut PluginRegistrationContext,
) -> Pin<Box<dyn Future<Output = PluginResult<()>> + Send + 'a>> {
Box::pin(async {
Err(PluginError::RegistrationFailed(
"built-in NeMo Guardrails plugin backend is not implemented yet".to_string(),
))
let parsed = parse_nemo_guardrails_config(plugin_config);
Box::pin(async move {
let config = parsed?;
register_nemo_guardrails_backend(config, ctx)
})
}
}
Expand Down Expand Up @@ -419,6 +443,21 @@ fn string_enum_schema(
schema.into()
}

fn register_nemo_guardrails_backend(
config: NeMoGuardrailsConfig,
ctx: &mut PluginRegistrationContext,
) -> PluginResult<()> {
match config.mode.as_str() {
"remote" => register_remote_backend(config, ctx),
"local" => Err(PluginError::RegistrationFailed(
"built-in NeMo Guardrails local backend is not implemented yet".to_string(),
)),
other => Err(PluginError::InvalidConfig(format!(
"unsupported NeMo Guardrails mode '{other}'"
))),
}
}

fn parse_nemo_guardrails_config(
plugin_config: &Map<String, Json>,
) -> PluginResult<NeMoGuardrailsConfig> {
Expand Down Expand Up @@ -497,6 +536,8 @@ fn validate_nemo_guardrails_plugin_config(
"request_defaults",
&[
"context",
"thread_id",
"state",
"rails",
"llm_params",
"llm_output",
Expand Down Expand Up @@ -526,6 +567,7 @@ fn validate_nemo_guardrails_plugin_config(
validate_config_shape(&mut diagnostics, &config.policy, &config);
validate_codec_requirements(&mut diagnostics, &config.policy, &config);
validate_surface_selection(&mut diagnostics, &config.policy, &config);
validate_remote_backend_support(&mut diagnostics, &config.policy, &config);
validate_request_defaults(&mut diagnostics, &config.policy, &config);

diagnostics
Expand Down Expand Up @@ -869,6 +911,32 @@ fn validate_surface_selection(
);
}

fn validate_remote_backend_support(
diagnostics: &mut Vec<ConfigDiagnostic>,
policy: &ConfigPolicy,
config: &NeMoGuardrailsConfig,
) {
if config.mode != "remote" {
return;
}

if (config.input || config.output)
&& config
.codec
.as_deref()
.is_some_and(|codec| codec != "openai_chat")
{
push_policy_diag(
diagnostics,
policy.unsupported_value,
"nemo_guardrails.unsupported_value",
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
Some("codec".to_string()),
"remote mode currently supports only codec = 'openai_chat'".to_string(),
);
}
}

fn validate_request_defaults(
diagnostics: &mut Vec<ConfigDiagnostic>,
policy: &ConfigPolicy,
Expand All @@ -885,6 +953,54 @@ fn validate_request_defaults(
"request_defaults.context",
"request_defaults.context must be a JSON object",
);
if let Some(thread_id) = &request_defaults.thread_id {
let trimmed_thread_id = thread_id.trim();
if trimmed_thread_id.is_empty() {
push_policy_diag(
diagnostics,
policy.unsupported_value,
"nemo_guardrails.unsupported_value",
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
Some("request_defaults.thread_id".to_string()),
"request_defaults.thread_id must not be empty".to_string(),
);
} else if trimmed_thread_id.len() < 16 {
push_policy_diag(
diagnostics,
policy.unsupported_value,
"nemo_guardrails.unsupported_value",
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
Some("request_defaults.thread_id".to_string()),
"request_defaults.thread_id must be at least 16 characters long".to_string(),
);
}
}
validate_json_object_field(
diagnostics,
policy,
request_defaults.state.as_ref(),
"request_defaults.state",
"request_defaults.state must be a JSON object",
);
if let Some(state) = request_defaults
.state
.as_ref()
.and_then(|value| value.as_object())
{
let contains_supported_key = state.contains_key("events") || state.contains_key("state");
let contains_unsupported_key = state.keys().any(|key| key != "events" && key != "state");
if (!state.is_empty() && !contains_supported_key) || contains_unsupported_key {
push_policy_diag(
diagnostics,
policy.unsupported_value,
"nemo_guardrails.unsupported_value",
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
Some("request_defaults.state".to_string()),
"request_defaults.state must be empty or contain only 'events' or 'state'"
.to_string(),
);
}
}
validate_json_object_field(
diagnostics,
policy,
Expand Down Expand Up @@ -1138,5 +1254,5 @@ fn default_timeout_millis() -> u64 {
}

#[cfg(test)]
#[path = "../../../tests/unit/plugins/nemo_guardrails/plugin_component_tests.rs"]
#[path = "../../../tests/unit/plugins/nemo_guardrails/component_tests.rs"]
mod tests;
2 changes: 1 addition & 1 deletion crates/core/src/plugins/nemo_guardrails/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ pub(crate) fn test_mutex() -> &'static Mutex<()> {
crate::shared_runtime::runtime_owner_test_mutex()
}

pub mod plugin_component;
pub mod component;
Loading
Loading