Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package vip.mate.tool.guard.guardian;

import io.modelcontextprotocol.spec.McpSchema;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import vip.mate.tool.guard.engine.ToolGuardRuleRegistry;
import vip.mate.tool.guard.model.*;
import vip.mate.tool.mcp.runtime.McpClientManager;
import vip.mate.tool.mcp.runtime.McpToolNameResolver;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* MCP 工具名审批守卫。
*
* <p>仅匹配明确配置了 {@code tool_name} 的规则,允许配置运行时 prefixed
* 名称,或配置 MCP server 暴露的 raw tool name。
*/
@Slf4j
@Component
public class McpToolApprovalGuardian implements ToolGuardGuardian {

private final ToolGuardRuleRegistry ruleRegistry;
private final McpClientManager mcpClientManager;

public McpToolApprovalGuardian(ToolGuardRuleRegistry ruleRegistry,
McpClientManager mcpClientManager) {
this.ruleRegistry = ruleRegistry;
this.mcpClientManager = mcpClientManager;
}

@Override
public boolean supports(ToolInvocationContext context) {
return McpToolNameResolver.isMcpPrefixedName(context.toolName());
}

@Override
public int priority() {
return 180;
}

@Override
public List<GuardFinding> evaluate(ToolInvocationContext context) {
String prefixedName = context.toolName();
String rawName = resolveRawToolName(prefixedName);
List<GuardFinding> findings = new ArrayList<>();

for (ToolGuardRuleEntity rule : ruleRegistry.getAllEnabled()) {
String configuredToolName = normalize(rule.getToolName());
if (configuredToolName == null) {
continue;
}
if (!matchesRule(rule, configuredToolName, prefixedName, rawName)) {
continue;
}

findings.add(new GuardFinding(
fallback(rule.getRuleId(), "MCP_TOOL_APPROVAL"),
approvalSeverity(rule.getSeverity()),
category(rule.getCategory()),
fallback(rule.getName(), "MCP 工具调用审批"),
fallback(rule.getDescription(), "MCP 工具调用匹配审批规则,需要用户确认后执行"),
fallback(rule.getRemediation(), "请确认是否允许执行该 MCP 工具"),
prefixedName,
"tool_name",
configuredToolName,
rawName != null ? rawName : prefixedName,
Map.of(
"configuredToolName", configuredToolName,
"prefixedToolName", prefixedName,
"rawToolName", rawName != null ? rawName : ""
)
));
}

return findings;
}

private static boolean matchesRule(ToolGuardRuleEntity rule, String configuredToolName,
String prefixedName, String rawName) {
if (configuredToolName.equals(prefixedName)) {
return true;
}
if (rawName == null || !configuredToolName.equals(rawName)) {
return false;
}
return !Boolean.TRUE.equals(rule.getBuiltin());
}

private String resolveRawToolName(String prefixedName) {
McpToolNameResolver.ParsedRef parsed = McpToolNameResolver.parse(prefixedName);
if (parsed == null) {
return null;
}
try {
for (McpSchema.Tool tool : mcpClientManager.getServerTools(parsed.serverId())) {
String raw = tool != null ? tool.name() : null;
if (raw != null && McpToolNameResolver.hash6(raw).equals(parsed.hash6())) {
return raw;
}
}
} catch (Exception e) {
log.debug("[McpToolApprovalGuardian] Failed to resolve raw MCP tool name for {}: {}",
prefixedName, e.getMessage());
}
return null;
}

private static GuardSeverity approvalSeverity(String configuredSeverity) {
if (configuredSeverity == null || configuredSeverity.isBlank()) {
return GuardSeverity.MEDIUM;
}
try {
GuardSeverity severity = GuardSeverity.valueOf(configuredSeverity.trim());
if (severity == GuardSeverity.HIGH) {
return GuardSeverity.HIGH;
}
} catch (IllegalArgumentException ignored) {
// Fall through to MEDIUM so a malformed rule still asks for approval.
}
return GuardSeverity.MEDIUM;
}

private static GuardCategory category(String configuredCategory) {
if (configuredCategory == null || configuredCategory.isBlank()) {
return GuardCategory.CODE_EXECUTION;
}
try {
return GuardCategory.valueOf(configuredCategory.trim());
} catch (IllegalArgumentException e) {
return GuardCategory.CODE_EXECUTION;
}
}

private static String normalize(String value) {
if (value == null || value.isBlank()) {
return null;
}
return value.trim();
}

private static String fallback(String value, String fallback) {
return value == null || value.isBlank() ? fallback : value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package vip.mate.tool.guard.guardian;

import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import vip.mate.tool.guard.engine.ToolGuardRuleRegistry;
import vip.mate.tool.guard.model.*;
import vip.mate.tool.mcp.runtime.McpClientManager;
import vip.mate.tool.mcp.runtime.McpToolNameResolver;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class McpToolApprovalGuardianTest {

private ToolGuardRuleRegistry ruleRegistry;
private McpClientManager mcpClientManager;
private McpToolApprovalGuardian guardian;

@BeforeEach
void setUp() {
ruleRegistry = mock(ToolGuardRuleRegistry.class);
mcpClientManager = mock(McpClientManager.class);
guardian = new McpToolApprovalGuardian(ruleRegistry, mcpClientManager);
}

@Test
@DisplayName("prefixed MCP tool name rule creates approval finding")
void prefixedRuleMatches() {
String toolName = McpToolNameResolver.prefixedName(42L, "search_entity");
when(ruleRegistry.getAllEnabled()).thenReturn(List.of(rule("MCP_SEARCH", toolName, "HIGH")));

List<GuardFinding> findings = guardian.evaluate(context(toolName));

assertEquals(1, findings.size());
assertEquals("MCP_SEARCH", findings.get(0).ruleId());
assertEquals(GuardSeverity.HIGH, findings.get(0).severity());
assertEquals(toolName, findings.get(0).toolName());
assertEquals(toolName, findings.get(0).matchedPattern());
}

@Test
@DisplayName("raw MCP tool name rule matches through server tool cache")
void rawRuleMatches() {
String rawName = "search_entity";
String toolName = McpToolNameResolver.prefixedName(42L, rawName);
when(ruleRegistry.getAllEnabled()).thenReturn(List.of(rule("MCP_SEARCH_RAW", rawName, "MEDIUM")));
when(mcpClientManager.getServerTools(42L)).thenReturn(List.of(tool(rawName)));

List<GuardFinding> findings = guardian.evaluate(context(toolName));

assertEquals(1, findings.size());
assertEquals("MCP_SEARCH_RAW", findings.get(0).ruleId());
assertEquals(rawName, findings.get(0).matchedPattern());
assertEquals(rawName, findings.get(0).snippet());
}

@Test
@DisplayName("blank tool_name rules are ignored for MCP approval")
void blankToolRuleIgnored() {
String toolName = McpToolNameResolver.prefixedName(42L, "search_entity");
when(ruleRegistry.getAllEnabled()).thenReturn(List.of(rule("GLOBAL", "", "HIGH")));

List<GuardFinding> findings = guardian.evaluate(context(toolName));

assertTrue(findings.isEmpty());
}

@Test
@DisplayName("raw name matching does not reuse builtin non-MCP rules")
void rawRuleDoesNotMatchBuiltinRule() {
String rawName = "execute_shell_command";
String toolName = McpToolNameResolver.prefixedName(42L, rawName);
ToolGuardRuleEntity builtinShellRule = rule("SHELL_RM", rawName, "HIGH");
builtinShellRule.setBuiltin(true);
when(ruleRegistry.getAllEnabled()).thenReturn(List.of(builtinShellRule));
when(mcpClientManager.getServerTools(42L)).thenReturn(List.of(tool(rawName)));

List<GuardFinding> findings = guardian.evaluate(context(toolName));

assertTrue(findings.isEmpty());
}

@Test
@DisplayName("critical severity is coerced to approval severity")
void criticalSeverityDoesNotBlock() {
String toolName = McpToolNameResolver.prefixedName(42L, "search_entity");
when(ruleRegistry.getAllEnabled()).thenReturn(List.of(rule("MCP_SEARCH", toolName, "CRITICAL")));

List<GuardFinding> findings = guardian.evaluate(context(toolName));
GuardDecision decision = new vip.mate.tool.guard.engine.ToolPolicyResolver()
.resolve(findings, context(toolName));

assertEquals(1, findings.size());
assertEquals(GuardSeverity.MEDIUM, findings.get(0).severity());
assertEquals(GuardDecision.NEEDS_APPROVAL, decision);
}

@Test
@DisplayName("non-MCP tools are not supported")
void nonMcpUnsupported() {
assertFalse(guardian.supports(context("execute_shell_command")));
}

private static ToolInvocationContext context(String toolName) {
return ToolInvocationContext.of(toolName, "{}", "conv-1", "agent-1");
}

private static ToolGuardRuleEntity rule(String ruleId, String toolName, String severity) {
ToolGuardRuleEntity rule = new ToolGuardRuleEntity();
rule.setRuleId(ruleId);
rule.setName(ruleId);
rule.setDescription(ruleId);
rule.setToolName(toolName);
rule.setSeverity(severity);
rule.setCategory(GuardCategory.CODE_EXECUTION.name());
rule.setDecision(GuardDecision.NEEDS_APPROVAL.name());
rule.setPattern(".*");
rule.setEnabled(true);
return rule;
}

private static McpSchema.Tool tool(String name) {
return new McpSchema.Tool(name, null, "Test tool", null, null, null, null);
}
}
Loading