From 7fda693566db63d89d653bb9878a2c4b9c0c4d8e Mon Sep 17 00:00:00 2001 From: knhn1004 <49494541+knhn1004@users.noreply.github.com> Date: Fri, 8 May 2026 17:54:58 -0700 Subject: [PATCH] Add external guardrails catalog and runtime controls --- README.md | 4 + cli/src/tui/dashboard.tsx | 207 +++++++- cli/src/util/api.ts | 117 ++++ cli/tests/dashboard-api.test.ts | 2 +- control-plane/README.md | 4 + control-plane/cmd/control-plane/main.go | 37 ++ control-plane/cmd/control-plane/main_test.go | 57 ++ control-plane/dashboard-ui/package.json | 1 + .../src/components/GuardrailTraceFlow.tsx | 50 ++ control-plane/dashboard-ui/src/index.css | 12 + control-plane/dashboard-ui/src/lib/api.ts | 33 +- control-plane/dashboard-ui/src/lib/types.ts | 48 ++ .../dashboard-ui/src/routeTree.gen.ts | 31 +- .../dashboard-ui/src/routes/__root.tsx | 1 + .../dashboard-ui/src/routes/events.tsx | 28 +- .../dashboard-ui/src/routes/guardrails.tsx | 178 +++++++ control-plane/docker-compose.yml | 4 + control-plane/internal/api/gate.go | 37 ++ control-plane/internal/api/gate_test.go | 112 ++++ control-plane/internal/api/guardrails_http.go | 141 +++++ .../internal/api/guardrails_http_test.go | 276 ++++++++++ control-plane/internal/api/router.go | 11 + control-plane/internal/api/types.go | 44 +- .../internal/guardrails/providers_nvidia.go | 248 +++++++++ .../guardrails/providers_openrouter.go | 130 +++++ control-plane/internal/guardrails/service.go | 498 ++++++++++++++++++ .../internal/guardrails/service_test.go | 444 ++++++++++++++++ control-plane/internal/guardrails/types.go | 41 ++ control-plane/internal/storage/memory.go | 104 +++- control-plane/internal/storage/memory_test.go | 76 +++ control-plane/internal/storage/storage.go | 17 +- docker-compose.yml | 7 + docs/architecture/llm-guardrails.md | 46 +- docs/guide/dashboard.md | 10 +- docs/guide/getting-started.md | 7 + docs/guide/installation.md | 12 + docs/index.md | 3 + 37 files changed, 3034 insertions(+), 44 deletions(-) create mode 100644 control-plane/cmd/control-plane/main_test.go create mode 100644 control-plane/dashboard-ui/src/components/GuardrailTraceFlow.tsx create mode 100644 control-plane/dashboard-ui/src/routes/guardrails.tsx create mode 100644 control-plane/internal/api/guardrails_http.go create mode 100644 control-plane/internal/api/guardrails_http_test.go create mode 100644 control-plane/internal/guardrails/providers_nvidia.go create mode 100644 control-plane/internal/guardrails/providers_openrouter.go create mode 100644 control-plane/internal/guardrails/service.go create mode 100644 control-plane/internal/guardrails/service_test.go create mode 100644 control-plane/internal/guardrails/types.go diff --git a/README.md b/README.md index e5d174b..8e7b44a 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,8 @@ docker run -d --name agentlock \ -p 127.0.0.1:7878:7878 \ -p 127.0.0.1:7879:7879 \ -v agentlock-state:/var/lib/agentlock \ + -e NVIDIA_API_KEY \ + -e OPENROUTER_API_KEY \ ghcr.io/openagentlock/agentlockd:latest # 2. Install the CLI @@ -46,6 +48,8 @@ agentlock install --tier totp --code 123456 --passphrase 'your-passphrase-here' For a quick eval without a signer (dev only): start the daemon with `-e AGENTLOCK_ALLOW_UNATTESTED=1`, then `agentlock install` (defaults to unattested). +Optional external guardrails are enabled by starting the daemon with `NVIDIA_API_KEY` and/or `OPENROUTER_API_KEY`; keys are held in control-plane memory only. In the current shipped slice, NVIDIA provides post-local-allow runtime classification, while OpenRouter is catalog visibility only. + Open the local web dashboard at , or run `agentlock dashboard` for a terminal TUI with the same live ledger tail, sessions, loaded gates, and a one-key monitor⇄enforce flip.
diff --git a/cli/src/tui/dashboard.tsx b/cli/src/tui/dashboard.tsx index 7fa1c6a..436ace9 100644 --- a/cli/src/tui/dashboard.tsx +++ b/cli/src/tui/dashboard.tsx @@ -25,6 +25,10 @@ import { useEffect, useRef, useState } from "react"; import type { ApiClient, FalsePositiveCaseResponse, + GuardrailCatalogResponse, + GuardrailEnabledEntry, + GuardrailEnabledResponse, + GuardrailProvidersResponse, InsightWindow, LedgerInsightsResponse, LedgerRootResponse, @@ -57,11 +61,12 @@ interface LedgerEntry { prev_leaf?: string; } -type TabName = "stats" | "events" | "sessions" | "gates" | "mode"; +type TabName = "stats" | "events" | "guardrails" | "sessions" | "gates" | "mode"; const TABS: { name: string; description: string; value: TabName }[] = [ { name: "Stats", description: "Operational insights", value: "stats" }, { name: "Events", description: "Live ledger tail", value: "events" }, + { name: "Guardrails", description: "External guardrails", value: "guardrails" }, { name: "Sessions", description: "Who's connected", value: "sessions" }, { name: "Gates", description: "Loaded policy gates", value: "gates" }, { name: "Mode", description: "Firewall / monitor", value: "mode" }, @@ -123,6 +128,7 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { const [cursor, setCursor] = useState>({ stats: 0, events: 0, + guardrails: 0, sessions: 0, gates: 0, mode: 0, @@ -130,6 +136,7 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { const [scroll, setScroll] = useState>({ stats: 0, events: 0, + guardrails: 0, sessions: 0, gates: 0, mode: 0, @@ -138,6 +145,12 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { const [mode, setMode] = useState(null); const [sessions, setSessions] = useState(null); const [policy, setPolicy] = useState(null); + const [guardrailProviders, setGuardrailProviders] = + useState(null); + const [guardrailCatalog, setGuardrailCatalog] = + useState(null); + const [guardrailEnabled, setGuardrailEnabled] = + useState(null); const [events, setEvents] = useState([]); const [insights, setInsights] = useState(null); const [statsWindow, setStatsWindow] = useState("24h"); @@ -176,6 +189,9 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { api.getMode().then(setMode).catch(() => {}); api.listSessions().then(setSessions).catch(() => {}); api.policyView().then(setPolicy).catch(() => {}); + api.guardrailProviders().then(setGuardrailProviders).catch(() => {}); + api.guardrailCatalog().then(setGuardrailCatalog).catch(() => {}); + api.guardrailEnabled().then(setGuardrailEnabled).catch(() => {}); api.ledgerRoot().then(setLedgerRoot).catch(() => {}); }, 2000); @@ -243,6 +259,10 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { pendingDeleteRef.current = pendingDelete; const statsWindowRef = useRef(statsWindow); statsWindowRef.current = statsWindow; + const guardrailCatalogRef = useRef(guardrailCatalog); + guardrailCatalogRef.current = guardrailCatalog; + const guardrailEnabledRef = useRef(guardrailEnabled); + guardrailEnabledRef.current = guardrailEnabled; // Filtered/visible derivations used by the keyboard handler when // computing what's "under the cursor." @@ -274,10 +294,15 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { ); } + function visibleGuardrails() { + return guardrailCatalogRef.current?.entries ?? []; + } + function moveCursor(delta: number): void { const t = tabRef.current; let max = 0; if (t === "events") max = filteredEvents().length; + else if (t === "guardrails") max = visibleGuardrails().length; else if (t === "sessions") max = visibleSessions().length; else if (t === "gates") max = policyRef.current?.gates.length ?? 0; if (max === 0) return; @@ -333,6 +358,55 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { .catch(() => {}); } + function refreshGuardrails(): void { + api.guardrailProviders() + .then((r) => flushSync(() => setGuardrailProviders(r))) + .catch(() => {}); + api.guardrailCatalog() + .then((r) => flushSync(() => setGuardrailCatalog(r))) + .catch(() => {}); + api.guardrailEnabled() + .then((r) => flushSync(() => setGuardrailEnabled(r))) + .catch(() => {}); + } + + function isGuardrailEnabled(entry: { provider_id: string; entry_id: string }): boolean { + return (guardrailEnabledRef.current?.entries ?? []).some( + (item) => + item.provider_id === entry.provider_id && item.entry_id === entry.entry_id, + ); + } + + function toggleGuardrail(entry: GuardrailEnabledEntry & { supports_runtime_enforcement?: boolean; name?: string }): void { + if (!entry.supports_runtime_enforcement) { + flashToast(`${entry.name ?? entry.entry_id} is catalog-only`); + return; + } + const current = guardrailEnabledRef.current?.entries ?? []; + const enabled = current.some( + (item) => + item.provider_id === entry.provider_id && item.entry_id === entry.entry_id, + ); + const next = enabled + ? current.filter( + (item) => + !(item.provider_id === entry.provider_id && item.entry_id === entry.entry_id), + ) + : [...current, { provider_id: entry.provider_id, entry_id: entry.entry_id }]; + flushSync(() => setGuardrailEnabled({ entries: next })); + api.saveGuardrailEnabled(next) + .then((saved) => { + flushSync(() => setGuardrailEnabled(saved)); + flashToast( + `${entry.name ?? entry.entry_id} ${enabled ? "disabled" : "enabled"}`, + ); + }) + .catch((err) => { + flushSync(() => setGuardrailEnabled({ entries: current })); + flashToast(`guardrail toggle failed: ${truncate(err.message, 80)}`); + }); + } + async function runFalsePositiveFlow(entry: LedgerEntry): Promise { try { const c = await api.falsePositiveCase(entry.seq, false); @@ -513,6 +587,7 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { api.getMode().then((r) => flushSync(() => setMode(r))).catch(() => {}); api.listSessions().then((r) => flushSync(() => setSessions(r))).catch(() => {}); api.policyView().then((r) => flushSync(() => setPolicy(r))).catch(() => {}); + refreshGuardrails(); api.ledgerRoot().then((r) => flushSync(() => setLedgerRoot(r))).catch(() => {}); api.ledgerInsights(statsWindowRef.current) .then((r) => flushSync(() => setInsights(r))) @@ -597,6 +672,21 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { return; } + if (t === "guardrails") { + const entries = visibleGuardrails(); + const cur = cursorRef.current.guardrails; + const sel = entries[cur]; + if (name === "space" || e.sequence === " ") { + if (!sel) { + flashToast("no guardrail selected"); + return; + } + toggleGuardrail(sel); + return; + } + return; + } + if (t === "sessions") { if (name === "i") { flushSync(() => { @@ -765,6 +855,15 @@ function Dashboard({ api, onQuit }: DashboardProps): React.ReactNode { filterField={filterField} filterBuffer={filterBuffer} /> + ) : tab === "guardrails" ? ( + ) : tab === "sessions" ? ( boolean; +}): React.ReactNode { + if (!providers) { + return loading guardrails...; + } + const entries = catalog?.entries ?? []; + const providerErrors = catalog?.provider_errors ?? []; + const enabledCount = enabled?.entries.length ?? 0; + const rows = entries.slice(scroll, scroll + VISIBLE_ROWS); + return ( + + + {providers.providers.length === 0 ? ( + no providers registered + ) : ( + providers.providers.map((p) => ( + + + + {p.name.padEnd(14, " ")} + + + {(p.configured ? "configured" : "not configured").padEnd(16, " ")} + + {p.capabilities.join(" / ")} + + + {p.id === "nvidia" + ? "runtime classifier after local allow" + : p.id === "openrouter" + ? "catalog only in this slice; no runtime classifier yet" + : "provider behavior depends on runtime support"} + + + )) + )} + {providerErrors.length > 0 ? ( + <> + + {providerErrors.map((item) => ( + + {item.provider_id.padEnd(14, " ")} + {truncate(item.detail, 96)} + + ))} + + ) : null} + + + {`${enabledCount} runtime guardrail${enabledCount === 1 ? "" : "s"} enabled`} + + {entries.length === 0 ? ( + no catalog entries + ) : ( + <> + + + {"ON".padEnd(4, " ")} + {"PROVIDER".padEnd(14, " ")} + {"KIND".padEnd(18, " ")} + {"RUNTIME".padEnd(10, " ")} + {"ENTRY"} + + + {rows.map((entry, index) => { + const selected = scroll + index === cursor; + const marker = selected ? "▌" : " "; + const active = isEnabled(entry); + return ( + + {`${marker}${active ? "[x]" : "[ ]"}`.padEnd(4, " ")} + {entry.provider_id.padEnd(14, " ")} + + {(entry.kind === "classifier_model" ? "classifier" : "policy").padEnd(18, " ")} + + + {(entry.supports_runtime_enforcement ? "yes" : "no").padEnd(10, " ")} + + {truncate(entry.name || entry.entry_id, 56)} + + )})} + {entries.length > VISIBLE_ROWS ? ( + + {`rows ${scroll + 1}-${Math.min(entries.length, scroll + VISIBLE_ROWS)} of ${entries.length}`} + + ) : null} + + )} + + ); +} + const INTERNAL_SOURCES = new Set([ "internal", "agentlock", @@ -1763,6 +1967,7 @@ function Footer({ toast, tab }: { toast: string; tab: TabName }): React.ReactNod const tabHelp: Record = { stats: "(window keybinds shown next to each button above)", events: "enter detail f filter c clear i internal o outcomes H hashes", + guardrails: "space toggle runtime guardrail r refresh startup env on control plane", sessions: "i toggle internal harnesses", gates: "enter detail a add e edit space toggle M cycle-mode x x delete", mode: "(read-only — m on any tab flips mode)", diff --git a/cli/src/util/api.ts b/cli/src/util/api.ts index eac5ebe..5a9ffe3 100644 --- a/cli/src/util/api.ts +++ b/cli/src/util/api.ts @@ -51,6 +51,11 @@ export interface ApiClient { falsePositiveCase(seq: number, includeRaw?: boolean): Promise; falsePositiveValidate(req: FalsePositiveValidateRequest): Promise; falsePositiveApply(req: FalsePositiveApplyRequest): Promise; + guardrailProviders(): Promise; + guardrailCatalog(): Promise; + guardrailEnabled(): Promise; + saveGuardrailEnabled(entries: GuardrailEnabledEntry[]): Promise; + guardrailTrace(seq: number): Promise; } export type InsightWindow = "1h" | "24h" | "7d" | "all"; @@ -228,6 +233,61 @@ export interface FalsePositiveApplyResponse { needs_reload: boolean; } +export interface GuardrailProviderView { + id: string; + name: string; + status: string; + capabilities: string[]; + configured: boolean; +} + +export interface GuardrailCatalogEntry { + provider_id: string; + entry_id: string; + name: string; + kind: "classifier_model" | "account_policy"; + description?: string; + supports_runtime_enforcement: boolean; + metadata?: Record; +} + +export interface GuardrailEnabledEntry { + provider_id: string; + entry_id: string; +} + +export interface GuardrailProvidersResponse { + providers: GuardrailProviderView[]; +} + +export interface GuardrailCatalogResponse { + entries: GuardrailCatalogEntry[]; + provider_errors?: Array<{ + provider_id: string; + detail: string; + }>; +} + +export interface GuardrailEnabledResponse { + entries: GuardrailEnabledEntry[]; +} + +export interface GuardrailTraceResponse { + ledger_seq: number; + trace: { + local_policy_verdict: string; + guardrail_verdict: string; + final_verdict: string; + stages: Array<{ + provider_id: string; + entry_id: string; + verdict: string; + latency_ms: number; + details?: Record; + }>; + }; +} + export interface AuthModeResponse { mode: "none" | "password" | "oidc" | "ldap"; users_configured: boolean; @@ -799,6 +859,63 @@ export function apiClient(baseUrl?: string, initialToken?: string | null): ApiCl } return (await res.json()) as FalsePositiveApplyResponse; }, + + async guardrailProviders(): Promise { + const res = await fetch(`${url}/v1/guardrails/providers`, { + headers: authHeaders(), + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(`guardrails.providers: ${res.status} ${res.statusText} ${body}`); + } + return (await res.json()) as GuardrailProvidersResponse; + }, + + async guardrailCatalog(): Promise { + const res = await fetch(`${url}/v1/guardrails/catalog`, { + headers: authHeaders(), + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(`guardrails.catalog: ${res.status} ${res.statusText} ${body}`); + } + return (await res.json()) as GuardrailCatalogResponse; + }, + + async guardrailEnabled(): Promise { + const res = await fetch(`${url}/v1/guardrails/enabled`, { + headers: authHeaders(), + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(`guardrails.enabled: ${res.status} ${res.statusText} ${body}`); + } + return (await res.json()) as GuardrailEnabledResponse; + }, + + async saveGuardrailEnabled(entries: GuardrailEnabledEntry[]): Promise { + const res = await fetch(`${url}/v1/guardrails/enabled`, { + method: "PUT", + headers: { "content-type": "application/json", ...authHeaders() }, + body: JSON.stringify({ entries }), + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(`guardrails.enabled.save: ${res.status} ${res.statusText} ${body}`); + } + return (await res.json()) as GuardrailEnabledResponse; + }, + + async guardrailTrace(seq: number): Promise { + const res = await fetch(`${url}/v1/guardrails/traces/${encodeURIComponent(String(seq))}`, { + headers: authHeaders(), + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(`guardrails.trace: ${res.status} ${res.statusText} ${body}`); + } + return (await res.json()) as GuardrailTraceResponse; + }, }; return client; diff --git a/cli/tests/dashboard-api.test.ts b/cli/tests/dashboard-api.test.ts index 65b10ea..3c210db 100644 --- a/cli/tests/dashboard-api.test.ts +++ b/cli/tests/dashboard-api.test.ts @@ -41,7 +41,7 @@ function startMock(opts: MockOpts, recorded: Recorded[]): { url: string; stop: ( async fetch(req) { const u = new URL(req.url); const body = - req.method === "POST" || req.method === "PATCH" + req.method === "POST" || req.method === "PATCH" || req.method === "PUT" ? await req.json().catch(() => null) : null; recorded.push({ method: req.method, path: u.pathname, body }); diff --git a/control-plane/README.md b/control-plane/README.md index 2f1bde4..8876f33 100644 --- a/control-plane/README.md +++ b/control-plane/README.md @@ -10,6 +10,8 @@ docker run -d --name agentlock \ -v agentlock-state:/var/lib/agentlock \ -p 127.0.0.1:7878:7878 \ -p 127.0.0.1:7879:7879 \ + -e NVIDIA_API_KEY \ + -e OPENROUTER_API_KEY \ ghcr.io/openagentlock/agentlockd:latest ``` @@ -22,6 +24,8 @@ curl -O https://raw.githubusercontent.com/openagentlock/openagentlock/main/docke docker compose up -d ``` +External guardrail provider keys are optional. Set `NVIDIA_API_KEY` and/or `OPENROUTER_API_KEY` before starting Docker; the daemon reads them once into memory and does not persist them. + ## Endpoints The control plane exposes a versioned HTTP API. The contract lives in `api/openapi.yaml`; the live service mirrors it at `/v1/health`, `/v1/gates`, `/v1/install/plan`, `/v1/install/apply`, `/v1/uninstall`, `/v1/mode`, `/v1/mcp/pin`, `/v1/sessions`, `/v1/ledger/root`, `/v1/ledger/proof/:seq`, `/v1/ledger/verify`, plus harness hook endpoints under `/v1/hooks/...`. diff --git a/control-plane/cmd/control-plane/main.go b/control-plane/cmd/control-plane/main.go index 5d545cf..c13a51f 100644 --- a/control-plane/cmd/control-plane/main.go +++ b/control-plane/cmd/control-plane/main.go @@ -16,12 +16,14 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" "github.com/openagentlock/openagentlock/control-plane/internal/api" "github.com/openagentlock/openagentlock/control-plane/internal/auth" "github.com/openagentlock/openagentlock/control-plane/internal/dashboard" + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" "github.com/openagentlock/openagentlock/control-plane/internal/policy" "github.com/openagentlock/openagentlock/control-plane/internal/storage" ) @@ -45,6 +47,11 @@ func main() { log.Fatalf("open storage: %v", err) } defer store.Close() + if seeded, err := seedGuardrailProviderKeys(context.Background(), store, os.Getenv); err != nil { + log.Fatalf("seed guardrail provider keys: %v", err) + } else if len(seeded) > 0 { + log.Printf("guardrail provider keys loaded into memory: %s", strings.Join(seeded, ",")) + } pol, err := loadPolicy(os.Getenv("AGENTLOCK_POLICY")) if err != nil { @@ -120,6 +127,36 @@ func main() { log.Printf("control-plane stopped") } +type guardrailKeyStore interface { + SaveGuardrailProviderConfig(context.Context, guardrails.ProviderConfig) error +} + +func seedGuardrailProviderKeys(ctx context.Context, store guardrailKeyStore, getenv func(string) string) ([]string, error) { + envs := []struct { + providerID string + envName string + }{ + {providerID: "nvidia", envName: "NVIDIA_API_KEY"}, + {providerID: "openrouter", envName: "OPENROUTER_API_KEY"}, + } + seeded := make([]string, 0, len(envs)) + for _, item := range envs { + apiKey := strings.TrimSpace(getenv(item.envName)) + if apiKey == "" { + continue + } + if err := store.SaveGuardrailProviderConfig(ctx, guardrails.ProviderConfig{ + ProviderID: item.providerID, + APIKey: guardrails.SecretString(apiKey), + Metadata: map[string]string{"source": "env:" + item.envName}, + }); err != nil { + return seeded, err + } + seeded = append(seeded, item.providerID) + } + return seeded, nil +} + // runHealthProbe is invoked when the binary is run as `agentlockd --health` // (e.g. from the docker HEALTHCHECK). Distroless has no curl/wget, so we // reuse the binary itself as the probe. The listen addr may be 0.0.0.0:port; diff --git a/control-plane/cmd/control-plane/main_test.go b/control-plane/cmd/control-plane/main_test.go new file mode 100644 index 0000000..3192f35 --- /dev/null +++ b/control-plane/cmd/control-plane/main_test.go @@ -0,0 +1,57 @@ +package main + +import ( + "context" + "testing" + + "github.com/openagentlock/openagentlock/control-plane/internal/storage" +) + +func TestSeedGuardrailProviderKeysFromEnv(t *testing.T) { + store, err := storage.NewMemory(t.TempDir()) + if err != nil { + t.Fatalf("NewMemory: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + + env := map[string]string{ + "NVIDIA_API_KEY": " nvapi-test ", + "OPENROUTER_API_KEY": "sk-or-test", + } + seeded, err := seedGuardrailProviderKeys(context.Background(), store, func(k string) string { + return env[k] + }) + if err != nil { + t.Fatalf("seedGuardrailProviderKeys: %v", err) + } + if len(seeded) != 2 || seeded[0] != "nvidia" || seeded[1] != "openrouter" { + t.Fatalf("seeded = %+v", seeded) + } + + nvidia, ok, err := store.GetGuardrailProviderConfig(context.Background(), "nvidia") + if err != nil || !ok || nvidia.APIKey.Value() != "nvapi-test" { + t.Fatalf("nvidia cfg = %+v ok=%v err=%v", nvidia, ok, err) + } + openrouter, ok, err := store.GetGuardrailProviderConfig(context.Background(), "openrouter") + if err != nil || !ok || openrouter.APIKey.Value() != "sk-or-test" { + t.Fatalf("openrouter cfg = %+v ok=%v err=%v", openrouter, ok, err) + } +} + +func TestSeedGuardrailProviderKeysSkipsEmptyEnv(t *testing.T) { + store, err := storage.NewMemory(t.TempDir()) + if err != nil { + t.Fatalf("NewMemory: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + + seeded, err := seedGuardrailProviderKeys(context.Background(), store, func(string) string { + return " " + }) + if err != nil { + t.Fatalf("seedGuardrailProviderKeys: %v", err) + } + if len(seeded) != 0 { + t.Fatalf("seeded = %+v", seeded) + } +} diff --git a/control-plane/dashboard-ui/package.json b/control-plane/dashboard-ui/package.json index bb6846a..6a3e940 100644 --- a/control-plane/dashboard-ui/package.json +++ b/control-plane/dashboard-ui/package.json @@ -11,6 +11,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@xyflow/react": "^12.8.7", "@tanstack/react-router": "^1.130.0", "@tanstack/react-start": "^1.130.0", "react": "^19.2.5", diff --git a/control-plane/dashboard-ui/src/components/GuardrailTraceFlow.tsx b/control-plane/dashboard-ui/src/components/GuardrailTraceFlow.tsx new file mode 100644 index 0000000..11f1057 --- /dev/null +++ b/control-plane/dashboard-ui/src/components/GuardrailTraceFlow.tsx @@ -0,0 +1,50 @@ +import "@xyflow/react/dist/style.css"; +import { Background, Controls, ReactFlow, type Edge, type Node } from "@xyflow/react"; +import type { GuardrailTrace } from "@/lib/types"; + +export function GuardrailTraceFlow({ trace }: { trace: GuardrailTrace }) { + const stages = trace.stages ?? []; + const stageLabel = + stages.length === 0 + ? "not run" + : stages.map((s) => `${s.provider_id}/${s.entry_id}: ${s.verdict}`).join("\n"); + const latency = stages.reduce((sum, stage) => sum + (stage.latency_ms || 0), 0); + const verdictClass = (verdict: string) => + verdict === "deny" + ? "trace-node deny" + : verdict === "abstain" + ? "trace-node abstain" + : "trace-node allow"; + const nodes: Node[] = [ + { + id: "local", + position: { x: 0, y: 20 }, + data: { label: `local\n${trace.local_policy_verdict}` }, + className: verdictClass(trace.local_policy_verdict), + }, + { + id: "guardrail", + position: { x: 220, y: 20 }, + data: { label: `guardrail\n${trace.guardrail_verdict}\n${latency}ms\n${stageLabel}` }, + className: verdictClass(trace.guardrail_verdict), + }, + { + id: "final", + position: { x: 500, y: 20 }, + data: { label: `final\n${trace.final_verdict}` }, + className: verdictClass(trace.final_verdict), + }, + ]; + const edges: Edge[] = [ + { id: "local-guardrail", source: "local", target: "guardrail", animated: true }, + { id: "guardrail-final", source: "guardrail", target: "final", animated: true }, + ]; + return ( +
+ + + + +
+ ); +} diff --git a/control-plane/dashboard-ui/src/index.css b/control-plane/dashboard-ui/src/index.css index 1666b48..bd092b6 100644 --- a/control-plane/dashboard-ui/src/index.css +++ b/control-plane/dashboard-ui/src/index.css @@ -52,4 +52,16 @@ body { .oal-btn-link:focus-visible { @apply outline-none ring-2 ring-accent rounded-sm text-neutral-100; } + .trace-node { + @apply whitespace-pre-line rounded border border-border bg-panel px-3 py-2 text-center text-[11px] text-neutral-100 shadow; + } + .trace-node.allow { + @apply border-allow/50 bg-allow/10; + } + .trace-node.deny { + @apply border-deny/50 bg-deny/10; + } + .trace-node.abstain { + @apply border-monitor/50 bg-monitor/10; + } } diff --git a/control-plane/dashboard-ui/src/lib/api.ts b/control-plane/dashboard-ui/src/lib/api.ts index 985d846..b240b91 100644 --- a/control-plane/dashboard-ui/src/lib/api.ts +++ b/control-plane/dashboard-ui/src/lib/api.ts @@ -51,7 +51,7 @@ export async function apiJSON(path: string, init?: RequestInit): Promise { export async function apiSend( path: string, - method: "POST" | "PATCH" | "DELETE", + method: "POST" | "PATCH" | "PUT" | "DELETE", body?: unknown, ): Promise { const res = await fetch(`${apiBase()}${path}`, { @@ -80,3 +80,34 @@ export async function apiSend( ); } } + +export async function listGuardrailProviders(): Promise<{ + providers: import("./types").GuardrailProviderView[]; +}> { + return apiJSON("/v1/guardrails/providers"); +} + +export async function listGuardrailCatalog(): Promise<{ + entries: import("./types").GuardrailCatalogEntry[]; + provider_errors?: import("./types").GuardrailCatalogProviderError[]; +}> { + return apiJSON("/v1/guardrails/catalog"); +} + +export async function listGuardrailEnabled(): Promise<{ + entries: import("./types").GuardrailEnabledEntry[]; +}> { + return apiJSON("/v1/guardrails/enabled"); +} + +export async function saveGuardrailEnabled( + entries: import("./types").GuardrailEnabledEntry[], +): Promise<{ entries: import("./types").GuardrailEnabledEntry[] }> { + return apiSend("/v1/guardrails/enabled", "PUT", { entries }); +} + +export async function guardrailTrace( + seq: number, +): Promise { + return apiJSON(`/v1/guardrails/traces/${encodeURIComponent(String(seq))}`); +} diff --git a/control-plane/dashboard-ui/src/lib/types.ts b/control-plane/dashboard-ui/src/lib/types.ts index 53766eb..0c27bf1 100644 --- a/control-plane/dashboard-ui/src/lib/types.ts +++ b/control-plane/dashboard-ui/src/lib/types.ts @@ -20,6 +20,54 @@ export interface LedgerEntry { prev_leaf: string; } +export interface GuardrailProviderView { + id: string; + name: string; + status: string; + capabilities: string[]; + configured: boolean; +} + +export interface GuardrailCatalogEntry { + provider_id: string; + entry_id: string; + name: string; + kind: "classifier_model" | "account_policy"; + description?: string; + supports_runtime_enforcement: boolean; + metadata?: Record; +} + +export interface GuardrailCatalogProviderError { + provider_id: string; + detail: string; +} + +export interface GuardrailEnabledEntry { + provider_id: string; + entry_id: string; +} + +export interface GuardrailRuntimeStage { + provider_id: string; + entry_id: string; + verdict: string; + latency_ms: number; + details?: Record; +} + +export interface GuardrailTrace { + local_policy_verdict: string; + guardrail_verdict: string; + final_verdict: string; + stages: GuardrailRuntimeStage[]; +} + +export interface GuardrailTraceResponse { + ledger_seq: number; + trace: GuardrailTrace; +} + export interface PolicyTraceItem { layer?: string; source?: string; diff --git a/control-plane/dashboard-ui/src/routeTree.gen.ts b/control-plane/dashboard-ui/src/routeTree.gen.ts index 2ea17dc..2bca981 100644 --- a/control-plane/dashboard-ui/src/routeTree.gen.ts +++ b/control-plane/dashboard-ui/src/routeTree.gen.ts @@ -12,6 +12,7 @@ import { Route as rootRouteImport } from './routes/__root' import { Route as SessionsRouteImport } from './routes/sessions' import { Route as RulesRouteImport } from './routes/rules' import { Route as McpRouteImport } from './routes/mcp' +import { Route as GuardrailsRouteImport } from './routes/guardrails' import { Route as EventsRouteImport } from './routes/events' import { Route as IndexRouteImport } from './routes/index' @@ -30,6 +31,11 @@ const McpRoute = McpRouteImport.update({ path: '/mcp', getParentRoute: () => rootRouteImport, } as any) +const GuardrailsRoute = GuardrailsRouteImport.update({ + id: '/guardrails', + path: '/guardrails', + getParentRoute: () => rootRouteImport, +} as any) const EventsRoute = EventsRouteImport.update({ id: '/events', path: '/events', @@ -44,6 +50,7 @@ const IndexRoute = IndexRouteImport.update({ export interface FileRoutesByFullPath { '/': typeof IndexRoute '/events': typeof EventsRoute + '/guardrails': typeof GuardrailsRoute '/mcp': typeof McpRoute '/rules': typeof RulesRoute '/sessions': typeof SessionsRoute @@ -51,6 +58,7 @@ export interface FileRoutesByFullPath { export interface FileRoutesByTo { '/': typeof IndexRoute '/events': typeof EventsRoute + '/guardrails': typeof GuardrailsRoute '/mcp': typeof McpRoute '/rules': typeof RulesRoute '/sessions': typeof SessionsRoute @@ -59,21 +67,30 @@ export interface FileRoutesById { __root__: typeof rootRouteImport '/': typeof IndexRoute '/events': typeof EventsRoute + '/guardrails': typeof GuardrailsRoute '/mcp': typeof McpRoute '/rules': typeof RulesRoute '/sessions': typeof SessionsRoute } export interface FileRouteTypes { fileRoutesByFullPath: FileRoutesByFullPath - fullPaths: '/' | '/events' | '/mcp' | '/rules' | '/sessions' + fullPaths: '/' | '/events' | '/guardrails' | '/mcp' | '/rules' | '/sessions' fileRoutesByTo: FileRoutesByTo - to: '/' | '/events' | '/mcp' | '/rules' | '/sessions' - id: '__root__' | '/' | '/events' | '/mcp' | '/rules' | '/sessions' + to: '/' | '/events' | '/guardrails' | '/mcp' | '/rules' | '/sessions' + id: + | '__root__' + | '/' + | '/events' + | '/guardrails' + | '/mcp' + | '/rules' + | '/sessions' fileRoutesById: FileRoutesById } export interface RootRouteChildren { IndexRoute: typeof IndexRoute EventsRoute: typeof EventsRoute + GuardrailsRoute: typeof GuardrailsRoute McpRoute: typeof McpRoute RulesRoute: typeof RulesRoute SessionsRoute: typeof SessionsRoute @@ -102,6 +119,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof McpRouteImport parentRoute: typeof rootRouteImport } + '/guardrails': { + id: '/guardrails' + path: '/guardrails' + fullPath: '/guardrails' + preLoaderRoute: typeof GuardrailsRouteImport + parentRoute: typeof rootRouteImport + } '/events': { id: '/events' path: '/events' @@ -122,6 +146,7 @@ declare module '@tanstack/react-router' { const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, EventsRoute: EventsRoute, + GuardrailsRoute: GuardrailsRoute, McpRoute: McpRoute, RulesRoute: RulesRoute, SessionsRoute: SessionsRoute, diff --git a/control-plane/dashboard-ui/src/routes/__root.tsx b/control-plane/dashboard-ui/src/routes/__root.tsx index c169441..cd6a811 100644 --- a/control-plane/dashboard-ui/src/routes/__root.tsx +++ b/control-plane/dashboard-ui/src/routes/__root.tsx @@ -15,6 +15,7 @@ function RootLayout() { { to: "/rules", label: "Rules" }, { to: "/sessions", label: "Sessions" }, { to: "/mcp", label: "MCP" }, + { to: "/guardrails", label: "Guardrails" }, ]} />
diff --git a/control-plane/dashboard-ui/src/routes/events.tsx b/control-plane/dashboard-ui/src/routes/events.tsx index f256cb2..84a29e1 100644 --- a/control-plane/dashboard-ui/src/routes/events.tsx +++ b/control-plane/dashboard-ui/src/routes/events.tsx @@ -2,13 +2,15 @@ import type React from "react"; import { useEffect, useMemo, useRef, useState } from "react"; import { createFileRoute } from "@tanstack/react-router"; import { AddRuleModal } from "@/components/AddRuleModal"; +import { GuardrailTraceFlow } from "@/components/GuardrailTraceFlow"; import { useRootInfo } from "@/hooks/usePoll"; import { useSSELedger } from "@/hooks/useSSE"; -import { apiJSON, apiSend } from "@/lib/api"; +import { apiJSON, apiSend, guardrailTrace } from "@/lib/api"; import type { FalsePositiveApplyResult, FalsePositiveCase, FalsePositiveValidation, + GuardrailTrace, LedgerEntry, } from "@/lib/types"; import { INTERNAL_SOURCES } from "@/lib/filters"; @@ -456,6 +458,7 @@ function EventDetail({ const [fpApply, setFpApply] = useState(null); const [fpError, setFpError] = useState(null); const [fpBusy, setFpBusy] = useState(false); + const [trace, setTrace] = useState(null); // Esc-to-close + initial focus on the close button. Focus trap is left // out for now — the dialog has a single interactive element so Tab // would just cycle there anyway. @@ -468,6 +471,21 @@ function EventDetail({ closeBtnRef.current?.focus(); return () => window.removeEventListener("keydown", onKey); }, [onClose]); + useEffect(() => { + if (!entry) return; + let active = true; + setTrace(null); + guardrailTrace(entry.seq) + .then((res) => { + if (active) setTrace(res.trace); + }) + .catch(() => { + if (active) setTrace(null); + }); + return () => { + active = false; + }; + }, [entry]); if (!entry) { return ( @@ -618,6 +636,14 @@ function EventDetail({ wrap /> )} + {trace && ( +
+
+ Guardrail trace +
+ +
+ )} diff --git a/control-plane/dashboard-ui/src/routes/guardrails.tsx b/control-plane/dashboard-ui/src/routes/guardrails.tsx new file mode 100644 index 0000000..be189e6 --- /dev/null +++ b/control-plane/dashboard-ui/src/routes/guardrails.tsx @@ -0,0 +1,178 @@ +import { createFileRoute } from "@tanstack/react-router"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { + listGuardrailCatalog, + listGuardrailEnabled, + listGuardrailProviders, + saveGuardrailEnabled, +} from "@/lib/api"; +import type { + GuardrailCatalogEntry, + GuardrailCatalogProviderError, + GuardrailEnabledEntry, + GuardrailProviderView, +} from "@/lib/types"; + +function GuardrailsTab() { + const [providers, setProviders] = useState([]); + const [catalog, setCatalog] = useState([]); + const [providerErrors, setProviderErrors] = useState([]); + const [enabled, setEnabled] = useState([]); + const [error, setError] = useState(""); + const saveRequestRef = useRef(0); + + async function refresh() { + setError(""); + setProviderErrors([]); + const [providerRes, catalogRes, enabledRes] = await Promise.all([ + listGuardrailProviders(), + listGuardrailCatalog(), + listGuardrailEnabled(), + ]); + setProviders(providerRes.providers); + setCatalog(catalogRes.entries); + setProviderErrors(catalogRes.provider_errors ?? []); + setEnabled(enabledRes.entries); + } + + useEffect(() => { + refresh().catch((err) => setError((err as Error).message)); + }, []); + + const enabledKeys = useMemo( + () => new Set(enabled.map((entry) => `${entry.provider_id}/${entry.entry_id}`)), + [enabled], + ); + + async function toggleEntry(entry: GuardrailCatalogEntry, checked: boolean) { + const key = `${entry.provider_id}/${entry.entry_id}`; + const prev = enabled; + const requestID = saveRequestRef.current + 1; + saveRequestRef.current = requestID; + const next = checked + ? [...enabled, { provider_id: entry.provider_id, entry_id: entry.entry_id }] + : enabled.filter((item) => `${item.provider_id}/${item.entry_id}` !== key); + setEnabled(next); + try { + const saved = await saveGuardrailEnabled(next); + if (saveRequestRef.current === requestID) { + setEnabled(saved.entries); + } + } catch (err) { + if (saveRequestRef.current === requestID) { + setEnabled(prev); + } + setError((err as Error).message); + } + } + + function providerNote(provider: GuardrailProviderView): string { + if (provider.id === "nvidia") { + return "Runtime classifier entries can block after local rules allow."; + } + if (provider.id === "openrouter") { + return "Catalog visibility only in this slice. OpenRouter entries do not run as runtime classifiers yet."; + } + return "Provider-specific behavior depends on runtime enforcement support."; + } + + return ( +
+
+
+
+
External Guardrails
+
Providers are opt-in and run after local rules allow.
+
+ +
+ {error &&
{error}
} + {providerErrors.length > 0 && ( +
+ {providerErrors.map((item) => ( +
+ {item.provider_id}: {item.detail} +
+ ))} +
+ )} +
+ {providers.map((provider) => ( +
+
+
+
{provider.name}
+
{provider.capabilities.join(" / ")}
+
+ {provider.configured ? "configured" : "not configured"} +
+
+ Configure keys when starting the control plane. Example:{" "} + + {provider.id === "nvidia" ? "NVIDIA_API_KEY" : "OPENROUTER_API_KEY"} + =... + +
+
{providerNote(provider)}
+
+ ))} +
+
+ +
+
Catalog
+
+ + + + + + + + + + + + {catalog.length === 0 ? ( + + + + ) : ( + catalog.map((entry) => { + const key = `${entry.provider_id}/${entry.entry_id}`; + return ( + + + + + + + + ); + }) + )} + +
enabledproviderentrykindruntime
+ no catalog entries +
+ toggleEntry(entry, e.target.checked)} + aria-label={`enable ${entry.name || entry.entry_id}`} + /> + + {entry.provider_id} + {entry.name || entry.entry_id}{entry.kind === "classifier_model" ? "classifier" : "policy"}{entry.supports_runtime_enforcement ? "yes" : "no"}
+
+
+
+ ); +} + +export const Route = createFileRoute("/guardrails")({ + component: GuardrailsTab, +}); diff --git a/control-plane/docker-compose.yml b/control-plane/docker-compose.yml index 14b78e5..e7ba207 100644 --- a/control-plane/docker-compose.yml +++ b/control-plane/docker-compose.yml @@ -6,6 +6,10 @@ services: - "127.0.0.1:7878:7878" environment: AGENTLOCK_LISTEN: "0.0.0.0:7878" + # Optional external guardrail providers. Values are read once at + # control-plane startup, kept in daemon RAM, and cleared on restart. + NVIDIA_API_KEY: "${NVIDIA_API_KEY:-}" + OPENROUTER_API_KEY: "${OPENROUTER_API_KEY:-}" volumes: - ../dev/agentlock:/var/lib/agentlock healthcheck: diff --git a/control-plane/internal/api/gate.go b/control-plane/internal/api/gate.go index f3bfef1..93d5b03 100644 --- a/control-plane/internal/api/gate.go +++ b/control-plane/internal/api/gate.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/sha256" "encoding/json" "errors" @@ -8,6 +9,8 @@ import ( "net/http" "time" + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" + "github.com/openagentlock/openagentlock/control-plane/internal/policy" "github.com/openagentlock/openagentlock/control-plane/internal/storage" ) @@ -73,6 +76,15 @@ func gateCheckHandler(d Deps) http.HandlerFunc { var origVerdict string result, _, origVerdict = applyDaemonModeOverride(result) + guardrailTrace, guardrailRuleID := evaluateGuardrailsAfterLocalAllow(r.Context(), d, req.Source, req.Tool, req.Input, origVerdict, result) + if guardrailTrace.FinalVerdict == "deny" { + result.Verdict = "deny" + result.RuleID = guardrailRuleID + result.Reason = "blocked by external guardrail" + result.MonitorMatch = false + result.Nudge = "" + origVerdict = "deny" + } // Ledger entry for the evaluated call. The payload hash is over the // canonical request body so verifiers can re-derive it from the same @@ -85,6 +97,7 @@ func gateCheckHandler(d Deps) http.HandlerFunc { "input": req.Input, "verdict": origVerdict, "rule_id": result.RuleID, + "guardrails": guardrailTrace, }) if err != nil { log.Printf("gate.check: payload marshal: %v", err) @@ -110,6 +123,9 @@ func gateCheckHandler(d Deps) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "ledger_error", err.Error()) return } + if d.Guardrails != nil && guardrailTrace.LocalPolicyVerdict != "" { + d.Guardrails.RecordTrace(entry.Seq, guardrailTrace) + } writeJSON(w, http.StatusOK, gateCheckResponse{ Verdict: result.Verdict, @@ -121,3 +137,24 @@ func gateCheckHandler(d Deps) http.HandlerFunc { }) } } + +func evaluateGuardrailsAfterLocalAllow(ctx context.Context, d Deps, source, tool string, input map[string]any, localVerdict string, result policy.EvalResult) (guardrails.Trace, string) { + if d.Guardrails == nil || result.Verdict != "allow" || localVerdict != "allow" { + return guardrails.Trace{}, "" + } + trace, final := d.Guardrails.EvaluatePostPolicy(ctx, guardrails.EvaluateRequest{ + LocalPolicyVerdict: "allow", + Source: source, + Tool: tool, + Input: input, + }) + if final != "deny" { + return trace, "" + } + for _, stage := range trace.Stages { + if stage.Verdict == "deny" { + return trace, "guardrail:" + stage.ProviderID + "/" + stage.EntryID + } + } + return trace, "guardrail" +} diff --git a/control-plane/internal/api/gate_test.go b/control-plane/internal/api/gate_test.go index c5f9f52..21cc634 100644 --- a/control-plane/internal/api/gate_test.go +++ b/control-plane/internal/api/gate_test.go @@ -15,6 +15,7 @@ import ( "strings" "testing" + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" "github.com/openagentlock/openagentlock/control-plane/internal/policy" "github.com/openagentlock/openagentlock/control-plane/internal/signer" "github.com/openagentlock/openagentlock/control-plane/internal/storage" @@ -109,6 +110,61 @@ func newGateFixture(t *testing.T, policyYAML string) gateFixture { return gateFixture{srv: srv, store: store, sessionID: id, home: home} } +func newGateFixtureWithGuardrails(t *testing.T, policyYAML string, provider guardrailsHTTPFakeProvider) gateFixture { + t.Helper() + pol, err := policy.LoadBytes([]byte(policyYAML)) + if err != nil { + t.Fatalf("LoadPolicy: %v", err) + } + home := t.TempDir() + store, err := storage.NewMemory(home) + if err != nil { + t.Fatalf("NewMemory: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + + if err := store.SaveGuardrailProviderConfig(context.Background(), guardrails.ProviderConfig{ProviderID: provider.id}); err != nil { + t.Fatalf("SaveGuardrailProviderConfig: %v", err) + } + enabled := make([]guardrails.EnabledEntry, 0, len(provider.catalog)) + for _, entry := range provider.catalog { + enabled = append(enabled, guardrails.EnabledEntry{ProviderID: entry.ProviderID, EntryID: entry.EntryID}) + } + if _, err := store.SaveGuardrailEnabled(context.Background(), enabled); err != nil { + t.Fatalf("SaveGuardrailEnabled: %v", err) + } + svc := guardrails.NewService(store, provider) + srv := httptest.NewServer(NewRouter(Deps{Store: store, Policy: pol, Guardrails: svc, AgentlockHome: home})) + t.Cleanup(srv.Close) + + pub, priv, _ := ed25519.GenerateKey(nil) + payload := signer.AttestationPayload{ + PolicyHash: pol.Hash, + SessionPubKey: "ed25519:" + hex.EncodeToString(pub), + Signer: "software", + SignerPubKey: "ed25519:" + hex.EncodeToString(pub), + } + canon := signer.CanonicalAttestation(payload) + sig := ed25519.Sign(priv, canon) + body := fmt.Sprintf(`{"policy_hash":%q,"session_pubkey":%q,"signer":"software","signer_pubkey":%q,"attestation":"ed25519:%s"}`, + payload.PolicyHash, payload.SessionPubKey, payload.SignerPubKey, hex.EncodeToString(sig)) + res, err := http.Post(srv.URL+"/v1/sessions", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("POST sessions: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + buf := new(bytes.Buffer) + _, _ = buf.ReadFrom(res.Body) + t.Fatalf("session create: %d %s", res.StatusCode, buf.String()) + } + var sess map[string]any + _ = json.NewDecoder(res.Body).Decode(&sess) + id, _ := sess["id"].(string) + + return gateFixture{srv: srv, store: store, sessionID: id, home: home} +} + func postGateCheck(t *testing.T, srv *httptest.Server, body string) (*http.Response, map[string]any) { t.Helper() res, err := http.Post(srv.URL+"/v1/gates/check", "application/json", strings.NewReader(body)) @@ -144,6 +200,62 @@ func TestGateCheck_AllowsBenignBash(t *testing.T) { } } +func TestGateCheck_LocalAllowGuardrailDenyReturnsDeny(t *testing.T) { + fx := newGateFixtureWithGuardrails(t, enforcePolicyYAML, guardrailsHTTPFakeProvider{ + id: "nvidia", + catalog: []guardrails.CatalogEntry{{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Name: "NeMo Content Safety", + Kind: guardrails.CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }}, + runtimeResult: guardrails.RuntimeResult{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Verdict: "deny", + LatencyMS: 180, + }, + }) + + body := fmt.Sprintf(`{ + "session_id": %q, + "source": "codex", + "tool": "Bash", + "input": {"command": "cat secrets.txt"} + }`, fx.sessionID) + res, out := postGateCheck(t, fx.srv, body) + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + if out["verdict"] != "deny" { + t.Fatalf("verdict = %v, want deny", out["verdict"]) + } + if out["rule_id"] != "guardrail:nvidia/nemo-content-safety" { + t.Fatalf("rule_id = %v", out["rule_id"]) + } + seqFloat, ok := out["ledger_seq"].(float64) + if !ok || seqFloat < 1 { + t.Fatalf("ledger_seq = %v", out["ledger_seq"]) + } + seq := uint64(seqFloat) + traceRes, err := http.Get(fmt.Sprintf("%s/v1/guardrails/traces/%d", fx.srv.URL, seq)) + if err != nil { + t.Fatalf("GET trace: %v", err) + } + defer traceRes.Body.Close() + if traceRes.StatusCode != http.StatusOK { + t.Fatalf("trace status = %d", traceRes.StatusCode) + } + var traceBody GuardrailTraceResponse + if err := json.NewDecoder(traceRes.Body).Decode(&traceBody); err != nil { + t.Fatalf("decode trace: %v", err) + } + if traceBody.Trace.LocalPolicyVerdict != "allow" || traceBody.Trace.FinalVerdict != "deny" { + t.Fatalf("trace = %+v", traceBody.Trace) + } +} + func TestGateCheck_DeniesDestructiveBash(t *testing.T) { fx := newGateFixture(t, enforcePolicyYAML) diff --git a/control-plane/internal/api/guardrails_http.go b/control-plane/internal/api/guardrails_http.go new file mode 100644 index 0000000..6db7219 --- /dev/null +++ b/control-plane/internal/api/guardrails_http.go @@ -0,0 +1,141 @@ +package api + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" +) + +type guardrailsEnabledRequest struct { + Entries []guardrails.EnabledEntry `json:"entries"` +} + +func guardrailsProvidersHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Guardrails == nil || d.Store == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails service is not configured") + return + } + views := make([]GuardrailProviderView, 0, len(d.Guardrails.ProviderIDs())) + for _, id := range d.Guardrails.ProviderIDs() { + _, configured, err := d.Store.GetGuardrailProviderConfig(r.Context(), id) + if err != nil { + writeError(w, http.StatusInternalServerError, "storage_error", err.Error()) + return + } + status := "available" + if configured { + status = "configured" + } + views = append(views, GuardrailProviderView{ + ID: id, + Name: d.Guardrails.ProviderName(id), + Status: status, + Capabilities: d.Guardrails.ProviderCapabilities(id), + Configured: configured, + }) + } + writeJSON(w, http.StatusOK, map[string]any{"providers": views}) + } +} + +func guardrailsProviderTestHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Guardrails == nil || d.Store == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails service is not configured") + return + } + providerID := routeParam("/v1/guardrails/providers/{id}/test", r.URL.Path, "id") + cfg, ok, err := d.Store.GetGuardrailProviderConfig(r.Context(), providerID) + if err != nil { + writeError(w, http.StatusInternalServerError, "storage_error", err.Error()) + return + } + if !ok { + writeError(w, http.StatusNotFound, "provider_not_configured", providerID) + return + } + if err := d.Guardrails.TestCredentials(r.Context(), cfg); err != nil { + writeError(w, http.StatusBadGateway, "guardrails_provider_test_failed", err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"ok": true, "provider_id": providerID}) + } +} + +func guardrailsCatalogHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Guardrails == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails service is not configured") + return + } + entries, providerErrors, err := d.Guardrails.ListCatalogStatus(r.Context()) + if err != nil { + writeError(w, http.StatusBadGateway, "guardrails_catalog_failed", err.Error()) + return + } + writeJSON(w, http.StatusOK, GuardrailCatalogResponse{ + Entries: entries, + ProviderErrors: providerErrors, + }) + } +} + +func guardrailsEnabledGetHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Store == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails storage is not configured") + return + } + enabled, err := d.Store.ListGuardrailEnabled(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "storage_error", err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"entries": enabled}) + } +} + +func guardrailsEnabledHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Store == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails storage is not configured") + return + } + var body guardrailsEnabledRequest + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "bad_json", err.Error()) + return + } + enabled, err := d.Store.SaveGuardrailEnabled(r.Context(), body.Entries) + if err != nil { + writeError(w, http.StatusInternalServerError, "storage_error", err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"entries": enabled}) + } +} + +func guardrailsTraceHandler(d Deps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if d.Guardrails == nil { + writeError(w, http.StatusServiceUnavailable, "guardrails_unavailable", "guardrails service is not configured") + return + } + rawSeq := routeParam("/v1/guardrails/traces/{seq}", r.URL.Path, "seq") + seq, err := strconv.ParseUint(rawSeq, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "bad_seq", "guardrail trace seq must be an unsigned integer") + return + } + trace, ok := d.Guardrails.Trace(seq) + if !ok { + writeError(w, http.StatusNotFound, "trace_not_found", rawSeq) + return + } + writeJSON(w, http.StatusOK, GuardrailTraceResponse{LedgerSeq: seq, Trace: trace}) + } +} diff --git a/control-plane/internal/api/guardrails_http_test.go b/control-plane/internal/api/guardrails_http_test.go new file mode 100644 index 0000000..cca9f27 --- /dev/null +++ b/control-plane/internal/api/guardrails_http_test.go @@ -0,0 +1,276 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" + "github.com/openagentlock/openagentlock/control-plane/internal/storage" +) + +type guardrailsHTTPFixture struct { + srv *httptest.Server + home string + store *storage.Memory + svc *guardrails.Service +} + +func newGuardrailsHTTPFixture(t *testing.T) guardrailsHTTPFixture { + t.Helper() + + home := t.TempDir() + store, err := storage.NewMemory(home) + if err != nil { + t.Fatalf("NewMemory: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + + if err := store.SaveGuardrailProviderConfig(context.Background(), guardrails.ProviderConfig{ + ProviderID: "nvidia", + }); err != nil { + t.Fatalf("SaveGuardrailProviderConfig: %v", err) + } + + svc := guardrails.NewService(store, guardrailsHTTPFakeProvider{ + id: "nvidia", + catalog: []guardrails.CatalogEntry{ + { + EntryID: "nemo-content-safety", + Name: "NeMo Content Safety", + Kind: guardrails.CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }, + { + EntryID: "topic-control", + Name: "Topic Control", + Kind: guardrails.CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }, + }, + }) + + svc.RecordTrace(7, guardrails.Trace{ + LocalPolicyVerdict: "allow", + GuardrailVerdict: "deny", + FinalVerdict: "deny", + Stages: []guardrails.RuntimeStage{{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Verdict: "deny", + LatencyMS: 180, + }}, + }) + + srv := httptest.NewServer(NewRouter(Deps{Store: store, Guardrails: svc})) + t.Cleanup(srv.Close) + + return guardrailsHTTPFixture{ + srv: srv, + home: home, + store: store, + svc: svc, + } +} + +type guardrailsHTTPFakeProvider struct { + id string + catalog []guardrails.CatalogEntry + catalogError error + testError error + runtimeResult guardrails.RuntimeResult + runtimeError error +} + +func (p guardrailsHTTPFakeProvider) ID() string { + return p.id +} + +func (p guardrailsHTTPFakeProvider) Name() string { + if p.id == "nvidia" { + return "NVIDIA" + } + return p.id +} + +func (p guardrailsHTTPFakeProvider) Capabilities() []string { + return []string{"catalog"} +} + +func (p guardrailsHTTPFakeProvider) TestCredentials(context.Context, guardrails.ProviderConfig) error { + return p.testError +} + +func (p guardrailsHTTPFakeProvider) ListCatalog(context.Context, guardrails.ProviderConfig) ([]guardrails.CatalogEntry, error) { + if p.catalogError != nil { + return nil, p.catalogError + } + return append([]guardrails.CatalogEntry(nil), p.catalog...), nil +} + +func (p guardrailsHTTPFakeProvider) RunRuntime(context.Context, guardrails.ProviderConfig, guardrails.CatalogEntry, guardrails.EvaluateRequest) (guardrails.RuntimeResult, error) { + if p.runtimeResult.Verdict == "" && p.runtimeError == nil { + return guardrails.RuntimeResult{Verdict: "allow"}, nil + } + return p.runtimeResult, p.runtimeError +} + +func TestGuardrailsCatalogHandler_ReturnsNormalizedEntries(t *testing.T) { + fx := newGuardrailsHTTPFixture(t) + + res, err := http.Get(fx.srv.URL + "/v1/guardrails/catalog") + if err != nil { + t.Fatalf("GET catalog: %v", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + + var body struct { + Entries []guardrails.CatalogEntry `json:"entries"` + } + if err := json.NewDecoder(res.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Entries) != 2 { + t.Fatalf("entries = %d, want 2", len(body.Entries)) + } +} + +func TestGuardrailsCatalogHandler_ReturnsPartialEntriesAndProviderErrors(t *testing.T) { + home := t.TempDir() + store, err := storage.NewMemory(home) + if err != nil { + t.Fatalf("NewMemory: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + + for _, providerID := range []string{"nvidia", "openrouter"} { + if err := store.SaveGuardrailProviderConfig(context.Background(), guardrails.ProviderConfig{ + ProviderID: providerID, + }); err != nil { + t.Fatalf("SaveGuardrailProviderConfig(%s): %v", providerID, err) + } + } + + svc := guardrails.NewService( + store, + guardrailsHTTPFakeProvider{ + id: "nvidia", + catalog: []guardrails.CatalogEntry{{ + EntryID: "nemo-content-safety", + Name: "NeMo Content Safety", + Kind: guardrails.CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }}, + }, + guardrailsHTTPFakeProvider{ + id: "openrouter", + catalogError: context.DeadlineExceeded, + }, + ) + + srv := httptest.NewServer(NewRouter(Deps{Store: store, Guardrails: svc})) + t.Cleanup(srv.Close) + + res, err := http.Get(srv.URL + "/v1/guardrails/catalog") + if err != nil { + t.Fatalf("GET catalog: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + + var body GuardrailCatalogResponse + if err := json.NewDecoder(res.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Entries) != 1 || body.Entries[0].ProviderID != "nvidia" { + t.Fatalf("entries = %+v", body.Entries) + } + if len(body.ProviderErrors) != 1 || body.ProviderErrors[0].ProviderID != "openrouter" { + t.Fatalf("provider errors = %+v", body.ProviderErrors) + } +} + +func TestGuardrailsProvidersHandler_ReturnsProviderStatus(t *testing.T) { + fx := newGuardrailsHTTPFixture(t) + + res, err := http.Get(fx.srv.URL + "/v1/guardrails/providers") + if err != nil { + t.Fatalf("GET providers: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + var body struct { + Providers []GuardrailProviderView `json:"providers"` + } + if err := json.NewDecoder(res.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Providers) != 1 || body.Providers[0].ID != "nvidia" || !body.Providers[0].Configured { + t.Fatalf("providers = %+v", body.Providers) + } +} + +func TestGuardrailsEnabledHandler_RoundTripsEnabledEntries(t *testing.T) { + fx := newGuardrailsHTTPFixture(t) + req, err := http.NewRequest(http.MethodPut, fx.srv.URL+"/v1/guardrails/enabled", strings.NewReader(`{ + "entries": [{"provider_id": "nvidia", "entry_id": "nemo-content-safety"}] + }`)) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + req.Header.Set("Content-Type", "application/json") + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("PUT enabled: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + got, err := fx.store.ListGuardrailEnabled(context.Background()) + if err != nil { + t.Fatalf("ListGuardrailEnabled: %v", err) + } + if len(got) != 1 || got[0].ProviderID != "nvidia" || got[0].EntryID != "nemo-content-safety" { + t.Fatalf("enabled = %+v", got) + } + + res, err = http.Get(fx.srv.URL + "/v1/guardrails/enabled") + if err != nil { + t.Fatalf("GET enabled: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("GET status = %d", res.StatusCode) + } +} + +func TestGuardrailsTraceHandler_ReturnsRecordedTrace(t *testing.T) { + fx := newGuardrailsHTTPFixture(t) + res, err := http.Get(fx.srv.URL + "/v1/guardrails/traces/7") + if err != nil { + t.Fatalf("GET trace: %v", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d", res.StatusCode) + } + var body GuardrailTraceResponse + if err := json.NewDecoder(res.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if body.LedgerSeq != 7 || body.Trace.FinalVerdict != "deny" { + t.Fatalf("trace response = %+v", body) + } +} diff --git a/control-plane/internal/api/router.go b/control-plane/internal/api/router.go index 33e8fa7..e32170e 100644 --- a/control-plane/internal/api/router.go +++ b/control-plane/internal/api/router.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/openagentlock/openagentlock/control-plane/internal/auth" + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" "github.com/openagentlock/openagentlock/control-plane/internal/policy" "github.com/openagentlock/openagentlock/control-plane/internal/storage" ) @@ -27,6 +28,7 @@ type route struct { type Deps struct { Store storage.Storage Policy *policy.Policy + Guardrails *guardrails.Service PinStorePath string AgentlockHome string // PolicyPath is the on-disk YAML the daemon loaded at startup. When @@ -50,6 +52,9 @@ func NewRouter(deps ...Deps) http.Handler { if d.Policy != nil { bootstrapPolicy(d.Policy) } + if d.Guardrails == nil && d.Store != nil { + d.Guardrails = guardrails.NewDefaultService(d.Store) + } routes := []route{ {"GET", "/v1/health", health}, @@ -114,6 +119,12 @@ func NewRouter(deps ...Deps) http.Handler { {"POST", "/v1/false-positives/validate", falsePositiveValidateHandler(d)}, {"POST", "/v1/false-positives/apply", falsePositiveApplyHandler(d)}, {"GET", "/v1/sessions", sessionsListHandler(d)}, + {"GET", "/v1/guardrails/providers", guardrailsProvidersHandler(d)}, + {"POST", "/v1/guardrails/providers/{id}/test", guardrailsProviderTestHandler(d)}, + {"GET", "/v1/guardrails/catalog", guardrailsCatalogHandler(d)}, + {"GET", "/v1/guardrails/enabled", guardrailsEnabledGetHandler(d)}, + {"PUT", "/v1/guardrails/enabled", guardrailsEnabledHandler(d)}, + {"GET", "/v1/guardrails/traces/{seq}", guardrailsTraceHandler(d)}, // Ledger. {"GET", "/v1/ledger/tail", ledgerTailHandler(d)}, diff --git a/control-plane/internal/api/types.go b/control-plane/internal/api/types.go index e375b47..94a1c04 100644 --- a/control-plane/internal/api/types.go +++ b/control-plane/internal/api/types.go @@ -3,7 +3,11 @@ package api -import "time" +import ( + "time" + + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" +) type Verdict string @@ -31,12 +35,12 @@ type Session struct { } type GateCheckRequest struct { - SessionID string `json:"session_id"` - Source string `json:"source"` // "claude-code" | "mcp-proxy" | ... - Tool string `json:"tool"` - Input map[string]any `json:"input"` - Cwd string `json:"cwd,omitempty"` - Meta map[string]any `json:"meta,omitempty"` + SessionID string `json:"session_id"` + Source string `json:"source"` // "claude-code" | "mcp-proxy" | ... + Tool string `json:"tool"` + Input map[string]any `json:"input"` + Cwd string `json:"cwd,omitempty"` + Meta map[string]any `json:"meta,omitempty"` } type GateCheckResponse struct { @@ -71,8 +75,26 @@ type ApprovalDecision struct { } type LedgerRoot struct { - Root string `json:"root"` - Seq uint64 `json:"seq"` - GenesisPK string `json:"genesis_pubkey"` - ComputedAt time.Time `json:"computed_at"` + Root string `json:"root"` + Seq uint64 `json:"seq"` + GenesisPK string `json:"genesis_pubkey"` + ComputedAt time.Time `json:"computed_at"` +} + +type GuardrailProviderView struct { + ID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Capabilities []string `json:"capabilities"` + Configured bool `json:"configured"` +} + +type GuardrailCatalogResponse struct { + Entries []guardrails.CatalogEntry `json:"entries"` + ProviderErrors []guardrails.CatalogProviderError `json:"provider_errors,omitempty"` +} + +type GuardrailTraceResponse struct { + LedgerSeq uint64 `json:"ledger_seq"` + Trace guardrails.Trace `json:"trace"` } diff --git a/control-plane/internal/guardrails/providers_nvidia.go b/control-plane/internal/guardrails/providers_nvidia.go new file mode 100644 index 0000000..a0429d6 --- /dev/null +++ b/control-plane/internal/guardrails/providers_nvidia.go @@ -0,0 +1,248 @@ +package guardrails + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +const defaultNVIDIABaseURL = "https://integrate.api.nvidia.com" + +type NVIDIAProvider struct { + BaseURL string + HTTPClient *http.Client +} + +func (p *NVIDIAProvider) ID() string { + return "nvidia" +} + +func (p *NVIDIAProvider) Name() string { + return "NVIDIA" +} + +func (p *NVIDIAProvider) Capabilities() []string { + return []string{"catalog", "runtime_classifier"} +} + +func (p *NVIDIAProvider) TestCredentials(ctx context.Context, cfg ProviderConfig) error { + if cfg.APIKey == "" { + return fmt.Errorf("nvidia api key is required") + } + _, err := p.ListCatalog(ctx, cfg) + return err +} + +func (p *NVIDIAProvider) ListCatalog(ctx context.Context, cfg ProviderConfig) ([]CatalogEntry, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSuffix(p.baseURL(), "/")+"/v1/models", nil) + if err != nil { + return nil, err + } + applyBearer(req, cfg.APIKey) + + res, err := p.client().Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return nil, fmt.Errorf("nvidia catalog status %d: %s", res.StatusCode, readResponseBody(res.Body)) + } + + var payload struct { + Data []struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Type string `json:"type"` + } `json:"data"` + } + if err := json.NewDecoder(res.Body).Decode(&payload); err != nil { + return nil, fmt.Errorf("decode nvidia catalog: %w", err) + } + + out := make([]CatalogEntry, 0, len(payload.Data)) + for _, model := range payload.Data { + if !looksLikeNVIDIAGuardrailModel(model.ID, model.Name) { + continue + } + entry := CatalogEntry{ + ProviderID: p.ID(), + EntryID: model.ID, + Name: firstNonEmpty(model.Name, model.ID), + Kind: CatalogEntryClassifierModel, + Description: model.Description, + SupportsRuntimeEnforcement: true, + } + if model.Type != "" { + entry.Metadata = map[string]string{"type": model.Type} + } + out = append(out, entry) + } + sortCatalogEntries(out) + return out, nil +} + +func (p *NVIDIAProvider) RunRuntime(ctx context.Context, cfg ProviderConfig, entry CatalogEntry, req EvaluateRequest) (RuntimeResult, error) { + prompt, err := buildNVIDIARuntimePrompt(req) + if err != nil { + return RuntimeResult{}, err + } + body := map[string]any{ + "model": entry.EntryID, + "temperature": 0, + "messages": []map[string]string{{ + "role": "user", + "content": prompt, + }}, + } + rawBody, err := json.Marshal(body) + if err != nil { + return RuntimeResult{}, err + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSuffix(p.baseURL(), "/")+"/v1/chat/completions", bytes.NewReader(rawBody)) + if err != nil { + return RuntimeResult{}, err + } + httpReq.Header.Set("Content-Type", "application/json") + applyBearer(httpReq, cfg.APIKey) + + res, err := p.client().Do(httpReq) + if err != nil { + return RuntimeResult{}, err + } + defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return RuntimeResult{}, fmt.Errorf("nvidia runtime status %d: %s", res.StatusCode, readResponseBody(res.Body)) + } + + var payload struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.NewDecoder(res.Body).Decode(&payload); err != nil { + return RuntimeResult{}, fmt.Errorf("decode nvidia runtime: %w", err) + } + if len(payload.Choices) == 0 { + return RuntimeResult{}, fmt.Errorf("nvidia runtime returned no choices") + } + + parsed, err := parseNVIDIAVerdict(payload.Choices[0].Message.Content) + if err != nil { + return RuntimeResult{}, err + } + parsed.ProviderID = p.ID() + parsed.EntryID = entry.EntryID + return parsed, nil +} + +func readResponseBody(body io.Reader) string { + raw, err := io.ReadAll(body) + if err != nil { + return "failed to read response body: " + err.Error() + } + return strings.TrimSpace(string(raw)) +} + +func (p *NVIDIAProvider) client() *http.Client { + if p != nil && p.HTTPClient != nil { + return p.HTTPClient + } + return http.DefaultClient +} + +func (p *NVIDIAProvider) baseURL() string { + if p != nil && p.BaseURL != "" { + return p.BaseURL + } + return defaultNVIDIABaseURL +} + +func looksLikeNVIDIAGuardrailModel(id, name string) bool { + text := strings.ToLower(id + " " + name) + return strings.Contains(text, "content-safety") || + strings.Contains(text, "content safety") || + strings.Contains(text, "safety-guard") || + strings.Contains(text, "safety guard") || + strings.Contains(text, "nemoguard") || + strings.Contains(text, "topic-control") || + strings.Contains(text, "topic control") +} + +func buildNVIDIARuntimePrompt(req EvaluateRequest) (string, error) { + inputJSON, err := json.Marshal(req.Input) + if err != nil { + return "", fmt.Errorf("marshal runtime input: %w", err) + } + return fmt.Sprintf(`Task: Check whether the following agent tool invocation is safe or unsafe. Return only JSON. +Conversation JSON: +{"source":%q,"tool":%q,"input":%s} +Output JSON: +{"User Safety":"safe or unsafe","Safety Categories":"comma-separated categories if unsafe"}`, req.Source, req.Tool, string(inputJSON)), nil +} + +func parseNVIDIAVerdict(content string) (RuntimeResult, error) { + clean := strings.TrimSpace(content) + clean = strings.TrimPrefix(clean, "```json") + clean = strings.TrimPrefix(clean, "```") + clean = strings.TrimSuffix(clean, "```") + clean = strings.TrimSpace(clean) + + var payload struct { + UserSafety string `json:"User Safety"` + ResponseSafety string `json:"Response Safety"` + SafetyCategories string `json:"Safety Categories"` + } + if err := json.Unmarshal([]byte(clean), &payload); err != nil { + return RuntimeResult{}, fmt.Errorf("parse nvidia verdict: %w", err) + } + + userSafety, userSet, err := parseSafetyLabel("User Safety", payload.UserSafety) + if err != nil { + return RuntimeResult{}, err + } + responseSafety, responseSet, err := parseSafetyLabel("Response Safety", payload.ResponseSafety) + if err != nil { + return RuntimeResult{}, err + } + if !userSet && !responseSet { + return RuntimeResult{}, fmt.Errorf("parse nvidia verdict: missing safety label") + } + + verdict := "allow" + if userSafety == "unsafe" || responseSafety == "unsafe" { + verdict = "deny" + } + details := map[string]string{} + if userSet { + details["user_safety"] = userSafety + } + if responseSet { + details["response_safety"] = responseSafety + } + if payload.SafetyCategories != "" { + details["safety_categories"] = payload.SafetyCategories + } + return RuntimeResult{Verdict: verdict, Details: details}, nil +} + +func parseSafetyLabel(field, value string) (string, bool, error) { + if strings.TrimSpace(value) == "" { + return "", false, nil + } + label := strings.ToLower(strings.TrimSpace(value)) + switch label { + case "safe", "unsafe": + return label, true, nil + default: + return "", false, fmt.Errorf("parse nvidia verdict: invalid %s %q", field, value) + } +} diff --git a/control-plane/internal/guardrails/providers_openrouter.go b/control-plane/internal/guardrails/providers_openrouter.go new file mode 100644 index 0000000..67580b3 --- /dev/null +++ b/control-plane/internal/guardrails/providers_openrouter.go @@ -0,0 +1,130 @@ +package guardrails + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" +) + +const defaultOpenRouterBaseURL = "https://openrouter.ai" + +type OpenRouterProvider struct { + BaseURL string + HTTPClient *http.Client +} + +func (p *OpenRouterProvider) ID() string { + return "openrouter" +} + +func (p *OpenRouterProvider) Name() string { + return "OpenRouter" +} + +func (p *OpenRouterProvider) Capabilities() []string { + return []string{"catalog", "catalog_only"} +} + +func (p *OpenRouterProvider) TestCredentials(ctx context.Context, cfg ProviderConfig) error { + _, err := p.ListCatalog(ctx, cfg) + return err +} + +func (p *OpenRouterProvider) ListCatalog(ctx context.Context, cfg ProviderConfig) ([]CatalogEntry, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimSuffix(p.baseURL(), "/")+"/api/v1/guardrails", nil) + if err != nil { + return nil, err + } + applyBearer(req, cfg.APIKey) + + res, err := p.client().Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return nil, fmt.Errorf("openrouter catalog status %d", res.StatusCode) + } + + var payload struct { + Data []struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + WorkspaceID string `json:"workspace_id"` + AllowedModels []string `json:"allowed_models"` + AllowedProviders []string `json:"allowed_providers"` + IgnoredModels []string `json:"ignored_models"` + IgnoredProviders []string `json:"ignored_providers"` + LimitUSD float64 `json:"limit_usd"` + ResetInterval string `json:"reset_interval"` + EnforceZDR bool `json:"enforce_zdr"` + } `json:"data"` + } + if err := json.NewDecoder(res.Body).Decode(&payload); err != nil { + return nil, fmt.Errorf("decode openrouter catalog: %w", err) + } + + out := make([]CatalogEntry, 0, len(payload.Data)) + for _, guardrail := range payload.Data { + metadata := map[string]string{"enforce_zdr": strconv.FormatBool(guardrail.EnforceZDR)} + if guardrail.WorkspaceID != "" { + metadata["workspace_id"] = guardrail.WorkspaceID + } + if len(guardrail.AllowedModels) > 0 { + metadata["allowed_models"] = strings.Join(guardrail.AllowedModels, ",") + } + if len(guardrail.AllowedProviders) > 0 { + metadata["allowed_providers"] = strings.Join(guardrail.AllowedProviders, ",") + } + if len(guardrail.IgnoredModels) > 0 { + metadata["ignored_models"] = strings.Join(guardrail.IgnoredModels, ",") + } + if len(guardrail.IgnoredProviders) > 0 { + metadata["ignored_providers"] = strings.Join(guardrail.IgnoredProviders, ",") + } + if guardrail.LimitUSD != 0 { + metadata["limit_usd"] = strconv.FormatFloat(guardrail.LimitUSD, 'f', 2, 64) + } + if guardrail.ResetInterval != "" { + metadata["reset_interval"] = guardrail.ResetInterval + } + out = append(out, CatalogEntry{ + ProviderID: p.ID(), + EntryID: guardrail.ID, + Name: firstNonEmpty(guardrail.Name, guardrail.ID), + Kind: CatalogEntryAccountPolicy, + Description: guardrail.Description, + Metadata: metadata, + }) + } + sortCatalogEntries(out) + return out, nil +} + +func (p *OpenRouterProvider) RunRuntime(context.Context, ProviderConfig, CatalogEntry, EvaluateRequest) (RuntimeResult, error) { + return RuntimeResult{Verdict: "abstain"}, ErrRuntimeUnsupported +} + +func (p *OpenRouterProvider) client() *http.Client { + if p != nil && p.HTTPClient != nil { + return p.HTTPClient + } + return http.DefaultClient +} + +func (p *OpenRouterProvider) baseURL() string { + if p != nil && p.BaseURL != "" { + return p.BaseURL + } + return defaultOpenRouterBaseURL +} + +func applyBearer(req *http.Request, apiKey SecretString) { + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey.Value()) + } +} diff --git a/control-plane/internal/guardrails/service.go b/control-plane/internal/guardrails/service.go new file mode 100644 index 0000000..80fc716 --- /dev/null +++ b/control-plane/internal/guardrails/service.go @@ -0,0 +1,498 @@ +package guardrails + +import ( + "context" + "errors" + "fmt" + "sort" + "sync" + "time" +) + +var ( + ErrProviderNotRegistered = errors.New("guardrails provider not registered") + ErrRuntimeUnsupported = errors.New("guardrail runtime unsupported") +) + +const ( + defaultCatalogCacheTTL = 5 * time.Minute + maxTraceEntries = 1024 +) + +type CatalogEntry struct { + ProviderID string `json:"provider_id"` + EntryID string `json:"entry_id"` + Name string `json:"name"` + Kind CatalogEntryKind `json:"kind"` + Description string `json:"description,omitempty"` + SupportsRuntimeEnforcement bool `json:"supports_runtime_enforcement"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type CatalogProviderError struct { + ProviderID string `json:"provider_id"` + Detail string `json:"detail"` +} + +type RuntimeResult struct { + ProviderID string `json:"provider_id"` + EntryID string `json:"entry_id"` + Verdict string `json:"verdict"` + LatencyMS int `json:"latency_ms"` + Details map[string]string `json:"details,omitempty"` +} + +type RuntimeStage struct { + ProviderID string `json:"provider_id"` + EntryID string `json:"entry_id"` + Verdict string `json:"verdict"` + LatencyMS int `json:"latency_ms"` + Details map[string]string `json:"details,omitempty"` +} + +type Trace struct { + LocalPolicyVerdict string `json:"local_policy_verdict"` + GuardrailVerdict string `json:"guardrail_verdict"` + FinalVerdict string `json:"final_verdict"` + Stages []RuntimeStage `json:"stages"` +} + +type EvaluateRequest struct { + LocalPolicyVerdict string `json:"local_policy_verdict"` + Source string `json:"source,omitempty"` + Tool string `json:"tool"` + Input map[string]any `json:"input"` +} + +type Provider interface { + ID() string + Name() string + Capabilities() []string + TestCredentials(ctx context.Context, cfg ProviderConfig) error + ListCatalog(ctx context.Context, cfg ProviderConfig) ([]CatalogEntry, error) + RunRuntime(ctx context.Context, cfg ProviderConfig, entry CatalogEntry, req EvaluateRequest) (RuntimeResult, error) +} + +type Store interface { + ListGuardrailProviderConfigs(ctx context.Context) ([]ProviderConfig, error) + GetGuardrailProviderConfig(ctx context.Context, providerID string) (ProviderConfig, bool, error) + ListGuardrailEnabled(ctx context.Context) ([]EnabledEntry, error) +} + +type Service struct { + store Store + + mu sync.RWMutex + providers map[string]Provider + catalogCache map[string]catalogCacheEntry + traces map[uint64]Trace + traceOrder []uint64 + now func() time.Time +} + +type catalogCacheEntry struct { + entries []CatalogEntry + expiresAt time.Time +} + +func NewService(store Store, providers ...Provider) *Service { + svc := &Service{ + store: store, + providers: make(map[string]Provider, len(providers)), + catalogCache: make(map[string]catalogCacheEntry), + traces: map[uint64]Trace{}, + now: time.Now, + } + for _, provider := range providers { + svc.Register(provider) + } + return svc +} + +func NewDefaultService(store Store) *Service { + return NewService(store, &NVIDIAProvider{}, &OpenRouterProvider{}) +} + +func (s *Service) Register(provider Provider) { + if provider == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.providers[provider.ID()] = provider +} + +func (s *Service) ProviderIDs() []string { + s.mu.RLock() + defer s.mu.RUnlock() + ids := make([]string, 0, len(s.providers)) + for id := range s.providers { + ids = append(ids, id) + } + sort.Strings(ids) + return ids +} + +func (s *Service) IsSupported(providerID string) bool { + _, ok := s.provider(providerID) + return ok +} + +func (s *Service) ProviderName(providerID string) string { + provider, ok := s.provider(providerID) + if !ok { + return providerID + } + return provider.Name() +} + +func (s *Service) ProviderCapabilities(providerID string) []string { + provider, ok := s.provider(providerID) + if !ok { + return []string{} + } + return append([]string(nil), provider.Capabilities()...) +} + +func (s *Service) TestCredentials(ctx context.Context, cfg ProviderConfig) error { + provider, ok := s.provider(cfg.ProviderID) + if !ok { + return fmt.Errorf("%w: %s", ErrProviderNotRegistered, cfg.ProviderID) + } + return provider.TestCredentials(ctx, cfg) +} + +func (s *Service) ListCatalog(ctx context.Context) ([]CatalogEntry, error) { + entries, providerErrors, err := s.ListCatalogStatus(ctx) + if err != nil { + return nil, err + } + if len(providerErrors) > 0 { + return nil, fmt.Errorf("list catalog for %s: %s", providerErrors[0].ProviderID, providerErrors[0].Detail) + } + return entries, nil +} + +func (s *Service) ListCatalogStatus(ctx context.Context) ([]CatalogEntry, []CatalogProviderError, error) { + if s.store == nil { + return nil, nil, nil + } + cfgs, err := s.store.ListGuardrailProviderConfigs(ctx) + if err != nil { + return nil, nil, fmt.Errorf("list provider configs: %w", err) + } + sort.Slice(cfgs, func(i, j int) bool { + return cfgs[i].ProviderID < cfgs[j].ProviderID + }) + + out := make([]CatalogEntry, 0) + providerErrors := make([]CatalogProviderError, 0) + for _, cfg := range cfgs { + provider, ok := s.provider(cfg.ProviderID) + if !ok { + return nil, nil, fmt.Errorf("%w: %s", ErrProviderNotRegistered, cfg.ProviderID) + } + entries, err := provider.ListCatalog(ctx, cfg) + if err != nil { + providerErrors = append(providerErrors, CatalogProviderError{ + ProviderID: cfg.ProviderID, + Detail: err.Error(), + }) + continue + } + normalized := cloneCatalogEntries(entries) + for i := range normalized { + if normalized[i].ProviderID == "" { + normalized[i].ProviderID = cfg.ProviderID + } + } + s.setCachedCatalog(cfg.ProviderID, normalized) + out = append(out, normalized...) + } + sortCatalogEntries(out) + return out, providerErrors, nil +} + +func (s *Service) EvaluatePostPolicy(ctx context.Context, req EvaluateRequest) (Trace, string) { + trace := Trace{ + LocalPolicyVerdict: req.LocalPolicyVerdict, + GuardrailVerdict: "allow", + FinalVerdict: req.LocalPolicyVerdict, + } + if req.LocalPolicyVerdict != "allow" { + trace.GuardrailVerdict = "skipped" + return trace, trace.FinalVerdict + } + if s.store == nil { + trace.FinalVerdict = "allow" + return trace, trace.FinalVerdict + } + + cfgs, err := s.store.ListGuardrailProviderConfigs(ctx) + if err != nil { + trace.GuardrailVerdict = "abstain" + trace.FinalVerdict = "allow" + trace.Stages = append(trace.Stages, RuntimeStage{ + Verdict: "abstain", + Details: map[string]string{ + "reason": "provider_config_load_failed", + "error": err.Error(), + }, + }) + return trace, trace.FinalVerdict + } + cfgByProvider := make(map[string]ProviderConfig, len(cfgs)) + for _, cfg := range cfgs { + cfgByProvider[cfg.ProviderID] = cfg + } + + enabled, err := s.store.ListGuardrailEnabled(ctx) + if err != nil { + trace.GuardrailVerdict = "abstain" + trace.FinalVerdict = "allow" + trace.Stages = append(trace.Stages, RuntimeStage{ + Verdict: "abstain", + Details: map[string]string{ + "reason": "enabled_guardrails_load_failed", + "error": err.Error(), + }, + }) + return trace, trace.FinalVerdict + } + + finalVerdict := "allow" + guardrailVerdict := "allow" + for _, enabledEntry := range enabled { + cfg, ok := cfgByProvider[enabledEntry.ProviderID] + if !ok { + trace.Stages = append(trace.Stages, RuntimeStage{ + ProviderID: enabledEntry.ProviderID, + EntryID: enabledEntry.EntryID, + Verdict: "abstain", + Details: map[string]string{"reason": "provider_not_configured"}, + }) + guardrailVerdict = "abstain" + continue + } + provider, ok := s.provider(enabledEntry.ProviderID) + if !ok { + trace.Stages = append(trace.Stages, RuntimeStage{ + ProviderID: enabledEntry.ProviderID, + EntryID: enabledEntry.EntryID, + Verdict: "abstain", + Details: map[string]string{"reason": "provider_not_registered"}, + }) + guardrailVerdict = "abstain" + continue + } + entry, ok, err := s.runtimeEntry(ctx, cfg, enabledEntry) + if err != nil { + trace.Stages = append(trace.Stages, RuntimeStage{ + ProviderID: enabledEntry.ProviderID, + EntryID: enabledEntry.EntryID, + Verdict: "abstain", + Details: map[string]string{ + "reason": "catalog_lookup_failed", + "error": err.Error(), + }, + }) + guardrailVerdict = "abstain" + continue + } + if !ok { + trace.Stages = append(trace.Stages, RuntimeStage{ + ProviderID: enabledEntry.ProviderID, + EntryID: enabledEntry.EntryID, + Verdict: "abstain", + Details: map[string]string{"reason": "catalog_entry_not_found"}, + }) + guardrailVerdict = "abstain" + continue + } + + result, err := provider.RunRuntime(ctx, cfg, entry, req) + if err != nil { + trace.Stages = append(trace.Stages, RuntimeStage{ + ProviderID: enabledEntry.ProviderID, + EntryID: enabledEntry.EntryID, + Verdict: "abstain", + LatencyMS: result.LatencyMS, + Details: runtimeErrorDetails(entry, err), + }) + guardrailVerdict = "abstain" + continue + } + + stage := RuntimeStage{ + ProviderID: firstNonEmpty(result.ProviderID, enabledEntry.ProviderID), + EntryID: firstNonEmpty(result.EntryID, enabledEntry.EntryID), + Verdict: normalizeVerdict(result.Verdict), + LatencyMS: result.LatencyMS, + Details: cloneStringMap(result.Details), + } + trace.Stages = append(trace.Stages, stage) + if stage.Verdict == "deny" { + trace.GuardrailVerdict = "deny" + trace.FinalVerdict = "deny" + return trace, "deny" + } + if stage.Verdict == "abstain" { + guardrailVerdict = "abstain" + } + } + trace.GuardrailVerdict = guardrailVerdict + trace.FinalVerdict = finalVerdict + return trace, finalVerdict +} + +func (s *Service) RecordTrace(seq uint64, trace Trace) { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.traces[seq]; !exists { + s.traceOrder = append(s.traceOrder, seq) + } + s.traces[seq] = cloneTrace(trace) + for len(s.traceOrder) > maxTraceEntries { + oldest := s.traceOrder[0] + s.traceOrder = s.traceOrder[1:] + delete(s.traces, oldest) + } +} + +func (s *Service) Trace(seq uint64) (Trace, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + trace, ok := s.traces[seq] + return cloneTrace(trace), ok +} + +func (s *Service) runtimeEntry(ctx context.Context, cfg ProviderConfig, enabled EnabledEntry) (CatalogEntry, bool, error) { + entries, ok := s.cachedCatalog(enabled.ProviderID) + if !ok { + provider, exists := s.provider(enabled.ProviderID) + if !exists { + return CatalogEntry{}, false, fmt.Errorf("%w: %s", ErrProviderNotRegistered, enabled.ProviderID) + } + var err error + rawEntries, err := provider.ListCatalog(ctx, cfg) + if err != nil { + return CatalogEntry{}, false, err + } + entries = cloneCatalogEntries(rawEntries) + for i := range entries { + if entries[i].ProviderID == "" { + entries[i].ProviderID = enabled.ProviderID + } + } + s.setCachedCatalog(enabled.ProviderID, entries) + } + for _, entry := range entries { + if entry.EntryID == enabled.EntryID { + return cloneCatalogEntry(entry), true, nil + } + } + return CatalogEntry{}, false, nil +} + +func (s *Service) provider(providerID string) (Provider, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + provider, ok := s.providers[providerID] + return provider, ok +} + +func (s *Service) cachedCatalog(providerID string) ([]CatalogEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entries, ok := s.catalogCache[providerID] + if !ok { + return nil, false + } + if !entries.expiresAt.IsZero() && s.now().After(entries.expiresAt) { + return nil, false + } + return cloneCatalogEntries(entries.entries), true +} + +func (s *Service) setCachedCatalog(providerID string, entries []CatalogEntry) { + s.mu.Lock() + defer s.mu.Unlock() + s.catalogCache[providerID] = catalogCacheEntry{ + entries: cloneCatalogEntries(entries), + expiresAt: s.now().Add(defaultCatalogCacheTTL), + } +} + +func sortCatalogEntries(entries []CatalogEntry) { + sort.Slice(entries, func(i, j int) bool { + if entries[i].ProviderID != entries[j].ProviderID { + return entries[i].ProviderID < entries[j].ProviderID + } + return entries[i].EntryID < entries[j].EntryID + }) +} + +func cloneCatalogEntries(entries []CatalogEntry) []CatalogEntry { + out := make([]CatalogEntry, 0, len(entries)) + for _, entry := range entries { + out = append(out, cloneCatalogEntry(entry)) + } + return out +} + +func cloneCatalogEntry(entry CatalogEntry) CatalogEntry { + entry.Metadata = cloneStringMap(entry.Metadata) + return entry +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneTrace(trace Trace) Trace { + trace.Stages = append([]RuntimeStage(nil), trace.Stages...) + for i := range trace.Stages { + trace.Stages[i].Details = cloneStringMap(trace.Stages[i].Details) + } + return trace +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func normalizeVerdict(verdict string) string { + switch verdict { + case "allow", "deny", "abstain": + return verdict + default: + return "abstain" + } +} + +func runtimeErrorDetails(entry CatalogEntry, err error) map[string]string { + details := map[string]string{"error": err.Error()} + switch { + case errors.Is(err, ErrRuntimeUnsupported): + details["reason"] = "runtime_unsupported" + case entry.Kind != CatalogEntryClassifierModel: + details["reason"] = "entry_not_classifier_model" + case !entry.SupportsRuntimeEnforcement: + details["reason"] = "entry_not_runtime_capable" + default: + details["reason"] = "runtime_error" + } + return details +} diff --git a/control-plane/internal/guardrails/service_test.go b/control-plane/internal/guardrails/service_test.go new file mode 100644 index 0000000..c5e8508 --- /dev/null +++ b/control-plane/internal/guardrails/service_test.go @@ -0,0 +1,444 @@ +package guardrails + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" +) + +type guardrailStoreForTest struct { + configs []ProviderConfig + configsError error + enableds []EnabledEntry + enabledError error +} + +func (s guardrailStoreForTest) ListGuardrailProviderConfigs(context.Context) ([]ProviderConfig, error) { + if s.configsError != nil { + return nil, s.configsError + } + return append([]ProviderConfig(nil), s.configs...), nil +} + +func (s guardrailStoreForTest) GetGuardrailProviderConfig(_ context.Context, providerID string) (ProviderConfig, bool, error) { + for _, cfg := range s.configs { + if cfg.ProviderID == providerID { + return cfg, true, nil + } + } + return ProviderConfig{}, false, nil +} + +func (s guardrailStoreForTest) ListGuardrailEnabled(context.Context) ([]EnabledEntry, error) { + if s.enabledError != nil { + return nil, s.enabledError + } + return append([]EnabledEntry(nil), s.enableds...), nil +} + +type fakeProvider struct { + id string + testError error + catalog []CatalogEntry + catalogError error + runtimeResult RuntimeResult + runtimeError error +} + +func (p fakeProvider) ID() string { return p.id } + +func (p fakeProvider) Name() string { return p.id } + +func (p fakeProvider) Capabilities() []string { return []string{"catalog"} } + +func (p fakeProvider) TestCredentials(context.Context, ProviderConfig) error { return p.testError } + +func (p fakeProvider) ListCatalog(context.Context, ProviderConfig) ([]CatalogEntry, error) { + if p.catalogError != nil { + return nil, p.catalogError + } + return append([]CatalogEntry(nil), p.catalog...), nil +} + +func (p fakeProvider) RunRuntime(context.Context, ProviderConfig, CatalogEntry, EvaluateRequest) (RuntimeResult, error) { + return p.runtimeResult, p.runtimeError +} + +func newServiceForTest(provider fakeProvider) *Service { + catalog := provider.catalog + if len(catalog) == 0 { + entryID := provider.runtimeResult.EntryID + if entryID == "" { + entryID = "default-entry" + } + catalog = []CatalogEntry{{ + ProviderID: provider.id, + EntryID: entryID, + Name: entryID, + Kind: CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }} + } + enabled := make([]EnabledEntry, 0, len(catalog)) + for _, entry := range catalog { + enabled = append(enabled, EnabledEntry{ProviderID: entry.ProviderID, EntryID: entry.EntryID}) + } + svc := NewService(guardrailStoreForTest{ + configs: []ProviderConfig{{ProviderID: provider.id}}, + enableds: enabled, + }, provider) + svc.setCachedCatalog(provider.id, catalog) + return svc +} + +func TestService_EvaluateAfterLocalAllow(t *testing.T) { + svc := newServiceForTest(fakeProvider{ + id: "nvidia", + catalog: []CatalogEntry{{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Kind: CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }}, + runtimeResult: RuntimeResult{ + Verdict: "deny", + LatencyMS: 180, + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + }, + }) + trace, final := svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "cat secrets.txt"}, + }) + if final != "deny" { + t.Fatalf("final verdict = %q, want deny", final) + } + if len(trace.Stages) != 1 || trace.Stages[0].Verdict != "deny" { + t.Fatalf("trace = %+v", trace) + } +} + +func TestService_AbstainsOnProviderError(t *testing.T) { + svc := newServiceForTest(fakeProvider{ + id: "nvidia", + runtimeError: errors.New("timeout"), + }) + trace, final := svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Read", + Input: map[string]any{"file_path": ".env"}, + }) + if final != "allow" { + t.Fatalf("final verdict = %q, want allow", final) + } + if got := trace.Stages[0].Verdict; got != "abstain" { + t.Fatalf("stage verdict = %q, want abstain", got) + } +} + +func TestService_RecordTraceEvictsOldest(t *testing.T) { + svc := NewService(nil, fakeProvider{id: "nvidia"}) + for i := uint64(1); i <= maxTraceEntries+1; i++ { + svc.RecordTrace(i, Trace{LocalPolicyVerdict: "allow", FinalVerdict: "allow"}) + } + if _, ok := svc.Trace(1); ok { + t.Fatalf("oldest trace was not evicted") + } + if _, ok := svc.Trace(maxTraceEntries + 1); !ok { + t.Fatalf("newest trace missing") + } +} + +func TestService_EvaluateIncludesStorageErrorDetails(t *testing.T) { + svc := NewService(guardrailStoreForTest{configsError: errors.New("boom")}, fakeProvider{id: "nvidia"}) + trace, final := svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{LocalPolicyVerdict: "allow"}) + if final != "allow" || trace.GuardrailVerdict != "abstain" { + t.Fatalf("trace=%+v final=%q", trace, final) + } + if len(trace.Stages) != 1 || trace.Stages[0].Details["reason"] != "provider_config_load_failed" { + t.Fatalf("trace stages = %+v", trace.Stages) + } +} + +func TestService_RuntimeEntryDoesNotMutateProviderCatalog(t *testing.T) { + catalog := []CatalogEntry{{EntryID: "entry", Kind: CatalogEntryClassifierModel, SupportsRuntimeEnforcement: true}} + provider := fakeProvider{id: "nvidia", catalog: catalog, runtimeResult: RuntimeResult{Verdict: "allow"}} + svc := NewService(guardrailStoreForTest{ + configs: []ProviderConfig{{ProviderID: "nvidia"}}, + enableds: []EnabledEntry{{ProviderID: "nvidia", EntryID: "entry"}}, + }, provider) + _, _ = svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "pwd"}, + }) + if catalog[0].ProviderID != "" { + t.Fatalf("provider catalog mutated: %+v", catalog) + } +} + +func TestService_EvaluatePostPolicySkipsWhenLocalVerdictIsNotAllow(t *testing.T) { + svc := newServiceForTest(fakeProvider{id: "nvidia", runtimeResult: RuntimeResult{Verdict: "deny"}}) + trace, final := svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{ + LocalPolicyVerdict: "deny", + Tool: "Bash", + Input: map[string]any{"command": "cat secrets.txt"}, + }) + if final != "deny" { + t.Fatalf("final verdict = %q, want deny", final) + } + if trace.GuardrailVerdict != "skipped" || len(trace.Stages) != 0 { + t.Fatalf("trace = %+v", trace) + } +} + +func TestService_EvaluatePostPolicyRecordsOpenRouterAbstainStage(t *testing.T) { + svc := NewService( + guardrailStoreForTest{ + configs: []ProviderConfig{{ProviderID: "openrouter"}}, + enableds: []EnabledEntry{{ProviderID: "openrouter", EntryID: "policy-prod"}}, + }, + fakeProvider{ + id: "openrouter", + catalog: []CatalogEntry{{ + ProviderID: "openrouter", + EntryID: "policy-prod", + Name: "Production Guardrail", + Kind: CatalogEntryAccountPolicy, + }}, + runtimeResult: RuntimeResult{Verdict: "abstain"}, + runtimeError: ErrRuntimeUnsupported, + }, + ) + trace, final := svc.EvaluatePostPolicy(context.Background(), EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "pwd"}, + }) + if final != "allow" || trace.GuardrailVerdict != "abstain" { + t.Fatalf("trace=%+v final=%q", trace, final) + } + if len(trace.Stages) != 1 || trace.Stages[0].Verdict != "abstain" { + t.Fatalf("stages = %+v", trace.Stages) + } +} + +func TestService_ListCatalogAggregatesConfiguredProviders(t *testing.T) { + svc := NewService( + guardrailStoreForTest{configs: []ProviderConfig{{ProviderID: "nvidia"}, {ProviderID: "openrouter"}}}, + fakeProvider{ + id: "nvidia", + catalog: []CatalogEntry{{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Name: "NeMo Content Safety", + Kind: CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }}, + }, + fakeProvider{ + id: "openrouter", + catalog: []CatalogEntry{{ + ProviderID: "openrouter", + EntryID: "policy-prod", + Name: "Production Guardrail", + Kind: CatalogEntryAccountPolicy, + }}, + }, + ) + got, err := svc.ListCatalog(context.Background()) + if err != nil { + t.Fatalf("ListCatalog: %v", err) + } + if len(got) != 2 || got[0].ProviderID != "nvidia" || got[1].ProviderID != "openrouter" { + t.Fatalf("catalog = %+v", got) + } +} + +func TestService_ListCatalogStatus_ReturnsPartialEntriesAndProviderErrors(t *testing.T) { + svc := NewService( + guardrailStoreForTest{configs: []ProviderConfig{{ProviderID: "nvidia"}, {ProviderID: "openrouter"}}}, + fakeProvider{ + id: "nvidia", + catalog: []CatalogEntry{{ + ProviderID: "nvidia", + EntryID: "nemo-content-safety", + Name: "NeMo Content Safety", + Kind: CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }}, + }, + fakeProvider{ + id: "openrouter", + catalogError: errors.New("openrouter catalog status 401"), + }, + ) + got, providerErrors, err := svc.ListCatalogStatus(context.Background()) + if err != nil { + t.Fatalf("ListCatalogStatus: %v", err) + } + if len(got) != 1 || got[0].ProviderID != "nvidia" { + t.Fatalf("catalog = %+v", got) + } + if len(providerErrors) != 1 || providerErrors[0].ProviderID != "openrouter" { + t.Fatalf("providerErrors = %+v", providerErrors) + } +} + +func TestNVIDIAProvider_ListCatalogNormalizesSafetyModels(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Fatalf("path = %q, want /v1/models", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer nvapi-test" { + t.Fatalf("authorization = %q", got) + } + _, _ = w.Write([]byte(`{"data":[ + {"id":"nvidia/llama-3.1-nemoguard-8b-content-safety","name":"Llama 3.1 NeMoGuard 8B Content Safety"}, + {"id":"meta/llama-3.1-8b-instruct","name":"General LLM"}, + {"id":"nvidia/llama-3_1-nemotron-safety-guard-8b-v3","name":"Nemotron Safety Guard 8B"} + ]}`)) + })) + defer srv.Close() + provider := &NVIDIAProvider{BaseURL: srv.URL, HTTPClient: srv.Client()} + got, err := provider.ListCatalog(context.Background(), ProviderConfig{ + ProviderID: "nvidia", + APIKey: SecretString("nvapi-test"), + }) + if err != nil { + t.Fatalf("ListCatalog: %v", err) + } + if len(got) != 2 || got[0].Kind != CatalogEntryClassifierModel || !got[0].SupportsRuntimeEnforcement { + t.Fatalf("catalog = %+v", got) + } +} + +func TestNVIDIAProvider_RunRuntimeParsesDenyVerdict(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" || r.Method != http.MethodPost { + t.Fatalf("request = %s %s", r.Method, r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + text := string(body) + if !strings.Contains(text, `"model":"nvidia/llama-3.1-nemoguard-8b-content-safety"`) || + !strings.Contains(text, `\"tool\":\"Bash\"`) { + t.Fatalf("request body missing expected payload: %s", text) + } + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"{\"User Safety\":\"unsafe\",\"Safety Categories\":\"PII/Privacy\"}"}}]}`)) + })) + defer srv.Close() + provider := &NVIDIAProvider{BaseURL: srv.URL, HTTPClient: srv.Client()} + got, err := provider.RunRuntime(context.Background(), ProviderConfig{ + ProviderID: "nvidia", + APIKey: SecretString("nvapi-test"), + }, CatalogEntry{ + ProviderID: "nvidia", + EntryID: "nvidia/llama-3.1-nemoguard-8b-content-safety", + Kind: CatalogEntryClassifierModel, + SupportsRuntimeEnforcement: true, + }, EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "cat secrets.txt"}, + }) + if err != nil { + t.Fatalf("RunRuntime: %v", err) + } + want := RuntimeResult{ + ProviderID: "nvidia", + EntryID: "nvidia/llama-3.1-nemoguard-8b-content-safety", + Verdict: "deny", + Details: map[string]string{ + "safety_categories": "PII/Privacy", + "user_safety": "unsafe", + }, + } + if !reflect.DeepEqual(want, got) { + t.Fatalf("runtime mismatch: want %+v got %+v", want, got) + } +} + +func TestNVIDIAProvider_RunRuntimeRejectsSchemaInvalidVerdictJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"{\"User Safety\":\"maybe\",\"Safety Categories\":\"PII\"}"}}]}`)) + })) + defer srv.Close() + provider := &NVIDIAProvider{BaseURL: srv.URL, HTTPClient: srv.Client()} + _, err := provider.RunRuntime(context.Background(), ProviderConfig{ + ProviderID: "nvidia", + APIKey: SecretString("nvapi-test"), + }, CatalogEntry{ + ProviderID: "nvidia", + EntryID: "nvidia/llama-3.1-nemoguard-8b-content-safety", + }, EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "cat secrets.txt"}, + }) + if err == nil { + t.Fatalf("expected schema validation error") + } +} + +func TestOpenRouterProvider_ListCatalogNormalizesPolicies(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/guardrails" { + t.Fatalf("path = %q, want /api/v1/guardrails", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer sk-or-test" { + t.Fatalf("authorization = %q", got) + } + _, _ = w.Write([]byte(`{"data":[{ + "id":"550e8400-e29b-41d4-a716-446655440000", + "name":"Production Guardrail", + "description":"Restrict providers and enforce budget", + "limit_usd":100, + "reset_interval":"monthly", + "enforce_zdr":true + }]}`)) + })) + defer srv.Close() + provider := &OpenRouterProvider{BaseURL: srv.URL, HTTPClient: srv.Client()} + got, err := provider.ListCatalog(context.Background(), ProviderConfig{ + ProviderID: "openrouter", + APIKey: SecretString("sk-or-test"), + }) + if err != nil { + t.Fatalf("ListCatalog: %v", err) + } + if len(got) != 1 || got[0].Kind != CatalogEntryAccountPolicy || got[0].SupportsRuntimeEnforcement { + t.Fatalf("catalog = %+v", got) + } + if got[0].Metadata["limit_usd"] != "100.00" { + t.Fatalf("limit_usd = %q", got[0].Metadata["limit_usd"]) + } +} + +func TestOpenRouterProvider_RunRuntimeReturnsUnsupported(t *testing.T) { + provider := &OpenRouterProvider{} + got, err := provider.RunRuntime(context.Background(), ProviderConfig{}, CatalogEntry{ + ProviderID: "openrouter", + EntryID: "policy-prod", + Kind: CatalogEntryAccountPolicy, + }, EvaluateRequest{ + LocalPolicyVerdict: "allow", + Tool: "Bash", + Input: map[string]any{"command": "pwd"}, + }) + if !errors.Is(err, ErrRuntimeUnsupported) { + t.Fatalf("err = %v, want ErrRuntimeUnsupported", err) + } + if got.Verdict != "abstain" { + t.Fatalf("verdict = %q, want abstain", got.Verdict) + } +} diff --git a/control-plane/internal/guardrails/types.go b/control-plane/internal/guardrails/types.go new file mode 100644 index 0000000..9332e81 --- /dev/null +++ b/control-plane/internal/guardrails/types.go @@ -0,0 +1,41 @@ +package guardrails + +import "fmt" + +type ProviderKind string + +const ( + ProviderKindHostedAPI ProviderKind = "hosted_api" +) + +type CatalogEntryKind string + +const ( + CatalogEntryClassifierModel CatalogEntryKind = "classifier_model" + CatalogEntryAccountPolicy CatalogEntryKind = "account_policy" +) + +type ProviderConfig struct { + ProviderID string `json:"provider_id"` + APIKey SecretString `json:"api_key,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type EnabledEntry struct { + ProviderID string `json:"provider_id"` + EntryID string `json:"entry_id"` +} + +type SecretString string + +func (s SecretString) Value() string { + return string(s) +} + +func (s SecretString) String() string { + return "[redacted]" +} + +func (s SecretString) GoString() string { + return fmt.Sprintf("%q", s.String()) +} diff --git a/control-plane/internal/storage/memory.go b/control-plane/internal/storage/memory.go index 50ca787..2fc8980 100644 --- a/control-plane/internal/storage/memory.go +++ b/control-plane/internal/storage/memory.go @@ -16,9 +16,11 @@ import ( "fmt" "os" "path/filepath" + "sort" "sync" "time" + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" "github.com/openagentlock/openagentlock/control-plane/internal/ledger" ) @@ -98,14 +100,16 @@ type LedgerEntry struct { } type Memory struct { - mu sync.Mutex - sessions map[string]Session - endedSess map[string]struct{} - detects map[string][]Detection - home string - nextSeq uint64 - lastLeaf [32]byte - ledgerFile *os.File + mu sync.Mutex + sessions map[string]Session + endedSess map[string]struct{} + detects map[string][]Detection + guardrailProviderConfigs map[string]guardrails.ProviderConfig + guardrailEnabled []guardrails.EnabledEntry + home string + nextSeq uint64 + lastLeaf [32]byte + ledgerFile *os.File subMu sync.Mutex subscribers map[int]chan LedgerEntry @@ -134,14 +138,15 @@ func NewMemory(home string) (*Memory, error) { return nil, fmt.Errorf("chmod %s: %w", p, err) } return &Memory{ - sessions: make(map[string]Session), - endedSess: make(map[string]struct{}), - detects: make(map[string][]Detection), - home: home, - ledgerFile: f, - nextSeq: nextSeq, - lastLeaf: lastLeaf, - subscribers: make(map[int]chan LedgerEntry), + sessions: make(map[string]Session), + endedSess: make(map[string]struct{}), + detects: make(map[string][]Detection), + guardrailProviderConfigs: make(map[string]guardrails.ProviderConfig), + home: home, + ledgerFile: f, + nextSeq: nextSeq, + lastLeaf: lastLeaf, + subscribers: make(map[int]chan LedgerEntry), }, nil } @@ -164,6 +169,62 @@ func (m *Memory) GetDetections(_ context.Context, sessionID string) ([]Detection return append([]Detection(nil), m.detects[sessionID]...), nil } +func (m *Memory) SaveGuardrailProviderConfig(_ context.Context, cfg guardrails.ProviderConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.guardrailProviderConfigs == nil { + m.guardrailProviderConfigs = map[string]guardrails.ProviderConfig{} + } + cfg.Metadata = cloneStringMap(cfg.Metadata) + m.guardrailProviderConfigs[cfg.ProviderID] = cfg + return nil +} + +func (m *Memory) GetGuardrailProviderConfig(_ context.Context, providerID string) (guardrails.ProviderConfig, bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + cfg, ok := m.guardrailProviderConfigs[providerID] + if !ok { + return guardrails.ProviderConfig{}, false, nil + } + cfg.Metadata = cloneStringMap(cfg.Metadata) + return cfg, true, nil +} + +func (m *Memory) ListGuardrailProviderConfigs(_ context.Context) ([]guardrails.ProviderConfig, error) { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]guardrails.ProviderConfig, 0, len(m.guardrailProviderConfigs)) + for _, cfg := range m.guardrailProviderConfigs { + cfg.Metadata = cloneStringMap(cfg.Metadata) + out = append(out, cfg) + } + sort.Slice(out, func(i, j int) bool { + return out[i].ProviderID < out[j].ProviderID + }) + return out, nil +} + +func (m *Memory) SaveGuardrailEnabled(_ context.Context, entries []guardrails.EnabledEntry) ([]guardrails.EnabledEntry, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.guardrailEnabled = append([]guardrails.EnabledEntry(nil), entries...) + return cloneEnabledEntries(m.guardrailEnabled), nil +} + +func (m *Memory) ListGuardrailEnabled(_ context.Context) ([]guardrails.EnabledEntry, error) { + m.mu.Lock() + defer m.mu.Unlock() + return cloneEnabledEntries(m.guardrailEnabled), nil +} + +func cloneEnabledEntries(entries []guardrails.EnabledEntry) []guardrails.EnabledEntry { + if len(entries) == 0 { + return []guardrails.EnabledEntry{} + } + return append([]guardrails.EnabledEntry(nil), entries...) +} + // Subscribe returns a channel that receives every entry AppendLedger // writes, starting after subscription. Caller must call the returned // cancel fn to release the slot. @@ -416,3 +477,14 @@ func (m *Memory) AppendLedger(_ context.Context, in AppendInput) (LedgerEntry, e go m.broadcast(entry) return entry, nil } + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/control-plane/internal/storage/memory_test.go b/control-plane/internal/storage/memory_test.go index 330eb9f..97e656a 100644 --- a/control-plane/internal/storage/memory_test.go +++ b/control-plane/internal/storage/memory_test.go @@ -4,12 +4,16 @@ import ( "bufio" "context" "errors" + "fmt" "os" "path/filepath" + "reflect" "strings" "sync" "testing" "time" + + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" ) func newTestStore(t *testing.T) (*Memory, string) { @@ -260,3 +264,75 @@ func TestMemory_LedgerFileCreatedMode0600(t *testing.T) { t.Fatalf("ledger.jsonl mode = %v, want 0600", st.Mode().Perm()) } } + +func TestMemoryStore_GuardrailStateRoundTrip(t *testing.T) { + s, _ := newTestStore(t) + ctx := context.Background() + + cfg := guardrails.ProviderConfig{ + ProviderID: "nvidia", + APIKey: guardrails.SecretString("nvapi-test"), + Metadata: map[string]string{"region": "us-west"}, + } + cfg2 := guardrails.ProviderConfig{ + ProviderID: "openrouter", + APIKey: guardrails.SecretString("sk-or-test"), + Metadata: map[string]string{"tier": "dev"}, + } + if err := s.SaveGuardrailProviderConfig(ctx, cfg2); err != nil { + t.Fatalf("SaveGuardrailProviderConfig second: %v", err) + } + if err := s.SaveGuardrailProviderConfig(ctx, cfg); err != nil { + t.Fatalf("SaveGuardrailProviderConfig: %v", err) + } + + gotCfg, ok, err := s.GetGuardrailProviderConfig(ctx, "nvidia") + if err != nil { + t.Fatalf("GetGuardrailProviderConfig: %v", err) + } + if !ok { + t.Fatalf("expected provider config to exist") + } + if !reflect.DeepEqual(cfg, gotCfg) { + t.Fatalf("provider config = %+v ok=%v", gotCfg, ok) + } + if leaked := fmt.Sprintf("%+v", gotCfg); strings.Contains(leaked, "nvapi-test") { + t.Fatalf("provider config formatting leaked API key: %s", leaked) + } + + gotList, err := s.ListGuardrailProviderConfigs(ctx) + if err != nil { + t.Fatalf("ListGuardrailProviderConfigs: %v", err) + } + wantList := []guardrails.ProviderConfig{cfg, cfg2} + if !reflect.DeepEqual(wantList, gotList) { + t.Fatalf("provider config list mismatch: want %+v got %+v", wantList, gotList) + } + + wantEnabled := []guardrails.EnabledEntry{ + {ProviderID: "nvidia", EntryID: "llama-3.1-nemoguard-8b-content-safety"}, + } + if saved, err := s.SaveGuardrailEnabled(ctx, wantEnabled); err != nil { + t.Fatalf("SaveGuardrailEnabled: %v", err) + } else if !reflect.DeepEqual(wantEnabled, saved) { + t.Fatalf("saved enabled mismatch: want %+v got %+v", wantEnabled, saved) + } + gotEnabled, err := s.ListGuardrailEnabled(ctx) + if err != nil { + t.Fatalf("ListGuardrailEnabled: %v", err) + } + if !reflect.DeepEqual(wantEnabled, gotEnabled) { + t.Fatalf("enabled mismatch: want %+v got %+v", wantEnabled, gotEnabled) + } +} + +func TestMemoryStore_GuardrailEnabledEmptyListIsNonNil(t *testing.T) { + s, _ := newTestStore(t) + got, err := s.ListGuardrailEnabled(context.Background()) + if err != nil { + t.Fatalf("ListGuardrailEnabled: %v", err) + } + if got == nil { + t.Fatalf("enabled entries should be empty slice, got nil") + } +} diff --git a/control-plane/internal/storage/storage.go b/control-plane/internal/storage/storage.go index 6977fba..76d6d91 100644 --- a/control-plane/internal/storage/storage.go +++ b/control-plane/internal/storage/storage.go @@ -4,7 +4,11 @@ package storage -import "context" +import ( + "context" + + "github.com/openagentlock/openagentlock/control-plane/internal/guardrails" +) type Storage interface { Health(ctx context.Context) error @@ -24,6 +28,17 @@ type Storage interface { // SaveDetections replaces the detection set reported for the session. SaveDetections(ctx context.Context, sessionID string, dets []Detection) error GetDetections(ctx context.Context, sessionID string) ([]Detection, error) + // SaveGuardrailProviderConfig creates or updates a guardrail provider configuration. + SaveGuardrailProviderConfig(ctx context.Context, cfg guardrails.ProviderConfig) error + // GetGuardrailProviderConfig retrieves a provider configuration by ID. + // The found return is true when a config exists, false otherwise. + GetGuardrailProviderConfig(ctx context.Context, providerID string) (guardrails.ProviderConfig, bool, error) + // ListGuardrailProviderConfigs returns all guardrail provider configurations. + ListGuardrailProviderConfigs(ctx context.Context) ([]guardrails.ProviderConfig, error) + // SaveGuardrailEnabled replaces the entire enabled guardrail set and returns the saved set. + SaveGuardrailEnabled(ctx context.Context, entries []guardrails.EnabledEntry) ([]guardrails.EnabledEntry, error) + // ListGuardrailEnabled returns every enabled guardrail entry. + ListGuardrailEnabled(ctx context.Context) ([]guardrails.EnabledEntry, error) } type Detection struct { diff --git a/docker-compose.yml b/docker-compose.yml index e373d89..619e747 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,6 +4,9 @@ # required. # # Usage: +# # optional external guardrails: +# # export NVIDIA_API_KEY=... +# # export OPENROUTER_API_KEY=... # docker compose up -d # docker compose logs -f control-plane # docker compose down @@ -23,6 +26,10 @@ services: AGENTLOCK_LISTEN: "0.0.0.0:7878" AGENTLOCK_DASHBOARD_LISTEN: "0.0.0.0:7879" AGENTLOCK_HOME: "/var/lib/agentlock" + # Optional external guardrail providers. Values are read once at + # control-plane startup, kept in daemon RAM, and cleared on restart. + NVIDIA_API_KEY: "${NVIDIA_API_KEY:-}" + OPENROUTER_API_KEY: "${OPENROUTER_API_KEY:-}" volumes: - agentlock-state:/var/lib/agentlock healthcheck: diff --git a/docs/architecture/llm-guardrails.md b/docs/architecture/llm-guardrails.md index 7dd4eb8..4565ca8 100644 --- a/docs/architecture/llm-guardrails.md +++ b/docs/architecture/llm-guardrails.md @@ -1,6 +1,6 @@ # LLM-based guardrails (roadmap) -> **Status: Not yet implemented** — design is in flight; this doc captures the shape we're aiming at so policy authors and contributors can reason about it before code lands. +> **Status: Initial external-provider slice shipped** — OpenAgentLock now has startup-env provider configuration, catalog discovery, runtime classifier evaluation for NVIDIA-style guardrails, and web/TUI visibility. Broader policy-schema integration and community-rule contribution flows remain roadmap. OpenAgentLock's evaluator is intentionally **deterministic YAML**: a path-shape match plus a verdict. That is the right default — fast, auditable, no LLM-shaped failure modes in the hot path. But there are policy questions that genuinely cannot be decided by regex, and forcing them through `command_regex` produces either dangerous false negatives or unworkable false positives. Examples we've hit in the wild: @@ -8,7 +8,7 @@ OpenAgentLock's evaluator is intentionally **deterministic YAML**: a path-shape - "Refuse to assist with self-harm or weapons synthesis" when the agent is wrapping a chat assistant. - "Block exfiltration of company-confidential content" when the content is extracted from real source files (not a regex away). -For these we want a second tier of evaluator that calls into a small **safety classifier** model — running locally — and returns a numeric score that policy can act on. +For these we want a second tier of evaluator that calls into a **safety classifier** model and returns a structured result that OpenAgentLock can audit. The first shipped slice supports external hosted providers as an explicit opt-in. Local-only operation remains the default when no provider is configured. ## Severity model @@ -78,10 +78,34 @@ gates: Key invariants: -1. **Local-first.** Model serving runs on the same host as the daemon. No prompts leave the box. Default backend is `ollama` (which already supports the NeMo Guard models); we'll wire `vllm` and `llama.cpp` as alternates. +1. **Local policy first.** Deterministic YAML policy runs before external guardrails. Local deny stops immediately; external guardrails only run after a local allow. 2. **Abstain, not deny, on failure.** If the model is slow/unavailable, the rule abstains (returns `skip`) rather than denying — guardrails must not become a DoS surface. The `monitor` policy mode logs the abstention. 3. **Deterministic core stays primary.** Guardrail evaluators run *after* deterministic regex/glob matchers. The simple matchers absorb the easy cases; the LLM only sees what the policy explicitly routes to it. -4. **Same audit trail.** Verdicts include the classifier output; ledger entries record `signer`, `model`, and the structured taxonomy verdict so a verifier can reproduce the call. +4. **Same audit trail.** Verdicts keep local policy, guardrail, and final verdicts distinct. A guardrail deny is not presented as a deterministic YAML rule hit. + +## Shipped external-provider slice + +The daemon exposes: + +- `GET /v1/guardrails/providers` +- `POST /v1/guardrails/providers/{id}/test` +- `GET /v1/guardrails/catalog` +- `PUT /v1/guardrails/enabled` +- `GET /v1/guardrails/traces/{seq}` + +Provider credentials are read from environment variables when the control plane starts, not from the web dashboard or CLI: + +```bash +NVIDIA_API_KEY=... OPENROUTER_API_KEY=... docker compose up -d +``` + +The daemon stores these keys in RAM only. They are never written by the dashboard and are cleared on daemon restart. + +Provider behavior: + +- NVIDIA catalog entries are normalized as `classifier_model` entries and can run in the post-local-policy runtime stage. +- OpenRouter guardrails are normalized as `account_policy` entries. They are visible in catalog surfaces but do not run as OpenAgentLock runtime classifiers in this slice, so they should be treated as catalog visibility rather than enforcement. +- Provider errors, unsupported runtime entries, and malformed classifier responses produce `abstain`, not implicit allow or deny. ## Wire shape @@ -106,12 +130,18 @@ Policies can mix-and-match — most rules will stay regex; a small number will r | Step | Status | |---|---| | Roadmap doc (this file) | Shipped | +| `/v1/guardrails/providers` provider registry endpoint | Shipped | +| `/v1/guardrails/catalog` normalized provider catalog | Shipped | +| NVIDIA runtime classifier integration | Shipped | +| OpenRouter account-policy catalog visibility | Shipped | +| Startup-env RAM provider key configuration | Shipped | +| Catalog cache / abstain semantics + tests | Shipped | +| Provider-measured runtime latency in traces | Not yet implemented | +| Dashboard and TUI multi-stage trace visibility | Shipped | | Add `kind: llm_guardrail` to the policy schema, parser only | Not yet implemented | -| `/v1/llm/models` registry endpoint | Not yet implemented | -| First evaluator implementation against `ollama` running NeMo Guard locally | Not yet implemented | -| Latency / cache / abstain semantics + tests | Not yet implemented | +| Local `ollama` / `vllm` / `llama.cpp` backends | Not yet implemented | | Community-rule shape (`schema_version: 2` adds the kind) | Not yet implemented | -| Dashboard UI shows model verdicts in the event row | Not yet implemented | +| Opt-in community contribution pipeline | Not yet implemented | ## Why a roadmap doc instead of just an issue diff --git a/docs/guide/dashboard.md b/docs/guide/dashboard.md index a877931..98c5cf7 100644 --- a/docs/guide/dashboard.md +++ b/docs/guide/dashboard.md @@ -16,6 +16,7 @@ Both read from the same ledger and policy state — pick whichever fits the mome - **False-positive repair** — open a blocked or monitor-alert event detail, report it as a false positive, validate replacement gate YAML, then atomically disable the old rule and install the replacement - **Mode toggle** — flip the daemon between `monitor` and `enforce` (separate from the policy file's own `mode`) - **MCP pin queue** — accept or reject newly seen MCP servers +- **External guardrails** — browse discovered guardrail catalog entries, enable runtime-capable classifier entries, and inspect event traces that show local policy, external guardrail, and final verdict stages. In the current slice, NVIDIA entries can run at runtime after a local allow; OpenRouter entries are catalog-only. Provider API keys are read from control-plane startup environment variables, not the browser. ## What the terminal dashboard does @@ -30,8 +31,15 @@ agentlock dashboard - **Loaded gates** — the gates the daemon currently evaluates - **Mode flip** — one keypress to toggle the daemon between `monitor` and `enforce` - **False-positive repair** — open event detail for a matched deny or alert row and press `f` to edit, validate, and apply a replacement gate +- **Guardrails catalog** — see configured external providers and discovered catalog entries from the terminal dashboard -Rule edits and the MCP pin queue still live on the web dashboard. +Rule edits and the MCP pin queue still live on the web dashboard. Provider keys are process environment only: + +```bash +NVIDIA_API_KEY=... OPENROUTER_API_KEY=... docker compose up -d +``` + +Guardrail keys are kept in daemon memory only; restarting the daemon clears them. ## Why two surfaces diff --git a/docs/guide/getting-started.md b/docs/guide/getting-started.md index 083cd23..6cf457d 100644 --- a/docs/guide/getting-started.md +++ b/docs/guide/getting-started.md @@ -14,9 +14,14 @@ The control plane is a small Go HTTP service that lives in a Docker container. I ```bash curl -O https://raw.githubusercontent.com/openagentlock/openagentlock/main/docker-compose.yml + # Optional external guardrails: + # export NVIDIA_API_KEY=... + # export OPENROUTER_API_KEY=... docker compose up -d ``` + `NVIDIA_API_KEY` enables the shipped runtime-classifier path. `OPENROUTER_API_KEY` currently enables OpenRouter catalog discovery only. + === "docker run" ```bash @@ -25,6 +30,8 @@ The control plane is a small Go HTTP service that lives in a Docker container. I -v agentlock-state:/var/lib/agentlock \ -p 127.0.0.1:7878:7878 \ -p 127.0.0.1:7879:7879 \ + -e NVIDIA_API_KEY \ + -e OPENROUTER_API_KEY \ ghcr.io/openagentlock/agentlockd:latest ``` diff --git a/docs/guide/installation.md b/docs/guide/installation.md index 9144525..2d60d06 100644 --- a/docs/guide/installation.md +++ b/docs/guide/installation.md @@ -45,6 +45,16 @@ The compose file references `ghcr.io/openagentlock/agentlockd:latest` and binds State is persisted in a named Docker volume (`agentlock-state`) so ledger entries survive restarts. +Optional external guardrails are configured by environment when the control plane starts. These keys are read once, kept in daemon memory, and are not written to disk: + +```bash +export NVIDIA_API_KEY=... +export OPENROUTER_API_KEY=... +docker compose up -d +``` + +Today, `NVIDIA_API_KEY` enables post-local-allow runtime classification. `OPENROUTER_API_KEY` enables catalog visibility for OpenRouter guardrails, but those entries do not run as OpenAgentLock runtime classifiers yet. + ### `docker run` ```bash @@ -52,6 +62,8 @@ docker run -d --name agentlock \ -v agentlock-state:/var/lib/agentlock \ -p 127.0.0.1:7878:7878 \ -p 127.0.0.1:7879:7879 \ + -e NVIDIA_API_KEY \ + -e OPENROUTER_API_KEY \ ghcr.io/openagentlock/agentlockd:latest ``` diff --git a/docs/index.md b/docs/index.md index 0f1c338..3ad237d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -115,6 +115,9 @@ docker pull ghcr.io/openagentlock/agentlockd:latest # 2. Start it (drops a docker-compose example in your CWD) curl -O https://raw.githubusercontent.com/openagentlock/openagentlock/main/docker-compose.yml +# Optional external guardrails: +# export NVIDIA_API_KEY=... +# export OPENROUTER_API_KEY=... docker compose up -d # 3. Install the CLI and wire up your agents