diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index 6088bd68fe2..dd2caf7fbe6 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -368,6 +368,19 @@ export class ProviderSettingsManager { typeof config.apiProvider === "string" && isRetiredProvider(config.apiProvider) ? providerSettingsWithIdSchema.passthrough().parse(config) : discriminatedProviderSettingsWithIdSchema.parse(config) + + // Preserve existing secret values when the webview sends back the + // "__ROO_REDACTED__" sentinel (secrets are redacted in the webview state + // to prevent API key exposure). + const existingConfig = providerProfiles.apiConfigs[name] + if (existingConfig) { + for (const key of Object.keys(filteredConfig)) { + if (isSecretStateKey(key) && (filteredConfig as any)[key] === "__ROO_REDACTED__") { + ;(filteredConfig as any)[key] = (existingConfig as any)[key] + } + } + } + providerProfiles.apiConfigs[name] = { ...filteredConfig, id } await this.store(providerProfiles) return id diff --git a/src/core/tools/ExecuteCommandTool.ts b/src/core/tools/ExecuteCommandTool.ts index 8fcb917b134..b356ebab1bd 100644 --- a/src/core/tools/ExecuteCommandTool.ts +++ b/src/core/tools/ExecuteCommandTool.ts @@ -11,7 +11,7 @@ import { Task } from "../task/Task" import { ToolUse, ToolResponse } from "../../shared/tools" import { formatResponse } from "../prompts/responses" -import { unescapeHtmlEntities } from "../../utils/text-normalization" +import { unescapeHtmlEntities, sanitizeForPromptInjection } from "../../utils/text-normalization" import { ExitCodeDetails, RooTerminalCallbacks, RooTerminalProcess } from "../../integrations/terminal/types" import { TerminalRegistry } from "../../integrations/terminal/TerminalRegistry" import { Terminal } from "../../integrations/terminal/Terminal" @@ -459,6 +459,8 @@ export async function executeCommandInTerminal( await onCompletedPromise } + const safeResult = sanitizeForPromptInjection(result) + if (message) { const { text, images } = message await task.say("user_feedback", text, images) @@ -468,7 +470,7 @@ export async function executeCommandInTerminal( formatResponse.toolResult( [ `Command is still running in terminal from '${terminal.getCurrentWorkingDirectory().toPosix()}'.`, - result.length > 0 ? `Here's the output so far:\n${result}\n` : "\n", + safeResult.length > 0 ? `Here's the output so far:\n${safeResult}\n` : "\n", `\n${text}\n`, ].join("\n"), images, @@ -509,14 +511,14 @@ export async function executeCommandInTerminal( return [ false, - `Command executed in terminal within working directory '${currentWorkingDir}'. ${exitStatus}\nOutput:\n${result}`, + `Command executed in terminal within working directory '${currentWorkingDir}'. ${exitStatus}\nOutput:\n${safeResult}`, ] } else { return [ false, [ `Command is still running in terminal ${workingDir ? ` from '${workingDir.toPosix()}'` : ""}.`, - result.length > 0 ? `Here's the output so far:\n${result}\n` : "\n", + safeResult.length > 0 ? `Here's the output so far:\n${safeResult}\n` : "\n", "You will be updated on the terminal status and new output in the future.", ].join("\n"), ] @@ -569,7 +571,7 @@ function formatPersistedOutput( `Output (${sizeStr}) persisted. Artifact ID: ${artifactId}`, "", "Preview:", - result.preview, + sanitizeForPromptInjection(result.preview), "", "Use read_command_output tool to view full output if needed.", ].join("\n") diff --git a/src/core/tools/ReadFileTool.ts b/src/core/tools/ReadFileTool.ts index 8ad6a3b33d1..200419b04d7 100644 --- a/src/core/tools/ReadFileTool.ts +++ b/src/core/tools/ReadFileTool.ts @@ -21,6 +21,7 @@ import { RecordSource } from "../context-tracking/FileContextTrackerTypes" import { isPathOutsideWorkspace } from "../../utils/pathUtils" import { getReadablePath } from "../../utils/path" import { extractTextFromFile, addLineNumbers, getSupportedBinaryFormats } from "../../integrations/misc/extract-text" +import { sanitizeForPromptInjection } from "../../utils/text-normalization" import { readWithIndentation, readWithSlice } from "../../integrations/misc/indentation-reader" import { DEFAULT_LINE_LIMIT } from "../prompts/tools/native-tools/read_file" import type { ToolUse, PushToolResult } from "../../shared/tools" @@ -221,7 +222,7 @@ export class ReadFileTool extends BaseTool<"read_file"> { await task.fileContextTracker.trackFileContext(relPath, "read_tool" as RecordSource) updateFileResult(relPath, { - nativeContent: `File: ${relPath}\n${result}`, + nativeContent: `File: ${relPath}\n${sanitizeForPromptInjection(result)}`, }) } catch (error) { const errorMsg = error instanceof Error ? error.message : String(error) @@ -397,7 +398,7 @@ export class ReadFileTool extends BaseTool<"read_file"> { updateFileResult(relPath, { nativeContent: lineCount > 0 - ? `File: ${relPath}\nLines 1-${lineCount}:\n${numberedContent}` + ? `File: ${relPath}\nLines 1-${lineCount}:\n${sanitizeForPromptInjection(numberedContent)}` : `File: ${relPath}\nNote: File is empty`, }) return @@ -794,7 +795,7 @@ export class ReadFileTool extends BaseTool<"read_file"> { } } - results.push(`File: ${relPath}\n${content}`) + results.push(`File: ${relPath}\n${sanitizeForPromptInjection(content)}`) // Track file in context await task.fileContextTracker.trackFileContext(relPath, "read_tool") diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 1106d340050..78e2f4433e9 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -97,6 +97,7 @@ import { Task } from "../task/Task" import { webviewMessageHandler } from "./webviewMessageHandler" import type { ClineMessage, TodoItem } from "@roo-code/types" +import { isSecretStateKey } from "@roo-code/types" import { readApiMessages, saveApiMessages, saveTaskMessages, TaskHistoryStore } from "../task-persistence" import { readTaskMessages } from "../task-persistence/taskMessages" import { getNonce } from "./getNonce" @@ -2468,9 +2469,17 @@ export class ClineProvider ) } + // Redact secrets before sending to webview to prevent API key exposure. + const redactedApiConfiguration = { ...providerSettings } + for (const key of Object.keys(redactedApiConfiguration)) { + if (isSecretStateKey(key) && typeof (redactedApiConfiguration as any)[key] === "string") { + ;(redactedApiConfiguration as any)[key] = "__ROO_REDACTED__" + } + } + // Return the same structure as before. return { - apiConfiguration: providerSettings, + apiConfiguration: redactedApiConfiguration, lastShownAnnouncementId: stateValues.lastShownAnnouncementId, customInstructions: stateValues.customInstructions, apiModelId: stateValues.apiModelId, @@ -2571,7 +2580,7 @@ export class ClineProvider maxGitStatusFiles: stateValues.maxGitStatusFiles ?? 0, taskSyncEnabled, imageGenerationProvider: stateValues.imageGenerationProvider, - openRouterImageApiKey: stateValues.openRouterImageApiKey, + openRouterImageApiKey: stateValues.openRouterImageApiKey ? "__ROO_REDACTED__" : undefined, openRouterImageGenerationSelectedModel: stateValues.openRouterImageGenerationSelectedModel, } } diff --git a/src/integrations/misc/extract-text.ts b/src/integrations/misc/extract-text.ts index f29fa915d13..fc4eca812aa 100644 --- a/src/integrations/misc/extract-text.ts +++ b/src/integrations/misc/extract-text.ts @@ -7,6 +7,7 @@ import { isBinaryFile } from "isbinaryfile" import { extractTextFromXLSX } from "./extract-text-from-xlsx" import { readWithSlice } from "./indentation-reader" import { DEFAULT_LINE_LIMIT } from "../../core/prompts/tools/native-tools/read_file" +import { sanitizeForPromptInjection } from "../../utils/text-normalization" async function extractTextFromPDF(filePath: string): Promise { const dataBuffer = await fs.readFile(filePath) @@ -91,7 +92,7 @@ export async function extractTextFromFileWithMetadata( const extractor = SUPPORTED_BINARY_FORMATS[fileExtension as keyof typeof SUPPORTED_BINARY_FORMATS] if (extractor) { // For binary formats, extract and count lines - const content = await extractor(filePath) + const content = sanitizeForPromptInjection(await extractor(filePath)) const lines = content.split("\n") return { content, @@ -130,7 +131,7 @@ export async function extractTextFromFileWithMetadata( */ export async function extractTextFromFile(filePath: string): Promise { const result = await extractTextFromFileWithMetadata(filePath) - return result.content + return sanitizeForPromptInjection(result.content) } export function addLineNumbers(content: string, startLine: number = 1): string { diff --git a/src/integrations/terminal/ExecaTerminalProcess.ts b/src/integrations/terminal/ExecaTerminalProcess.ts index cc2af938027..82d39f0bfda 100644 --- a/src/integrations/terminal/ExecaTerminalProcess.ts +++ b/src/integrations/terminal/ExecaTerminalProcess.ts @@ -40,7 +40,7 @@ export class ExecaTerminalProcess extends BaseTerminalProcess { this.isHot = true this.subprocess = execa({ - shell: BaseTerminal.getExecaShellPath() || true, + shell: BaseTerminal.getExecaShellPath() || false, cwd: this.terminal.getCurrentWorkingDirectory(), all: true, // Ignore stdin to ensure non-interactive mode and prevent hanging diff --git a/src/integrations/terminal/__tests__/ExecaTerminalProcess.spec.ts b/src/integrations/terminal/__tests__/ExecaTerminalProcess.spec.ts index 5f0a21869ec..500c67260b8 100644 --- a/src/integrations/terminal/__tests__/ExecaTerminalProcess.spec.ts +++ b/src/integrations/terminal/__tests__/ExecaTerminalProcess.spec.ts @@ -63,7 +63,7 @@ describe("ExecaTerminalProcess", () => { const execaMock = vitest.mocked(execa) expect(execaMock).toHaveBeenCalledWith( expect.objectContaining({ - shell: true, + shell: false, cwd: "/test/cwd", all: true, env: expect.objectContaining({ @@ -105,13 +105,13 @@ describe("ExecaTerminalProcess", () => { ) }) - it("should fall back to shell=true when execaShellPath is undefined", async () => { + it("should fall back to shell=false when execaShellPath is undefined", async () => { BaseTerminal.setExecaShellPath(undefined) await terminalProcess.run("echo test") const execaMock = vitest.mocked(execa) expect(execaMock).toHaveBeenCalledWith( expect.objectContaining({ - shell: true, + shell: false, }), ) }) diff --git a/src/utils/__tests__/text-normalization.spec.ts b/src/utils/__tests__/text-normalization.spec.ts index e672617d18b..b5f34709976 100644 --- a/src/utils/__tests__/text-normalization.spec.ts +++ b/src/utils/__tests__/text-normalization.spec.ts @@ -1,4 +1,4 @@ -import { normalizeString, unescapeHtmlEntities } from "../text-normalization" +import { normalizeString, unescapeHtmlEntities, sanitizeForPromptInjection } from "../text-normalization" describe("Text normalization utilities", () => { describe("normalizeString", () => { @@ -100,5 +100,26 @@ describe("Text normalization utilities", () => { const expected = "array[0] and [1]" expect(unescapeHtmlEntities(input)).toBe(expected) }) + + describe("sanitizeForPromptInjection", () => { + it("escapes XML-like tags", () => { + expect(sanitizeForPromptInjection("inject")).toBe( + "\\inject\\", + ) + }) + + it("escapes HTML comment-like sequences", () => { + expect(sanitizeForPromptInjection("")).toBe("\\") + }) + + it("does not escape standalone less-than signs", () => { + expect(sanitizeForPromptInjection("a < b")).toBe("a < b") + }) + + it("returns original string when no tags are present", () => { + const original = "Plain text without any markup" + expect(sanitizeForPromptInjection(original)).toBe(original) + }) + }) }) }) diff --git a/src/utils/text-normalization.ts b/src/utils/text-normalization.ts index 9e25d140c4e..88df034a350 100644 --- a/src/utils/text-normalization.ts +++ b/src/utils/text-normalization.ts @@ -76,6 +76,17 @@ export function normalizeString(str: string, options: NormalizeOptions = DEFAULT return normalized } +/** + * Escapes potential XML/HTML-like tags to prevent indirect prompt injection + * via tool outputs (command output, file contents, etc.). + * + * @param content The untrusted content to sanitize + * @returns The sanitized content with tag-like sequences escaped + */ +export function sanitizeForPromptInjection(content: string): string { + return content.replace(/<(\/?[a-zA-Z!?])/g, "\\<$1") +} + /** * Unescapes common HTML entities in a string *