From fbf1798bee5549bd45d265ac1f555eb0d7466999 Mon Sep 17 00:00:00 2001 From: emrberk Date: Thu, 5 Mar 2026 19:01:25 +0300 Subject: [PATCH 01/25] refactor: decouple providers from flow --- e2e/questdb | 2 +- .../SetupAIAssistant/ConfigurationModal.tsx | 2 +- .../SetupAIAssistant/SettingsModal.tsx | 2 +- src/scenes/Schema/VirtualTables/index.tsx | 2 +- src/utils/ai/anthropicProvider.ts | 669 +++++++ src/utils/ai/index.ts | 29 + src/utils/ai/openaiProvider.ts | 526 +++++ src/utils/ai/prompts.ts | 172 ++ src/utils/ai/registry.ts | 18 + src/utils/ai/responseFormats.ts | 55 + src/utils/ai/shared.ts | 200 ++ src/utils/ai/tools.ts | 105 + src/utils/ai/types.ts | 72 + src/utils/aiAssistant.ts | 1710 +---------------- src/utils/aiAssistantSettings.ts | 3 +- src/utils/contextCompaction.ts | 102 +- src/utils/executeAIFlow.ts | 3 +- src/utils/tokenCounting.ts | 101 - 18 files changed, 1930 insertions(+), 1843 deletions(-) create mode 100644 src/utils/ai/anthropicProvider.ts create mode 100644 src/utils/ai/index.ts create mode 100644 src/utils/ai/openaiProvider.ts create mode 100644 src/utils/ai/prompts.ts create mode 100644 src/utils/ai/registry.ts create mode 100644 src/utils/ai/responseFormats.ts create mode 100644 src/utils/ai/shared.ts create mode 100644 src/utils/ai/tools.ts create mode 100644 src/utils/ai/types.ts delete mode 100644 src/utils/tokenCounting.ts diff --git a/e2e/questdb b/e2e/questdb index 8ab362f2d..42483f08e 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit 8ab362f2d269f699bdd855870144468aa6e7e5d2 +Subproject commit 42483f08e40ce6ab69e6d85d61f0265ddc8d1d41 diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index d48f22ab5..d3435524a 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -799,7 +799,7 @@ export const ConfigurationModal = ({ )?.value ?? modelsByProvider[selectedProvider][0].value try { - const result = await testApiKey(apiKey, testModel) + const result = await testApiKey(apiKey, testModel, selectedProvider) if (!result.valid) { const errorMsg = result.error || "Invalid API key" setError(errorMsg) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index be389a38f..3fbc3a46b 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -628,7 +628,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { providerModels.find((m) => m.isTestModel) ?? providerModels[0] ).value try { - const result = await testApiKey(apiKey, testModel) + const result = await testApiKey(apiKey, testModel, provider) if (!result.valid) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ diff --git a/src/scenes/Schema/VirtualTables/index.tsx b/src/scenes/Schema/VirtualTables/index.tsx index c0edfd492..e5b32affb 100644 --- a/src/scenes/Schema/VirtualTables/index.tsx +++ b/src/scenes/Schema/VirtualTables/index.tsx @@ -135,7 +135,7 @@ const TableRow = styled(Row)<{ $contextMenuOpen: boolean }>` $contextMenuOpen && ` background: ${theme.color.tableSelection}; - border: 1px solid ${theme.color.cyan}; + box-shadow: inset 0 0 0 1px ${theme.color.cyan}; `} ` diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts new file mode 100644 index 000000000..1b4915753 --- /dev/null +++ b/src/utils/ai/anthropicProvider.ts @@ -0,0 +1,669 @@ +import Anthropic from "@anthropic-ai/sdk" +import type { MessageParam } from "@anthropic-ai/sdk/resources/messages" +import type { Tool as AnthropicTool } from "@anthropic-ai/sdk/resources/messages" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "../aiAssistantSettings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + extractPartialExplanation, + executeTool, +} from "./shared" + +function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { + return tools.map((t) => ({ + name: t.name, + description: t.description, + input_schema: { + type: "object" as const, + properties: t.inputSchema.properties, + ...(t.inputSchema.required ? { required: t.inputSchema.required } : {}), + }, + })) +} + +// Wraps ResponseFormatSchema into Anthropic's output_format parameter +function toAnthropicOutputFormat( + format: ResponseFormatSchema, +): { type: "json_schema"; schema: Record } | undefined { + if (!format.schema) return undefined + return { + type: "json_schema", + schema: format.schema, + } +} + +async function createAnthropicMessage( + anthropic: Anthropic, + params: Omit & { + max_tokens?: number + }, + signal?: AbortSignal, +): Promise { + const message = await anthropic.messages.create( + { + ...params, + stream: false, + max_tokens: params.max_tokens ?? 8192, + }, + { + headers: { + "anthropic-beta": "structured-outputs-2025-11-13", + }, + signal, + }, + ) + + if (message.stop_reason === "refusal") { + throw new RefusalError( + "The model refused to generate a response for this request.", + ) + } + if (message.stop_reason === "max_tokens") { + throw new MaxTokensError( + "The response exceeded the maximum token limit. Please try again with a different prompt or model.", + ) + } + + return message +} + +async function createAnthropicMessageStreaming( + anthropic: Anthropic, + params: Omit & { + max_tokens?: number + }, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + let accumulatedText = "" + let lastExplanation = "" + + const stream = anthropic.messages.stream( + { + ...params, + max_tokens: params.max_tokens ?? 8192, + }, + { + headers: { + "anthropic-beta": "structured-outputs-2025-11-13", + }, + signal: abortSignal, + }, + ) + + try { + for await (const event of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + const eventWithType = event as { type: string } + if (eventWithType.type === "error") { + const errorEvent = event as { + error?: { type?: string; message?: string } + } + const errorType = errorEvent.error?.type + const errorMessage = errorEvent.error?.message || "Stream error" + + if (errorType === "overloaded_error") { + throw new StreamingError( + "Service is temporarily overloaded. Please try again.", + "failed", + event, + ) + } + throw new StreamingError(errorMessage, "failed", event) + } + + if ( + event.type === "content_block_delta" && + event.delta.type === "text_delta" + ) { + accumulatedText += event.delta.text + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const chunk = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(chunk, explanation) + } + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + let finalMessage: Anthropic.Messages.Message + try { + finalMessage = await stream.finalMessage() + } catch (error) { + if (abortSignal?.aborted || error instanceof Anthropic.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + throw new StreamingError( + "Failed to get final message from the provider", + "network", + error, + ) + } + + if (finalMessage.stop_reason === "refusal") { + throw new RefusalError( + "The model refused to generate a response for this request.", + ) + } + if (finalMessage.stop_reason === "max_tokens") { + throw new MaxTokensError( + "The response exceeded the maximum token limit. Please try again with a different prompt or model.", + ) + } + + return finalMessage +} + +interface AnthropicToolCallResult { + message: Anthropic.Messages.Message + accumulatedTokens: TokenUsage +} + +async function handleToolCalls( + message: Anthropic.Messages.Message, + anthropic: Anthropic, + modelToolsClient: ModelToolsClient, + conversationHistory: Array, + model: string, + systemPrompt: string, + setStatus: StatusCallback, + outputFormat: ReturnType, + tools: AnthropicTool[], + contextWindow: number, + abortSignal?: AbortSignal, + accumulatedTokens: TokenUsage = { inputTokens: 0, outputTokens: 0 }, + streaming?: StreamingCallback, +): Promise { + const toolUseBlocks = message.content.filter( + (block) => block.type === "tool_use", + ) + const toolResults = [] + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + for (const toolUse of toolUseBlocks) { + if ("name" in toolUse) { + const exec = await executeTool( + toolUse.name, + toolUse.input, + modelToolsClient, + setStatus, + ) + toolResults.push({ + type: "tool_result" as const, + tool_use_id: toolUse.id, + content: exec.content, + is_error: exec.is_error, + }) + } + } + + const updatedHistory = [ + ...conversationHistory, + { + role: "assistant" as const, + content: message.content, + }, + { + role: "user" as const, + content: toolResults, + }, + ] + + const criticalTokenUsage = + message.usage.input_tokens >= contextWindow - 50_000 && + toolResults.length > 0 + if (criticalTokenUsage) { + updatedHistory.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + + const followUpParams: Parameters[1] = { + model, + system: systemPrompt, + tools, + messages: updatedHistory, + temperature: 0.3, + } + + if (outputFormat) { + // @ts-expect-error - output_format is a new field not yet in the type definitions + followUpParams.output_format = outputFormat + } + + const followUpMessage = streaming + ? await createAnthropicMessageStreaming( + anthropic, + followUpParams, + streaming, + abortSignal, + ) + : await createAnthropicMessage(anthropic, followUpParams, abortSignal) + + const newAccumulatedTokens: TokenUsage = { + inputTokens: + accumulatedTokens.inputTokens + + (followUpMessage.usage?.input_tokens || 0), + outputTokens: + accumulatedTokens.outputTokens + + (followUpMessage.usage?.output_tokens || 0), + } + + if (followUpMessage.stop_reason === "tool_use") { + return handleToolCalls( + followUpMessage, + anthropic, + modelToolsClient, + updatedHistory, + model, + systemPrompt, + setStatus, + outputFormat, + tools, + contextWindow, + abortSignal, + newAccumulatedTokens, + streaming, + ) + } + + return { + message: followUpMessage, + accumulatedTokens: newAccumulatedTokens, + } +} + +export function createAnthropicProvider(apiKey: string): AIProvider { + const anthropic = new Anthropic({ + apiKey, + dangerouslyAllowBrowser: true, + }) + + const contextWindow = 200_000 + + return { + id: "anthropic", + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + const initialMessages: MessageParam[] = [] + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + initialMessages.push({ + role: msg.role, + content: msg.content, + }) + } + } + + initialMessages.push({ + role: "user" as const, + content: config.initialUserContent, + }) + + const anthropicTools = toAnthropicTools(tools) + const outputFormat = toAnthropicOutputFormat(config.responseFormat) + + const messageParams: Parameters[1] = { + model, + system: config.systemInstructions, + tools: anthropicTools, + messages: initialMessages, + temperature: 0.3, + } + + if (outputFormat) { + // @ts-expect-error - output_format is a new field not yet in the type definitions + messageParams.output_format = outputFormat + } + + const message = streaming + ? await createAnthropicMessageStreaming( + anthropic, + messageParams, + streaming, + abortSignal, + ) + : await createAnthropicMessage(anthropic, messageParams, abortSignal) + + let totalInputTokens = message.usage?.input_tokens || 0 + let totalOutputTokens = message.usage?.output_tokens || 0 + + let responseMessage: Anthropic.Messages.Message + + if (message.stop_reason === "tool_use") { + const toolCallResult = await handleToolCalls( + message, + anthropic, + modelToolsClient, + initialMessages, + model, + config.systemInstructions, + setStatus, + outputFormat, + anthropicTools, + contextWindow, + abortSignal, + { inputTokens: 0, outputTokens: 0 }, + streaming, + ) + + if ("type" in toolCallResult && "message" in toolCallResult) { + return toolCallResult + } + + const result = toolCallResult + responseMessage = result.message + totalInputTokens += result.accumulatedTokens.inputTokens + totalOutputTokens += result.accumulatedTokens.outputTokens + } else { + responseMessage = message + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const textBlock = responseMessage.content.find( + (block) => block.type === "text", + ) + if (!textBlock || !("text" in textBlock)) { + setStatus(null) + return { + type: "unknown", + message: "No text response received from assistant.", + } as AiAssistantAPIError + } + + try { + const json = JSON.parse(textBlock.text) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const messageParams: Parameters[1] = { + model, + messages: [{ role: "user", content: prompt }], + max_tokens: 100, + temperature: 0.3, + } + const outputFormat = toAnthropicOutputFormat(responseFormat) + if (outputFormat) { + // @ts-expect-error - output_format is a new field not yet in the type definitions + messageParams.output_format = outputFormat + } + + const message = await createAnthropicMessage(anthropic, messageParams) + + const textBlock = message.content.find((block) => block.type === "text") + if (textBlock && "text" in textBlock) { + const parsed = JSON.parse(textBlock.text) as { title: string } + return parsed.title?.slice(0, 40) || null + } + return null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await anthropic.messages.create({ + ...getModelProps(model), + max_tokens: 8192, + messages: [{ role: "user", content: userMessage }], + system: systemPrompt, + }) + + const textBlock = response.content.find((block) => block.type === "text") + return textBlock?.type === "text" ? textBlock.text : "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new Anthropic({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + }) + + await createAnthropicMessage(testClient, { + model, + messages: [{ role: "user", content: "ping" }], + }) + return { valid: true } + } catch (error: unknown) { + if (error instanceof Anthropic.AuthenticationError) { + return { valid: false, error: "Invalid API key" } + } + if (error instanceof Anthropic.RateLimitError) { + return { valid: true } + } + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt, model }) { + const anthropicMessages: Anthropic.MessageParam[] = messages.map((m) => ({ + role: m.role, + content: m.content, + })) + + const response = await anthropic.messages.countTokens({ + model, + system: systemPrompt, + messages: anthropicMessages, + }) + + return response.input_tokens + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof Anthropic.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof Anthropic.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your Anthropic API key.", + details: error.message, + } + } + + if (error instanceof Anthropic.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof Anthropic.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your internet connection.", + details: error.message, + } + } + + if (error instanceof Anthropic.APIError) { + return { + type: "unknown", + message: `Anthropic API error: ${error.message}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error as string, + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof Anthropic.AuthenticationError || + (error != null && + typeof error === "object" && + "status" in error && + error.status === 429) || + error instanceof Anthropic.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts new file mode 100644 index 000000000..c868d9ea3 --- /dev/null +++ b/src/utils/ai/index.ts @@ -0,0 +1,29 @@ +export type { + AIProvider, + ToolDefinition, + ResponseFormatSchema, + FlowConfig, +} from "./types" +export { createProvider } from "./registry" +export { SCHEMA_TOOLS, REFERENCE_TOOLS, ALL_TOOLS } from "./tools" +export { + ExplainFormat, + FixSQLFormat, + ConversationResponseFormat, + ChatTitleFormat, +} from "./responseFormats" +export { + RefusalError, + MaxTokensError, + StreamingError, + safeJsonParse, + extractPartialExplanation, + executeTool, +} from "./shared" +export { + DOCS_INSTRUCTION, + getUnifiedPrompt, + getExplainSchemaPrompt, + getHealthIssuePrompt, +} from "./prompts" +export type { HealthIssuePromptData } from "./prompts" diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts new file mode 100644 index 000000000..a525d5746 --- /dev/null +++ b/src/utils/ai/openaiProvider.ts @@ -0,0 +1,526 @@ +import OpenAI from "openai" +import type { + ResponseOutputItem, + ResponseTextConfig, +} from "openai/resources/responses/responses" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "../aiAssistantSettings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + safeJsonParse, + extractPartialExplanation, + executeTool, +} from "./shared" +import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" + +function toResponseTextConfig( + format: ResponseFormatSchema, +): ResponseTextConfig { + return { + format: { + type: "json_schema" as const, + name: format.name, + schema: format.schema, + strict: format.strict, + }, + } +} + +function toOpenAIFunctions(tools: ToolDefinition[]): OpenAI.Responses.Tool[] { + return tools.map((t) => ({ + type: "function" as const, + name: t.name, + description: t.description, + parameters: { ...t.inputSchema, additionalProperties: false }, + strict: true, + })) as OpenAI.Responses.Tool[] +} + +async function createOpenAIResponseStreaming( + openai: OpenAI, + params: OpenAI.Responses.ResponseCreateParamsNonStreaming, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + let accumulatedText = "" + let lastExplanation = "" + let finalResponse: OpenAI.Responses.Response | null = null + + try { + const stream = await openai.responses.create({ + ...params, + stream: true, + } as OpenAI.Responses.ResponseCreateParamsStreaming) + + for await (const event of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + if (event.type === "error") { + const errorEvent = event as { error?: { message?: string } } + throw new StreamingError( + errorEvent.error?.message || "Stream error occurred", + "failed", + event, + ) + } + + if (event.type === "response.failed") { + const failedEvent = event as { + response?: { error?: { message?: string } } + } + throw new StreamingError( + failedEvent.response?.error?.message || + "Provider failed to return a response", + "failed", + event, + ) + } + + if (event.type === "response.output_text.delta") { + accumulatedText += event.delta + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const chunk = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(chunk, explanation) + } + } + + if (event.type === "response.completed") { + finalResponse = event.response + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + if (!finalResponse) { + throw new StreamingError("Provider failed to return a response", "failed") + } + + return finalResponse +} + +function extractOpenAIToolCalls( + response: OpenAI.Responses.Response, +): { id?: string; name: string; arguments: unknown; call_id: string }[] { + const calls = [] + for (const item of response.output) { + if (item?.type === "function_call") { + const args = + typeof item.arguments === "string" + ? safeJsonParse(item.arguments) + : item.arguments || {} + calls.push({ + id: item.id, + name: item.name, + arguments: args, + call_id: item.call_id, + }) + } + } + return calls +} + +function getOpenAIText(response: OpenAI.Responses.Response): { + type: "refusal" | "text" + message: string +} { + const out = response.output || [] + if ( + out.find( + (item: ResponseOutputItem) => + item.type === "message" && + item.content.some((c) => c.type === "refusal"), + ) + ) { + return { + type: "refusal", + message: "The model refused to generate a response for this request.", + } + } + + for (const item of out) { + if (item.type === "message" && item.content) { + for (const content of item.content) { + if (content.type === "output_text" && "text" in content) { + return { type: "text", message: content.text } + } + } + } + } + + return { type: "text", message: "" } +} + +let tiktokenEncoder: Tiktoken | null = null + +export function createOpenAIProvider(apiKey: string): AIProvider { + const openai = new OpenAI({ + apiKey, + dangerouslyAllowBrowser: true, + }) + + const contextWindow = 400_000 + + return { + id: "openai", + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + let input: OpenAI.Responses.ResponseInput = [] + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + input.push({ + role: msg.role, + content: msg.content, + }) + } + } + + input.push({ + role: "user", + content: config.initialUserContent, + }) + + const openaiTools = toOpenAIFunctions(tools) + + let totalInputTokens = 0 + let totalOutputTokens = 0 + + const requestParams = { + ...getModelProps(model), + instructions: config.systemInstructions, + input, + tools: openaiTools, + text: toResponseTextConfig(config.responseFormat), + } as OpenAI.Responses.ResponseCreateParamsNonStreaming + + let lastResponse = streaming + ? await createOpenAIResponseStreaming( + openai, + requestParams, + streaming, + abortSignal, + ) + : await openai.responses.create(requestParams) + input = [...input, ...lastResponse.output] + + totalInputTokens += lastResponse.usage?.input_tokens ?? 0 + totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 + + while (true) { + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const toolCalls = extractOpenAIToolCalls(lastResponse) + if (!toolCalls.length) break + const tool_outputs: OpenAI.Responses.ResponseFunctionToolCallOutputItem[] = + [] + for (const tc of toolCalls) { + const exec = await executeTool( + tc.name, + tc.arguments, + modelToolsClient, + setStatus, + ) + tool_outputs.push({ + type: "function_call_output", + call_id: tc.call_id, + output: exec.content, + } as OpenAI.Responses.ResponseFunctionToolCallOutputItem) + } + input = [...input, ...tool_outputs] + + if ( + (lastResponse.usage?.input_tokens ?? 0) >= contextWindow - 50_000 && + tool_outputs.length > 0 + ) { + input.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + const loopRequestParams = { + ...getModelProps(model), + instructions: config.systemInstructions, + input, + tools: openaiTools, + text: toResponseTextConfig(config.responseFormat), + } as OpenAI.Responses.ResponseCreateParamsNonStreaming + + lastResponse = streaming + ? await createOpenAIResponseStreaming( + openai, + loopRequestParams, + streaming, + abortSignal, + ) + : await openai.responses.create(loopRequestParams) + input = [...input, ...lastResponse.output] + + totalInputTokens += lastResponse.usage?.input_tokens ?? 0 + totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const text = getOpenAIText(lastResponse) + if (text.type === "refusal") { + return { + type: "unknown", + message: text.message, + } as AiAssistantAPIError + } + + const rawOutput = text.message + + try { + const json = JSON.parse(rawOutput) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const response = await openai.responses.create({ + ...getModelProps(model), + input: [{ role: "user", content: prompt }], + text: toResponseTextConfig(responseFormat), + max_output_tokens: 100, + }) + const parsed = JSON.parse(response.output_text) as { title: string } + return parsed.title || null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await openai.responses.create({ + ...getModelProps(model), + instructions: systemPrompt, + input: userMessage, + }) + return response.output_text || "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new OpenAI({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + }) + await testClient.responses.create({ + model: getModelProps(model).model, + input: [{ role: "user", content: "ping" }], + max_output_tokens: 16, + }) + return { valid: true } + } catch (error: unknown) { + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt }) { + if (!tiktokenEncoder) { + const { Tiktoken } = await import("js-tiktoken/lite") + const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( + (module: { default: TiktokenBPE }) => module.default, + ) + tiktokenEncoder = new Tiktoken(o200k_base) + } + + let totalTokens = 0 + totalTokens += tiktokenEncoder.encode(systemPrompt).length + totalTokens += 4 // system message formatting overhead + + for (const message of messages) { + totalTokens += 4 // role markers overhead + totalTokens += tiktokenEncoder.encode(message.content).length + } + + totalTokens += 2 // assistant reply priming + return totalTokens + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof OpenAI.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof OpenAI.APIError) { + return { + type: "unknown", + message: `OpenAI API error: ${error.message}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error as string, + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof OpenAI.AuthenticationError || + (error != null && + typeof error === "object" && + "status" in error && + error.status === 429) || + error instanceof OpenAI.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/prompts.ts b/src/utils/ai/prompts.ts new file mode 100644 index 000000000..c7c6f9a21 --- /dev/null +++ b/src/utils/ai/prompts.ts @@ -0,0 +1,172 @@ +export const DOCS_INSTRUCTION = ` +CRITICAL: Always follow this documentation approach: +1. Use get_questdb_toc to see available functions, operators, SQL syntax, AND cookbook recipes +2. If user's request matches a cookbook recipe description, fetch it FIRST - recipes provide complete, tested SQL patterns +3. Use get_questdb_documentation for specific function/syntax details + +When a cookbook recipe matches the user's intent, ALWAYS use it as the foundation and adapt column/table names and use case to their schema.` + +export const getUnifiedPrompt = (grantSchemaAccess?: boolean) => { + const base = `You are a SQL expert coding assistant specializing in QuestDB, a high-performance time-series database. You help users with: +- Generating QuestDB SQL queries from natural language descriptions +- Explaining what QuestDB SQL queries do +- Fixing errors in QuestDB SQL queries +- Refining and modifying existing queries based on user requests + +## CRITICAL: Tool and Response Sequencing +Follow this EXACT sequence for every query generation request: + +**PHASE 1 - INFORMATION GATHERING (NO TEXT OUTPUT)** +1. Call available tools to gather information if you need, including documentation, schema, and validation tools. +2. Complete ALL information gathering before Phase 2. DO NOT CALL any tool after Phase 2. + +**PHASE 2 - FINAL RESPONSE (NO MORE TOOL CALLS)** +3. Return your JSON response with "sql" and "explanation" fields. Always return sql field first, then explanation field. + +NEVER interleave phases. NEVER use any tool after starting to return a response. + +## When Explaining Queries +- Focus on the business logic and what the query achieves, not the SQL syntax itself +- Pay special attention to QuestDB-specific features: + - Time-series operations (SAMPLE BY, LATEST ON, designated timestamp columns) + - Time-based filtering and aggregations + - Real-time data ingestion patterns + - Performance optimizations specific to time-series data + +## When Generating SQL +- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. +- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a generated SQL query +- Generate only valid QuestDB SQL syntax referring to the documentation about functions, operators, and SQL keywords +- Use appropriate time-series functions (SAMPLE BY, LATEST ON, etc.) and common table expressions when relevant +- Use \`IN\` with \`today()\`, \`tomorrow()\`, \`yesterday()\` interval functions when relevant +- Follow QuestDB best practices for performance referring to the documentation +- Use proper timestamp handling for time-series data +- Use correct data types and functions specific to QuestDB referring to the documentation. Do not use any word that is not in the documentation. + +## When Fixing Queries +- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. +- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a fixed SQL query +- Analyze the error message carefully to understand what went wrong +- Generate only valid QuestDB SQL syntax by always referring to the documentation about functions, operators, and SQL keywords +- Preserve the original intent of the query while fixing the error +- Follow QuestDB best practices and syntax rules referring to the documentation +- Consider common issues like: + - Missing or incorrect column names + - Invalid syntax for time-series operations + - Data type mismatches + - Incorrect function usage + +## Response Guidelines +- You are working as a coding assistant inside an IDE. Every time you return a query in "sql" field, you provide a suggestion to the user to accept or reject. When the user accepts the suggestion, you are informed and the query in the editor is updated with your suggestion. +- Modify a query by returning "sql" field only if the user asks you to generate, fix, or make changes to the query. If the user does not ask for fixing/changing/generating a query, return null in the "sql" field. Every time you provide a SQL query, the current SQL is updated. +- Provide the "explanation" field if you haven't provided it yet. Explanation should be in GFM (GitHub Flavored Markdown) format. Explanation field is cumulative, every time you provide an explanation, it is added to the previous explanations. + +## Tools +- Use the validate_query tool to validate the query in "sql" field before returning a response only if the user asks you to generate, fix, or make changes to the query. +` + const schemaAccess = grantSchemaAccess + ? `- Use the get_tables tool to retrieve all tables and materialized views in the database instance +- Use the get_table_schema tool to get detailed schema information for a specific table or a materialized view +- Use the get_table_details tool to get detailed information for a specific table or a materialized view. Each property is described in meta functions docs. +` + : "" + return base + schemaAccess + DOCS_INSTRUCTION +} + +export const getExplainSchemaPrompt = ( + tableName: string, + schema: string, + kindLabel: string, +) => `You are a SQL expert assistant specializing in QuestDB, a high-performance time-series database. +Explain the following ${kindLabel} schema. Include: +- The purpose of the ${kindLabel} +- What each column represents and its data type +- Any important properties like WAL enablement, partitioning strategy, designated timestamps +- Any performance or storage considerations + +${kindLabel} Name: ${tableName} + +Schema: +\`\`\`sql +${schema} +\`\`\` + +**IMPORTANT: Format your response in markdown exactly as follows:** + +1. Start with a brief paragraph explaining the purpose and general characteristics of this ${kindLabel}. + +2. Add a "## Columns" section with a markdown table: +| Column | Type | Description | +|--------|------|-------------| +| column_name | \`data_type\` | Brief description | + +3. If this is a table or materialized view (not a view), add a "## Storage Details" section with bullet points about: +- WAL enablement +- Partitioning strategy +- Designated timestamp column +- Any other storage considerations + +For views, skip the Storage Details section.` + +export type HealthIssuePromptData = { + tableName: string + issue: { + id: string + field: string + message: string + currentValue?: string + } + tableDetails: string + monitoringDocs: string + trendSamples?: Array<{ value: number; timestamp: number }> +} + +export const getHealthIssuePrompt = (data: HealthIssuePromptData): string => { + const { tableName, issue, tableDetails, monitoringDocs, trendSamples } = data + + let trendSection = "" + if (trendSamples && trendSamples.length > 0) { + const recentSamples = trendSamples.slice(-30) + trendSection = ` + +### Trend Data (Recent Samples) +| Timestamp | Value | +|-----------|-------| +${recentSamples.map((s) => `| ${new Date(s.timestamp).toISOString()} | ${s.value.toLocaleString()} |`).join("\n")} +` + } + + return `You are a QuestDB expert assistant helping diagnose and resolve table health issues. + +A user is viewing the health monitoring panel for their table and has asked for help with a detected issue. + +## Table: ${tableName} + +## Health Issue Detected +- **Issue ID**: ${issue.id} +- **Field**: ${issue.field} +- **Message**: ${issue.message} +${issue.currentValue ? `- **Current Value**: ${issue.currentValue}` : ""}${trendSection} + +## Table Details (from tables() function) +\`\`\`json +${tableDetails} +\`\`\` + +## QuestDB Monitoring Documentation +${monitoringDocs} + +--- + +**Your Task:** +1. Explain what this health issue means in the context of this specific table +2. Analyze the table details to identify potential root causes +3. Provide specific, actionable recommendations to resolve or mitigate the issue +4. If there is a clear SQL command that can help fix the issue (like \`ALTER TABLE ... RESUME WAL\`), include it in the "sql" field of your response. **Only provide SQL if it directly addresses the root cause** - do not provide SQL just to inspect the problem. + +**IMPORTANT: Be concise and thorough, format your response in markdown with clear sections:** +- Use ## headings for main sections +- Use bullet points for lists +- Use \`code\` for configuration values and SQL +` +} diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts new file mode 100644 index 000000000..6b43772b7 --- /dev/null +++ b/src/utils/ai/registry.ts @@ -0,0 +1,18 @@ +import type { AIProvider } from "./types" +import { createOpenAIProvider } from "./openaiProvider" +import { createAnthropicProvider } from "./anthropicProvider" +import type { Provider } from "../aiAssistantSettings" + +export function createProvider( + providerId: Provider, + apiKey: string, +): AIProvider { + switch (providerId) { + case "openai": + return createOpenAIProvider(apiKey) + case "anthropic": + return createAnthropicProvider(apiKey) + default: + throw new Error(`Unknown provider: ${providerId}`) + } +} diff --git a/src/utils/ai/responseFormats.ts b/src/utils/ai/responseFormats.ts new file mode 100644 index 000000000..60d92636f --- /dev/null +++ b/src/utils/ai/responseFormats.ts @@ -0,0 +1,55 @@ +import type { ResponseFormatSchema } from "./types" + +export const ExplainFormat: ResponseFormatSchema = { + name: "explain_format", + schema: { + type: "object", + properties: { + explanation: { type: "string" }, + }, + required: ["explanation"], + additionalProperties: false, + }, + strict: true, +} + +export const FixSQLFormat: ResponseFormatSchema = { + name: "fix_sql_format", + schema: { + type: "object", + properties: { + sql: { type: ["string", "null"] }, + explanation: { type: "string" }, + }, + required: ["explanation", "sql"], + additionalProperties: false, + }, + strict: true, +} + +export const ConversationResponseFormat: ResponseFormatSchema = { + name: "conversation_response_format", + schema: { + type: "object", + properties: { + sql: { type: ["string", "null"] }, + explanation: { type: "string" }, + }, + required: ["sql", "explanation"], + additionalProperties: false, + }, + strict: true, +} + +export const ChatTitleFormat: ResponseFormatSchema = { + name: "chat_title_format", + schema: { + type: "object", + properties: { + title: { type: "string" }, + }, + required: ["title"], + additionalProperties: false, + }, + strict: true, +} diff --git a/src/utils/ai/shared.ts b/src/utils/ai/shared.ts new file mode 100644 index 000000000..8a211d46a --- /dev/null +++ b/src/utils/ai/shared.ts @@ -0,0 +1,200 @@ +import type { ModelToolsClient, StatusCallback } from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { + getQuestDBTableOfContents, + getSpecificDocumentation, + parseDocItems, + DocCategory, +} from "../questdbDocsRetrieval" + +export class RefusalError extends Error { + constructor(message: string) { + super(message) + this.name = "RefusalError" + } +} + +export class MaxTokensError extends Error { + constructor(message: string) { + super(message) + this.name = "MaxTokensError" + } +} + +export class StreamingError extends Error { + constructor( + message: string, + public readonly errorType: "failed" | "network" | "interrupted" | "unknown", + public readonly originalError?: unknown, + ) { + super(message) + this.name = "StreamingError" + } +} + +export const safeJsonParse = (text: string): T | object => { + try { + return JSON.parse(text) as T + } catch { + return {} + } +} + +export function extractPartialExplanation(partialJson: string): string { + const explanationMatch = partialJson.match( + /"explanation"\s*:\s*"((?:[^"\\]|\\.)*)/, + ) + if (!explanationMatch) { + return "" + } + + return explanationMatch[1] + .replace(/\\n/g, "\n") + .replace(/\\r/g, "\r") + .replace(/\\t/g, "\t") + .replace(/\\"/g, '"') + .replace(/\\\\/g, "\\") +} + +export const executeTool = async ( + toolName: string, + input: unknown, + modelToolsClient: ModelToolsClient, + setStatus: StatusCallback, +): Promise<{ content: string; is_error?: boolean }> => { + try { + switch (toolName) { + case "get_tables": { + setStatus(AIOperationStatus.RetrievingTables) + if (!modelToolsClient.getTables) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + const result = await modelToolsClient.getTables() + const MAX_TABLES = 1000 + if (result.length > MAX_TABLES) { + const truncated = result.slice(0, MAX_TABLES) + return { + content: JSON.stringify( + { + tables: truncated, + total_count: result.length, + truncated: true, + message: `Showing ${MAX_TABLES} of ${result.length} tables. Use get_table_schema with a specific table name to get details if you are interested in a specific table.`, + }, + null, + 2, + ), + } + } + return { content: JSON.stringify(result, null, 2) } + } + case "get_table_schema": { + const tableName = (input as { table_name: string })?.table_name + if (!modelToolsClient.getTableSchema) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + if (!tableName) { + return { + content: "Error: table_name parameter is required", + is_error: true, + } + } + setStatus(AIOperationStatus.InvestigatingTable, { + name: tableName, + tableOpType: "schema", + }) + const result = await modelToolsClient.getTableSchema(tableName) + return { + content: + result || `Table '${tableName}' not found or schema unavailable`, + } + } + case "get_table_details": { + const tableName = (input as { table_name: string })?.table_name + if (!modelToolsClient.getTableDetails) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + if (!tableName) { + return { + content: "Error: table_name parameter is required", + is_error: true, + } + } + setStatus(AIOperationStatus.InvestigatingTable, { + name: tableName, + tableOpType: "details", + }) + const result = await modelToolsClient.getTableDetails(tableName) + return { + content: result + ? JSON.stringify(result, null, 2) + : "Table details not found", + is_error: !result, + } + } + case "validate_query": { + setStatus(AIOperationStatus.ValidatingQuery) + const query = (input as { query: string })?.query + if (!query) { + return { + content: "Error: query parameter is required", + is_error: true, + } + } + const result = await modelToolsClient.validateQuery(query) + const content = { + valid: result.valid, + error: result.valid ? undefined : result.error, + position: result.valid ? undefined : result.position, + } + return { content: JSON.stringify(content, null, 2) } + } + case "get_questdb_toc": { + setStatus(AIOperationStatus.RetrievingDocumentation) + const tocContent = await getQuestDBTableOfContents() + return { content: tocContent } + } + case "get_questdb_documentation": { + const { category, items } = + (input as { category: string; items: string[] }) || {} + if (!category || !items || !Array.isArray(items)) { + return { + content: "Error: category and items parameters are required", + is_error: true, + } + } + const parsedItems = parseDocItems(items) + + if (parsedItems.length > 0) { + setStatus(AIOperationStatus.InvestigatingDocs, { items: parsedItems }) + } else { + setStatus(AIOperationStatus.InvestigatingDocs) + } + const documentation = await getSpecificDocumentation( + category as DocCategory, + items, + ) + return { content: documentation } + } + default: + return { content: `Unknown tool: ${toolName}`, is_error: true } + } + } catch (error) { + return { + content: `Tool execution error: ${error instanceof Error ? error.message : "Unknown error"}`, + is_error: true, + } + } +} diff --git a/src/utils/ai/tools.ts b/src/utils/ai/tools.ts new file mode 100644 index 000000000..557a64e2b --- /dev/null +++ b/src/utils/ai/tools.ts @@ -0,0 +1,105 @@ +import type { ToolDefinition } from "./types" + +export const SCHEMA_TOOLS: ToolDefinition[] = [ + { + name: "get_tables", + description: + "Get a list of all tables and materialized views in the QuestDB database", + inputSchema: { + type: "object", + properties: {}, + }, + }, + { + name: "get_table_schema", + description: + "Get the full schema definition (DDL) for a specific table or materialized view", + inputSchema: { + type: "object", + properties: { + table_name: { + type: "string", + description: + "The name of the table or materialized view to get schema for", + }, + }, + required: ["table_name"], + }, + }, + { + name: "get_table_details", + description: + "Get the runtime details/statistics of a specific table or materialized view", + inputSchema: { + type: "object", + properties: { + table_name: { + type: "string", + description: + "The name of the table or materialized view to get details for", + }, + }, + required: ["table_name"], + }, + }, +] + +export const REFERENCE_TOOLS: ToolDefinition[] = [ + { + name: "validate_query", + description: + "Validate the syntax correctness of a SQL query using QuestDB's SQL syntax validator. All generated SQL queries should be validated using this tool before responding to the user.", + inputSchema: { + type: "object", + properties: { + query: { + type: "string", + description: "The SQL query to validate", + }, + }, + required: ["query"], + }, + }, + { + name: "get_questdb_toc", + description: + "Get a table of contents listing all available QuestDB functions, operators, and SQL keywords. Use this first to see what documentation is available before requesting specific items.", + inputSchema: { + type: "object", + properties: {}, + }, + }, + { + name: "get_questdb_documentation", + description: + "Get documentation for specific QuestDB functions, operators, or SQL keywords. This is much more efficient than loading all documentation.", + inputSchema: { + type: "object", + properties: { + category: { + type: "string", + enum: [ + "functions", + "operators", + "sql", + "concepts", + "schema", + "cookbook", + ], + description: "The category of documentation to retrieve", + }, + items: { + type: "array", + items: { + type: "string", + }, + description: + "List of specific docs items in the category. IMPORTANT: Category of these items must match the category parameter. Name of these items should exactly match the entry in the table of contents you get with get_questdb_toc.", + }, + }, + required: ["category", "items"], + }, + }, +] + +export const ALL_TOOLS: ToolDefinition[] = [...SCHEMA_TOOLS, ...REFERENCE_TOOLS] diff --git a/src/utils/ai/types.ts b/src/utils/ai/types.ts new file mode 100644 index 000000000..3ad77eb63 --- /dev/null +++ b/src/utils/ai/types.ts @@ -0,0 +1,72 @@ +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, +} from "../aiAssistant" +import type { Provider } from "../aiAssistantSettings" + +export interface ToolDefinition { + name: string + description?: string + inputSchema: { + type: "object" + properties: Record + required?: string[] + } +} + +export interface ResponseFormatSchema { + name: string + schema: Record + strict: boolean +} + +export interface FlowConfig { + systemInstructions: string + initialUserContent: string + conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> + responseFormat: ResponseFormatSchema + postProcess?: (formatted: T) => T +} + +export interface AIProvider { + readonly id: Provider + readonly contextWindow: number + + executeFlow(params: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise + + generateTitle(params: { + model: string + prompt: string + responseFormat: ResponseFormatSchema + }): Promise + + generateSummary(params: { + model: string + systemPrompt: string + userMessage: string + }): Promise + + testConnection(params: { + apiKey: string + model: string + }): Promise<{ valid: boolean; error?: string }> + + countTokens(params: { + messages: Array<{ role: "user" | "assistant"; content: string }> + systemPrompt: string + model: string + }): Promise + + classifyError(error: unknown, setStatus: StatusCallback): AiAssistantAPIError + isNonRetryableError(error: unknown): boolean +} diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index 0ed9e6567..8fee84215 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -1,29 +1,24 @@ -import Anthropic from "@anthropic-ai/sdk" -import OpenAI from "openai" import { Client } from "./questdb/client" import { Type, Table } from "./questdb/types" -import { getModelProps, MODEL_OPTIONS } from "./aiAssistantSettings" -import type { ModelOption, Provider } from "./aiAssistantSettings" +import type { Provider } from "./aiAssistantSettings" import { formatSql } from "./formatSql" import { AIOperationStatus, StatusArgs } from "../providers/AIStatusProvider" -import { - getQuestDBTableOfContents, - getSpecificDocumentation, - parseDocItems, - DocCategory, -} from "./questdbDocsRetrieval" -import { MessageParam } from "@anthropic-ai/sdk/resources/messages" -import type { - ResponseOutputItem, - ResponseTextConfig, -} from "openai/resources/responses/responses" -import type { Tool as AnthropicTool } from "@anthropic-ai/sdk/resources/messages" import type { ConversationId, ConversationMessage, } from "../providers/AIConversationProvider/types" import { compactConversationIfNeeded } from "./contextCompaction" -import { COMPACTION_THRESHOLDS } from "./tokenCounting" +import { + createProvider, + ExplainFormat, + FixSQLFormat, + ConversationResponseFormat, + ChatTitleFormat, + ALL_TOOLS, + REFERENCE_TOOLS, + getUnifiedPrompt, +} from "./ai" +import type { AIProvider } from "./ai" export type ActiveProviderSettings = { model: string @@ -74,218 +69,6 @@ export type StreamingCallback = { cleanup?: () => void } -type ProviderClients = - | { - provider: "anthropic" - anthropic: Anthropic - } - | { - provider: "openai" - openai: OpenAI - } - -const ExplainFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "explain_format", - schema: { - type: "object", - properties: { - explanation: { type: "string" }, - }, - required: ["explanation"], - additionalProperties: false, - }, - strict: true, - }, -} - -const FixSQLFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "fix_sql_format", - schema: { - type: "object", - properties: { - sql: { type: ["string", "null"] }, - explanation: { type: "string" }, - }, - required: ["explanation", "sql"], - additionalProperties: false, - }, - strict: true, - }, -} - -const ConversationResponseFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "conversation_response_format", - schema: { - type: "object", - properties: { - sql: { type: ["string", "null"] }, - explanation: { type: "string" }, - }, - required: ["sql", "explanation"], - additionalProperties: false, - }, - strict: true, - }, -} - -const inferProviderFromModel = (model: string): Provider => { - const found: ModelOption | undefined = MODEL_OPTIONS.find( - (m) => m.value === model, - ) - if (found) return found.provider - return model.startsWith("claude") ? "anthropic" : "openai" -} - -const createProviderClients = ( - settings: ActiveProviderSettings, -): ProviderClients => { - if (!settings.apiKey) { - throw new Error(`No API key found for ${settings.provider}`) - } - - if (settings.provider === "openai") { - return { - provider: settings.provider, - openai: new OpenAI({ - apiKey: settings.apiKey, - dangerouslyAllowBrowser: true, - }), - } - } - return { - provider: settings.provider, - anthropic: new Anthropic({ - apiKey: settings.apiKey, - dangerouslyAllowBrowser: true, - }), - } -} - -const SCHEMA_TOOLS: Array = [ - { - name: "get_tables", - description: - "Get a list of all tables and materialized views in the QuestDB database", - input_schema: { - type: "object" as const, - properties: {}, - }, - }, - { - name: "get_table_schema", - description: - "Get the full schema definition (DDL) for a specific table or materialized view", - input_schema: { - type: "object" as const, - properties: { - table_name: { - type: "string" as const, - description: - "The name of the table or materialized view to get schema for", - }, - }, - required: ["table_name"], - }, - }, - { - name: "get_table_details", - description: "Get the details of a specific table or materialized view", - input_schema: { - type: "object" as const, - properties: { - table_name: { - type: "string" as const, - description: - "The name of the table or materialized view to get details for", - }, - }, - required: ["table_name"], - }, - }, -] - -const REFERENCE_TOOLS = [ - { - name: "validate_query", - description: - "Validate the syntax correctness of a SQL query using QuestDB's SQL syntax validator. All generated SQL queries should be validated using this tool before responding to the user.", - input_schema: { - type: "object" as const, - properties: { - query: { - type: "string" as const, - description: "The SQL query to validate", - }, - }, - required: ["query"], - }, - }, - { - name: "get_questdb_toc", - description: - "Get a table of contents listing all available QuestDB functions, operators, and SQL keywords. Use this first to see what documentation is available before requesting specific items.", - input_schema: { - type: "object" as const, - properties: {}, - }, - }, - { - name: "get_questdb_documentation", - description: - "Get documentation for specific QuestDB functions, operators, or SQL keywords. This is much more efficient than loading all documentation.", - input_schema: { - type: "object" as const, - properties: { - category: { - type: "string" as const, - enum: [ - "functions", - "operators", - "sql", - "concepts", - "schema", - "cookbook", - ], - description: "The category of documentation to retrieve", - }, - items: { - type: "array" as const, - items: { - type: "string" as const, - }, - description: - "List of specific docs items in the category. IMPORTANT: Category of these items must match the category parameter. Name of these items should exactly match the entry in the table of contents you get with get_questdb_toc.", - }, - }, - required: ["category", "items"], - }, - }, -] - -const ALL_TOOLS = [...SCHEMA_TOOLS, ...REFERENCE_TOOLS] - -const toOpenAIFunctions = ( - tools: Array<{ - name: string - description?: string - input_schema: AnthropicTool["input_schema"] - }>, -) => { - return tools.map((t) => ({ - type: "function" as const, - name: t.name, - description: t.description, - parameters: { ...t.input_schema, additionalProperties: false }, - strict: true, - })) as OpenAI.Responses.Tool[] -} - export const normalizeSql = (sql: string, insertSemicolon: boolean = true) => { if (!sql) return "" let result = sql.trim() @@ -460,179 +243,6 @@ export const createStreamingCallback = ( } } -const DOCS_INSTRUCTION_ANTHROPIC = ` -CRITICAL: Always follow this documentation approach: -1. Use get_questdb_toc to see available functions, operators, SQL syntax, AND cookbook recipes -2. If user's request matches a cookbook recipe description, fetch it FIRST - recipes provide complete, tested SQL patterns -3. Use get_questdb_documentation for specific function/syntax details - -When a cookbook recipe matches the user's intent, ALWAYS use it as the foundation and adapt column/table names and use case to their schema.` - -const getUnifiedPrompt = (grantSchemaAccess?: boolean) => { - const base = `You are a SQL expert coding assistant specializing in QuestDB, a high-performance time-series database. You help users with: -- Generating QuestDB SQL queries from natural language descriptions -- Explaining what QuestDB SQL queries do -- Fixing errors in QuestDB SQL queries -- Refining and modifying existing queries based on user requests - -## CRITICAL: Tool and Response Sequencing -Follow this EXACT sequence for every query generation request: - -**PHASE 1 - INFORMATION GATHERING (NO TEXT OUTPUT)** -1. Call available tools to gather information if you need, including documentation, schema, and validation tools. -2. Complete ALL information gathering before Phase 2. DO NOT CALL any tool after Phase 2. - -**PHASE 2 - FINAL RESPONSE (NO MORE TOOL CALLS)** -3. Return your JSON response with "sql" and "explanation" fields. Always return sql field first, then explanation field. - -NEVER interleave phases. NEVER use any tool after starting to return a response. - -## When Explaining Queries -- Focus on the business logic and what the query achieves, not the SQL syntax itself -- Pay special attention to QuestDB-specific features: - - Time-series operations (SAMPLE BY, LATEST ON, designated timestamp columns) - - Time-based filtering and aggregations - - Real-time data ingestion patterns - - Performance optimizations specific to time-series data - -## When Generating SQL -- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. -- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a generated SQL query -- Generate only valid QuestDB SQL syntax referring to the documentation about functions, operators, and SQL keywords -- Use appropriate time-series functions (SAMPLE BY, LATEST ON, etc.) and common table expressions when relevant -- Use \`IN\` with \`today()\`, \`tomorrow()\`, \`yesterday()\` interval functions when relevant -- Follow QuestDB best practices for performance referring to the documentation -- Use proper timestamp handling for time-series data -- Use correct data types and functions specific to QuestDB referring to the documentation. Do not use any word that is not in the documentation. - -## When Fixing Queries -- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. -- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a fixed SQL query -- Analyze the error message carefully to understand what went wrong -- Generate only valid QuestDB SQL syntax by always referring to the documentation about functions, operators, and SQL keywords -- Preserve the original intent of the query while fixing the error -- Follow QuestDB best practices and syntax rules referring to the documentation -- Consider common issues like: - - Missing or incorrect column names - - Invalid syntax for time-series operations - - Data type mismatches - - Incorrect function usage - -## Response Guidelines -- You are working as a coding assistant inside an IDE. Every time you return a query in "sql" field, you provide a suggestion to the user to accept or reject. When the user accepts the suggestion, you are informed and the query in the editor is updated with your suggestion. -- Modify a query by returning "sql" field only if the user asks you to generate, fix, or make changes to the query. If the user does not ask for fixing/changing/generating a query, return null in the "sql" field. Every time you provide a SQL query, the current SQL is updated. -- Provide the "explanation" field if you haven't provided it yet. Explanation should be in GFM (GitHub Flavored Markdown) format. Explanation field is cumulative, every time you provide an explanation, it is added to the previous explanations. - -## Tools -- Use the validate_query tool to validate the query in "sql" field before returning a response only if the user asks you to generate, fix, or make changes to the query. -` - const schemaAccess = grantSchemaAccess - ? `- Use the get_tables tool to retrieve all tables and materialized views in the database instance -- Use the get_table_schema tool to get detailed schema information for a specific table or a materialized view -- Use the get_table_details tool to get detailed information for a specific table or a materialized view. Each property is described in meta functions docs. -` - : "" - return base + schemaAccess + DOCS_INSTRUCTION_ANTHROPIC -} - -export const getExplainSchemaPrompt = ( - tableName: string, - schema: string, - kindLabel: string, -) => `You are a SQL expert assistant specializing in QuestDB, a high-performance time-series database. -Explain the following ${kindLabel} schema. Include: -- The purpose of the ${kindLabel} -- What each column represents and its data type -- Any important properties like WAL enablement, partitioning strategy, designated timestamps -- Any performance or storage considerations - -${kindLabel} Name: ${tableName} - -Schema: -\`\`\`sql -${schema} -\`\`\` - -**IMPORTANT: Format your response in markdown exactly as follows:** - -1. Start with a brief paragraph explaining the purpose and general characteristics of this ${kindLabel}. - -2. Add a "## Columns" section with a markdown table: -| Column | Type | Description | -|--------|------|-------------| -| column_name | \`data_type\` | Brief description | - -3. If this is a table or materialized view (not a view), add a "## Storage Details" section with bullet points about: -- WAL enablement -- Partitioning strategy -- Designated timestamp column -- Any other storage considerations - -For views, skip the Storage Details section.` - -export type HealthIssuePromptData = { - tableName: string - issue: { - id: string - field: string - message: string - currentValue?: string - } - tableDetails: string - monitoringDocs: string - trendSamples?: Array<{ value: number; timestamp: number }> -} - -export const getHealthIssuePrompt = (data: HealthIssuePromptData): string => { - const { tableName, issue, tableDetails, monitoringDocs, trendSamples } = data - - let trendSection = "" - if (trendSamples && trendSamples.length > 0) { - const recentSamples = trendSamples.slice(-30) - trendSection = ` - -### Trend Data (Recent Samples) -| Timestamp | Value | -|-----------|-------| -${recentSamples.map((s) => `| ${new Date(s.timestamp).toISOString()} | ${s.value.toLocaleString()} |`).join("\n")} -` - } - - return `You are a QuestDB expert assistant helping diagnose and resolve table health issues. - -A user is viewing the health monitoring panel for their table and has asked for help with a detected issue. - -## Table: ${tableName} - -## Health Issue Detected -- **Issue ID**: ${issue.id} -- **Field**: ${issue.field} -- **Message**: ${issue.message} -${issue.currentValue ? `- **Current Value**: ${issue.currentValue}` : ""}${trendSection} - -## Table Details (from tables() function) -\`\`\`json -${tableDetails} -\`\`\` - -## QuestDB Monitoring Documentation -${monitoringDocs} - ---- - -**Your Task:** -1. Explain what this health issue means in the context of this specific table -2. Analyze the table details to identify potential root causes -3. Provide specific, actionable recommendations to resolve or mitigate the issue -4. If there is a clear SQL command that can help fix the issue (like \`ALTER TABLE ... RESUME WAL\`), include it in the "sql" field of your response. **Only provide SQL if it directly addresses the root cause** - do not provide SQL just to inspect the problem. - -**IMPORTANT: Be concise and thorough, format your response in markdown with clear sections:** -- Use ## headings for main sections -- Use bullet points for lists -- Use \`code\` for configuration values and SQL -` -} - const MAX_RETRIES = 2 const RETRY_DELAY = 1000 @@ -650,451 +260,9 @@ const handleRateLimit = async () => { lastRequestTime = Date.now() } -const isNonRetryableError = (error: unknown) => { - if (error instanceof StreamingError) { - return error.errorType === "interrupted" || error.errorType === "failed" - } - return ( - error instanceof RefusalError || - error instanceof MaxTokensError || - error instanceof Anthropic.AuthenticationError || - (typeof OpenAI !== "undefined" && - error instanceof OpenAI.AuthenticationError) || - // @ts-expect-error no proper rate limit error type - ("status" in error && error.status === 429) || - error instanceof OpenAI.APIUserAbortError || - error instanceof Anthropic.APIUserAbortError - ) -} - -const executeTool = async ( - toolName: string, - input: unknown, - modelToolsClient: ModelToolsClient, - setStatus: StatusCallback, -): Promise<{ content: string; is_error?: boolean }> => { - try { - switch (toolName) { - case "get_tables": { - setStatus(AIOperationStatus.RetrievingTables) - if (!modelToolsClient.getTables) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - const result = await modelToolsClient.getTables() - const MAX_TABLES = 1000 - if (result.length > MAX_TABLES) { - const truncated = result.slice(0, MAX_TABLES) - return { - content: JSON.stringify( - { - tables: truncated, - total_count: result.length, - truncated: true, - message: `Showing ${MAX_TABLES} of ${result.length} tables. Use get_table_schema with a specific table name to get details if you are interested in a specific table.`, - }, - null, - 2, - ), - } - } - return { content: JSON.stringify(result, null, 2) } - } - case "get_table_schema": { - const tableName = (input as { table_name: string })?.table_name - if (!modelToolsClient.getTableSchema) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - if (!tableName) { - return { - content: "Error: table_name parameter is required", - is_error: true, - } - } - setStatus(AIOperationStatus.InvestigatingTable, { - name: tableName, - tableOpType: "schema", - }) - const result = await modelToolsClient.getTableSchema(tableName) - return { - content: - result || `Table '${tableName}' not found or schema unavailable`, - } - } - case "get_table_details": { - const tableName = (input as { table_name: string })?.table_name - if (!modelToolsClient.getTableDetails) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - if (!tableName) { - return { - content: "Error: table_name parameter is required", - is_error: true, - } - } - setStatus(AIOperationStatus.InvestigatingTable, { - name: tableName, - tableOpType: "details", - }) - const result = await modelToolsClient.getTableDetails(tableName) - return { - content: result - ? JSON.stringify(result, null, 2) - : "Table details not found", - is_error: !result, - } - } - case "validate_query": { - setStatus(AIOperationStatus.ValidatingQuery) - const query = (input as { query: string })?.query - if (!query) { - return { - content: "Error: query parameter is required", - is_error: true, - } - } - const result = await modelToolsClient.validateQuery(query) - const content = { - valid: result.valid, - error: result.valid ? undefined : result.error, - position: result.valid ? undefined : result.position, - } - return { content: JSON.stringify(content, null, 2) } - } - case "get_questdb_toc": { - setStatus(AIOperationStatus.RetrievingDocumentation) - const tocContent = await getQuestDBTableOfContents() - return { content: tocContent } - } - case "get_questdb_documentation": { - const { category, items } = - (input as { category: string; items: string[] }) || {} - if (!category || !items || !Array.isArray(items)) { - return { - content: "Error: category and items parameters are required", - is_error: true, - } - } - const parsedItems = parseDocItems(items) - - if (parsedItems.length > 0) { - setStatus(AIOperationStatus.InvestigatingDocs, { items: parsedItems }) - } else { - setStatus(AIOperationStatus.InvestigatingDocs) - } - const documentation = await getSpecificDocumentation( - category as DocCategory, - items, - ) - return { content: documentation } - } - default: - return { content: `Unknown tool: ${toolName}`, is_error: true } - } - } catch (error) { - return { - content: `Tool execution error: ${error instanceof Error ? error.message : "Unknown error"}`, - is_error: true, - } - } -} - -interface AnthropicToolCallResult { - message: Anthropic.Messages.Message - accumulatedTokens: TokenUsage -} - -async function handleToolCalls( - message: Anthropic.Messages.Message, - anthropic: Anthropic, - modelToolsClient: ModelToolsClient, - conversationHistory: Array, - model: string, - setStatus: StatusCallback, - responseFormat: ResponseTextConfig, - abortSignal?: AbortSignal, - accumulatedTokens: TokenUsage = { inputTokens: 0, outputTokens: 0 }, - streaming?: StreamingCallback, -): Promise { - const toolUseBlocks = message.content.filter( - (block) => block.type === "tool_use", - ) - const toolResults = [] - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - for (const toolUse of toolUseBlocks) { - if ("name" in toolUse) { - const exec = await executeTool( - toolUse.name, - toolUse.input, - modelToolsClient, - setStatus, - ) - toolResults.push({ - type: "tool_result" as const, - tool_use_id: toolUse.id, - content: exec.content, - is_error: exec.is_error, - }) - } - } - - const updatedHistory = [ - ...conversationHistory, - { - role: "assistant" as const, - content: message.content, - }, - { - role: "user" as const, - content: toolResults, - }, - ] - - const criticalTokenUsage = - message.usage.input_tokens >= COMPACTION_THRESHOLDS["anthropic"] && - toolResults.length > 0 - if (criticalTokenUsage) { - updatedHistory.push({ - role: "user" as const, - content: - "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", - }) - } - - const followUpParams: Parameters[1] = { - model, - tools: modelToolsClient.getTables ? ALL_TOOLS : REFERENCE_TOOLS, - messages: updatedHistory, - temperature: 0.3, - } - - const format = responseFormat.format as { type: string; schema?: object } - if (format.type === "json_schema" && format.schema) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - followUpParams.output_format = { - type: "json_schema", - schema: format.schema, - } - } - - const followUpMessage = streaming - ? await createAnthropicMessageStreaming( - anthropic, - followUpParams, - streaming, - abortSignal, - ) - : await createAnthropicMessage(anthropic, followUpParams, abortSignal) - - // Accumulate tokens from this response - const newAccumulatedTokens: TokenUsage = { - inputTokens: - accumulatedTokens.inputTokens + - (followUpMessage.usage?.input_tokens || 0), - outputTokens: - accumulatedTokens.outputTokens + - (followUpMessage.usage?.output_tokens || 0), - } - - if (followUpMessage.stop_reason === "tool_use") { - return handleToolCalls( - followUpMessage, - anthropic, - modelToolsClient, - updatedHistory, - model, - setStatus, - responseFormat, - abortSignal, - newAccumulatedTokens, - streaming, - ) - } - - return { - message: followUpMessage, - accumulatedTokens: newAccumulatedTokens, - } -} - -async function createOpenAIResponseStreaming( - openai: OpenAI, - params: OpenAI.Responses.ResponseCreateParamsNonStreaming, - streamCallback: StreamingCallback, - abortSignal?: AbortSignal, -): Promise { - let accumulatedText = "" - let lastExplanation = "" - let finalResponse: OpenAI.Responses.Response | null = null - - try { - const stream = await openai.responses.create({ - ...params, - stream: true, - } as OpenAI.Responses.ResponseCreateParamsStreaming) - - for await (const event of stream) { - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - - if (event.type === "error") { - const errorEvent = event as { error?: { message?: string } } - throw new StreamingError( - errorEvent.error?.message || "Stream error occurred", - "failed", - event, - ) - } - - if (event.type === "response.failed") { - const failedEvent = event as { - response?: { error?: { message?: string } } - } - throw new StreamingError( - failedEvent.response?.error?.message || - "Provider failed to return a response", - "failed", - event, - ) - } - - if (event.type === "response.output_text.delta") { - accumulatedText += event.delta - const explanation = extractPartialExplanation(accumulatedText) - if (explanation !== lastExplanation) { - const chunk = explanation.slice(lastExplanation.length) - lastExplanation = explanation - streamCallback.onTextChunk(chunk, explanation) - } - } - - if (event.type === "response.completed") { - finalResponse = event.response - } - } - } catch (error) { - if (error instanceof StreamingError) { - throw error - } - if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - error instanceof Error ? error.message : "Stream interrupted", - "network", - error, - ) - } - - if (!finalResponse) { - throw new StreamingError("Provider failed to return a response", "failed") - } - - return finalResponse -} - -const extractOpenAIToolCalls = ( - response: OpenAI.Responses.Response, -): { id?: string; name: string; arguments: unknown; call_id: string }[] => { - const calls = [] - for (const item of response.output) { - if (item?.type === "function_call") { - const args = - typeof item.arguments === "string" - ? safeJsonParse(item.arguments) - : item.arguments || {} - calls.push({ - id: item.id, - name: item.name, - arguments: args, - call_id: item.call_id, - }) - } - } - return calls -} - -const getOpenAIText = ( - response: OpenAI.Responses.Response, -): { type: "refusal" | "text"; message: string } => { - const out = response.output || [] - if ( - out.find( - (item: ResponseOutputItem) => - item.type === "message" && - item.content.some((c) => c.type === "refusal"), - ) - ) { - return { - type: "refusal", - message: "The model refused to generate a response for this request.", - } - } - - for (const item of out) { - if (item.type === "message" && item.content) { - for (const content of item.content) { - if (content.type === "output_text" && "text" in content) { - return { type: "text", message: content.text } - } - } - } - } - - return { type: "text", message: "" } -} - -const safeJsonParse = (text: string): T | object => { - try { - return JSON.parse(text) as T - } catch { - return {} - } -} - -/** - * Extracts partial explanation text from incomplete JSON during streaming. - * Handles JSON escape sequences and partial content. - */ -function extractPartialExplanation(partialJson: string): string { - // Match "explanation": "content... where content may be incomplete - const explanationMatch = partialJson.match( - /"explanation"\s*:\s*"((?:[^"\\]|\\.)*)/, - ) - if (!explanationMatch) { - return "" - } - - // Unescape JSON string escape sequences - return explanationMatch[1] - .replace(/\\n/g, "\n") - .replace(/\\r/g, "\r") - .replace(/\\t/g, "\t") - .replace(/\\"/g, '"') - .replace(/\\\\/g, "\\") -} - const tryWithRetries = async ( fn: () => Promise, + provider: AIProvider, setStatus: StatusCallback, abortSignal?: AbortSignal, ): Promise => { @@ -1113,13 +281,13 @@ const tryWithRetries = async ( console.error( "AI Assistant error:", error instanceof Error ? error.message : String(error), - isNonRetryableError(error) + provider.isNonRetryableError(error) ? "Non-retryable error." : "Remaining retries: " + (MAX_RETRIES - retries) + ".", ) retries++ - if (retries > MAX_RETRIES || isNonRetryableError(error)) { - return handleAiAssistantError(error, setStatus) + if (retries > MAX_RETRIES || provider.isNonRetryableError(error)) { + return provider.classifyError(error, setStatus) } await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY * retries)) @@ -1133,709 +301,28 @@ const tryWithRetries = async ( } } -interface OpenAIFlowConfig { - systemInstructions: string - initialUserContent: string - conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> - responseFormat: ResponseTextConfig - postProcess?: (formatted: T) => T +export const testApiKey = async ( + apiKey: string, + model: string, + providerId: Provider, +): Promise<{ valid: boolean; error?: string }> => { + const provider = createProvider(providerId, apiKey) + return provider.testConnection({ apiKey, model }) } -interface AnthropicFlowConfig { - systemInstructions: string - initialUserContent: string - conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> - responseFormat: ResponseTextConfig - postProcess?: (formatted: T) => T -} +export const generateChatTitle = async ({ + firstUserMessage, + settings, +}: { + firstUserMessage: string + settings: ActiveProviderSettings +}): Promise => { + if (!settings.apiKey || !settings.model) { + return null + } -interface ExecuteAnthropicFlowParams { - anthropic: Anthropic - model: string - config: AnthropicFlowConfig - modelToolsClient: ModelToolsClient - setStatus: StatusCallback - abortSignal?: AbortSignal - streaming?: StreamingCallback -} - -interface ExecuteOpenAIFlowParams { - openai: OpenAI - model: string - config: OpenAIFlowConfig - modelToolsClient: ModelToolsClient - setStatus: StatusCallback - abortSignal?: AbortSignal - streaming?: StreamingCallback -} - -const executeOpenAIFlow = async ({ - openai, - model, - config, - modelToolsClient, - setStatus, - abortSignal, - streaming, -}: ExecuteOpenAIFlowParams): Promise => { - let input: OpenAI.Responses.ResponseInput = [] - if (config.conversationHistory && config.conversationHistory.length > 0) { - const validMessages = config.conversationHistory.filter( - (msg) => msg.content && msg.content.trim() !== "", - ) - for (const msg of validMessages) { - input.push({ - role: msg.role, - content: msg.content, - }) - } - } - - input.push({ - role: "user", - content: config.initialUserContent, - }) - - const grantSchemaAccess = !!modelToolsClient.getTables - const openaiTools = toOpenAIFunctions( - grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS, - ) - - // Accumulate tokens across all iterations - let totalInputTokens = 0 - let totalOutputTokens = 0 - - const requestParams = { - ...getModelProps(model), - instructions: config.systemInstructions, - input, - tools: openaiTools, - text: config.responseFormat, - } as OpenAI.Responses.ResponseCreateParamsNonStreaming - - // Use streaming for the initial call if callback provided - let lastResponse = streaming - ? await createOpenAIResponseStreaming( - openai, - requestParams, - streaming, - abortSignal, - ) - : await openai.responses.create(requestParams) - input = [...input, ...lastResponse.output] - - // Add tokens from first response - totalInputTokens += lastResponse.usage?.input_tokens ?? 0 - totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 - - while (true) { - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const toolCalls = extractOpenAIToolCalls(lastResponse) - if (!toolCalls.length) break - const tool_outputs: OpenAI.Responses.ResponseFunctionToolCallOutputItem[] = - [] - for (const tc of toolCalls) { - const exec = await executeTool( - tc.name, - tc.arguments, - modelToolsClient, - setStatus, - ) - tool_outputs.push({ - type: "function_call_output", - call_id: tc.call_id, - output: exec.content, - } as OpenAI.Responses.ResponseFunctionToolCallOutputItem) - } - input = [...input, ...tool_outputs] - - if ( - (lastResponse.usage?.input_tokens ?? 0) >= - COMPACTION_THRESHOLDS["openai"] && - tool_outputs.length > 0 - ) { - input.push({ - role: "user" as const, - content: - "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", - }) - } - const loopRequestParams = { - ...getModelProps(model), - instructions: config.systemInstructions, - input, - tools: openaiTools, - text: config.responseFormat, - } as OpenAI.Responses.ResponseCreateParamsNonStreaming - - // Use streaming for follow-up calls if callback provided - lastResponse = streaming - ? await createOpenAIResponseStreaming( - openai, - loopRequestParams, - streaming, - abortSignal, - ) - : await openai.responses.create(loopRequestParams) - input = [...input, ...lastResponse.output] - - // Accumulate tokens from each iteration - totalInputTokens += lastResponse.usage?.input_tokens ?? 0 - totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 - } - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const text = getOpenAIText(lastResponse) - if (text.type === "refusal") { - return { - type: "unknown", - message: text.message, - } as AiAssistantAPIError - } - - const rawOutput = text.message - - try { - const json = JSON.parse(rawOutput) as T - setStatus(null) - - const resultWithTokens = { - ...json, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - - if (config.postProcess) { - const processed = config.postProcess(json) - return { - ...processed, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - } - return resultWithTokens - } catch (error) { - setStatus(null) - return { - type: "unknown", - message: "Failed to parse assistant response.", - } as AiAssistantAPIError - } -} - -const executeAnthropicFlow = async ({ - anthropic, - model, - config, - modelToolsClient, - setStatus, - abortSignal, - streaming, -}: ExecuteAnthropicFlowParams): Promise => { - const initialMessages: MessageParam[] = [] - if (config.conversationHistory && config.conversationHistory.length > 0) { - const validMessages = config.conversationHistory.filter( - (msg) => msg.content && msg.content.trim() !== "", - ) - for (const msg of validMessages) { - initialMessages.push({ - role: msg.role, - content: msg.content, - }) - } - } - - initialMessages.push({ - role: "user" as const, - content: config.initialUserContent, - }) - - const grantSchemaAccess = !!modelToolsClient.getTables - - const messageParams: Parameters[1] = { - model, - system: config.systemInstructions, - tools: grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS, - messages: initialMessages, - temperature: 0.3, - } - - if (config.responseFormat?.format) { - const format = config.responseFormat.format as { - type: string - schema?: object - } - if (format.type === "json_schema" && format.schema) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = { - type: "json_schema", - schema: format.schema, - } - } - } - - // Use streaming for the initial call if callback provided - const message = streaming - ? await createAnthropicMessageStreaming( - anthropic, - messageParams, - streaming, - abortSignal, - ) - : await createAnthropicMessage(anthropic, messageParams, abortSignal) - - let totalInputTokens = message.usage?.input_tokens || 0 - let totalOutputTokens = message.usage?.output_tokens || 0 - - let responseMessage: Anthropic.Messages.Message - - if (message.stop_reason === "tool_use") { - const toolCallResult = await handleToolCalls( - message, - anthropic, - modelToolsClient, - initialMessages, - model, - setStatus, - config.responseFormat, - abortSignal, - { inputTokens: 0, outputTokens: 0 }, // Start fresh, we already counted initial message - streaming, - ) - - if ("type" in toolCallResult && "message" in toolCallResult) { - return toolCallResult - } - - const result = toolCallResult - responseMessage = result.message - totalInputTokens += result.accumulatedTokens.inputTokens - totalOutputTokens += result.accumulatedTokens.outputTokens - } else { - responseMessage = message - } - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const textBlock = responseMessage.content.find( - (block) => block.type === "text", - ) - if (!textBlock || !("text" in textBlock)) { - setStatus(null) - return { - type: "unknown", - message: "No text response received from assistant.", - } as AiAssistantAPIError - } - - try { - const json = JSON.parse(textBlock.text) as T - setStatus(null) - - const resultWithTokens = { - ...json, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - - if (config.postProcess) { - const processed = config.postProcess(json) - return { - ...processed, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - } - return resultWithTokens - } catch (error) { - setStatus(null) - return { - type: "unknown", - message: "Failed to parse assistant response.", - } as AiAssistantAPIError - } -} - -class RefusalError extends Error { - constructor(message: string) { - super(message) - this.name = "RefusalError" - } -} - -class MaxTokensError extends Error { - constructor(message: string) { - super(message) - this.name = "MaxTokensError" - } -} - -class StreamingError extends Error { - constructor( - message: string, - public readonly errorType: "failed" | "network" | "interrupted" | "unknown", - public readonly originalError?: unknown, - ) { - super(message) - this.name = "StreamingError" - } -} - -async function createAnthropicMessage( - anthropic: Anthropic, - params: Omit & { - max_tokens?: number - }, - signal?: AbortSignal, -): Promise { - const message = await anthropic.messages.create( - { - ...params, - stream: false, - max_tokens: params.max_tokens ?? 8192, - }, - { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, - signal, - }, - ) - - if (message.stop_reason === "refusal") { - throw new RefusalError( - "The model refused to generate a response for this request.", - ) - } - if (message.stop_reason === "max_tokens") { - throw new MaxTokensError( - "The response exceeded the maximum token limit. Please try again with a different prompt or model.", - ) - } - - return message -} - -async function createAnthropicMessageStreaming( - anthropic: Anthropic, - params: Omit & { - max_tokens?: number - }, - streamCallback: StreamingCallback, - abortSignal?: AbortSignal, -): Promise { - let accumulatedText = "" - let lastExplanation = "" - - const stream = anthropic.messages.stream( - { - ...params, - max_tokens: params.max_tokens ?? 8192, - }, - { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, - signal: abortSignal, - }, - ) - - try { - for await (const event of stream) { - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - - const eventWithType = event as { type: string } - if (eventWithType.type === "error") { - const errorEvent = event as { - error?: { type?: string; message?: string } - } - const errorType = errorEvent.error?.type - const errorMessage = errorEvent.error?.message || "Stream error" - - if (errorType === "overloaded_error") { - throw new StreamingError( - "Service is temporarily overloaded. Please try again.", - "failed", - event, - ) - } - throw new StreamingError(errorMessage, "failed", event) - } - - if ( - event.type === "content_block_delta" && - event.delta.type === "text_delta" - ) { - accumulatedText += event.delta.text - const explanation = extractPartialExplanation(accumulatedText) - if (explanation !== lastExplanation) { - const chunk = explanation.slice(lastExplanation.length) - lastExplanation = explanation - streamCallback.onTextChunk(chunk, explanation) - } - } - } - } catch (error) { - if (error instanceof StreamingError) { - throw error - } - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - error instanceof Error ? error.message : "Stream interrupted", - "network", - error, - ) - } - - let finalMessage: Anthropic.Messages.Message - try { - finalMessage = await stream.finalMessage() - } catch (error) { - if (abortSignal?.aborted || error instanceof Anthropic.APIUserAbortError) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - "Failed to get final message from the provider", - "network", - error, - ) - } - - if (finalMessage.stop_reason === "refusal") { - throw new RefusalError( - "The model refused to generate a response for this request.", - ) - } - if (finalMessage.stop_reason === "max_tokens") { - throw new MaxTokensError( - "The response exceeded the maximum token limit. Please try again with a different prompt or model.", - ) - } - - return finalMessage -} - -function handleAiAssistantError( - error: unknown, - setStatus: StatusCallback, -): AiAssistantAPIError { - if ( - error instanceof OpenAI.APIUserAbortError || - error instanceof Anthropic.APIUserAbortError || - (error instanceof StreamingError && error.errorType === "interrupted") - ) { - setStatus(AIOperationStatus.Aborted) - return { - type: "aborted", - message: "Operation was cancelled", - } - } - setStatus(null) - - if (error instanceof RefusalError) { - return { - type: "unknown", - message: "The model refused to generate a response for this request.", - details: error.message, - } - } - - if (error instanceof MaxTokensError) { - return { - type: "unknown", - message: - "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", - details: error.message, - } - } - - if (error instanceof StreamingError) { - switch (error.errorType) { - case "network": - return { - type: "network", - message: - "Network error during streaming. Please check your connection.", - details: error.message, - } - case "failed": - default: - return { - type: "unknown", - message: error.message || "Stream failed unexpectedly.", - details: - error.originalError instanceof Error - ? error.originalError.message - : undefined, - } - } - } - - if (error instanceof Anthropic.AuthenticationError) { - return { - type: "invalid_key", - message: "Invalid API key. Please check your Anthropic API key.", - details: error.message, - } - } - - if (error instanceof Anthropic.RateLimitError) { - return { - type: "rate_limit", - message: "Rate limit exceeded. Please try again later.", - details: error.message, - } - } - - if (error instanceof Anthropic.APIConnectionError) { - return { - type: "network", - message: "Network error. Please check your internet connection.", - details: error.message, - } - } - - if (error instanceof Anthropic.APIError) { - return { - type: "unknown", - message: `Anthropic API error: ${error.message}`, - } - } - - if (error instanceof OpenAI.APIError) { - return { - type: "unknown", - message: `OpenAI API error: ${error.message}`, - } - } - - return { - type: "unknown", - message: "An unexpected error occurred. Please try again.", - details: error as string, - } -} - -export const testApiKey = async ( - apiKey: string, - model: string, -): Promise<{ valid: boolean; error?: string }> => { - try { - if (inferProviderFromModel(model) === "anthropic") { - const anthropic = new Anthropic({ - apiKey, - dangerouslyAllowBrowser: true, - }) - - await createAnthropicMessage(anthropic, { - model, - messages: [ - { - role: "user", - content: "ping", - }, - ], - }) - } else { - const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true }) - await openai.responses.create({ - model: getModelProps(model).model, - input: [{ role: "user", content: "ping" }], - max_output_tokens: 16, - }) - } - - return { valid: true } - } catch (error: unknown) { - if (error instanceof Anthropic.AuthenticationError) { - return { - valid: false, - error: "Invalid API key", - } - } - - if (error instanceof Anthropic.RateLimitError) { - return { - valid: true, - } - } - - const status = - (error as { status?: number })?.status || - (error as { error?: { status?: number } })?.error?.status - if (status === 401) { - return { valid: false, error: "Invalid API key" } - } - if (status === 429) { - return { valid: true } - } - - return { - valid: false, - error: - error instanceof Error ? error.message : "Failed to validate API key", - } - } -} - -const ChatTitleFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "chat_title_format", - schema: { - type: "object", - properties: { - title: { type: "string" }, - }, - required: ["title"], - additionalProperties: false, - }, - strict: true, - }, -} - -export const generateChatTitle = async ({ - firstUserMessage, - settings, -}: { - firstUserMessage: string - settings: ActiveProviderSettings -}): Promise => { - if (!settings.apiKey || !settings.model) { - return null - } - - try { - const clients = createProviderClients(settings) + try { + const provider = createProvider(settings.provider, settings.apiKey) const prompt = `Generate a concise chat title (max 30 characters) for this conversation with QuestDB AI Assistant. The title should capture the main topic or intent. @@ -1844,54 +331,12 @@ ${firstUserMessage} Return a JSON object with the following structure: { "title": "Your title here" }` - if (clients.provider === "openai") { - const response = await clients.openai.responses.create({ - ...getModelProps(settings.model), - input: [{ role: "user", content: prompt }], - text: ChatTitleFormat, - max_output_tokens: 100, - }) - try { - const parsed = JSON.parse(response.output_text) as { title: string } - return parsed.title || null - } catch { - return null - } - } - - const messageParams: Parameters[1] = { + return await provider.generateTitle({ model: settings.model, - messages: [{ role: "user", content: prompt }], - max_tokens: 100, - temperature: 0.3, - } - const titleFormat = ChatTitleFormat.format as { - type: string - schema?: object - } - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = { - type: "json_schema", - schema: titleFormat.schema, - } - - const message = await createAnthropicMessage( - clients.anthropic, - messageParams, - ) - - const textBlock = message.content.find((block) => block.type === "text") - if (textBlock && "text" in textBlock) { - try { - const parsed = JSON.parse(textBlock.text) as { title: string } - return parsed.title?.slice(0, 40) || null - } catch { - return null - } - } - return null + prompt, + responseFormat: ChatTitleFormat, + }) } catch (error) { - // Silently fail - title generation is not critical console.warn("Failed to generate chat title:", error) return null } @@ -1953,9 +398,10 @@ export const continueConversation = async ({ health_issue: ConversationResponseFormat, }[operation] + const provider = createProvider(settings.provider, settings.apiKey) + return tryWithRetries( async () => { - const clients = createProviderClients(settings) const grantSchemaAccess = !!modelToolsClient.getTables const systemPrompt = getUnifiedPrompt(grantSchemaAccess) @@ -1966,17 +412,11 @@ export const continueConversation = async ({ if (conversationHistory.length > 0) { const compactionResult = await compactConversationIfNeeded( conversationHistory, - settings.provider, + provider, systemPrompt, userMessage, () => setStatus(AIOperationStatus.Compacting), - { - anthropicClient: - clients.provider === "anthropic" ? clients.anthropic : undefined, - openaiClient: - clients.provider === "openai" ? clients.openai : undefined, - model: settings.model, - }, + { model: settings.model }, ) if ("error" in compactionResult) { @@ -2025,57 +465,13 @@ export const continueConversation = async ({ } } - if (clients.provider === "openai") { - const result = await executeOpenAIFlow<{ - sql?: string | null - explanation: string - tokenUsage?: TokenUsage - }>({ - openai: clients.openai, - model: settings.model, - config: { - systemInstructions: getUnifiedPrompt(grantSchemaAccess), - initialUserContent: userMessage, - conversationHistory: workingConversationHistory.filter( - (m) => !m.isCompacted, - ), - responseFormat, - postProcess: (formatted) => { - const sql = - formatted?.sql === null - ? null - : formatted?.sql - ? normalizeSql(formatted.sql) - : currentSQL || "" - return { - sql, - explanation: formatted?.explanation || "", - tokenUsage: formatted.tokenUsage, - } - }, - }, - modelToolsClient, - setStatus, - abortSignal, - streaming, - }) - if (isAiAssistantError(result)) { - return result - } - return { - ...postProcess(result), - compactedConversationHistory: isCompacted - ? workingConversationHistory - : undefined, - } - } + const tools = grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS - const result = await executeAnthropicFlow<{ + const result = await provider.executeFlow<{ sql?: string | null explanation: string tokenUsage?: TokenUsage }>({ - anthropic: clients.anthropic, model: settings.model, config: { systemInstructions: getUnifiedPrompt(grantSchemaAccess), @@ -2084,25 +480,14 @@ export const continueConversation = async ({ (m) => !m.isCompacted, ), responseFormat, - postProcess: (formatted) => { - const sql = - formatted?.sql === null - ? null - : formatted?.sql - ? normalizeSql(formatted.sql) - : currentSQL || "" - return { - sql, - explanation: formatted?.explanation || "", - tokenUsage: formatted.tokenUsage, - } - }, }, modelToolsClient, + tools, setStatus, abortSignal, streaming, }) + if (isAiAssistantError(result)) { return result } @@ -2113,6 +498,7 @@ export const continueConversation = async ({ : undefined, } }, + provider, setStatus, abortSignal, ) diff --git a/src/utils/aiAssistantSettings.ts b/src/utils/aiAssistantSettings.ts index 7892e29b5..2c68f1e07 100644 --- a/src/utils/aiAssistantSettings.ts +++ b/src/utils/aiAssistantSettings.ts @@ -1,6 +1,7 @@ -import { ReasoningEffort } from "openai/resources/shared" import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" +type ReasoningEffort = "high" | "medium" | "low" + export type Provider = "anthropic" | "openai" export type ModelOption = { diff --git a/src/utils/contextCompaction.ts b/src/utils/contextCompaction.ts index 63c5dd90a..194489320 100644 --- a/src/utils/contextCompaction.ts +++ b/src/utils/contextCompaction.ts @@ -1,16 +1,6 @@ -import Anthropic from "@anthropic-ai/sdk" -import OpenAI from "openai" import type { ConversationMessage } from "../providers/AIConversationProvider/types" -import { - countTokens, - COMPACTION_THRESHOLDS, - type ConversationMessage as TokenConversationMessage, -} from "./tokenCounting" -import { - type Provider, - MODEL_OPTIONS, - getModelProps, -} from "./aiAssistantSettings" +import { MODEL_OPTIONS } from "./aiAssistantSettings" +import type { AIProvider } from "./ai" type CompactionResultSuccess = { compactedMessage: string @@ -73,7 +63,7 @@ ${summary} function toTokenMessages( messages: [...ConversationMessage[], Omit], -): TokenConversationMessage[] { +): Array<{ role: "user" | "assistant"; content: string }> { return messages .filter((m) => m.content && m.content.trim() !== "") .map((m) => ({ @@ -84,12 +74,10 @@ function toTokenMessages( async function generateSummary( middleMessages: ConversationMessage[], - provider: Provider, - anthropicClient?: Anthropic, - openaiClient?: OpenAI, + aiProvider: AIProvider, ): Promise { const testModel = MODEL_OPTIONS.find( - (m) => m.provider === provider && m.isTestModel, + (m) => m.provider === aiProvider.id && m.isTestModel, ) if (!testModel) { throw new Error("No test model found for provider") @@ -101,41 +89,22 @@ async function generateSummary( const userMessage = `Please summarize the following conversation:\n\n${conversationText}` - if (provider === "anthropic" && anthropicClient) { - const response = await anthropicClient.messages.create({ - ...getModelProps(testModel.value), - max_tokens: 8192, - messages: [{ role: "user", content: userMessage }], - system: SUMMARIZATION_PROMPT, - }) - - const textBlock = response.content.find((block) => block.type === "text") - return textBlock?.type === "text" ? textBlock.text : "" - } else if (provider === "openai" && openaiClient) { - const response = await openaiClient.responses.create({ - ...getModelProps(testModel.value), - instructions: SUMMARIZATION_PROMPT, - input: userMessage, - }) - - return response.output_text || "" - } - - throw new Error("No valid client provided for summarization") + return aiProvider.generateSummary({ + model: testModel.value, + systemPrompt: SUMMARIZATION_PROMPT, + userMessage, + }) } export async function compactConversationIfNeeded( conversationHistory: ConversationMessage[], - provider: Provider, + aiProvider: AIProvider, systemPrompt: string, userMessage: string, setStatusCompacting: () => void, - options: { - anthropicClient?: Anthropic - openaiClient?: OpenAI - model?: string - } = {}, + options: { model?: string } = {}, ): Promise { + const compactionThreshold = aiProvider.contextWindow - 50_000 const messages = [ ...conversationHistory, { @@ -144,33 +113,31 @@ export async function compactConversationIfNeeded( timestamp: Date.now(), } as Omit, ] as [...ConversationMessage[], Omit] + const totalChars = systemPrompt.length + messages.reduce((sum, m) => sum + m.content.length, 0) - if (totalChars < COMPACTION_THRESHOLDS[provider]) { + + if (totalChars < compactionThreshold) { return { wasCompacted: false } } const tokenMessages = toTokenMessages(messages) - const estimatedTokens = await countTokens( - provider, - tokenMessages, - systemPrompt, - { - anthropicClient: options.anthropicClient, - model: options.model, - }, - ) - if (estimatedTokens === -1) { + let estimatedTokens: number + try { + estimatedTokens = await aiProvider.countTokens({ + messages: tokenMessages, + systemPrompt, + model: options.model ?? "", + }) + } catch { console.error( "Failed to estimate tokens for conversation, using full messages list.", ) - return { - wasCompacted: false, - } + return { wasCompacted: false } } - if (estimatedTokens <= COMPACTION_THRESHOLDS[provider]) { + if (estimatedTokens <= compactionThreshold) { return { wasCompacted: false } } @@ -184,9 +151,8 @@ export async function compactConversationIfNeeded( const result = await compactConversationInternal( conversationHistory, - provider, + aiProvider, setStatusCompacting, - options, ) if (!result.wasCompacted) { @@ -202,13 +168,8 @@ export async function compactConversationIfNeeded( async function compactConversationInternal( messages: ConversationMessage[], - provider: Provider, + aiProvider: AIProvider, setStatusCompacting: () => void, - options: { - anthropicClient?: Anthropic - openaiClient?: OpenAI - model?: string - } = {}, ): Promise { if (messages.length === 0) { return { wasCompacted: false } @@ -217,12 +178,7 @@ async function compactConversationInternal( setStatusCompacting() try { - const summary = await generateSummary( - messages, - provider, - options.anthropicClient, - options.openaiClient, - ) + const summary = await generateSummary(messages, aiProvider) return { compactedMessage: buildContinuationPrompt(summary), diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 5f244165b..ecfe4295f 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -18,14 +18,13 @@ import { createStreamingCallback, isAiAssistantError, generateChatTitle, - getExplainSchemaPrompt, - getHealthIssuePrompt, type ActiveProviderSettings, type GeneratedSQL, type AiAssistantExplanation, type AiAssistantAPIError, type AIOperation, } from "./aiAssistant" +import { getExplainSchemaPrompt, getHealthIssuePrompt } from "./ai" import { providerForModel, MODEL_OPTIONS } from "./aiAssistantSettings" import { eventBus } from "../modules/EventBus" import { EventType } from "../modules/EventBus/types" diff --git a/src/utils/tokenCounting.ts b/src/utils/tokenCounting.ts deleted file mode 100644 index 35600b38d..000000000 --- a/src/utils/tokenCounting.ts +++ /dev/null @@ -1,101 +0,0 @@ -import Anthropic from "@anthropic-ai/sdk" -import type { Provider } from "./aiAssistantSettings" -import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" - -export interface ConversationMessage { - role: "user" | "assistant" - content: string -} - -export const CONTEXT_LIMITS: Record = { - anthropic: 200_000, - openai: 400_000, -} - -export const COMPACTION_THRESHOLDS: Record = { - anthropic: 150_000, - openai: 350_000, -} - -export async function countTokensAnthropic( - client: Anthropic, - messages: ConversationMessage[], - systemPrompt: string, - model: string, -): Promise { - const anthropicMessages: Anthropic.MessageParam[] = messages.map((m) => ({ - role: m.role, - content: m.content, - })) - - const response = await client.messages.countTokens({ - model, - system: systemPrompt, - messages: anthropicMessages, - }) - - return response.input_tokens -} -let tiktokenEncoder: Tiktoken | null = null - -export async function countTokensOpenAI( - messages: ConversationMessage[], - systemPrompt: string, -): Promise { - if (!tiktokenEncoder) { - const { Tiktoken } = await import("js-tiktoken/lite") - const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( - (module: { default: TiktokenBPE }) => module.default, - ) - tiktokenEncoder = new Tiktoken(o200k_base) - } - - let totalTokens = 0 - - totalTokens += tiktokenEncoder.encode(systemPrompt).length - // Add overhead for system message formatting - totalTokens += 4 // <|start|>system<|end|> overhead - - for (const message of messages) { - // Each message has overhead for role markers - totalTokens += 4 // <|start|>{role}<|end|> overhead - totalTokens += tiktokenEncoder.encode(message.content).length - } - - // Add 2 tokens for assistant reply priming - totalTokens += 2 - - return totalTokens -} - -export async function countTokens( - provider: Provider, - messages: ConversationMessage[], - systemPrompt: string, - options: { - anthropicClient?: Anthropic - model?: string - } = {}, -): Promise { - try { - if (provider === "anthropic") { - if (!options.anthropicClient || !options.model) { - return -1 - } - return await countTokensAnthropic( - options.anthropicClient, - messages, - systemPrompt, - options.model, - ) - } else { - return countTokensOpenAI(messages, systemPrompt) - } - } catch (error) { - console.warn( - "Failed to estimate tokens for conversation, using full messages list.", - error, - ) - return -1 - } -} From 10469dd869373a9acf2216923bdc8a6f94636f06 Mon Sep 17 00:00:00 2001 From: emrberk Date: Thu, 5 Mar 2026 19:06:36 +0300 Subject: [PATCH 02/25] move settings --- src/components/AIStatusIndicator/index.tsx | 2 +- .../SetupAIAssistant/ConfigurationModal.tsx | 6 +----- src/components/SetupAIAssistant/ModelDropdown.tsx | 2 +- src/components/SetupAIAssistant/SettingsModal.tsx | 2 +- src/providers/AIStatusProvider/index.tsx | 2 +- src/utils/ai/anthropicProvider.ts | 2 +- src/utils/ai/index.ts | 12 ++++++++++++ src/utils/ai/openaiProvider.ts | 2 +- src/utils/ai/registry.ts | 2 +- src/utils/{aiAssistantSettings.ts => ai/settings.ts} | 2 +- src/utils/ai/types.ts | 2 +- src/utils/aiAssistant.ts | 2 +- src/utils/contextCompaction.ts | 2 +- src/utils/executeAIFlow.ts | 2 +- 14 files changed, 25 insertions(+), 17 deletions(-) rename src/utils/{aiAssistantSettings.ts => ai/settings.ts} (98%) diff --git a/src/components/AIStatusIndicator/index.tsx b/src/components/AIStatusIndicator/index.tsx index 646680160..9b8f7286c 100644 --- a/src/components/AIStatusIndicator/index.tsx +++ b/src/components/AIStatusIndicator/index.tsx @@ -15,7 +15,7 @@ import { color } from "../../utils" import { slideAnimation } from "../Animation" import { AISparkle } from "../AISparkle" import { pinkLinearGradientHorizontal } from "../../theme" -import { MODEL_OPTIONS } from "../../utils/aiAssistantSettings" +import { MODEL_OPTIONS } from "../../utils/ai" import { useAIConversation } from "../../providers/AIConversationProvider" import { Button } from "../../components/Button" import { BrainIcon } from "../SetupAIAssistant/BrainIcon" diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index d3435524a..4b44504ab 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -11,11 +11,7 @@ import { useLocalStorage } from "../../providers/LocalStorageProvider" import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" import { toast } from "../Toast" -import { - MODEL_OPTIONS, - type ModelOption, - type Provider, -} from "../../utils/aiAssistantSettings" +import { MODEL_OPTIONS, type ModelOption, type Provider } from "../../utils/ai" import { useModalNavigation } from "../MultiStepModal" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" diff --git a/src/components/SetupAIAssistant/ModelDropdown.tsx b/src/components/SetupAIAssistant/ModelDropdown.tsx index 063d53845..d05f4ace1 100644 --- a/src/components/SetupAIAssistant/ModelDropdown.tsx +++ b/src/components/SetupAIAssistant/ModelDropdown.tsx @@ -6,7 +6,7 @@ import { PopperToggle } from "../PopperToggle" import { Box } from "../Box" import { Text } from "../Text" import { useLocalStorage } from "../../providers/LocalStorageProvider" -import { MODEL_OPTIONS } from "../../utils/aiAssistantSettings" +import { MODEL_OPTIONS } from "../../utils/ai" import { useAIStatus } from "../../providers/AIStatusProvider" import { StoreKey } from "../../utils/localStorage/types" import { OpenAIIcon } from "./OpenAIIcon" diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 3fbc3a46b..860ccfd2a 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -24,7 +24,7 @@ import { type ModelOption, type Provider, getNextModel, -} from "../../utils/aiAssistantSettings" +} from "../../utils/ai" import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" import { ForwardRef } from "../ForwardRef" import { Badge, BadgeType } from "../../components/Badge" diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index e9dd5a43f..c72ce8538 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -14,7 +14,7 @@ import { hasSchemaAccess, providerForModel, canUseAiAssistant, -} from "../../utils/aiAssistantSettings" +} from "../../utils/ai" import { useAIConversation } from "../AIConversationProvider" export const useAIStatus = () => { diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 1b4915753..20cd34edc 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -9,7 +9,7 @@ import type { TokenUsage, } from "../aiAssistant" import { AIOperationStatus } from "../../providers/AIStatusProvider" -import { getModelProps } from "../aiAssistantSettings" +import { getModelProps } from "./settings" import type { AIProvider, FlowConfig, diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index c868d9ea3..4481e5459 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -27,3 +27,15 @@ export { getHealthIssuePrompt, } from "./prompts" export type { HealthIssuePromptData } from "./prompts" +export { + MODEL_OPTIONS, + providerForModel, + getModelProps, + getAllProviders, + getSelectedModel, + getNextModel, + isAiAssistantConfigured, + canUseAiAssistant, + hasSchemaAccess, +} from "./settings" +export type { Provider, ModelOption } from "./settings" diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index a525d5746..dc881d823 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -11,7 +11,7 @@ import type { TokenUsage, } from "../aiAssistant" import { AIOperationStatus } from "../../providers/AIStatusProvider" -import { getModelProps } from "../aiAssistantSettings" +import { getModelProps } from "./settings" import type { AIProvider, FlowConfig, diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts index 6b43772b7..812d2d323 100644 --- a/src/utils/ai/registry.ts +++ b/src/utils/ai/registry.ts @@ -1,7 +1,7 @@ import type { AIProvider } from "./types" import { createOpenAIProvider } from "./openaiProvider" import { createAnthropicProvider } from "./anthropicProvider" -import type { Provider } from "../aiAssistantSettings" +import type { Provider } from "./settings" export function createProvider( providerId: Provider, diff --git a/src/utils/aiAssistantSettings.ts b/src/utils/ai/settings.ts similarity index 98% rename from src/utils/aiAssistantSettings.ts rename to src/utils/ai/settings.ts index 2c68f1e07..7972d7253 100644 --- a/src/utils/aiAssistantSettings.ts +++ b/src/utils/ai/settings.ts @@ -1,4 +1,4 @@ -import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" +import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" type ReasoningEffort = "high" | "medium" | "low" diff --git a/src/utils/ai/types.ts b/src/utils/ai/types.ts index 3ad77eb63..2a89109ae 100644 --- a/src/utils/ai/types.ts +++ b/src/utils/ai/types.ts @@ -4,7 +4,7 @@ import type { StatusCallback, StreamingCallback, } from "../aiAssistant" -import type { Provider } from "../aiAssistantSettings" +import type { Provider } from "./settings" export interface ToolDefinition { name: string diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index 8fee84215..dd3383755 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -1,6 +1,6 @@ import { Client } from "./questdb/client" import { Type, Table } from "./questdb/types" -import type { Provider } from "./aiAssistantSettings" +import type { Provider } from "./ai" import { formatSql } from "./formatSql" import { AIOperationStatus, StatusArgs } from "../providers/AIStatusProvider" import type { diff --git a/src/utils/contextCompaction.ts b/src/utils/contextCompaction.ts index 194489320..42eaa7c0d 100644 --- a/src/utils/contextCompaction.ts +++ b/src/utils/contextCompaction.ts @@ -1,5 +1,5 @@ import type { ConversationMessage } from "../providers/AIConversationProvider/types" -import { MODEL_OPTIONS } from "./aiAssistantSettings" +import { MODEL_OPTIONS } from "./ai" import type { AIProvider } from "./ai" type CompactionResultSuccess = { diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index ecfe4295f..1a7e84e23 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -25,7 +25,7 @@ import { type AIOperation, } from "./aiAssistant" import { getExplainSchemaPrompt, getHealthIssuePrompt } from "./ai" -import { providerForModel, MODEL_OPTIONS } from "./aiAssistantSettings" +import { providerForModel, MODEL_OPTIONS } from "./ai" import { eventBus } from "../modules/EventBus" import { EventType } from "../modules/EventBus/types" From e75cbb609557bbb78a2ebeb17872bec54fce6805 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 6 Mar 2026 14:20:14 +0300 Subject: [PATCH 03/25] remove structured output beta, update anthropic sdk --- package.json | 2 +- src/utils/ai/anthropicProvider.ts | 46 +++++++++---------------------- yarn.lock | 10 +++---- 3 files changed, 19 insertions(+), 39 deletions(-) diff --git a/package.json b/package.json index 778044fb1..a62d31fa7 100644 --- a/package.json +++ b/package.json @@ -31,7 +31,7 @@ "prepare": "husky" }, "dependencies": { - "@anthropic-ai/sdk": "^0.71.2", + "@anthropic-ai/sdk": "^0.78.0", "@date-fns/tz": "^1.2.0", "@docsearch/css": "^3.5.2", "@docsearch/react": "^3.5.2", diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 20cd34edc..b5ba1cfce 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -1,5 +1,6 @@ import Anthropic from "@anthropic-ai/sdk" import type { MessageParam } from "@anthropic-ai/sdk/resources/messages" +import type { OutputConfig } from "@anthropic-ai/sdk/resources/messages" import type { Tool as AnthropicTool } from "@anthropic-ai/sdk/resources/messages" import type { AiAssistantAPIError, @@ -36,14 +37,12 @@ function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { })) } -// Wraps ResponseFormatSchema into Anthropic's output_format parameter -function toAnthropicOutputFormat( - format: ResponseFormatSchema, -): { type: "json_schema"; schema: Record } | undefined { - if (!format.schema) return undefined +function toAnthropicOutputConfig(format: ResponseFormatSchema): OutputConfig { return { - type: "json_schema", - schema: format.schema, + format: { + type: "json_schema", + schema: format.schema, + }, } } @@ -61,9 +60,6 @@ async function createAnthropicMessage( max_tokens: params.max_tokens ?? 8192, }, { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, signal, }, ) @@ -99,9 +95,6 @@ async function createAnthropicMessageStreaming( max_tokens: params.max_tokens ?? 8192, }, { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, signal: abortSignal, }, ) @@ -198,7 +191,7 @@ async function handleToolCalls( model: string, systemPrompt: string, setStatus: StatusCallback, - outputFormat: ReturnType, + outputConfig: OutputConfig, tools: AnthropicTool[], contextWindow: number, abortSignal?: AbortSignal, @@ -263,11 +256,7 @@ async function handleToolCalls( tools, messages: updatedHistory, temperature: 0.3, - } - - if (outputFormat) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - followUpParams.output_format = outputFormat + output_config: outputConfig, } const followUpMessage = streaming @@ -297,7 +286,7 @@ async function handleToolCalls( model, systemPrompt, setStatus, - outputFormat, + outputConfig, tools, contextWindow, abortSignal, @@ -360,7 +349,7 @@ export function createAnthropicProvider(apiKey: string): AIProvider { }) const anthropicTools = toAnthropicTools(tools) - const outputFormat = toAnthropicOutputFormat(config.responseFormat) + const outputConfig = toAnthropicOutputConfig(config.responseFormat) const messageParams: Parameters[1] = { model, @@ -368,11 +357,7 @@ export function createAnthropicProvider(apiKey: string): AIProvider { tools: anthropicTools, messages: initialMessages, temperature: 0.3, - } - - if (outputFormat) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = outputFormat + output_config: outputConfig, } const message = streaming @@ -398,7 +383,7 @@ export function createAnthropicProvider(apiKey: string): AIProvider { model, config.systemInstructions, setStatus, - outputFormat, + outputConfig, anthropicTools, contextWindow, abortSignal, @@ -475,13 +460,8 @@ export function createAnthropicProvider(apiKey: string): AIProvider { messages: [{ role: "user", content: prompt }], max_tokens: 100, temperature: 0.3, + output_config: toAnthropicOutputConfig(responseFormat), } - const outputFormat = toAnthropicOutputFormat(responseFormat) - if (outputFormat) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = outputFormat - } - const message = await createAnthropicMessage(anthropic, messageParams) const textBlock = message.content.find((block) => block.type === "text") diff --git a/yarn.lock b/yarn.lock index 5e869060a..dfef7b0d4 100644 --- a/yarn.lock +++ b/yarn.lock @@ -211,9 +211,9 @@ __metadata: languageName: node linkType: hard -"@anthropic-ai/sdk@npm:^0.71.2": - version: 0.71.2 - resolution: "@anthropic-ai/sdk@npm:0.71.2" +"@anthropic-ai/sdk@npm:^0.78.0": + version: 0.78.0 + resolution: "@anthropic-ai/sdk@npm:0.78.0" dependencies: json-schema-to-ts: "npm:^3.1.1" peerDependencies: @@ -223,7 +223,7 @@ __metadata: optional: true bin: anthropic-ai-sdk: bin/cli - checksum: 10/a8190f9e860079dd97a544a95f36bd4b0b3a9a941610d7e067c431dc47febe03e3e761fc371166b261af9629d832533eeb3d8e72298e9f73dd52994a61881a2c + checksum: 10/7cb34e36d4fc766f0765b2581596825996073b03eec97a1193f07c6ca4ab48a021310dae9df630d61550ae2aa7fb3a6cf54236f7418932b25ea1a0e32624fdf1 languageName: node linkType: hard @@ -2461,7 +2461,7 @@ __metadata: resolution: "@questdb/web-console@workspace:." dependencies: "@4tw/cypress-drag-drop": "npm:^2.2.5" - "@anthropic-ai/sdk": "npm:^0.71.2" + "@anthropic-ai/sdk": "npm:^0.78.0" "@babel/core": "npm:^7.28.5" "@babel/preset-env": "npm:^7.20.2" "@babel/preset-react": "npm:^7.17.12" From 6d7e488f31fd07862bbb3e9bc5d7ac214acda645 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 6 Mar 2026 18:24:13 +0300 Subject: [PATCH 04/25] add chat completions provider support --- e2e/questdb | 2 +- src/providers/AIStatusProvider/index.tsx | 1 + src/providers/LocalStorageProvider/index.tsx | 7 +- src/providers/LocalStorageProvider/types.ts | 5 +- src/utils/ai/anthropicProvider.ts | 8 +- src/utils/ai/index.ts | 8 +- src/utils/ai/openaiChatCompletionsProvider.ts | 598 ++++++++++++++++++ src/utils/ai/openaiProvider.ts | 35 +- src/utils/ai/registry.ts | 25 +- src/utils/ai/settings.test.ts | 375 +++++++++++ src/utils/ai/settings.ts | 141 +++-- src/utils/ai/types.ts | 4 +- src/utils/executeAIFlow.ts | 5 +- 13 files changed, 1149 insertions(+), 65 deletions(-) create mode 100644 src/utils/ai/openaiChatCompletionsProvider.ts create mode 100644 src/utils/ai/settings.test.ts diff --git a/e2e/questdb b/e2e/questdb index 42483f08e..2263b2adb 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit 42483f08e40ce6ab69e6d85d61f0265ddc8d1d41 +Subproject commit 2263b2adb482cf7b8290306a7f736381670f76cc diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index c72ce8538..b4c7e8dcf 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -136,6 +136,7 @@ export const AIStatusProvider: React.FC = ({ const apiKey = useMemo(() => { if (!currentModel) return null const provider = providerForModel(currentModel) + if (!provider) return null return aiAssistantSettings.providers?.[provider]?.apiKey || null }, [currentModel, aiAssistantSettings]) diff --git a/src/providers/LocalStorageProvider/index.tsx b/src/providers/LocalStorageProvider/index.tsx index ec3ecf258..baeae1d89 100644 --- a/src/providers/LocalStorageProvider/index.tsx +++ b/src/providers/LocalStorageProvider/index.tsx @@ -33,6 +33,7 @@ import { LeftPanelState, LeftPanelType, } from "./types" +import { reconcileSettings } from "../../utils/ai/settings" export const DEFAULT_AI_ASSISTANT_SETTINGS: AiAssistantSettings = { providers: {}, @@ -139,10 +140,14 @@ export const LocalStorageProvider = ({ if (stored) { try { const parsed = JSON.parse(stored) as AiAssistantSettings - return { + const reconciled = reconcileSettings({ selectedModel: parsed.selectedModel, providers: parsed.providers || {}, + }) + if (JSON.stringify(reconciled) !== stored) { + setValue(StoreKey.AI_ASSISTANT_SETTINGS, JSON.stringify(reconciled)) } + return reconciled } catch (e) { return defaultConfig.aiAssistantSettings } diff --git a/src/providers/LocalStorageProvider/types.ts b/src/providers/LocalStorageProvider/types.ts index 7a7335d09..81b1a3d17 100644 --- a/src/providers/LocalStorageProvider/types.ts +++ b/src/providers/LocalStorageProvider/types.ts @@ -6,10 +6,7 @@ export type ProviderSettings = { export type AiAssistantSettings = { selectedModel?: string - providers: { - anthropic?: ProviderSettings - openai?: ProviderSettings - } + providers: Partial> } export type SettingsType = string | boolean | number | AiAssistantSettings diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index b5ba1cfce..5a9b14983 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -11,6 +11,7 @@ import type { } from "../aiAssistant" import { AIOperationStatus } from "../../providers/AIStatusProvider" import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" import type { AIProvider, FlowConfig, @@ -301,7 +302,10 @@ async function handleToolCalls( } } -export function createAnthropicProvider(apiKey: string): AIProvider { +export function createAnthropicProvider( + apiKey: string, + providerId: ProviderId = "anthropic", +): AIProvider { const anthropic = new Anthropic({ apiKey, dangerouslyAllowBrowser: true, @@ -310,7 +314,7 @@ export function createAnthropicProvider(apiKey: string): AIProvider { const contextWindow = 200_000 return { - id: "anthropic", + id: providerId, contextWindow, async executeFlow({ diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index 4481e5459..2cee7f130 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -38,4 +38,10 @@ export { canUseAiAssistant, hasSchemaAccess, } from "./settings" -export type { Provider, ModelOption } from "./settings" +export type { + ProviderId, + Provider, + ProviderType, + ModelOption, +} from "./settings" +export { PROVIDER_TYPE } from "./settings" diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts new file mode 100644 index 000000000..5f903d859 --- /dev/null +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -0,0 +1,598 @@ +import OpenAI from "openai" +import type { + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, +} from "openai/resources/chat/completions" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + safeJsonParse, + extractPartialExplanation, + executeTool, +} from "./shared" +import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" + +function toResponseFormat(format: ResponseFormatSchema) { + return { + type: "json_schema" as const, + json_schema: { + name: format.name, + schema: format.schema, + strict: format.strict, + }, + } +} + +function toOpenAITools( + tools: ToolDefinition[], +): OpenAI.Chat.Completions.ChatCompletionTool[] { + return tools.map((t) => ({ + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: { ...t.inputSchema, additionalProperties: false }, + strict: true, + }, + })) +} + +interface RequestResult { + content: string + toolCalls: { id: string; name: string; arguments: unknown }[] + promptTokens: number + completionTokens: number + assistantMessage: ChatCompletionMessageParam +} + +async function createChatCompletionStreaming( + openai: OpenAI, + params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise<{ + content: string + refusal: string | null + finishReason: string | null + toolCalls: ChatCompletionMessageToolCall[] + usage: { prompt_tokens: number; completion_tokens: number } | null +}> { + let accumulatedText = "" + let accumulatedRefusal = "" + let lastExplanation = "" + let finishReason: string | null = null + const toolCallAccumulator: Map< + number, + { id: string; name: string; arguments: string } + > = new Map() + let usage: { prompt_tokens: number; completion_tokens: number } | null = null + + try { + const stream = await openai.chat.completions.create({ + ...params, + stream: true, + stream_options: { include_usage: true }, + }) + + for await (const chunk of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + const choice = chunk.choices?.[0] + + if (choice?.delta?.content) { + accumulatedText += choice.delta.content + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const delta = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(delta, explanation) + } + } + + if (choice?.delta?.refusal) { + accumulatedRefusal += choice.delta.refusal + } + + if (choice?.finish_reason) { + finishReason = choice.finish_reason + } + + if (choice?.delta?.tool_calls) { + for (const tc of choice.delta.tool_calls) { + const existing = toolCallAccumulator.get(tc.index) + if (existing) { + if (tc.id) existing.id = tc.id + if (tc.function?.name) existing.name = tc.function.name + existing.arguments += tc.function?.arguments ?? "" + } else { + toolCallAccumulator.set(tc.index, { + id: tc.id ?? "", + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }) + } + } + } + + if (chunk.usage) { + usage = { + prompt_tokens: chunk.usage.prompt_tokens, + completion_tokens: chunk.usage.completion_tokens, + } + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + const toolCalls: ChatCompletionMessageToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments }, + })) + + return { + content: accumulatedText, + refusal: accumulatedRefusal || null, + finishReason, + toolCalls, + usage, + } +} + +function extractToolCallsFromMessage( + toolCalls: ChatCompletionMessageToolCall[], +): { id: string; name: string; arguments: unknown }[] { + return toolCalls + .filter((tc) => tc.type === "function") + .map((tc) => ({ + id: tc.id, + name: tc.function.name, + arguments: safeJsonParse(tc.function.arguments), + })) +} + +function buildAssistantMessage( + content: string | null, + toolCalls: ChatCompletionMessageToolCall[], +): ChatCompletionMessageParam { + return { + role: "assistant" as const, + content, + ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}), + } +} + +async function executeRequest( + openai: OpenAI, + params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming?: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + if (streaming) { + const accumulated = await createChatCompletionStreaming( + openai, + params, + streaming, + abortSignal, + ) + + if (accumulated.refusal) { + throw new RefusalError(accumulated.refusal) + } + if (accumulated.finishReason === "length") { + throw new MaxTokensError( + "Response truncated: the model ran out of tokens.", + ) + } + + return { + content: accumulated.content, + toolCalls: extractToolCallsFromMessage(accumulated.toolCalls), + promptTokens: accumulated.usage?.prompt_tokens ?? 0, + completionTokens: accumulated.usage?.completion_tokens ?? 0, + assistantMessage: buildAssistantMessage( + accumulated.content || null, + accumulated.toolCalls, + ), + } + } + + const response = await openai.chat.completions.create(params) + const message = response.choices[0]?.message + + if (message?.refusal) { + throw new RefusalError(message.refusal) + } + if (response.choices[0]?.finish_reason === "length") { + throw new MaxTokensError("Response truncated: the model ran out of tokens.") + } + + const rawToolCalls = message?.tool_calls ?? [] + + return { + content: message?.content ?? "", + toolCalls: extractToolCallsFromMessage(rawToolCalls), + promptTokens: response.usage?.prompt_tokens ?? 0, + completionTokens: response.usage?.completion_tokens ?? 0, + assistantMessage: buildAssistantMessage( + message?.content ?? null, + rawToolCalls.filter( + (tc): tc is Extract => + tc.type === "function", + ), + ), + } +} + +let tiktokenEncoder: Tiktoken | null = null + +function toChatCompletionsAPIProps(model: string): { + model: string + reasoning_effort?: OpenAI.ReasoningEffort +} { + const props = getModelProps(model) + return { + model: props.model, + ...(props.reasoningEffort + ? { reasoning_effort: props.reasoningEffort as OpenAI.ReasoningEffort } + : {}), + } +} + +export function createOpenAIChatCompletionsProvider( + apiKey: string, + providerId: ProviderId = "openai", +): AIProvider { + const openai = new OpenAI({ + apiKey, + dangerouslyAllowBrowser: true, + }) + + const contextWindow = 400_000 + + return { + id: providerId, + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + const messages: ChatCompletionMessageParam[] = [ + { role: "system", content: config.systemInstructions }, + ] + + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + messages.push({ role: msg.role, content: msg.content }) + } + } + + messages.push({ role: "user", content: config.initialUserContent }) + + const openaiTools = toOpenAITools(tools) + let totalInputTokens = 0 + let totalOutputTokens = 0 + let lastPromptTokens = 0 + + const baseParams = { + ...toChatCompletionsAPIProps(model), + tools: openaiTools, + response_format: toResponseFormat(config.responseFormat), + } + + let result = await executeRequest( + openai, + { + ...baseParams, + messages, + } as OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming, + abortSignal, + ) + totalInputTokens += result.promptTokens + totalOutputTokens += result.completionTokens + lastPromptTokens = result.promptTokens + messages.push(result.assistantMessage) + + while (true) { + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + if (!result.toolCalls.length) break + + for (const tc of result.toolCalls) { + const exec = await executeTool( + tc.name, + tc.arguments, + modelToolsClient, + setStatus, + ) + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: exec.content, + }) + } + + if ( + lastPromptTokens >= contextWindow - 50_000 && + result.toolCalls.length > 0 + ) { + messages.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + + result = await executeRequest( + openai, + { + ...baseParams, + messages, + } as OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming, + abortSignal, + ) + totalInputTokens += result.promptTokens + totalOutputTokens += result.completionTokens + lastPromptTokens = result.promptTokens + messages.push(result.assistantMessage) + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + try { + const json = JSON.parse(result.content) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const response = await openai.chat.completions.create({ + ...toChatCompletionsAPIProps(model), + messages: [{ role: "user", content: prompt }], + response_format: toResponseFormat(responseFormat), + max_completion_tokens: 100, + }) + const content = response.choices[0]?.message?.content || "" + const parsed = JSON.parse(content) as { title: string } + return parsed.title || null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await openai.chat.completions.create({ + ...toChatCompletionsAPIProps(model), + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userMessage }, + ], + }) + return response.choices[0]?.message?.content || "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new OpenAI({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + }) + await testClient.chat.completions.create({ + model: getModelProps(model).model, + messages: [{ role: "user", content: "ping" }], + max_completion_tokens: 16, + }) + return { valid: true } + } catch (error: unknown) { + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt }) { + if (!tiktokenEncoder) { + const { Tiktoken } = await import("js-tiktoken/lite") + const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( + (module: { default: TiktokenBPE }) => module.default, + ) + tiktokenEncoder = new Tiktoken(o200k_base) + } + + let totalTokens = 0 + totalTokens += tiktokenEncoder.encode(systemPrompt).length + totalTokens += 4 // system message formatting overhead + + for (const message of messages) { + totalTokens += 4 // role markers overhead + totalTokens += tiktokenEncoder.encode(message.content).length + } + + totalTokens += 2 // assistant reply priming + return totalTokens + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof OpenAI.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof OpenAI.APIError) { + return { + type: "unknown", + message: `OpenAI API error: ${error.message}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error as string, + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof OpenAI.AuthenticationError || + (error != null && + typeof error === "object" && + "status" in error && + error.status === 429) || + error instanceof OpenAI.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index dc881d823..7450c69a2 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -12,6 +12,7 @@ import type { } from "../aiAssistant" import { AIOperationStatus } from "../../providers/AIStatusProvider" import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" import type { AIProvider, FlowConfig, @@ -182,7 +183,27 @@ function getOpenAIText(response: OpenAI.Responses.Response): { let tiktokenEncoder: Tiktoken | null = null -export function createOpenAIProvider(apiKey: string): AIProvider { +function toResponsesAPIProps(model: string): { + model: string + reasoning?: OpenAI.Reasoning +} { + const props = getModelProps(model) + return { + model: props.model, + ...(props.reasoningEffort + ? { + reasoning: { + effort: props.reasoningEffort as OpenAI.ReasoningEffort, + }, + } + : {}), + } +} + +export function createOpenAIProvider( + apiKey: string, + providerId: ProviderId = "openai", +): AIProvider { const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true, @@ -191,7 +212,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { const contextWindow = 400_000 return { - id: "openai", + id: providerId, contextWindow, async executeFlow({ @@ -235,7 +256,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { let totalOutputTokens = 0 const requestParams = { - ...getModelProps(model), + ...toResponsesAPIProps(model), instructions: config.systemInstructions, input, tools: openaiTools, @@ -293,7 +314,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { }) } const loopRequestParams = { - ...getModelProps(model), + ...toResponsesAPIProps(model), instructions: config.systemInstructions, input, tools: openaiTools, @@ -366,7 +387,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { async generateTitle({ model, prompt, responseFormat }) { try { const response = await openai.responses.create({ - ...getModelProps(model), + ...toResponsesAPIProps(model), input: [{ role: "user", content: prompt }], text: toResponseTextConfig(responseFormat), max_output_tokens: 100, @@ -380,7 +401,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { async generateSummary({ model, systemPrompt, userMessage }) { const response = await openai.responses.create({ - ...getModelProps(model), + ...toResponsesAPIProps(model), instructions: systemPrompt, input: userMessage, }) @@ -394,7 +415,7 @@ export function createOpenAIProvider(apiKey: string): AIProvider { dangerouslyAllowBrowser: true, }) await testClient.responses.create({ - model: getModelProps(model).model, + model: getModelProps(model).model, // testConnection only needs model name input: [{ role: "user", content: "ping" }], max_output_tokens: 16, }) diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts index 812d2d323..0b4f6e5c1 100644 --- a/src/utils/ai/registry.ts +++ b/src/utils/ai/registry.ts @@ -1,18 +1,31 @@ import type { AIProvider } from "./types" import { createOpenAIProvider } from "./openaiProvider" +import { createOpenAIChatCompletionsProvider } from "./openaiChatCompletionsProvider" import { createAnthropicProvider } from "./anthropicProvider" -import type { Provider } from "./settings" +import { PROVIDER_TYPE } from "./settings" +import type { ProviderId, ProviderType } from "./settings" export function createProvider( - providerId: Provider, + providerId: ProviderId, apiKey: string, ): AIProvider { - switch (providerId) { + const providerType = PROVIDER_TYPE[providerId] + return createProviderByType(providerType, providerId, apiKey) +} + +export function createProviderByType( + providerType: ProviderType, + providerId: ProviderId, + apiKey: string, +): AIProvider { + switch (providerType) { case "openai": - return createOpenAIProvider(apiKey) + return createOpenAIProvider(apiKey, providerId) + case "openai-chat-completions": + return createOpenAIChatCompletionsProvider(apiKey, providerId) case "anthropic": - return createAnthropicProvider(apiKey) + return createAnthropicProvider(apiKey, providerId) default: - throw new Error(`Unknown provider: ${providerId}`) + throw new Error(`Unknown provider type: ${providerType}`) } } diff --git a/src/utils/ai/settings.test.ts b/src/utils/ai/settings.test.ts new file mode 100644 index 000000000..fdbe5b7dc --- /dev/null +++ b/src/utils/ai/settings.test.ts @@ -0,0 +1,375 @@ +import { describe, it, expect, afterEach } from "vitest" +import { reconcileSettings, getSelectedModel, MODEL_OPTIONS } from "./settings" +import type { ModelOption } from "./settings" + +import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" + +const makeSettings = ( + overrides: Partial = {}, +): AiAssistantSettings => ({ + providers: {}, + ...overrides, +}) + +describe("reconcileSettings", () => { + it("removes stale model IDs from enabledModels", () => { + const settings = makeSettings({ + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "removed-model", "also-removed"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.openai!.enabledModels).toEqual(["gpt-5-mini"]) + }) + + it("does not add defaultEnabled models when user has valid models", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.anthropic!.enabledModels).toEqual([ + "claude-sonnet-4-5", + ]) + }) + + it("leaves enabledModels empty when all previous models were removed", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["removed-model-1", "removed-model-2"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.anthropic!.enabledModels).toEqual([]) + }) + + it("does not add defaults for unconfigured providers", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.openai).toBeUndefined() + }) + + it("is idempotent", () => { + const settings = makeSettings({ + selectedModel: "claude-sonnet-4-5", + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5", "stale-model"], + grantSchemaAccess: true, + }, + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const once = reconcileSettings(settings) + const twice = reconcileSettings(once) + expect(twice).toEqual(once) + }) + + it("preserves unknown fields (forward compat)", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const settingsWithFutureField = settings as unknown as Record< + string, + string + > + settingsWithFutureField.futureField = "preserved" + const result = reconcileSettings(settings) + expect((result as unknown as Record).futureField).toBe( + "preserved", + ) + }) + + it("clears selectedModel if not in any enabledModels", () => { + const settings = makeSettings({ + selectedModel: "removed-model", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.selectedModel).toEqual("gpt-5-mini") + }) + + it("preserves selectedModel if it is in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "gpt-5-mini", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.selectedModel).toBe("gpt-5-mini") + }) + + it("handles empty providers gracefully", () => { + const settings = makeSettings({ providers: {} }) + const result = reconcileSettings(settings) + expect(result.providers).toEqual({}) + }) + + it("does not mutate the input settings", () => { + const settings = makeSettings({ + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "stale-model"], + grantSchemaAccess: false, + }, + }, + }) + const originalModels = [...settings.providers.openai!.enabledModels] + reconcileSettings(settings) + expect(settings.providers.openai!.enabledModels).toEqual(originalModels) + }) +}) + +describe("getSelectedModel", () => { + it("returns selectedModel when it is in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "gpt-5-mini", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "gpt-5"], + grantSchemaAccess: false, + }, + }, + }) + expect(getSelectedModel(settings)).toBe("gpt-5-mini") + }) + + it("does not return selectedModel if not in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "claude-sonnet-4-5", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + expect(getSelectedModel(settings)).not.toBe("claude-sonnet-4-5") + expect(getSelectedModel(settings)).toBe("gpt-5-mini") + }) + + it("returns null when no models are enabled", () => { + const settings = makeSettings({ providers: {} }) + expect(getSelectedModel(settings)).toBeNull() + }) +}) + +/** + * Simulates version upgrades by temporarily replacing MODEL_OPTIONS contents. + * Tests verify that user settings from a previous version are handled correctly + * when the app is updated with a different model list. + */ +describe("version compatibility scenarios", () => { + let originalOptions: ModelOption[] + + function setModelOptions(options: ModelOption[]) { + originalOptions = [...MODEL_OPTIONS] + MODEL_OPTIONS.length = 0 + MODEL_OPTIONS.push(...options) + } + + afterEach(() => { + MODEL_OPTIONS.length = 0 + MODEL_OPTIONS.push(...originalOptions) + }) + + it("upgrade: model removed, selectedModel was that model", () => { + // v1: user had model-A and model-B, selected model-A + setModelOptions([ + { label: "A", value: "model-a", provider: "openai" }, + { label: "B", value: "model-b", provider: "openai" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-a", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: model-A removed, model-C added + setModelOptions([ + { label: "B", value: "model-b", provider: "openai" }, + { + label: "C", + value: "model-c", + provider: "openai", + defaultEnabled: true, + }, + ]) + + const reconciled = reconcileSettings(v1Settings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-b"]) + expect(reconciled.selectedModel).toBe("model-b") + }) + + it("upgrade: all models removed for a provider", () => { + setModelOptions([{ label: "A", value: "model-a", provider: "openai" }]) + + const v1Settings = makeSettings({ + selectedModel: "model-a", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: provider's models completely replaced + setModelOptions([ + { + label: "X", + value: "model-x", + provider: "openai", + defaultEnabled: true, + }, + { label: "Y", value: "model-y", provider: "openai" }, + ]) + + const reconciled = reconcileSettings(v1Settings) + // all old models gone, empty list — user must re-enable in settings + expect(reconciled.providers.openai!.enabledModels).toEqual([]) + expect(reconciled.selectedModel).toBeUndefined() + expect(getSelectedModel(reconciled)).toBeNull() + }) + + it("upgrade: new models added, user keeps their selection", () => { + setModelOptions([ + { label: "A", value: "model-a", provider: "anthropic", default: true }, + { label: "B", value: "model-b", provider: "anthropic" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-b", + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b"], + grantSchemaAccess: true, + }, + }, + }) + + // v2: model-C added + setModelOptions([ + { label: "A", value: "model-a", provider: "anthropic", default: true }, + { label: "B", value: "model-b", provider: "anthropic" }, + { + label: "C", + value: "model-c", + provider: "anthropic", + defaultEnabled: true, + }, + ]) + + const reconciled = reconcileSettings(v1Settings) + // existing models preserved, new model NOT auto-added + expect(reconciled.providers.anthropic!.enabledModels).toEqual([ + "model-a", + "model-b", + ]) + expect(reconciled.selectedModel).toBe("model-b") + expect(getSelectedModel(reconciled)).toBe("model-b") + }) + + it("upgrade: selected model survives but some enabled models removed", () => { + setModelOptions([ + { label: "A", value: "model-a", provider: "openai" }, + { label: "B", value: "model-b", provider: "openai" }, + { label: "C", value: "model-c", provider: "openai" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-b", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b", "model-c"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: model-A and model-C removed + setModelOptions([ + { label: "B", value: "model-b", provider: "openai" }, + { label: "D", value: "model-d", provider: "openai" }, + ]) + + const reconciled = reconcileSettings(v1Settings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-b"]) + expect(reconciled.selectedModel).toBe("model-b") + expect(getSelectedModel(reconciled)).toBe("model-b") + }) + + it("downgrade: user has models from a newer version", () => { + setModelOptions([{ label: "A", value: "model-a", provider: "openai" }]) + + const futureSettings = makeSettings({ + selectedModel: "model-future", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-future"], + grantSchemaAccess: false, + }, + }, + }) + + const reconciled = reconcileSettings(futureSettings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-a"]) + expect(reconciled.selectedModel).toBe("model-a") + }) +}) diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 7972d7253..81290b3a1 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -1,13 +1,21 @@ import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" -type ReasoningEffort = "high" | "medium" | "low" +export type ProviderType = "anthropic" | "openai" | "openai-chat-completions" -export type Provider = "anthropic" | "openai" +export type ProviderId = "anthropic" | "openai" + +/** @deprecated Use ProviderId instead */ +export type Provider = ProviderId + +export const PROVIDER_TYPE: Record = { + anthropic: "anthropic", + openai: "openai", +} export type ModelOption = { label: string value: string - provider: Provider + provider: ProviderId isSlow?: boolean isTestModel?: boolean default?: boolean @@ -16,22 +24,22 @@ export type ModelOption = { export const MODEL_OPTIONS: ModelOption[] = [ { - label: "Claude Sonnet 4.5", - value: "claude-sonnet-4-5", + label: "Claude Opus 4.6", + value: "claude-opus-4-6", provider: "anthropic", - default: true, + isSlow: true, defaultEnabled: true, }, { - label: "Claude Opus 4.5", - value: "claude-opus-4-5", + label: "Claude Sonnet 4.6", + value: "claude-sonnet-4-6", provider: "anthropic", - isSlow: true, + default: true, defaultEnabled: true, }, { - label: "Claude Sonnet 4", - value: "claude-sonnet-4", + label: "Claude Sonnet 4.5", + value: "claude-sonnet-4-5", provider: "anthropic", }, { @@ -41,50 +49,52 @@ export const MODEL_OPTIONS: ModelOption[] = [ isTestModel: true, }, { - label: "GPT-5.1 (High Reasoning)", - value: "gpt-5.1@reasoning=high", + label: "GPT-5.4 (High Reasoning)", + value: "gpt-5.4@reasoning=high", provider: "openai", - isSlow: true, }, { - label: "GPT-5.1 (Medium Reasoning)", - value: "gpt-5.1@reasoning=medium", + label: "GPT-5.4 (Medium Reasoning)", + value: "gpt-5.4@reasoning=medium", provider: "openai", - isSlow: true, defaultEnabled: true, }, { - label: "GPT-5.1 (No Reasoning)", - value: "gpt-5.1", + label: "GPT-5.4 (Low Reasoning)", + value: "gpt-5.4@reasoning=low", provider: "openai", defaultEnabled: true, - isTestModel: true, + default: true, }, { - label: "GPT-5", - value: "gpt-5", + label: "GPT-5 mini", + value: "gpt-5-mini", provider: "openai", defaultEnabled: true, }, { - label: "GPT-5 mini", - value: "gpt-5-mini", + label: "GPT-5 nano", + value: "gpt-5-nano", provider: "openai", - default: true, defaultEnabled: true, + isTestModel: true, }, ] -export const providerForModel = (model: ModelOption["value"]): Provider => { - return MODEL_OPTIONS.find((m) => m.value === model)!.provider +export type ReasoningEffort = "high" | "medium" | "low" + +export type ModelProps = { + model: string + reasoningEffort?: ReasoningEffort } -export const getModelProps = ( +export const providerForModel = ( model: ModelOption["value"], -): { - model: string - reasoning?: { effort: ReasoningEffort } -} => { +): ProviderId | null => { + return MODEL_OPTIONS.find((m) => m.value === model)?.provider ?? null +} + +export const getModelProps = (model: ModelOption["value"]): ModelProps => { const modelOption = MODEL_OPTIONS.find((m) => m.value === model) if (!modelOption) { return { model } @@ -99,15 +109,15 @@ export const getModelProps = ( if (paramName === "reasoning" && paramValue) { return { model: modelName, - reasoning: { effort: paramValue as ReasoningEffort }, + reasoningEffort: paramValue as ReasoningEffort, } } } return { model: modelName } } -export const getAllProviders = (): Provider[] => { - const providers = new Set() +export const getAllProviders = (): ProviderId[] => { + const providers = new Set() MODEL_OPTIONS.forEach((model) => { providers.add(model.provider) }) @@ -117,26 +127,45 @@ export const getAllProviders = (): Provider[] => { export const getSelectedModel = ( settings: AiAssistantSettings, ): string | null => { + const enabledModels = getAllEnabledModels(settings) const selectedModel = settings.selectedModel if ( selectedModel && typeof selectedModel === "string" && - MODEL_OPTIONS.find((m) => m.value === selectedModel) + enabledModels.includes(selectedModel) ) { return selectedModel } - return MODEL_OPTIONS.find((m) => m.default)?.value ?? null + // Fall back to first enabled default model, then first enabled model + return ( + enabledModels.find( + (id) => MODEL_OPTIONS.find((m) => m.value === id)?.default, + ) ?? + enabledModels[0] ?? + null + ) +} + +const getAllEnabledModels = (settings: AiAssistantSettings): string[] => { + const models: string[] = [] + for (const provider of getAllProviders()) { + const providerModels = settings.providers?.[provider]?.enabledModels + if (providerModels) { + models.push(...providerModels) + } + } + return models } export const getNextModel = ( currentModel: string | undefined, - enabledModels: Record, + enabledModels: Record, ): string | null => { let nextModel: string | null | undefined = currentModel const modelProvider = currentModel ? providerForModel(currentModel) : null - if (modelProvider && enabledModels[modelProvider].length > 0) { + if (modelProvider && enabledModels[modelProvider]?.length > 0) { // Current model is still enabled, so we can use it if (currentModel && enabledModels[modelProvider].includes(currentModel)) { return currentModel @@ -175,6 +204,40 @@ export const canUseAiAssistant = (settings: AiAssistantSettings): boolean => { return isAiAssistantConfigured(settings) && !!settings.selectedModel } +/** + * Reconciles persisted AI assistant settings against the current MODEL_OPTIONS. + * Removes model IDs not present in MODEL_OPTIONS from enabledModels. + * + * Pure function — does not write to localStorage. + * Idempotent: applying it multiple times produces the same result. + */ +export const reconcileSettings = ( + settings: AiAssistantSettings, +): AiAssistantSettings => { + const validModelIds = new Set(MODEL_OPTIONS.map((m) => m.value)) + const result = { + ...settings, + providers: { ...settings.providers }, + } + + for (const providerKey of Object.keys(result.providers) as ProviderId[]) { + const providerSettings = result.providers[providerKey] + if (!providerSettings?.enabledModels) continue + + const models = providerSettings.enabledModels.filter((id) => + validModelIds.has(id), + ) + result.providers[providerKey] = { + ...providerSettings, + enabledModels: models, + } + } + + result.selectedModel = getSelectedModel(result) ?? undefined + + return result +} + export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { const selectedModel = getSelectedModel(settings) if (!selectedModel) return false diff --git a/src/utils/ai/types.ts b/src/utils/ai/types.ts index 2a89109ae..109db4a4b 100644 --- a/src/utils/ai/types.ts +++ b/src/utils/ai/types.ts @@ -4,7 +4,7 @@ import type { StatusCallback, StreamingCallback, } from "../aiAssistant" -import type { Provider } from "./settings" +import type { ProviderId } from "./settings" export interface ToolDefinition { name: string @@ -31,7 +31,7 @@ export interface FlowConfig { } export interface AIProvider { - readonly id: Provider + readonly id: ProviderId readonly contextWindow: number executeFlow(params: { diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 1a7e84e23..5df165af3 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -359,7 +359,8 @@ async function generateChatTitleIfNeeded( return } - const provider = providerForModel(config.settings.model) + const provider = providerForModel(config.settings.model)! + const testModel = MODEL_OPTIONS.find( (m) => m.isTestModel && m.provider === provider, ) @@ -447,7 +448,7 @@ export async function executeAIFlow( eventBus.publish(EventType.AI_QUERY_HIGHLIGHT, conversationId) } - const provider = providerForModel(settings.model) + const provider = providerForModel(settings.model)! const providerSettings: ActiveProviderSettings = { model: settings.model, provider, From 557a1b30d33a3da97e28ca8d2225869ef5f16240 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 6 Mar 2026 18:29:19 +0300 Subject: [PATCH 05/25] provider -> providerId --- .../SetupAIAssistant/ConfigurationModal.tsx | 20 ++++---- .../SetupAIAssistant/SettingsModal.tsx | 48 ++++++++++--------- src/utils/ai/index.ts | 7 +-- src/utils/ai/settings.ts | 3 -- src/utils/aiAssistant.ts | 6 +-- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index 4b44504ab..8417e8176 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -11,7 +11,11 @@ import { useLocalStorage } from "../../providers/LocalStorageProvider" import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" import { toast } from "../Toast" -import { MODEL_OPTIONS, type ModelOption, type Provider } from "../../utils/ai" +import { + MODEL_OPTIONS, + type ModelOption, + type ProviderId, +} from "../../utils/ai" import { useModalNavigation } from "../MultiStepModal" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" @@ -407,22 +411,22 @@ type ConfigurationModalProps = { onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: Provider | null) => { +const getProviderName = (provider: ProviderId | null) => { if (!provider) return "" return provider === "openai" ? "OpenAI" : "Anthropic" } type StepOneContentProps = { - selectedProvider: Provider | null + selectedProvider: ProviderId | null apiKey: string error: string | null providerName: string - onProviderSelect: (provider: Provider) => void + onProviderSelect: (provider: ProviderId) => void onApiKeyChange: (value: string) => void } type StepTwoContentProps = { - selectedProvider: Provider | null + selectedProvider: ProviderId | null enabledModels: string[] grantSchemaAccess: boolean modelsByProvider: { anthropic: ModelOption[]; openai: ModelOption[] } @@ -567,7 +571,7 @@ const StepTwoContent = ({ const handleClose: () => void = navigation.handleClose const currentProvider = selectedProvider - const getModelsForProvider = (provider: Provider) => { + const getModelsForProvider = (provider: ProviderId) => { return provider === "openai" ? modelsByProvider.openai : modelsByProvider.anthropic @@ -696,7 +700,7 @@ export const ConfigurationModal = ({ onOpenChange, }: ConfigurationModalProps) => { const { aiAssistantSettings, updateSettings } = useLocalStorage() - const [selectedProvider, setSelectedProvider] = useState( + const [selectedProvider, setSelectedProvider] = useState( null, ) const providerName = useMemo( @@ -722,7 +726,7 @@ export const ConfigurationModal = ({ return { anthropic, openai } }, []) - const handleProviderSelect = useCallback((provider: Provider) => { + const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) setError(null) setApiKey("") diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 860ccfd2a..7e2172edb 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -22,7 +22,7 @@ import { getAllProviders, MODEL_OPTIONS, type ModelOption, - type Provider, + type ProviderId, getNextModel, } from "../../utils/ai" import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" @@ -500,16 +500,18 @@ type SettingsModalProps = { onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: Provider) => { +const getProviderName = (provider: ProviderId) => { return provider === "openai" ? "OpenAI" : "Anthropic" } -const getModelsForProvider = (provider: Provider): ModelOption[] => { +const getModelsForProvider = (provider: ProviderId): ModelOption[] => { return MODEL_OPTIONS.filter((m) => m.provider === provider) } -const getProvidersWithApiKeys = (settings: AiAssistantSettings): Provider[] => { - const providers: Provider[] = [] +const getProvidersWithApiKeys = ( + settings: AiAssistantSettings, +): ProviderId[] => { + const providers: ProviderId[] = [] const allProviders = getAllProviders() for (const provider of allProviders) { if (settings.providers?.[provider]?.apiKey) { @@ -523,11 +525,11 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const { aiAssistantSettings, updateSettings } = useLocalStorage() const initializeProviderState = useCallback( ( - getValue: (provider: Provider) => T, + getValue: (provider: ProviderId) => T, defaultValue: T, - ): Record => { + ): Record => { const allProviders = getAllProviders() - const state = {} as Record + const state = {} as Record for (const provider of allProviders) { state[provider] = getValue(provider) ?? defaultValue } @@ -536,18 +538,18 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [], ) - const [selectedProvider, setSelectedProvider] = useState(() => { + const [selectedProvider, setSelectedProvider] = useState(() => { const providersWithKeys = getProvidersWithApiKeys(aiAssistantSettings) return providersWithKeys[0] || getAllProviders()[0] }) - const [apiKeys, setApiKeys] = useState>(() => + const [apiKeys, setApiKeys] = useState>(() => initializeProviderState( (provider) => aiAssistantSettings.providers?.[provider]?.apiKey || "", "", ), ) const [enabledModels, setEnabledModels] = useState< - Record + Record >(() => initializeProviderState( (provider) => @@ -556,7 +558,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ), ) const [grantSchemaAccess, setGrantSchemaAccess] = useState< - Record + Record >(() => initializeProviderState( (provider) => @@ -565,7 +567,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ), ) const [validatedApiKeys, setValidatedApiKeys] = useState< - Record + Record >(() => initializeProviderState( (provider) => !!aiAssistantSettings.providers?.[provider]?.apiKey, @@ -573,23 +575,23 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ), ) const [validationState, setValidationState] = useState< - Record + Record >(() => initializeProviderState(() => "idle" as const, "idle" as const)) const [validationErrors, setValidationErrors] = useState< - Record + Record >(() => initializeProviderState(() => null, null)) const [isInputFocused, setIsInputFocused] = useState< - Record + Record >(() => initializeProviderState(() => false, false)) const inputRef = useRef(null) - const handleProviderSelect = useCallback((provider: Provider) => { + const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) setValidationErrors((prev) => ({ ...prev, [provider]: null })) }, []) const handleApiKeyChange = useCallback( - (provider: Provider, value: string) => { + (provider: ProviderId, value: string) => { setApiKeys((prev) => ({ ...prev, [provider]: value })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) // If API key changes, mark as not validated @@ -601,7 +603,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ) const handleValidateApiKey = useCallback( - async (provider: Provider) => { + async (provider: ProviderId) => { const apiKey = apiKeys[provider] if (!apiKey) { setValidationErrors((prev) => ({ @@ -656,7 +658,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [apiKeys], ) - const handleRemoveApiKey = useCallback((provider: Provider) => { + const handleRemoveApiKey = useCallback((provider: ProviderId) => { // Remove API key from local state only // Settings will be persisted when Save Settings is clicked setApiKeys((prev) => ({ ...prev, [provider]: "" })) @@ -667,7 +669,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { }, []) const handleModelToggle = useCallback( - (provider: Provider, modelValue: string) => { + (provider: ProviderId, modelValue: string) => { setEnabledModels((prev) => { const current = prev[provider] const isEnabled = current.includes(modelValue) @@ -683,7 +685,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ) const handleSchemaAccessChange = useCallback( - (provider: Provider, checked: boolean) => { + (provider: ProviderId, checked: boolean) => { setGrantSchemaAccess((prev) => ({ ...prev, [provider]: checked })) }, [], @@ -751,7 +753,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const allProviders = useMemo(() => getAllProviders(), []) - const renderProviderIcon = (provider: Provider, isActive: boolean) => { + const renderProviderIcon = (provider: ProviderId, isActive: boolean) => { const color = isActive ? "#f8f8f2" : "#9ca3af" if (provider === "openai") { return diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index 2cee7f130..c886a9d38 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -38,10 +38,5 @@ export { canUseAiAssistant, hasSchemaAccess, } from "./settings" -export type { - ProviderId, - Provider, - ProviderType, - ModelOption, -} from "./settings" +export type { ProviderId, ProviderType, ModelOption } from "./settings" export { PROVIDER_TYPE } from "./settings" diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 81290b3a1..e5f43283a 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -4,9 +4,6 @@ export type ProviderType = "anthropic" | "openai" | "openai-chat-completions" export type ProviderId = "anthropic" | "openai" -/** @deprecated Use ProviderId instead */ -export type Provider = ProviderId - export const PROVIDER_TYPE: Record = { anthropic: "anthropic", openai: "openai", diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index dd3383755..090e920fa 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -1,6 +1,6 @@ import { Client } from "./questdb/client" import { Type, Table } from "./questdb/types" -import type { Provider } from "./ai" +import type { ProviderId } from "./ai" import { formatSql } from "./formatSql" import { AIOperationStatus, StatusArgs } from "../providers/AIStatusProvider" import type { @@ -22,7 +22,7 @@ import type { AIProvider } from "./ai" export type ActiveProviderSettings = { model: string - provider: Provider + provider: ProviderId apiKey: string } @@ -304,7 +304,7 @@ const tryWithRetries = async ( export const testApiKey = async ( apiKey: string, model: string, - providerId: Provider, + providerId: ProviderId, ): Promise<{ valid: boolean; error?: string }> => { const provider = createProvider(providerId, apiKey) return provider.testConnection({ apiKey, model }) From 30886ac376e11200258b9555fbfe78f31c5274e0 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 6 Mar 2026 19:25:49 +0300 Subject: [PATCH 06/25] refactor: dynamic provider system with PlugsIcon for custom providers - Replace PROVIDER_TYPE with PROVIDERS map containing type + name - Centralize getProviderName, remove local duplicates - Make provider lists dynamic (getAllProviders) instead of hardcoded - Use PlugsIcon from @phosphor-icons/react for custom provider fallback - Support comma-separated extra params in model value parsing - Simplify hasSchemaAccess with providerForModel lookup Co-Authored-By: Claude Opus 4.6 --- .../SetupAIAssistant/ConfigurationModal.tsx | 35 +++++++--------- .../SetupAIAssistant/ModelDropdown.tsx | 13 ++++-- .../SetupAIAssistant/SettingsModal.tsx | 16 ++++---- src/providers/AIStatusProvider/index.tsx | 11 ++--- src/utils/ai/index.ts | 3 +- src/utils/ai/registry.ts | 4 +- src/utils/ai/settings.ts | 40 ++++++++++--------- 7 files changed, 64 insertions(+), 58 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index 8417e8176..f75a9eb53 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -15,11 +15,13 @@ import { MODEL_OPTIONS, type ModelOption, type ProviderId, + getProviderName, } from "../../utils/ai" import { useModalNavigation } from "../MultiStepModal" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" +import { Plugs as PlugsIcon } from "@phosphor-icons/react" import { theme } from "../../theme" const ModalContent = styled.div` @@ -411,11 +413,6 @@ type ConfigurationModalProps = { onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: ProviderId | null) => { - if (!provider) return "" - return provider === "openai" ? "OpenAI" : "Anthropic" -} - type StepOneContentProps = { selectedProvider: ProviderId | null apiKey: string @@ -429,7 +426,7 @@ type StepTwoContentProps = { selectedProvider: ProviderId | null enabledModels: string[] grantSchemaAccess: boolean - modelsByProvider: { anthropic: ModelOption[]; openai: ModelOption[] } + modelsByProvider: Record onModelToggle: (modelValue: string) => void onSchemaAccessChange: (checked: boolean) => void } @@ -504,7 +501,7 @@ const StepOneContent = ({ height="40" color={theme.color.foreground} /> - OpenAI + {getProviderName("openai")} - Anthropic + {getProviderName("anthropic")} @@ -572,9 +569,7 @@ const StepTwoContent = ({ const currentProvider = selectedProvider const getModelsForProvider = (provider: ProviderId) => { - return provider === "openai" - ? modelsByProvider.openai - : modelsByProvider.anthropic + return modelsByProvider[provider] || [] } return ( @@ -600,10 +595,12 @@ const StepTwoContent = ({ Enable Models - {currentProvider === "openai" ? ( + {currentProvider === "anthropic" ? ( + + ) : currentProvider === "openai" ? ( ) : ( - + )} {getProviderName(currentProvider)} @@ -714,16 +711,14 @@ export const ConfigurationModal = ({ const [grantSchemaAccess, setGrantSchemaAccess] = useState(true) const modelsByProvider = useMemo(() => { - const anthropic: ModelOption[] = [] - const openai: ModelOption[] = [] + const result: Record = {} MODEL_OPTIONS.forEach((model) => { - if (model.provider === "anthropic") { - anthropic.push(model) - } else { - openai.push(model) + if (!result[model.provider]) { + result[model.provider] = [] } + result[model.provider].push(model) }) - return { anthropic, openai } + return result }, []) const handleProviderSelect = useCallback((provider: ProviderId) => { diff --git a/src/components/SetupAIAssistant/ModelDropdown.tsx b/src/components/SetupAIAssistant/ModelDropdown.tsx index d05f4ace1..bd2650983 100644 --- a/src/components/SetupAIAssistant/ModelDropdown.tsx +++ b/src/components/SetupAIAssistant/ModelDropdown.tsx @@ -12,6 +12,7 @@ import { StoreKey } from "../../utils/localStorage/types" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" +import { PlugsIcon } from "@phosphor-icons/react" import { Tooltip } from "../Tooltip" const ExpandUpDown = () => ( @@ -221,10 +222,12 @@ export const ModelDropdown = () => { ]} trigger={ - {displayModel.provider === "openai" ? ( + {displayModel.provider === "anthropic" ? ( + + ) : displayModel.provider === "openai" ? ( ) : ( - + )} {displayModel.label} @@ -246,10 +249,12 @@ export const ModelDropdown = () => { $selected={isSelected} > - {model.provider === "openai" ? ( + {model.provider === "anthropic" ? ( + + ) : model.provider === "openai" ? ( ) : ( - + )} {model.label} diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 7e2172edb..69dba40a0 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -16,6 +16,7 @@ import { Edit } from "@styled-icons/remix-line" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" +import { PlugsIcon } from "@phosphor-icons/react" import { LoadingSpinner } from "../LoadingSpinner" import { Overlay } from "../Overlay" import { @@ -24,6 +25,7 @@ import { type ModelOption, type ProviderId, getNextModel, + getProviderName, } from "../../utils/ai" import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" import { ForwardRef } from "../ForwardRef" @@ -500,10 +502,6 @@ type SettingsModalProps = { onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: ProviderId) => { - return provider === "openai" ? "OpenAI" : "Anthropic" -} - const getModelsForProvider = (provider: ProviderId): ModelOption[] => { return MODEL_OPTIONS.filter((m) => m.provider === provider) } @@ -755,10 +753,14 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const renderProviderIcon = (provider: ProviderId, isActive: boolean) => { const color = isActive ? "#f8f8f2" : "#9ca3af" - if (provider === "openai") { - return + switch (provider) { + case "openai": + return + case "anthropic": + return + default: + return } - return } return ( diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index b4c7e8dcf..3575dcd27 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -14,6 +14,7 @@ import { hasSchemaAccess, providerForModel, canUseAiAssistant, + getAllProviders, } from "../../utils/ai" import { useAIConversation } from "../AIConversationProvider" @@ -142,11 +143,11 @@ export const AIStatusProvider: React.FC = ({ const models = useMemo(() => { const allModels: string[] = [] - const anthropicModels = - aiAssistantSettings.providers?.anthropic?.enabledModels || [] - const openaiModels = - aiAssistantSettings.providers?.openai?.enabledModels || [] - allModels.push(...anthropicModels, ...openaiModels) + for (const provider of getAllProviders()) { + const providerModels = + aiAssistantSettings.providers?.[provider]?.enabledModels || [] + allModels.push(...providerModels) + } return allModels }, [aiAssistantSettings]) diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index c886a9d38..7288a9da7 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -29,8 +29,10 @@ export { export type { HealthIssuePromptData } from "./prompts" export { MODEL_OPTIONS, + PROVIDERS, providerForModel, getModelProps, + getProviderName, getAllProviders, getSelectedModel, getNextModel, @@ -39,4 +41,3 @@ export { hasSchemaAccess, } from "./settings" export type { ProviderId, ProviderType, ModelOption } from "./settings" -export { PROVIDER_TYPE } from "./settings" diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts index 0b4f6e5c1..2f8a5505f 100644 --- a/src/utils/ai/registry.ts +++ b/src/utils/ai/registry.ts @@ -2,14 +2,14 @@ import type { AIProvider } from "./types" import { createOpenAIProvider } from "./openaiProvider" import { createOpenAIChatCompletionsProvider } from "./openaiChatCompletionsProvider" import { createAnthropicProvider } from "./anthropicProvider" -import { PROVIDER_TYPE } from "./settings" +import { PROVIDERS } from "./settings" import type { ProviderId, ProviderType } from "./settings" export function createProvider( providerId: ProviderId, apiKey: string, ): AIProvider { - const providerType = PROVIDER_TYPE[providerId] + const providerType = PROVIDERS[providerId].type return createProviderByType(providerType, providerId, apiKey) } diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index e5f43283a..bf53b0e7e 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -4,9 +4,19 @@ export type ProviderType = "anthropic" | "openai" | "openai-chat-completions" export type ProviderId = "anthropic" | "openai" -export const PROVIDER_TYPE: Record = { - anthropic: "anthropic", - openai: "openai", +export type ProviderDefinition = { + type: ProviderType + name: string +} + +export const PROVIDERS: Record = { + anthropic: { type: "anthropic", name: "Anthropic" }, + openai: { type: "openai", name: "OpenAI" }, +} + +export const getProviderName = (providerId: ProviderId | null): string => { + if (!providerId) return "" + return PROVIDERS[providerId]?.name ?? providerId } export type ModelOption = { @@ -99,14 +109,14 @@ export const getModelProps = (model: ModelOption["value"]): ModelProps => { const parts = modelOption.value.split("@") const modelName = parts[0] const extraParams = parts[1] + ?.split(",") + ?.map((p) => ({ key: p.split("=")[0], value: p.split("=")[1] })) if (extraParams) { - const params = extraParams.split("=") - const paramName = params[0] - const paramValue = params[1] - if (paramName === "reasoning" && paramValue) { + const reasoningParam = extraParams.find((p) => p.key === "reasoning") + if (reasoningParam && reasoningParam.value) { return { model: modelName, - reasoningEffort: paramValue as ReasoningEffort, + reasoningEffort: reasoningParam.value as ReasoningEffort, } } } @@ -239,16 +249,8 @@ export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { const selectedModel = getSelectedModel(settings) if (!selectedModel) return false - const anthropicModels = settings.providers?.anthropic?.enabledModels || [] - const openaiModels = settings.providers?.openai?.enabledModels || [] - - if (anthropicModels.includes(selectedModel)) { - return settings.providers?.anthropic?.grantSchemaAccess === true - } - - if (openaiModels.includes(selectedModel)) { - return settings.providers?.openai?.grantSchemaAccess === true - } + const provider = providerForModel(selectedModel) + if (!provider) return false - return false + return settings.providers?.[provider]?.grantSchemaAccess === true } From 657fea97038ee4d787f89f164297a39a1e806511 Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 00:12:44 +0300 Subject: [PATCH 07/25] custom provider configuration support --- src/components/AIStatusIndicator/index.tsx | 9 +- src/components/ExplainQueryButton/index.tsx | 2 + src/components/FixQueryButton/index.tsx | 11 +- .../SetupAIAssistant/ConfigurationModal.tsx | 16 ++- .../SetupAIAssistant/ModelDropdown.tsx | 6 +- .../SetupAIAssistant/SettingsModal.tsx | 35 +++-- src/hooks/useAIQuickActions.ts | 5 + src/providers/AIStatusProvider/index.tsx | 11 +- src/providers/LocalStorageProvider/types.ts | 11 ++ src/scenes/Editor/AIChatWindow/index.tsx | 3 + src/utils/ai/anthropicProvider.ts | 4 +- src/utils/ai/index.ts | 13 +- src/utils/ai/openaiChatCompletionsProvider.ts | 14 +- src/utils/ai/openaiProvider.ts | 4 +- src/utils/ai/registry.ts | 34 ++++- src/utils/ai/settings.ts | 133 ++++++++++++++---- src/utils/aiAssistant.ts | 22 ++- src/utils/contextCompaction.ts | 18 +-- src/utils/executeAIFlow.ts | 24 ++-- 19 files changed, 286 insertions(+), 89 deletions(-) diff --git a/src/components/AIStatusIndicator/index.tsx b/src/components/AIStatusIndicator/index.tsx index 9b8f7286c..bd26a25c3 100644 --- a/src/components/AIStatusIndicator/index.tsx +++ b/src/components/AIStatusIndicator/index.tsx @@ -15,7 +15,7 @@ import { color } from "../../utils" import { slideAnimation } from "../Animation" import { AISparkle } from "../AISparkle" import { pinkLinearGradientHorizontal } from "../../theme" -import { MODEL_OPTIONS } from "../../utils/ai" +import { getAllModelOptions } from "../../utils/ai" import { useAIConversation } from "../../providers/AIConversationProvider" import { Button } from "../../components/Button" import { BrainIcon } from "../SetupAIAssistant/BrainIcon" @@ -310,6 +310,7 @@ export const AIStatusIndicator: React.FC = () => { currentModel, abortOperation, clearOperation, + aiAssistantSettings, } = useAIStatus() const { chatWindowState, openChatWindow } = useAIConversation() const [expanded, setExpanded] = useState(false) @@ -320,8 +321,10 @@ export const AIStatusIndicator: React.FC = () => { const activeSidebar = useSelector(selectors.console.getActiveSidebar) const statusRef = useRef(null) const hasExtendedThinking = useMemo(() => { - return MODEL_OPTIONS.find((model) => model.value === currentModel)?.isSlow - }, [currentModel]) + return getAllModelOptions(aiAssistantSettings).find( + (model) => model.value === currentModel, + )?.isSlow + }, [currentModel, aiAssistantSettings]) const operationSections = useMemo( () => buildOperationSections(currentOperation, status, true), diff --git a/src/components/ExplainQueryButton/index.tsx b/src/components/ExplainQueryButton/index.tsx index 74ffc5d25..ada476df1 100644 --- a/src/components/ExplainQueryButton/index.tsx +++ b/src/components/ExplainQueryButton/index.tsx @@ -44,6 +44,7 @@ export const ExplainQueryButton = ({ hasSchemaAccess, currentModel: currentModelValue, apiKey: apiKeyValue, + aiAssistantSettings, } = useAIStatus() const { addMessage, @@ -62,6 +63,7 @@ export const ExplainQueryButton = ({ conversationId, queryText, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/components/FixQueryButton/index.tsx b/src/components/FixQueryButton/index.tsx index b79b552e9..37ebe80c9 100644 --- a/src/components/FixQueryButton/index.tsx +++ b/src/components/FixQueryButton/index.tsx @@ -21,8 +21,14 @@ export const FixQueryButton = () => { const { quest } = useContext(QuestContext) const { editorRef, executionRefs } = useEditor() const tables = useSelector(selectors.query.getTables) - const { setStatus, abortController, hasSchemaAccess, currentModel, apiKey } = - useAIStatus() + const { + setStatus, + abortController, + hasSchemaAccess, + currentModel, + apiKey, + aiAssistantSettings, + } = useAIStatus() const { chatWindowState, getConversationMeta, @@ -53,6 +59,7 @@ export const FixQueryButton = () => { errorMessage, errorWord: word ?? undefined, settings: { model: currentModel!, apiKey: apiKey! }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index f75a9eb53..da21d58d1 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -12,7 +12,7 @@ import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" import { toast } from "../Toast" import { - MODEL_OPTIONS, + getAllModelOptions, type ModelOption, type ProviderId, getProviderName, @@ -712,7 +712,7 @@ export const ConfigurationModal = ({ const modelsByProvider = useMemo(() => { const result: Record = {} - MODEL_OPTIONS.forEach((model) => { + getAllModelOptions(aiAssistantSettings).forEach((model) => { if (!result[model.provider]) { result[model.provider] = [] } @@ -750,7 +750,9 @@ export const ConfigurationModal = ({ const selectedModel = enabledModels.find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, + (m) => + getAllModelOptions(aiAssistantSettings).find((mo) => mo.value === m) + ?.default, ) ?? enabledModels[0] const newSettings = { @@ -789,7 +791,7 @@ export const ConfigurationModal = ({ } const testModel = - MODEL_OPTIONS.find( + getAllModelOptions(aiAssistantSettings).find( (m) => m.isTestModel && m.provider === selectedProvider, )?.value ?? modelsByProvider[selectedProvider][0].value @@ -800,9 +802,9 @@ export const ConfigurationModal = ({ setError(errorMsg) return errorMsg } - const defaultModels = MODEL_OPTIONS.filter( - (m) => m.defaultEnabled && m.provider === selectedProvider, - ).map((m) => m.value) + const defaultModels = getAllModelOptions(aiAssistantSettings) + .filter((m) => m.defaultEnabled && m.provider === selectedProvider) + .map((m) => m.value) if (defaultModels.length > 0) { setEnabledModels(defaultModels) } diff --git a/src/components/SetupAIAssistant/ModelDropdown.tsx b/src/components/SetupAIAssistant/ModelDropdown.tsx index bd2650983..3b1f77148 100644 --- a/src/components/SetupAIAssistant/ModelDropdown.tsx +++ b/src/components/SetupAIAssistant/ModelDropdown.tsx @@ -6,7 +6,7 @@ import { PopperToggle } from "../PopperToggle" import { Box } from "../Box" import { Text } from "../Text" import { useLocalStorage } from "../../providers/LocalStorageProvider" -import { MODEL_OPTIONS } from "../../utils/ai" +import { getAllModelOptions } from "../../utils/ai" import { useAIStatus } from "../../providers/AIStatusProvider" import { StoreKey } from "../../utils/localStorage/types" import { OpenAIIcon } from "./OpenAIIcon" @@ -165,10 +165,10 @@ export const ModelDropdown = () => { const [dropdownActive, setDropdownActive] = useState(false) const enabledModels = useMemo(() => { - return MODEL_OPTIONS.filter((model) => + return getAllModelOptions(aiAssistantSettings).filter((model) => enabledModelValues.includes(model.value), ) - }, [enabledModelValues]) + }, [enabledModelValues, aiAssistantSettings]) const handleModelSelect = (modelValue: string) => { updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 69dba40a0..07f13dc72 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -21,7 +21,8 @@ import { LoadingSpinner } from "../LoadingSpinner" import { Overlay } from "../Overlay" import { getAllProviders, - MODEL_OPTIONS, + getAllModelOptions, + getApiKey, type ModelOption, type ProviderId, getNextModel, @@ -502,17 +503,20 @@ type SettingsModalProps = { onOpenChange?: (open: boolean) => void } -const getModelsForProvider = (provider: ProviderId): ModelOption[] => { - return MODEL_OPTIONS.filter((m) => m.provider === provider) +const getModelsForProvider = ( + provider: ProviderId, + settings?: AiAssistantSettings, +): ModelOption[] => { + return getAllModelOptions(settings).filter((m) => m.provider === provider) } const getProvidersWithApiKeys = ( settings: AiAssistantSettings, ): ProviderId[] => { const providers: ProviderId[] = [] - const allProviders = getAllProviders() + const allProviders = getAllProviders(settings) for (const provider of allProviders) { - if (settings.providers?.[provider]?.apiKey) { + if (getApiKey(provider, settings)) { providers.push(provider) } } @@ -526,7 +530,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { getValue: (provider: ProviderId) => T, defaultValue: T, ): Record => { - const allProviders = getAllProviders() + const allProviders = getAllProviders(aiAssistantSettings) const state = {} as Record for (const provider of allProviders) { state[provider] = getValue(provider) ?? defaultValue @@ -538,7 +542,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const [selectedProvider, setSelectedProvider] = useState(() => { const providersWithKeys = getProvidersWithApiKeys(aiAssistantSettings) - return providersWithKeys[0] || getAllProviders()[0] + return providersWithKeys[0] || getAllProviders(aiAssistantSettings)[0] }) const [apiKeys, setApiKeys] = useState>(() => initializeProviderState( @@ -614,7 +618,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { setValidationState((prev) => ({ ...prev, [provider]: "validating" })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) - const providerModels = getModelsForProvider(provider) + const providerModels = getModelsForProvider(provider, aiAssistantSettings) if (providerModels.length === 0) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ @@ -636,9 +640,9 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [provider]: result.error || "Invalid API key", })) } else { - const defaultModels = MODEL_OPTIONS.filter( - (m) => m.defaultEnabled && m.provider === provider, - ).map((m) => m.value) + const defaultModels = getAllModelOptions(aiAssistantSettings) + .filter((m) => m.defaultEnabled && m.provider === provider) + .map((m) => m.value) if (defaultModels.length > 0) { setEnabledModels((prev) => ({ ...prev, [provider]: defaultModels })) } @@ -691,7 +695,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const handleSave = useCallback(() => { const updatedProviders = { ...aiAssistantSettings.providers } - const allProviders = getAllProviders() + const allProviders = getAllProviders(aiAssistantSettings) for (const provider of allProviders) { if (validatedApiKeys[provider]) { @@ -740,7 +744,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const maskInput = !!(currentProviderApiKey && !currentProviderIsFocused) const modelsForProvider = useMemo( - () => getModelsForProvider(selectedProvider), + () => getModelsForProvider(selectedProvider, aiAssistantSettings), [selectedProvider], ) @@ -749,7 +753,10 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [enabledModels, selectedProvider], ) - const allProviders = useMemo(() => getAllProviders(), []) + const allProviders = useMemo( + () => getAllProviders(aiAssistantSettings), + [aiAssistantSettings], + ) const renderProviderIcon = (provider: ProviderId, isActive: boolean) => { const color = isActive ? "#f8f8f2" : "#9ca3af" diff --git a/src/hooks/useAIQuickActions.ts b/src/hooks/useAIQuickActions.ts index 1fdd31f29..7c2c1aab1 100644 --- a/src/hooks/useAIQuickActions.ts +++ b/src/hooks/useAIQuickActions.ts @@ -49,6 +49,7 @@ export const useAIQuickActions = () => { hasSchemaAccess, currentModel, apiKey, + aiAssistantSettings, } = useAIStatus() const { @@ -130,6 +131,7 @@ export const useAIQuickActions = () => { ...schemaDisplayData, }, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -171,6 +173,7 @@ export const useAIQuickActions = () => { ...schemaDisplayData, }, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -248,6 +251,7 @@ export const useAIQuickActions = () => { monitoringDocs, trendSamples, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -295,6 +299,7 @@ export const useAIQuickActions = () => { monitoringDocs, trendSamples, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index 3575dcd27..e736abad1 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -15,7 +15,9 @@ import { providerForModel, canUseAiAssistant, getAllProviders, + getApiKey, } from "../../utils/ai" +import type { AiAssistantSettings } from "../LocalStorageProvider/types" import { useAIConversation } from "../AIConversationProvider" export const useAIStatus = () => { @@ -79,6 +81,7 @@ type BaseAIStatusContextType = { models: string[] currentOperation: OperationHistory clearOperation: () => void + aiAssistantSettings: AiAssistantSettings } export type AIStatusContextType = @@ -136,14 +139,14 @@ export const AIStatusProvider: React.FC = ({ const apiKey = useMemo(() => { if (!currentModel) return null - const provider = providerForModel(currentModel) + const provider = providerForModel(currentModel, aiAssistantSettings) if (!provider) return null - return aiAssistantSettings.providers?.[provider]?.apiKey || null + return getApiKey(provider, aiAssistantSettings) }, [currentModel, aiAssistantSettings]) const models = useMemo(() => { const allModels: string[] = [] - for (const provider of getAllProviders()) { + for (const provider of getAllProviders(aiAssistantSettings)) { const providerModels = aiAssistantSettings.providers?.[provider]?.enabledModels || [] allModels.push(...providerModels) @@ -253,6 +256,7 @@ export const AIStatusProvider: React.FC = ({ apiKey: apiKey!, models, currentOperation, + aiAssistantSettings, } : { status, @@ -267,6 +271,7 @@ export const AIStatusProvider: React.FC = ({ apiKey, models, currentOperation, + aiAssistantSettings, } return ( diff --git a/src/providers/LocalStorageProvider/types.ts b/src/providers/LocalStorageProvider/types.ts index 81b1a3d17..dca97feff 100644 --- a/src/providers/LocalStorageProvider/types.ts +++ b/src/providers/LocalStorageProvider/types.ts @@ -4,9 +4,20 @@ export type ProviderSettings = { grantSchemaAccess: boolean } +export type CustomProviderDefinition = { + type: "anthropic" | "openai" | "openai-chat-completions" + name: string + baseURL: string + apiKey?: string + contextWindow: number + testModel?: string + models: string[] +} + export type AiAssistantSettings = { selectedModel?: string providers: Partial> + customProviders?: Record } export type SettingsType = string | boolean | number | AiAssistantSettings diff --git a/src/scenes/Editor/AIChatWindow/index.tsx b/src/scenes/Editor/AIChatWindow/index.tsx index c4a202b3d..b0a28e0ad 100644 --- a/src/scenes/Editor/AIChatWindow/index.tsx +++ b/src/scenes/Editor/AIChatWindow/index.tsx @@ -215,6 +215,7 @@ const AIChatWindow: React.FC = () => { hasSchemaAccess, currentModel, apiKey, + aiAssistantSettings, } = useAIStatus() const tables = useSelector(selectors.query.getTables) const running = useSelector(selectors.query.getRunning) @@ -384,6 +385,7 @@ const AIChatWindow: React.FC = () => { conversationHistory: conversation.messages, isFirstMessage: !hasAssistantMessages, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -560,6 +562,7 @@ const AIChatWindow: React.FC = () => { const settings = { model: currentModel, apiKey } const commonConfig = { settings, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 5a9b14983..b42ead59d 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -305,13 +305,15 @@ async function handleToolCalls( export function createAnthropicProvider( apiKey: string, providerId: ProviderId = "anthropic", + options?: { baseURL?: string; contextWindow?: number }, ): AIProvider { const anthropic = new Anthropic({ apiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) - const contextWindow = 200_000 + const contextWindow = options?.contextWindow ?? 200_000 return { id: providerId, diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index 7288a9da7..262cf4400 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -29,15 +29,24 @@ export { export type { HealthIssuePromptData } from "./prompts" export { MODEL_OPTIONS, - PROVIDERS, + BUILTIN_PROVIDERS, providerForModel, getModelProps, getProviderName, getAllProviders, + getAllModelOptions, getSelectedModel, getNextModel, + getTestModel, + getProviderContextWindow, + getApiKey, isAiAssistantConfigured, canUseAiAssistant, hasSchemaAccess, } from "./settings" -export type { ProviderId, ProviderType, ModelOption } from "./settings" +export type { + ProviderId, + ProviderType, + ModelOption, + CustomProviderDefinition, +} from "./settings" diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 5f903d859..c3f7b123a 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -274,13 +274,15 @@ function toChatCompletionsAPIProps(model: string): { export function createOpenAIChatCompletionsProvider( apiKey: string, providerId: ProviderId = "openai", + options?: { baseURL?: string; contextWindow?: number }, ): AIProvider { const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) - const contextWindow = 400_000 + const contextWindow = options?.contextWindow ?? 400_000 return { id: providerId, @@ -492,6 +494,16 @@ export function createOpenAIChatCompletionsProvider( }, async countTokens({ messages, systemPrompt }) { + // Custom providers (non-default baseURL) use chars/3.5 estimation + // because the actual tokenizer is unknown and tiktoken underestimates + // Claude tokens by 15-25% (dangerous for compaction). + if (options?.baseURL) { + const totalChars = + systemPrompt.length + + messages.reduce((sum, m) => sum + m.content.length, 0) + return Math.ceil(totalChars / 3.5) + } + if (!tiktokenEncoder) { const { Tiktoken } = await import("js-tiktoken/lite") const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 7450c69a2..ce56ac8a8 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -203,13 +203,15 @@ function toResponsesAPIProps(model: string): { export function createOpenAIProvider( apiKey: string, providerId: ProviderId = "openai", + options?: { baseURL?: string; contextWindow?: number }, ): AIProvider { const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) - const contextWindow = 400_000 + const contextWindow = options?.contextWindow ?? 400_000 return { id: providerId, diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts index 2f8a5505f..9c8e1994f 100644 --- a/src/utils/ai/registry.ts +++ b/src/utils/ai/registry.ts @@ -1,30 +1,52 @@ import type { AIProvider } from "./types" +import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" import { createOpenAIProvider } from "./openaiProvider" import { createOpenAIChatCompletionsProvider } from "./openaiChatCompletionsProvider" import { createAnthropicProvider } from "./anthropicProvider" -import { PROVIDERS } from "./settings" +import { BUILTIN_PROVIDERS } from "./settings" import type { ProviderId, ProviderType } from "./settings" +type ProviderOptions = { + baseURL?: string + contextWindow?: number +} + export function createProvider( providerId: ProviderId, apiKey: string, + settings?: AiAssistantSettings, ): AIProvider { - const providerType = PROVIDERS[providerId].type - return createProviderByType(providerType, providerId, apiKey) + // Check built-in providers first + const builtin = BUILTIN_PROVIDERS[providerId] + if (builtin) { + return createProviderByType(builtin.type, providerId, apiKey) + } + + // Check custom providers + const custom = settings?.customProviders?.[providerId] + if (custom) { + return createProviderByType(custom.type, providerId, apiKey, { + baseURL: custom.baseURL, + contextWindow: custom.contextWindow, + }) + } + + throw new Error(`Unknown provider: ${providerId}`) } export function createProviderByType( providerType: ProviderType, providerId: ProviderId, apiKey: string, + options?: ProviderOptions, ): AIProvider { switch (providerType) { case "openai": - return createOpenAIProvider(apiKey, providerId) + return createOpenAIProvider(apiKey, providerId, options) case "openai-chat-completions": - return createOpenAIChatCompletionsProvider(apiKey, providerId) + return createOpenAIChatCompletionsProvider(apiKey, providerId, options) case "anthropic": - return createAnthropicProvider(apiKey, providerId) + return createAnthropicProvider(apiKey, providerId, options) default: throw new Error(`Unknown provider type: ${providerType}`) } diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index bf53b0e7e..5bac3975c 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -1,22 +1,34 @@ -import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" +import type { + AiAssistantSettings, + CustomProviderDefinition, +} from "../../providers/LocalStorageProvider/types" export type ProviderType = "anthropic" | "openai" | "openai-chat-completions" -export type ProviderId = "anthropic" | "openai" +/** Provider ID — built-in ("anthropic", "openai") or user-defined string for custom providers. */ +export type ProviderId = string export type ProviderDefinition = { type: ProviderType name: string } -export const PROVIDERS: Record = { +export { type CustomProviderDefinition } + +export const BUILTIN_PROVIDERS: Record = { anthropic: { type: "anthropic", name: "Anthropic" }, openai: { type: "openai", name: "OpenAI" }, } -export const getProviderName = (providerId: ProviderId | null): string => { +export const getProviderName = ( + providerId: ProviderId | null, + settings?: AiAssistantSettings, +): string => { if (!providerId) return "" - return PROVIDERS[providerId]?.name ?? providerId + if (BUILTIN_PROVIDERS[providerId]) return BUILTIN_PROVIDERS[providerId].name + const custom = settings?.customProviders?.[providerId] + if (custom) return custom.name + return providerId } export type ModelOption = { @@ -95,18 +107,35 @@ export type ModelProps = { reasoningEffort?: ReasoningEffort } +export const getAllModelOptions = ( + settings?: AiAssistantSettings, +): ModelOption[] => { + if (!settings?.customProviders) return MODEL_OPTIONS + const customModels: ModelOption[] = [] + for (const [providerId, def] of Object.entries(settings.customProviders)) { + for (const modelId of def.models) { + customModels.push({ + label: modelId, + value: modelId, + provider: providerId, + }) + } + } + return [...MODEL_OPTIONS, ...customModels] +} + export const providerForModel = ( model: ModelOption["value"], + settings?: AiAssistantSettings, ): ProviderId | null => { - return MODEL_OPTIONS.find((m) => m.value === model)?.provider ?? null + return ( + getAllModelOptions(settings).find((m) => m.value === model)?.provider ?? + null + ) } export const getModelProps = (model: ModelOption["value"]): ModelProps => { - const modelOption = MODEL_OPTIONS.find((m) => m.value === model) - if (!modelOption) { - return { model } - } - const parts = modelOption.value.split("@") + const parts = model.split("@") const modelName = parts[0] const extraParams = parts[1] ?.split(",") @@ -123,11 +152,18 @@ export const getModelProps = (model: ModelOption["value"]): ModelProps => { return { model: modelName } } -export const getAllProviders = (): ProviderId[] => { +export const getAllProviders = ( + settings?: AiAssistantSettings, +): ProviderId[] => { const providers = new Set() MODEL_OPTIONS.forEach((model) => { providers.add(model.provider) }) + if (settings?.customProviders) { + for (const id of Object.keys(settings.customProviders)) { + providers.add(id) + } + } return Array.from(providers) } @@ -144,10 +180,11 @@ export const getSelectedModel = ( return selectedModel } + const allModels = getAllModelOptions(settings) // Fall back to first enabled default model, then first enabled model return ( enabledModels.find( - (id) => MODEL_OPTIONS.find((m) => m.value === id)?.default, + (id) => allModels.find((m) => m.value === id)?.default, ) ?? enabledModels[0] ?? null @@ -156,7 +193,7 @@ export const getSelectedModel = ( const getAllEnabledModels = (settings: AiAssistantSettings): string[] => { const models: string[] = [] - for (const provider of getAllProviders()) { + for (const provider of getAllProviders(settings)) { const providerModels = settings.providers?.[provider]?.enabledModels if (providerModels) { models.push(...providerModels) @@ -168,10 +205,14 @@ const getAllEnabledModels = (settings: AiAssistantSettings): string[] => { export const getNextModel = ( currentModel: string | undefined, enabledModels: Record, + settings?: AiAssistantSettings, ): string | null => { let nextModel: string | null | undefined = currentModel - const modelProvider = currentModel ? providerForModel(currentModel) : null + const allModels = getAllModelOptions(settings) + const modelProvider = currentModel + ? providerForModel(currentModel, settings) + : null if (modelProvider && enabledModels[modelProvider]?.length > 0) { // Current model is still enabled, so we can use it if (currentModel && enabledModels[modelProvider].includes(currentModel)) { @@ -180,17 +221,17 @@ export const getNextModel = ( // Take the default model of this provider, otherwise the first enabled model of this provider nextModel = enabledModels[modelProvider].find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, + (m) => allModels.find((mo) => mo.value === m)?.default, ) ?? enabledModels[modelProvider][0] } else { // No other enabled models for this provider, we have to choose from another provider if exists - const otherProviderWithEnabledModel = getAllProviders().find( - (p) => enabledModels[p].length > 0, + const otherProviderWithEnabledModel = getAllProviders(settings).find( + (p) => enabledModels[p]?.length > 0, ) if (otherProviderWithEnabledModel) { nextModel = enabledModels[otherProviderWithEnabledModel].find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, + (m) => allModels.find((mo) => mo.value === m)?.default, ) ?? enabledModels[otherProviderWithEnabledModel][0] } else { nextModel = null @@ -202,18 +243,48 @@ export const getNextModel = ( export const isAiAssistantConfigured = ( settings: AiAssistantSettings, ): boolean => { - return getAllProviders().some( + const builtinConfigured = Object.keys(BUILTIN_PROVIDERS).some( (provider) => !!settings.providers?.[provider]?.apiKey, ) + if (builtinConfigured) return true + return Object.keys(settings.customProviders ?? {}).length > 0 } export const canUseAiAssistant = (settings: AiAssistantSettings): boolean => { return isAiAssistantConfigured(settings) && !!settings.selectedModel } +export const getTestModel = ( + providerId: ProviderId, + settings?: AiAssistantSettings, +): string | null => { + const custom = settings?.customProviders?.[providerId] + if (custom) { + return custom.testModel ?? custom.models[0] ?? null + } + return ( + MODEL_OPTIONS.find((m) => m.provider === providerId && m.isTestModel) + ?.value ?? null + ) +} + +/** + * Returns the context window for a given provider. + * For custom providers, returns the configured value. + * For built-in providers, returns null (factory uses its own default). + */ +export const getProviderContextWindow = ( + providerId: ProviderId, + settings?: AiAssistantSettings, +): number | null => { + const custom = settings?.customProviders?.[providerId] + return custom?.contextWindow ?? null +} + /** - * Reconciles persisted AI assistant settings against the current MODEL_OPTIONS. - * Removes model IDs not present in MODEL_OPTIONS from enabledModels. + * Reconciles persisted AI assistant settings against current model options. + * Removes stale model IDs from built-in providers' enabledModels. + * Preserves custom provider models (validated against customProviders definitions). * * Pure function — does not write to localStorage. * Idempotent: applying it multiple times produces the same result. @@ -221,18 +292,18 @@ export const canUseAiAssistant = (settings: AiAssistantSettings): boolean => { export const reconcileSettings = ( settings: AiAssistantSettings, ): AiAssistantSettings => { - const validModelIds = new Set(MODEL_OPTIONS.map((m) => m.value)) + const allValidIds = new Set(getAllModelOptions(settings).map((m) => m.value)) const result = { ...settings, providers: { ...settings.providers }, } - for (const providerKey of Object.keys(result.providers) as ProviderId[]) { + for (const providerKey of Object.keys(result.providers)) { const providerSettings = result.providers[providerKey] if (!providerSettings?.enabledModels) continue const models = providerSettings.enabledModels.filter((id) => - validModelIds.has(id), + allValidIds.has(id), ) result.providers[providerKey] = { ...providerSettings, @@ -245,11 +316,21 @@ export const reconcileSettings = ( return result } +export const getApiKey = ( + providerId: ProviderId, + settings: AiAssistantSettings, +): string | null => { + const builtinKey = settings.providers?.[providerId]?.apiKey + if (builtinKey) return builtinKey + const custom = settings.customProviders?.[providerId] + return custom?.apiKey || null +} + export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { const selectedModel = getSelectedModel(settings) if (!selectedModel) return false - const provider = providerForModel(selectedModel) + const provider = providerForModel(selectedModel, settings) if (!provider) return false return settings.providers?.[provider]?.grantSchemaAccess === true diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index 090e920fa..ef3f6e23d 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -1,6 +1,7 @@ import { Client } from "./questdb/client" import { Type, Table } from "./questdb/types" import type { ProviderId } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" import { formatSql } from "./formatSql" import { AIOperationStatus, StatusArgs } from "../providers/AIStatusProvider" import type { @@ -24,6 +25,7 @@ export type ActiveProviderSettings = { model: string provider: ProviderId apiKey: string + aiAssistantSettings?: AiAssistantSettings } export interface AiAssistantAPIError { @@ -305,8 +307,9 @@ export const testApiKey = async ( apiKey: string, model: string, providerId: ProviderId, + settings?: AiAssistantSettings, ): Promise<{ valid: boolean; error?: string }> => { - const provider = createProvider(providerId, apiKey) + const provider = createProvider(providerId, apiKey, settings) return provider.testConnection({ apiKey, model }) } @@ -322,7 +325,11 @@ export const generateChatTitle = async ({ } try { - const provider = createProvider(settings.provider, settings.apiKey) + const provider = createProvider( + settings.provider, + settings.apiKey, + settings.aiAssistantSettings, + ) const prompt = `Generate a concise chat title (max 30 characters) for this conversation with QuestDB AI Assistant. The title should capture the main topic or intent. @@ -398,7 +405,11 @@ export const continueConversation = async ({ health_issue: ConversationResponseFormat, }[operation] - const provider = createProvider(settings.provider, settings.apiKey) + const provider = createProvider( + settings.provider, + settings.apiKey, + settings.aiAssistantSettings, + ) return tryWithRetries( async () => { @@ -416,7 +427,10 @@ export const continueConversation = async ({ systemPrompt, userMessage, () => setStatus(AIOperationStatus.Compacting), - { model: settings.model }, + { + model: settings.model, + aiAssistantSettings: settings.aiAssistantSettings, + }, ) if ("error" in compactionResult) { diff --git a/src/utils/contextCompaction.ts b/src/utils/contextCompaction.ts index 42eaa7c0d..c8e36ae41 100644 --- a/src/utils/contextCompaction.ts +++ b/src/utils/contextCompaction.ts @@ -1,6 +1,7 @@ import type { ConversationMessage } from "../providers/AIConversationProvider/types" -import { MODEL_OPTIONS } from "./ai" +import { getTestModel } from "./ai" import type { AIProvider } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" type CompactionResultSuccess = { compactedMessage: string @@ -75,11 +76,10 @@ function toTokenMessages( async function generateSummary( middleMessages: ConversationMessage[], aiProvider: AIProvider, + settings?: AiAssistantSettings, ): Promise { - const testModel = MODEL_OPTIONS.find( - (m) => m.provider === aiProvider.id && m.isTestModel, - ) - if (!testModel) { + const testModelValue = getTestModel(aiProvider.id, settings) + if (!testModelValue) { throw new Error("No test model found for provider") } @@ -90,7 +90,7 @@ async function generateSummary( const userMessage = `Please summarize the following conversation:\n\n${conversationText}` return aiProvider.generateSummary({ - model: testModel.value, + model: testModelValue, systemPrompt: SUMMARIZATION_PROMPT, userMessage, }) @@ -102,7 +102,7 @@ export async function compactConversationIfNeeded( systemPrompt: string, userMessage: string, setStatusCompacting: () => void, - options: { model?: string } = {}, + options: { model?: string; aiAssistantSettings?: AiAssistantSettings } = {}, ): Promise { const compactionThreshold = aiProvider.contextWindow - 50_000 const messages = [ @@ -153,6 +153,7 @@ export async function compactConversationIfNeeded( conversationHistory, aiProvider, setStatusCompacting, + options.aiAssistantSettings, ) if (!result.wasCompacted) { @@ -170,6 +171,7 @@ async function compactConversationInternal( messages: ConversationMessage[], aiProvider: AIProvider, setStatusCompacting: () => void, + settings?: AiAssistantSettings, ): Promise { if (messages.length === 0) { return { wasCompacted: false } @@ -178,7 +180,7 @@ async function compactConversationInternal( setStatusCompacting() try { - const summary = await generateSummary(messages, aiProvider) + const summary = await generateSummary(messages, aiProvider, settings) return { compactedMessage: buildContinuationPrompt(summary), diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 5df165af3..8ecbabc00 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -25,7 +25,8 @@ import { type AIOperation, } from "./aiAssistant" import { getExplainSchemaPrompt, getHealthIssuePrompt } from "./ai" -import { providerForModel, MODEL_OPTIONS } from "./ai" +import { providerForModel, getTestModel } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" import { eventBus } from "../modules/EventBus" import { EventType } from "../modules/EventBus/types" @@ -35,6 +36,7 @@ type BaseFlowConfig = { model: string apiKey: string } + aiAssistantSettings?: AiAssistantSettings questClient: Client tables?: Array hasSchemaAccess: boolean @@ -359,21 +361,23 @@ async function generateChatTitleIfNeeded( return } - const provider = providerForModel(config.settings.model)! - - const testModel = MODEL_OPTIONS.find( - (m) => m.isTestModel && m.provider === provider, + const provider = providerForModel( + config.settings.model, + config.aiAssistantSettings, ) + if (!provider) return - if (!testModel) return + const testModelValue = getTestModel(provider, config.aiAssistantSettings) + if (!testModelValue) return try { const title = await generateChatTitle({ firstUserMessage: userMessageContent, settings: { - model: testModel.value, + model: testModelValue, provider, apiKey: config.settings.apiKey, + aiAssistantSettings: config.aiAssistantSettings, }, }) @@ -448,11 +452,15 @@ export async function executeAIFlow( eventBus.publish(EventType.AI_QUERY_HIGHLIGHT, conversationId) } - const provider = providerForModel(settings.model)! + const provider = providerForModel(settings.model, config.aiAssistantSettings) + if (!provider) { + throw new Error(`No provider found for model: ${settings.model}`) + } const providerSettings: ActiveProviderSettings = { model: settings.model, provider, apiKey: settings.apiKey, + aiAssistantSettings: config.aiAssistantSettings, } const modelToolsClient = createModelToolsClient( From b37f4d660ebb8d57d55bbc792879dc8ac1c18a45 Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 00:39:06 +0300 Subject: [PATCH 08/25] small fixes --- src/components/SetupAIAssistant/ConfigurationModal.tsx | 2 +- src/components/SetupAIAssistant/SettingsModal.tsx | 8 ++++++-- src/providers/LocalStorageProvider/index.tsx | 3 +++ src/providers/LocalStorageProvider/types.ts | 1 + src/utils/ai/anthropicProvider.ts | 1 + src/utils/ai/openaiChatCompletionsProvider.ts | 1 + src/utils/ai/openaiProvider.ts | 1 + src/utils/ai/settings.ts | 5 ++++- 8 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index da21d58d1..ee98609e7 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -719,7 +719,7 @@ export const ConfigurationModal = ({ result[model.provider].push(model) }) return result - }, []) + }, [aiAssistantSettings]) const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 07f13dc72..d222554cf 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -716,7 +716,11 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { providers: updatedProviders, } - const nextModel = getNextModel(updatedSettings.selectedModel, enabledModels) + const nextModel = getNextModel( + updatedSettings.selectedModel, + enabledModels, + updatedSettings, + ) updatedSettings.selectedModel = nextModel || undefined updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, updatedSettings) @@ -745,7 +749,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const modelsForProvider = useMemo( () => getModelsForProvider(selectedProvider, aiAssistantSettings), - [selectedProvider], + [selectedProvider, aiAssistantSettings], ) const enabledModelsForProvider = useMemo( diff --git a/src/providers/LocalStorageProvider/index.tsx b/src/providers/LocalStorageProvider/index.tsx index baeae1d89..7f53aa549 100644 --- a/src/providers/LocalStorageProvider/index.tsx +++ b/src/providers/LocalStorageProvider/index.tsx @@ -143,6 +143,9 @@ export const LocalStorageProvider = ({ const reconciled = reconcileSettings({ selectedModel: parsed.selectedModel, providers: parsed.providers || {}, + ...(parsed.customProviders && { + customProviders: parsed.customProviders, + }), }) if (JSON.stringify(reconciled) !== stored) { setValue(StoreKey.AI_ASSISTANT_SETTINGS, JSON.stringify(reconciled)) diff --git a/src/providers/LocalStorageProvider/types.ts b/src/providers/LocalStorageProvider/types.ts index dca97feff..6e79ac040 100644 --- a/src/providers/LocalStorageProvider/types.ts +++ b/src/providers/LocalStorageProvider/types.ts @@ -12,6 +12,7 @@ export type CustomProviderDefinition = { contextWindow: number testModel?: string models: string[] + grantSchemaAccess?: boolean } export type AiAssistantSettings = { diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index b42ead59d..5ef4a20d8 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -498,6 +498,7 @@ export function createAnthropicProvider( const testClient = new Anthropic({ apiKey: testApiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) await createAnthropicMessage(testClient, { diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index c3f7b123a..486d208be 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -466,6 +466,7 @@ export function createOpenAIChatCompletionsProvider( const testClient = new OpenAI({ apiKey: testApiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) await testClient.chat.completions.create({ model: getModelProps(model).model, diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index ce56ac8a8..61569cebd 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -415,6 +415,7 @@ export function createOpenAIProvider( const testClient = new OpenAI({ apiKey: testApiKey, dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), }) await testClient.responses.create({ model: getModelProps(model).model, // testConnection only needs model name diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 5bac3975c..3df59f9b1 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -333,5 +333,8 @@ export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { const provider = providerForModel(selectedModel, settings) if (!provider) return false - return settings.providers?.[provider]?.grantSchemaAccess === true + return ( + settings.providers?.[provider]?.grantSchemaAccess === true || + settings.customProviders?.[provider]?.grantSchemaAccess === true + ) } From cb33880df161e1794f449c26946f4418fcdd65a1 Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 01:07:09 +0300 Subject: [PATCH 09/25] fixes --- .../SetupAIAssistant/ConfigurationModal.tsx | 7 +++- .../SetupAIAssistant/SettingsModal.tsx | 27 ++++++++----- src/utils/ai/index.ts | 2 + src/utils/ai/settings.ts | 40 +++++++++++++++---- 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index ee98609e7..a5f04c738 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -796,7 +796,12 @@ export const ConfigurationModal = ({ )?.value ?? modelsByProvider[selectedProvider][0].value try { - const result = await testApiKey(apiKey, testModel, selectedProvider) + const result = await testApiKey( + apiKey, + testModel, + selectedProvider, + aiAssistantSettings, + ) if (!result.valid) { const errorMsg = result.error || "Invalid API key" setError(errorMsg) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index d222554cf..63c049dbb 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -546,7 +546,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { }) const [apiKeys, setApiKeys] = useState>(() => initializeProviderState( - (provider) => aiAssistantSettings.providers?.[provider]?.apiKey || "", + (provider) => getApiKey(provider, aiAssistantSettings) || "", "", ), ) @@ -562,17 +562,19 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const [grantSchemaAccess, setGrantSchemaAccess] = useState< Record >(() => - initializeProviderState( - (provider) => - aiAssistantSettings.providers?.[provider]?.grantSchemaAccess !== false, - true, - ), + initializeProviderState((provider) => { + const custom = aiAssistantSettings.customProviders?.[provider] + if (custom) return custom.grantSchemaAccess !== false + return ( + aiAssistantSettings.providers?.[provider]?.grantSchemaAccess !== false + ) + }, true), ) const [validatedApiKeys, setValidatedApiKeys] = useState< Record >(() => initializeProviderState( - (provider) => !!aiAssistantSettings.providers?.[provider]?.apiKey, + (provider) => !!getApiKey(provider, aiAssistantSettings), false, ), ) @@ -632,7 +634,12 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { providerModels.find((m) => m.isTestModel) ?? providerModels[0] ).value try { - const result = await testApiKey(apiKey, testModel, provider) + const result = await testApiKey( + apiKey, + testModel, + provider, + aiAssistantSettings, + ) if (!result.valid) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ @@ -1029,9 +1036,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { handleSchemaAccessChange( diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index 262cf4400..0c3e889e5 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -40,6 +40,8 @@ export { getTestModel, getProviderContextWindow, getApiKey, + makeCustomModelValue, + parseModelValue, isAiAssistantConfigured, canUseAiAssistant, hasSchemaAccess, diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 3df59f9b1..567e1d568 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -107,6 +107,27 @@ export type ModelProps = { reasoningEffort?: ReasoningEffort } +const CUSTOM_MODEL_SEP = ":" + +export const makeCustomModelValue = ( + providerId: ProviderId, + modelId: string, +): string => `${providerId}${CUSTOM_MODEL_SEP}${modelId}` + +export const parseModelValue = ( + value: string, +): { customProviderId: string; rawModel: string } | { rawModel: string } => { + const sepIndex = value.indexOf(CUSTOM_MODEL_SEP) + if (sepIndex === -1) return { rawModel: value } + const candidateProvider = value.slice(0, sepIndex) + // Only treat as namespaced if the prefix is NOT a built-in provider. + if (BUILTIN_PROVIDERS[candidateProvider]) return { rawModel: value } + return { + customProviderId: candidateProvider, + rawModel: value.slice(sepIndex + 1), + } +} + export const getAllModelOptions = ( settings?: AiAssistantSettings, ): ModelOption[] => { @@ -116,7 +137,7 @@ export const getAllModelOptions = ( for (const modelId of def.models) { customModels.push({ label: modelId, - value: modelId, + value: makeCustomModelValue(providerId, modelId), provider: providerId, }) } @@ -126,16 +147,18 @@ export const getAllModelOptions = ( export const providerForModel = ( model: ModelOption["value"], - settings?: AiAssistantSettings, + _settings?: AiAssistantSettings, ): ProviderId | null => { - return ( - getAllModelOptions(settings).find((m) => m.value === model)?.provider ?? - null - ) + // Check for namespaced custom model value (providerId:modelId) + const parsed = parseModelValue(model) + if ("customProviderId" in parsed) return parsed.customProviderId + // Fall back to built-in model lookup + return MODEL_OPTIONS.find((m) => m.value === model)?.provider ?? null } export const getModelProps = (model: ModelOption["value"]): ModelProps => { - const parts = model.split("@") + const { rawModel } = parseModelValue(model) + const parts = rawModel.split("@") const modelName = parts[0] const extraParams = parts[1] ?.split(",") @@ -260,7 +283,8 @@ export const getTestModel = ( ): string | null => { const custom = settings?.customProviders?.[providerId] if (custom) { - return custom.testModel ?? custom.models[0] ?? null + const rawModel = custom.testModel ?? custom.models[0] ?? null + return rawModel ? makeCustomModelValue(providerId, rawModel) : null } return ( MODEL_OPTIONS.find((m) => m.provider === providerId && m.isTestModel) From fe2bb06b60e1cf0b6d26d49a242b2e67b4d86fcf Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 01:19:34 +0300 Subject: [PATCH 10/25] handle errors --- .../SetupAIAssistant/SettingsModal.tsx | 6 +++--- src/utils/aiAssistant.ts | 21 ++++++++++++++----- src/utils/executeAIFlow.ts | 8 ++++++- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 63c049dbb..1aab474bd 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -563,11 +563,11 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { Record >(() => initializeProviderState((provider) => { + const providerSettings = aiAssistantSettings.providers?.[provider] + if (providerSettings) return providerSettings.grantSchemaAccess !== false const custom = aiAssistantSettings.customProviders?.[provider] if (custom) return custom.grantSchemaAccess !== false - return ( - aiAssistantSettings.providers?.[provider]?.grantSchemaAccess !== false - ) + return true }, true), ) const [validatedApiKeys, setValidatedApiKeys] = useState< diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index ef3f6e23d..255232209 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -405,11 +405,22 @@ export const continueConversation = async ({ health_issue: ConversationResponseFormat, }[operation] - const provider = createProvider( - settings.provider, - settings.apiKey, - settings.aiAssistantSettings, - ) + let provider: ReturnType + try { + provider = createProvider( + settings.provider, + settings.apiKey, + settings.aiAssistantSettings, + ) + } catch (error) { + return { + type: "unknown", + message: + error instanceof Error + ? error.message + : "Failed to initialize provider", + } + } return tryWithRetries( async () => { diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 8ecbabc00..4d82bc4aa 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -454,7 +454,13 @@ export async function executeAIFlow( const provider = providerForModel(settings.model, config.aiAssistantSettings) if (!provider) { - throw new Error(`No provider found for model: ${settings.model}`) + callbacks.updateMessage(conversationId, assistantMessageId, { + error: `No provider found for model: ${settings.model}`, + }) + return { + success: false, + error: `No provider found for model: ${settings.model}`, + } } const providerSettings: ActiveProviderSettings = { model: settings.model, From 4a64a5fbc3ff3d389f8b6a852e3cc089c2ec36d7 Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 01:39:26 +0300 Subject: [PATCH 11/25] provider/model name, api key removal fxes --- .../SetupAIAssistant/ConfigurationModal.tsx | 4 +-- .../SetupAIAssistant/SettingsModal.tsx | 33 ++++++++++++++++--- src/providers/AIStatusProvider/index.tsx | 15 +++------ src/utils/ai/anthropicProvider.ts | 16 ++++++--- src/utils/ai/index.ts | 1 + src/utils/ai/settings.ts | 10 +++++- 6 files changed, 57 insertions(+), 22 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index a5f04c738..767deded3 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -701,8 +701,8 @@ export const ConfigurationModal = ({ null, ) const providerName = useMemo( - () => getProviderName(selectedProvider), - [selectedProvider], + () => getProviderName(selectedProvider, aiAssistantSettings), + [selectedProvider, aiAssistantSettings], ) const [apiKey, setApiKey] = useState("") const [error, setError] = useState(null) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 1aab474bd..f0897e5fb 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -718,9 +718,28 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { } } + // Sync API keys into customProviders so getApiKey() stays consistent + const updatedCustomProviders = aiAssistantSettings.customProviders + ? { ...aiAssistantSettings.customProviders } + : undefined + if (updatedCustomProviders) { + for (const provider of allProviders) { + if (updatedCustomProviders[provider]) { + updatedCustomProviders[provider] = { + ...updatedCustomProviders[provider], + apiKey: validatedApiKeys[provider] ? apiKeys[provider] : undefined, + grantSchemaAccess: grantSchemaAccess[provider], + } + } + } + } + const updatedSettings: AiAssistantSettings = { ...aiAssistantSettings, providers: updatedProviders, + ...(updatedCustomProviders && { + customProviders: updatedCustomProviders, + }), } const nextModel = getNextModel( @@ -832,7 +851,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { {renderProviderIcon(provider, isActive)} - {getProviderName(provider)} + {getProviderName(provider, aiAssistantSettings)} @@ -878,7 +897,10 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { target="_blank" rel="noopener noreferrer" > - {getProviderName(selectedProvider)} + {getProviderName( + selectedProvider, + aiAssistantSettings, + )} . @@ -896,7 +918,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { onChange={(e) => { handleApiKeyChange(selectedProvider, e.target.value) }} - placeholder={`Enter ${getProviderName(selectedProvider)} API key`} + placeholder={`Enter ${getProviderName(selectedProvider, aiAssistantSettings)} API key`} $hasError={!!currentProviderError} $showEditButton={maskInput} readOnly={maskInput} @@ -1051,7 +1073,10 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { Grant schema access to{" "} - {getProviderName(selectedProvider)} + {getProviderName( + selectedProvider, + aiAssistantSettings, + )} When enabled, the AI assistant can access your diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index e736abad1..fdb479632 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -14,7 +14,7 @@ import { hasSchemaAccess, providerForModel, canUseAiAssistant, - getAllProviders, + getAllEnabledModels, getApiKey, } from "../../utils/ai" import type { AiAssistantSettings } from "../LocalStorageProvider/types" @@ -144,15 +144,10 @@ export const AIStatusProvider: React.FC = ({ return getApiKey(provider, aiAssistantSettings) }, [currentModel, aiAssistantSettings]) - const models = useMemo(() => { - const allModels: string[] = [] - for (const provider of getAllProviders(aiAssistantSettings)) { - const providerModels = - aiAssistantSettings.providers?.[provider]?.enabledModels || [] - allModels.push(...providerModels) - } - return allModels - }, [aiAssistantSettings]) + const models = useMemo( + () => getAllEnabledModels(aiAssistantSettings), + [aiAssistantSettings], + ) const setStatus = useCallback( ( diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 5ef4a20d8..ef2219525 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -38,6 +38,10 @@ function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { })) } +function toAnthropicModel(model: string): string { + return getModelProps(model).model +} + function toAnthropicOutputConfig(format: ResponseFormatSchema): OutputConfig { return { format: { @@ -357,8 +361,10 @@ export function createAnthropicProvider( const anthropicTools = toAnthropicTools(tools) const outputConfig = toAnthropicOutputConfig(config.responseFormat) + const resolvedModel = toAnthropicModel(model) + const messageParams: Parameters[1] = { - model, + model: resolvedModel, system: config.systemInstructions, tools: anthropicTools, messages: initialMessages, @@ -386,7 +392,7 @@ export function createAnthropicProvider( anthropic, modelToolsClient, initialMessages, - model, + resolvedModel, config.systemInstructions, setStatus, outputConfig, @@ -462,7 +468,7 @@ export function createAnthropicProvider( async generateTitle({ model, prompt, responseFormat }) { try { const messageParams: Parameters[1] = { - model, + model: toAnthropicModel(model), messages: [{ role: "user", content: prompt }], max_tokens: 100, temperature: 0.3, @@ -502,7 +508,7 @@ export function createAnthropicProvider( }) await createAnthropicMessage(testClient, { - model, + model: toAnthropicModel(model), messages: [{ role: "user", content: "ping" }], }) return { valid: true } @@ -539,7 +545,7 @@ export function createAnthropicProvider( })) const response = await anthropic.messages.countTokens({ - model, + model: toAnthropicModel(model), system: systemPrompt, messages: anthropicMessages, }) diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index 0c3e889e5..b4f41913a 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -35,6 +35,7 @@ export { getProviderName, getAllProviders, getAllModelOptions, + getAllEnabledModels, getSelectedModel, getNextModel, getTestModel, diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 567e1d568..e18771a8f 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -214,12 +214,20 @@ export const getSelectedModel = ( ) } -const getAllEnabledModels = (settings: AiAssistantSettings): string[] => { +export const getAllEnabledModels = ( + settings: AiAssistantSettings, +): string[] => { const models: string[] = [] for (const provider of getAllProviders(settings)) { const providerModels = settings.providers?.[provider]?.enabledModels if (providerModels) { models.push(...providerModels) + } else if (settings.customProviders?.[provider]) { + models.push( + ...settings.customProviders[provider].models.map((m) => + makeCustomModelValue(provider, m), + ), + ) } } return models From 76c074335ace2211827e1a1419435169082f168f Mon Sep 17 00:00:00 2001 From: emrberk Date: Sat, 7 Mar 2026 23:08:37 +0300 Subject: [PATCH 12/25] feat: add listModels to AIProvider interface All three providers (Anthropic, OpenAI Responses, OpenAI Chat Completions) implement model listing via their respective SDKs. Used by custom provider setup UI to fetch available models. Co-Authored-By: Claude Opus 4.6 --- src/utils/ai/anthropicProvider.ts | 8 ++++++++ src/utils/ai/openaiChatCompletionsProvider.ts | 8 ++++++++ src/utils/ai/openaiProvider.ts | 8 ++++++++ src/utils/ai/types.ts | 2 ++ 4 files changed, 26 insertions(+) diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index ef2219525..fc74e53c0 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -553,6 +553,14 @@ export function createAnthropicProvider( return response.input_tokens }, + async listModels(): Promise { + const models: string[] = [] + for await (const model of anthropic.models.list()) { + models.push(model.id) + } + return models + }, + classifyError( error: unknown, setStatus: StatusCallback, diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 486d208be..1a8380ac8 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -526,6 +526,14 @@ export function createOpenAIChatCompletionsProvider( return totalTokens }, + async listModels(): Promise { + const models: string[] = [] + for await (const model of openai.models.list()) { + models.push(model.id) + } + return models + }, + classifyError( error: unknown, setStatus: StatusCallback, diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 61569cebd..554ab0dfa 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -465,6 +465,14 @@ export function createOpenAIProvider( return totalTokens }, + async listModels(): Promise { + const models: string[] = [] + for await (const model of openai.models.list()) { + models.push(model.id) + } + return models + }, + classifyError( error: unknown, setStatus: StatusCallback, diff --git a/src/utils/ai/types.ts b/src/utils/ai/types.ts index 109db4a4b..9667e0704 100644 --- a/src/utils/ai/types.ts +++ b/src/utils/ai/types.ts @@ -67,6 +67,8 @@ export interface AIProvider { model: string }): Promise + listModels(): Promise + classifyError(error: unknown, setStatus: StatusCallback): AiAssistantAPIError isNonRetryableError(error: unknown): boolean } From 66e4766f9c5d088b71cbc9e7975af5dc9faafdef Mon Sep 17 00:00:00 2001 From: emrberk Date: Mon, 9 Mar 2026 03:46:07 +0300 Subject: [PATCH 13/25] custom provider ui initial --- .../SetupAIAssistant/CustomProviderModal.tsx | 1082 +++++++++++++++++ .../SetupAIAssistant/SettingsModal.tsx | 868 +++++++------ src/utils/ai/anthropicProvider.ts | 2 +- src/utils/ai/openaiChatCompletionsProvider.ts | 2 +- src/utils/ai/openaiProvider.ts | 2 +- src/utils/ai/settings.ts | 3 +- src/utils/aiAssistant.ts | 7 +- 7 files changed, 1609 insertions(+), 357 deletions(-) create mode 100644 src/components/SetupAIAssistant/CustomProviderModal.tsx diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx new file mode 100644 index 000000000..b63e4547c --- /dev/null +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -0,0 +1,1082 @@ +import React, { useState, useMemo, useCallback, useEffect, useRef } from "react" +import styled, { useTheme } from "styled-components" +import { MultiStepModal } from "../MultiStepModal" +import type { Step } from "../MultiStepModal" +import { useModalNavigation } from "../MultiStepModal" +import { Box } from "../Box" +import { Input } from "../Input" +import { Checkbox } from "../Checkbox" +import { Text } from "../Text" +import { Dialog } from "../Dialog" +import { createProviderByType } from "../../utils/ai/registry" +import type { + ProviderType, + CustomProviderDefinition, +} from "../../utils/ai/settings" +import { Select } from "../Select" +import { WarningIcon, XIcon } from "@phosphor-icons/react" + +const ModalContent = styled.div` + display: flex; + flex-direction: column; + width: 100%; +` + +const HeaderSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.6rem", +})` + padding: 2.4rem; + padding-top: 0; + width: 100%; +` + +const HeaderTitleRow = styled(Box).attrs({ + justifyContent: "space-between", + align: "flex-start", + gap: "1rem", +})` + width: 100%; +` + +const HeaderText = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", + align: "flex-start", +})` + flex: 1; +` + +const ModalTitle = styled(Dialog.Title)` + font-size: 2.4rem; + font-weight: 600; + margin: 0; + padding: 0; + color: ${({ theme }) => theme.color.foreground}; + border: 0; +` + +const ModalSubtitle = styled(Dialog.Description)` + color: ${({ theme }) => theme.color.gray2}; + margin: 0; + padding: 0; +` + +const StyledCloseButton = styled.button` + background: transparent; + border: none; + cursor: pointer; + padding: 0; + display: flex; + align-items: center; + justify-content: center; + color: ${({ theme }) => theme.color.gray1}; + border-radius: 0.4rem; + flex-shrink: 0; + width: 2.2rem; + height: 2.2rem; + + &:hover { + color: ${({ theme }) => theme.color.foreground}; + } +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const ContentSection = styled(Box).attrs({ + flexDirection: "column", + gap: "2rem", +})` + padding: 2.4rem; + width: 100%; +` + +const InputSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", +})` + width: 100%; +` + +const InputLabel = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; +` + +const StyledInput = styled(Input)<{ $hasError?: boolean }>` + width: 100%; + background: #262833; + border: 0.1rem solid + ${({ theme, $hasError }) => ($hasError ? theme.color.red : "#6b7280")}; + border-radius: 0.8rem; + font-size: 1.4rem; + min-height: 3rem; + + &::placeholder { + color: ${({ theme }) => theme.color.gray2}; + font-family: inherit; + } +` + +const PasswordInput = styled(StyledInput)` + text-security: disc; + -webkit-text-security: disc; + -moz-text-security: disc; +` + +const StyledSelect = styled(Select)` + width: 100%; + background: #262833; + color: ${({ theme }) => theme.color.foreground}; + border: 0.1rem solid #6b7280; + border-radius: 0.8rem; + min-height: 3.2rem; + padding: 0 0.75rem; + cursor: pointer; + + &:focus { + border-color: ${({ theme }) => theme.color.pink}; + outline: none; + } + + option { + background: ${({ theme }) => theme.color.backgroundDarker}; + color: ${({ theme }) => theme.color.foreground}; + } +` + +const HelperText = styled(Text)` + font-size: 1.3rem; + font-weight: 300; + color: ${({ theme }) => theme.color.gray2}; +` + +const WarningBanner = styled(Box).attrs({ + flexDirection: "row", + gap: "0.6rem", + align: "center", +})` + width: 100%; + background: rgba(255, 165, 0, 0.08); + border: 0.1rem solid ${({ theme }) => theme.color.orange}; + border-radius: 0.8rem; + padding: 0.75rem; +` + +const WarningText = styled(Text)` + font-size: 1.3rem; + color: ${({ theme }) => theme.color.orange}; +` + +const ModelListContainer = styled.div` + max-height: 30rem; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 0.25rem; + border: 0.1rem solid #6b7280; + border-radius: 0.4rem; + width: 100%; +` + +const ModelRow = styled.label` + display: flex; + align-items: center; + gap: 0.8rem; + padding: 0.6rem 0.8rem; + cursor: pointer; + font-size: 1.4rem; + color: ${({ theme }) => theme.color.foreground}; + + &:hover { + background: ${({ theme }) => theme.color.selection}; + } +` + +const ModelChipsContainer = styled.div` + display: flex; + flex-wrap: wrap; + gap: 0.6rem; +` + +const ModelChip = styled.div` + display: inline-flex; + align-items: center; + gap: 0.5rem; + background: ${({ theme }) => theme.color.selection}; + border-radius: 0.4rem; + padding: 0.4rem 0.8rem; + font-size: 1.3rem; + color: ${({ theme }) => theme.color.foreground}; +` + +const ChipRemoveButton = styled.button` + background: none; + border: none; + cursor: pointer; + padding: 0; + display: flex; + justify-content: center; + align-items: center; + color: ${({ theme }) => theme.color.gray2}; + + &:hover { + color: ${({ theme }) => theme.color.foreground}; + } +` + +const AddModelRow = styled(Box).attrs({ + gap: "0.8rem", + align: "center", +})` + width: 100%; +` + +const AddModelButton = styled.button` + height: 3rem; + border: 0.1rem solid ${({ theme }) => theme.color.pinkDarker}; + background: ${({ theme }) => theme.color.background}; + color: ${({ theme }) => theme.color.foreground}; + border-radius: 0.4rem; + padding: 0 1.2rem; + font-size: 1.4rem; + font-weight: 500; + cursor: pointer; + white-space: nowrap; + + &:hover:not(:disabled) { + background: ${({ theme }) => theme.color.pinkDarker}; + } + + &:disabled { + opacity: 0.6; + cursor: not-allowed; + } +` + +const SelectAllRow = styled(Box).attrs({ + gap: "2rem", + align: "center", +})` + display: inline-flex; + margin-left: auto; +` + +const SelectAllLink = styled.button` + background: none; + border: none; + cursor: pointer; + color: ${({ theme }) => theme.color.cyan}; + font-size: 1.4rem; + padding: 0; + + &:hover { + text-decoration: underline; + } +` + +const SchemaAccessSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.6rem", + align: "flex-start", +})` + width: 100%; +` + +const SchemaAccessTitle = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; + flex: 1; +` + +const SchemaCheckboxContainer = styled(Box).attrs({ + gap: "1.5rem", + align: "flex-start", +})` + background: rgba(68, 71, 90, 0.56); + padding: 0.75rem; + border-radius: 0.4rem; + width: 100%; +` + +const SchemaCheckboxInner = styled(Box).attrs({ + gap: "1.5rem", + align: "center", +})` + flex: 1; + padding: 0.75rem; + border-radius: 0.5rem; +` + +const SchemaCheckboxWrapper = styled.div` + flex-shrink: 0; + display: flex; + align-items: center; +` + +const SchemaCheckboxContent = styled(Box).attrs({ + flexDirection: "column", + gap: "0.6rem", +})` + flex: 1; +` + +const SchemaCheckboxLabel = styled(Text)` + font-size: 1.4rem; + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const SchemaCheckboxDescription = styled(Text)` + font-size: 1.3rem; + font-weight: 400; + color: ${({ theme }) => theme.color.gray2}; +` + +const SchemaCheckboxDescriptionBold = styled.span` + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const CloseButton = ({ onClick }: { onClick: () => void }) => ( + + + + + +) + +const SchemaAccessToggle = ({ + checked, + onChange, + providerName, +}: { + checked: boolean + onChange: (checked: boolean) => void + providerName: string +}) => ( + + Schema Access + + + + onChange(e.target.checked)} + /> + + + + Grant schema access to {providerName} + + + When enabled, the AI assistant can access your database schema + information to provide more accurate suggestions and explanations. + Schema information helps the AI understand your table structures, + column names, and relationships.{" "} + + The AI model will not have access to your data. + + + + + + +) + +type StepOneProps = { + name: string + providerType: ProviderType + baseURL: string + apiKey: string + onNameChange: (v: string) => void + onProviderTypeChange: (v: ProviderType) => void + onBaseURLChange: (v: string) => void + onApiKeyChange: (v: string) => void +} + +const StepOneContent = ({ + name, + providerType, + baseURL, + apiKey, + onNameChange, + onProviderTypeChange, + onBaseURLChange, + onApiKeyChange, +}: StepOneProps) => { + const navigation = useModalNavigation() + + return ( + + + + + Add Custom Provider + + Configure a custom AI provider endpoint. Supports + OpenAI-compatible, Anthropic-compatible, and local providers like + Ollama. + + + + + + + + + Provider Name + onNameChange(e.target.value)} + placeholder="e.g., My Ollama, Azure GPT" + /> + + + Provider Type + + onProviderTypeChange(e.target.value as ProviderType) + } + options={[ + { + label: "OpenAI Chat Completions API", + value: "openai-chat-completions", + }, + { + label: "OpenAI Responses API", + value: "openai", + }, + { + label: "Anthropic Messages API", + value: "anthropic", + }, + ]} + /> + + Most third-party providers and local models use the OpenAI Chat + Completions format. + + + + Base URL + onBaseURLChange(e.target.value)} + placeholder="e.g., http://localhost:11434/v1" + /> + + The base URL of your provider's API endpoint. + + + + API Key + onApiKeyChange(e.target.value)} + placeholder="Optional for local providers" + /> + + Stored locally in your browser. Optional for local providers like + Ollama. + + + + + ) +} + +type StepTwoAutoProps = { + fetchedModels: string[] + selectedModels: string[] + contextWindow: number + grantSchemaAccess: boolean + providerName: string + manualModelInput: string + onToggleModel: (model: string) => void + onSelectAll: () => void + onDeselectAll: () => void + onContextWindowChange: (v: number) => void + onSchemaAccessChange: (v: boolean) => void + onManualModelInputChange: (v: string) => void + onAddManualModel: () => void +} + +const StepTwoAutoContent = ({ + fetchedModels, + selectedModels, + contextWindow, + grantSchemaAccess, + providerName, + manualModelInput, + onToggleModel, + onSelectAll, + onDeselectAll, + onContextWindowChange, + onSchemaAccessChange, + onManualModelInputChange, + onAddManualModel, +}: StepTwoAutoProps) => { + const navigation = useModalNavigation() + + return ( + + + + + Configure Settings + + Configure the settings for your custom provider. + + + + + + + + + + Select Models + + + Select All + + + Deselect All + + + + + {fetchedModels.map((model) => ( + + onToggleModel(model)} + /> + {model} + + ))} + + + + Don't see your model? Add it manually: + + onManualModelInputChange(e.target.value)} + placeholder="e.g., llama3" + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault() + onAddManualModel() + } + }} + /> + + Add + + + {selectedModels.filter((m) => !fetchedModels.includes(m)).length > + 0 && ( + + {selectedModels + .filter((m) => !fetchedModels.includes(m)) + .map((model) => ( + + {model} + onToggleModel(model)} + > + + + + ))} + + )} + + + + + + Context Window + onContextWindowChange(Number(e.target.value))} + min={1} + /> + + Maximum number of tokens the model can process. + + + + + + + + ) +} + +type StepTwoManualProps = { + manualModels: string[] + manualModelInput: string + contextWindow: number + grantSchemaAccess: boolean + providerName: string + onManualModelInputChange: (v: string) => void + onAddManualModel: () => void + onRemoveManualModel: (model: string) => void + onContextWindowChange: (v: number) => void + onSchemaAccessChange: (v: boolean) => void +} + +const StepTwoManualContent = ({ + manualModels, + manualModelInput, + contextWindow, + grantSchemaAccess, + providerName, + onManualModelInputChange, + onAddManualModel, + onRemoveManualModel, + onContextWindowChange, + onSchemaAccessChange, +}: StepTwoManualProps) => { + const theme = useTheme() + const navigation = useModalNavigation() + + return ( + + + + + Models & Settings + + Configure the models and settings for your custom provider. + + + + + + + + + + + Could not fetch models automatically from this provider. Please + enter model IDs manually. + + + + Add Models + + onManualModelInputChange(e.target.value)} + placeholder="e.g., llama3, gpt-4o, claude-sonnet-4-20250514" + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault() + onAddManualModel() + } + }} + /> + + Add + + + {manualModels.length > 0 && ( + + {manualModels.map((model) => ( + + {model} + onRemoveManualModel(model)} + title={`Remove ${model}`} + > + + + + ))} + + )} + + + + + + Context Window + onContextWindowChange(Number(e.target.value))} + min={1} + /> + + Maximum number of tokens the model can process. + + + + + + + + + ) +} + +export type CustomProviderModalProps = { + open: boolean + onOpenChange: (open: boolean) => void + onSave: (providerId: string, provider: CustomProviderDefinition) => void + existingProviderIds: string[] +} + +const generateProviderId = (name: string): string => + name + .toLowerCase() + .replace(/[^a-z0-9]+/g, "-") + .replace(/^-|-$/g, "") || "custom-provider" + +export const CustomProviderModal = ({ + open, + onOpenChange, + onSave, + existingProviderIds, +}: CustomProviderModalProps) => { + const [name, setName] = useState("") + const [providerType, setProviderType] = useState( + "openai-chat-completions", + ) + const [baseURL, setBaseURL] = useState("") + const [apiKey, setApiKey] = useState("") + + const [contextWindow, setContextWindow] = useState(128_000) + const [fetchedModels, setFetchedModels] = useState(null) + const [selectedModels, setSelectedModels] = useState([]) + const [manualModels, setManualModels] = useState([]) + const [manualModelInput, setManualModelInput] = useState("") + + const [flowPath, setFlowPath] = useState<"auto" | "manual">("manual") + + const [grantSchemaAccess, setGrantSchemaAccess] = useState(false) + + const abortControllerRef = useRef(null) + + useEffect(() => { + return () => { + abortControllerRef.current?.abort() + } + }, []) + + const handleToggleModel = useCallback((model: string) => { + setSelectedModels((prev) => + prev.includes(model) ? prev.filter((m) => m !== model) : [...prev, model], + ) + }, []) + + const handleSelectAll = useCallback(() => { + if (fetchedModels) { + setSelectedModels((prev) => { + const manual = prev.filter((m) => !fetchedModels.includes(m)) + return [...fetchedModels, ...manual] + }) + } + }, [fetchedModels]) + + const handleDeselectAll = useCallback(() => { + setSelectedModels((prev) => + fetchedModels ? prev.filter((m) => !fetchedModels.includes(m)) : [], + ) + }, [fetchedModels]) + + const handleAddManualModel = useCallback(() => { + const trimmed = manualModelInput.trim() + if (!trimmed) return + + if (flowPath === "auto") { + setSelectedModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } else { + setManualModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } + setManualModelInput("") + }, [manualModelInput, flowPath]) + + const handleRemoveManualModel = useCallback((model: string) => { + setManualModels((prev) => prev.filter((m) => m !== model)) + }, []) + + const connectionValidate = useCallback(async (): Promise< + string | boolean + > => { + if (!name.trim()) return "Provider name is required" + if (!baseURL.trim()) return "Base URL is required" + if (!baseURL.startsWith("http://") && !baseURL.startsWith("https://")) + return "Base URL must start with http:// or https://" + + const providerId = generateProviderId(name) + if (existingProviderIds.includes(providerId)) + return `A provider with a similar name already exists` + + // First, check that the URL is reachable with a simple fetch + try { + abortControllerRef.current?.abort() + abortControllerRef.current = new AbortController() + + const normalizedURL = baseURL.replace(/\/+$/, "") + await fetch(normalizedURL, { + method: "GET", + signal: abortControllerRef.current.signal, + }) + } catch (err) { + if (err instanceof DOMException && err.name === "AbortError") { + return "Connection check was cancelled" + } + return `Could not connect to ${baseURL}. Please check the URL and make sure the server is running.` + } + + // URL is reachable — try to fetch models + try { + const tempProvider = createProviderByType( + providerType, + "temp", + apiKey || "", + { baseURL, contextWindow }, + ) + const models = await tempProvider.listModels() + if (models && models.length > 0) { + setFetchedModels(models) + setFlowPath("auto") + } else { + setFetchedModels(null) + setFlowPath("manual") + } + } catch (err) { + const message = err instanceof Error ? err.message : "Unknown error" + return `Could not connect to provider: ${message}` + } + + return true + }, [name, baseURL, providerType, apiKey, contextWindow, existingProviderIds]) + + const modelsValidate = useCallback((): string | boolean => { + if (flowPath === "auto") { + if (selectedModels.length === 0) return "Select at least one model" + } else { + if (manualModels.length === 0 && !manualModelInput.trim()) + return "Add at least one model" + } + return true + }, [flowPath, selectedModels, manualModels, manualModelInput]) + + const handleComplete = useCallback(() => { + const providerId = generateProviderId(name) + + // Auto-add any pending manual model input + const pendingModel = manualModelInput.trim() + let models: string[] + + if (flowPath === "auto") { + models = + pendingModel && !selectedModels.includes(pendingModel) + ? [...selectedModels, pendingModel] + : selectedModels + } else { + models = + pendingModel && !manualModels.includes(pendingModel) + ? [...manualModels, pendingModel] + : manualModels + } + + const definition: CustomProviderDefinition = { + type: providerType, + name: name.trim(), + baseURL: baseURL.trim(), + apiKey: apiKey || undefined, + contextWindow, + models, + grantSchemaAccess, + } + + onSave(providerId, definition) + }, [ + name, + flowPath, + selectedModels, + manualModels, + manualModelInput, + providerType, + baseURL, + apiKey, + contextWindow, + grantSchemaAccess, + onSave, + ]) + + const steps: Step[] = useMemo(() => { + const connectionStep: Step = { + id: "connection", + title: "Add Custom Provider", + stepName: "Connection", + content: ( + + ), + validate: connectionValidate, + } + + if (flowPath === "auto" && fetchedModels !== null) { + return [ + connectionStep, + { + id: "select-models", + title: "Add Custom Provider", + stepName: "Models & Settings", + content: ( + + ), + validate: modelsValidate, + }, + ] + } + + return [ + connectionStep, + { + id: "manual-models", + title: "Add Custom Provider", + stepName: "Models & Settings", + content: ( + + ), + validate: modelsValidate, + }, + ] + }, [ + name, + providerType, + baseURL, + apiKey, + connectionValidate, + flowPath, + fetchedModels, + selectedModels, + contextWindow, + grantSchemaAccess, + manualModelInput, + handleToggleModel, + handleSelectAll, + handleDeselectAll, + handleAddManualModel, + modelsValidate, + manualModels, + handleRemoveManualModel, + ]) + + const canProceed = useCallback( + (stepIndex: number): boolean => { + if (stepIndex === 0) { + return !!name.trim() && !!baseURL.trim() + } + return true + }, + [name, baseURL], + ) + + return ( + + ) +} diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index f0897e5fb..6c99c5453 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -13,25 +13,31 @@ import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" import { toast } from "../Toast" import { Edit } from "@styled-icons/remix-line" +import { TrashIcon, PlugsIcon, PlusIcon } from "@phosphor-icons/react" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" -import { PlugsIcon } from "@phosphor-icons/react" import { LoadingSpinner } from "../LoadingSpinner" import { Overlay } from "../Overlay" import { getAllProviders, getAllModelOptions, getApiKey, + makeCustomModelValue, + BUILTIN_PROVIDERS, type ModelOption, type ProviderId, getNextModel, getProviderName, } from "../../utils/ai" -import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" +import type { + AiAssistantSettings, + CustomProviderDefinition, +} from "../../providers/LocalStorageProvider/types" import { ForwardRef } from "../ForwardRef" import { Badge, BadgeType } from "../../components/Badge" import { CheckboxCircle } from "@styled-icons/remix-fill" +import { CustomProviderModal } from "./CustomProviderModal" const ModalContent = styled.div` display: flex; @@ -460,6 +466,18 @@ const SchemaCheckboxDescriptionBold = styled.span` color: ${({ theme }) => theme.color.foreground}; ` +const RemoveProviderButton = styled(Button)` + border: 0.1rem solid ${({ theme }) => theme.color.red}; + background: ${({ theme }) => theme.color.backgroundDarker}; + color: ${({ theme }) => theme.color.foreground}; + + &:hover:not(:disabled) { + background: ${({ theme }) => theme.color.background}; + border: 0.1rem solid ${({ theme }) => theme.color.red}; + color: ${({ theme }) => theme.color.foreground}; + } +` + const FooterSection = styled(Box).attrs({ flexDirection: "column", gap: "2rem", @@ -498,6 +516,27 @@ const SaveButton = styled(Button)` width: 100%; ` +const AddProviderButton = styled.button` + display: flex; + flex-direction: column; + align-items: center; + gap: 0.5rem; + padding: 0.8rem 1.6rem; + background: none; + border: 0.1rem dashed ${({ theme }) => theme.color.gray2}; + border-radius: 0.4rem; + color: ${({ theme }) => theme.color.gray2}; + cursor: pointer; + font-size: 1.3rem; + justify-content: center; + margin: 0 1rem; + + &:hover { + border-color: ${({ theme }) => theme.color.foreground}; + color: ${({ theme }) => theme.color.foreground}; + } +` + type SettingsModalProps = { open?: boolean onOpenChange?: (open: boolean) => void @@ -544,6 +583,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const providersWithKeys = getProvidersWithApiKeys(aiAssistantSettings) return providersWithKeys[0] || getAllProviders(aiAssistantSettings)[0] }) + const isCustomProvider = !BUILTIN_PROVIDERS[selectedProvider] const [apiKeys, setApiKeys] = useState>(() => initializeProviderState( (provider) => getApiKey(provider, aiAssistantSettings) || "", @@ -574,7 +614,9 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { Record >(() => initializeProviderState( - (provider) => !!getApiKey(provider, aiAssistantSettings), + (provider) => + !BUILTIN_PROVIDERS[provider] || + !!getApiKey(provider, aiAssistantSettings), false, ), ) @@ -589,6 +631,23 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { >(() => initializeProviderState(() => false, false)) const inputRef = useRef(null) + const [customProviderModalOpen, setCustomProviderModalOpen] = useState(false) + + const [localCustomProviders, setLocalCustomProviders] = useState< + Record + >(() => ({ ...(aiAssistantSettings.customProviders ?? {}) })) + + const localSettings = useMemo( + () => ({ + ...aiAssistantSettings, + customProviders: + Object.keys(localCustomProviders).length > 0 + ? localCustomProviders + : undefined, + }), + [aiAssistantSettings, localCustomProviders], + ) + const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) setValidationErrors((prev) => ({ ...prev, [provider]: null })) @@ -620,7 +679,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { setValidationState((prev) => ({ ...prev, [provider]: "validating" })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) - const providerModels = getModelsForProvider(provider, aiAssistantSettings) + const providerModels = getModelsForProvider(provider, localSettings) if (providerModels.length === 0) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ @@ -638,7 +697,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { apiKey, testModel, provider, - aiAssistantSettings, + localSettings, ) if (!result.valid) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) @@ -647,7 +706,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [provider]: result.error || "Invalid API key", })) } else { - const defaultModels = getAllModelOptions(aiAssistantSettings) + const defaultModels = getAllModelOptions(localSettings) .filter((m) => m.defaultEnabled && m.provider === provider) .map((m) => m.value) if (defaultModels.length > 0) { @@ -667,16 +726,6 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [apiKeys], ) - const handleRemoveApiKey = useCallback((provider: ProviderId) => { - // Remove API key from local state only - // Settings will be persisted when Save Settings is clicked - setApiKeys((prev) => ({ ...prev, [provider]: "" })) - setValidatedApiKeys((prev) => ({ ...prev, [provider]: false })) - setValidationState((prev) => ({ ...prev, [provider]: "idle" })) - setValidationErrors((prev) => ({ ...prev, [provider]: null })) - setIsInputFocused((prev) => ({ ...prev, [provider]: false })) - }, []) - const handleModelToggle = useCallback( (provider: ProviderId, modelValue: string) => { setEnabledModels((prev) => { @@ -702,34 +751,39 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const handleSave = useCallback(() => { const updatedProviders = { ...aiAssistantSettings.providers } - const allProviders = getAllProviders(aiAssistantSettings) + const allProviderIds = getAllProviders(localSettings) - for (const provider of allProviders) { - if (validatedApiKeys[provider]) { - // Only save providers with validated API keys + for (const provider of allProviderIds) { + const isCustom = !BUILTIN_PROVIDERS[provider] + if (validatedApiKeys[provider] || isCustom) { updatedProviders[provider] = { - apiKey: apiKeys[provider], + apiKey: apiKeys[provider] ?? "", enabledModels: enabledModels[provider], grantSchemaAccess: grantSchemaAccess[provider], } } else { - // Remove provider entry if no validated API key delete updatedProviders[provider] } } - // Sync API keys into customProviders so getApiKey() stays consistent - const updatedCustomProviders = aiAssistantSettings.customProviders - ? { ...aiAssistantSettings.customProviders } - : undefined + // Remove provider entries for deleted custom providers + for (const provider of Object.keys(updatedProviders)) { + if (!BUILTIN_PROVIDERS[provider] && !localCustomProviders[provider]) { + delete updatedProviders[provider] + } + } + + // Sync API keys and schema access into custom provider definitions + const updatedCustomProviders = + Object.keys(localCustomProviders).length > 0 + ? { ...localCustomProviders } + : undefined if (updatedCustomProviders) { - for (const provider of allProviders) { - if (updatedCustomProviders[provider]) { - updatedCustomProviders[provider] = { - ...updatedCustomProviders[provider], - apiKey: validatedApiKeys[provider] ? apiKeys[provider] : undefined, - grantSchemaAccess: grantSchemaAccess[provider], - } + for (const provider of Object.keys(updatedCustomProviders)) { + updatedCustomProviders[provider] = { + ...updatedCustomProviders[provider], + apiKey: apiKeys[provider] || undefined, + grantSchemaAccess: grantSchemaAccess[provider], } } } @@ -737,9 +791,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const updatedSettings: AiAssistantSettings = { ...aiAssistantSettings, providers: updatedProviders, - ...(updatedCustomProviders && { - customProviders: updatedCustomProviders, - }), + customProviders: updatedCustomProviders, } const nextModel = getNextModel( @@ -754,6 +806,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { onOpenChange?.(false) }, [ aiAssistantSettings, + localSettings, + localCustomProviders, apiKeys, enabledModels, grantSchemaAccess, @@ -766,6 +820,72 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { onOpenChange?.(false) }, [onOpenChange]) + const handleRemoveProvider = useCallback( + (providerId: ProviderId) => { + const isCustom = !BUILTIN_PROVIDERS[providerId] + + if (isCustom) { + setLocalCustomProviders((prev) => { + const { [providerId]: _, ...rest } = prev + return rest + }) + } + + setApiKeys((prev) => ({ ...prev, [providerId]: "" })) + setGrantSchemaAccess((prev) => ({ ...prev, [providerId]: false })) + setValidatedApiKeys((prev) => ({ ...prev, [providerId]: false })) + setValidationState((prev) => ({ ...prev, [providerId]: "idle" })) + setValidationErrors((prev) => ({ ...prev, [providerId]: null })) + setEnabledModels((prev) => ({ ...prev, [providerId]: [] })) + setIsInputFocused((prev) => ({ ...prev, [providerId]: false })) + + // Switch to first remaining active provider + const updatedCustomProviders = isCustom + ? (() => { + const { [providerId]: _, ...rest } = localCustomProviders + return Object.keys(rest).length > 0 ? rest : undefined + })() + : localSettings.customProviders + const remaining = getAllProviders({ + ...localSettings, + customProviders: updatedCustomProviders, + }).filter((p) => p !== providerId || BUILTIN_PROVIDERS[p]) + setSelectedProvider(remaining[0] ?? "openai") + }, + [localSettings, localCustomProviders], + ) + + const handleCustomProviderSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + setLocalCustomProviders((prev) => ({ + ...prev, + [providerId]: definition, + })) + setApiKeys((prev) => ({ + ...prev, + [providerId]: definition.apiKey ?? "", + })) + setGrantSchemaAccess((prev) => ({ + ...prev, + [providerId]: definition.grantSchemaAccess ?? false, + })) + setValidatedApiKeys((prev) => ({ + ...prev, + [providerId]: true, + })) + setEnabledModels((prev) => ({ + ...prev, + [providerId]: definition.models.map((m) => + makeCustomModelValue(providerId, m), + ), + })) + + setSelectedProvider(providerId) + setCustomProviderModalOpen(false) + }, + [], + ) + const currentProviderValidated = validatedApiKeys[selectedProvider] const currentProviderApiKey = apiKeys[selectedProvider] const currentProviderValidationState = validationState[selectedProvider] @@ -774,8 +894,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const maskInput = !!(currentProviderApiKey && !currentProviderIsFocused) const modelsForProvider = useMemo( - () => getModelsForProvider(selectedProvider, aiAssistantSettings), - [selectedProvider, aiAssistantSettings], + () => getModelsForProvider(selectedProvider, localSettings), + [selectedProvider, localSettings], ) const enabledModelsForProvider = useMemo( @@ -784,8 +904,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ) const allProviders = useMemo( - () => getAllProviders(aiAssistantSettings), - [aiAssistantSettings], + () => getAllProviders(localSettings), + [localSettings], ) const renderProviderIcon = (provider: ProviderId, isActive: boolean) => { @@ -801,322 +921,368 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { } return ( - - - - - - - - - - - Assistant Settings - - Modify settings for your AI assistant, set up new providers, - and review access. - - - - - - - - - - - - - {allProviders.map((provider) => { - const isActive = selectedProvider === provider - return ( - handleProviderSelect(provider)} - data-hook={`ai-settings-provider-${provider}`} - > - - {renderProviderIcon(provider, isActive)} - - {getProviderName(provider, aiAssistantSettings)} - - - - - - {validatedApiKeys[provider] ? "Enabled" : "Inactive"} - - - - ) - })} - - - - - - + + + + + + + + + + + Assistant Settings + + Modify settings for your AI assistant, set up new + providers, and review access. + + + + - API Key - {validatedApiKeys[selectedProvider] && ( - } - data-hook="ai-settings-validated-badge" - > - Validated - - )} - - Get your API key from{" "} - - {getProviderName( - selectedProvider, - aiAssistantSettings, - )} - - . - - - - { - handleApiKeyChange(selectedProvider, e.target.value) - }} - placeholder={`Enter ${getProviderName(selectedProvider, aiAssistantSettings)} API key`} - $hasError={!!currentProviderError} - $showEditButton={maskInput} - readOnly={maskInput} - onFocus={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: true, - })) - }} - onBlur={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: false, - })) - if (inputRef.current) { - inputRef.current.blur() - } - }} - onMouseDown={(e) => { - if (maskInput) { - e.preventDefault() - } - }} - tabIndex={maskInput ? -1 : 0} - style={{ - cursor: maskInput ? "default" : "text", - }} - data-hook="ai-settings-api-key" + - {maskInput && ( - { - inputRef.current?.focus() - }} - title="Edit API key" - > - - - )} - - {currentProviderError && ( - {currentProviderError} - )} - {!currentProviderError && ( - - Stored locally in your browser and never sent to QuestDB - servers. This API key is used to authenticate your - requests to the model provider. - - )} - - currentProviderValidated - ? handleRemoveApiKey(selectedProvider) - : handleValidateApiKey(selectedProvider) - } - disabled={ - currentProviderValidationState === "validating" || - (!currentProviderValidated && !currentProviderApiKey) - } - data-hook="ai-settings-test-api" - > - {currentProviderValidationState === "validating" ? ( - - - Validating... - - ) : currentProviderValidated ? ( - "Remove API Key" + + + + + + + + {allProviders.map((provider) => { + const isActive = selectedProvider === provider + return ( + handleProviderSelect(provider)} + data-hook={`ai-settings-provider-${provider}`} + > + + {renderProviderIcon(provider, isActive)} + + {getProviderName(provider, localSettings)} + + + + + + {validatedApiKeys[provider] + ? "Enabled" + : "Inactive"} + + + + ) + })} + { + setCustomProviderModalOpen(true) + }} + > + Add custom provider + + + + + + + {isCustomProvider && !currentProviderApiKey ? ( + <> + API Key + + This provider does not have an API key. + + ) : ( - "Validate API Key" - )} - - - - - - Enable Models - {currentProviderValidated ? ( - - {modelsForProvider.map((model) => { - const isEnabled = enabledModelsForProvider.includes( - model.value, - ) - return ( - - - {model.label} - {model.isSlow && ( - - - - Due to advanced reasoning & thinking - capabilities, responses using this model - can be slow. - - - )} - - - handleModelToggle( + <> + + API Key + {validatedApiKeys[selectedProvider] && ( + } + data-hook="ai-settings-validated-badge" + > + Validated + + )} + {!isCustomProvider && ( + + Get your API key from{" "} + + {getProviderName( selectedProvider, - model.value, - ) + localSettings, + )} + + . + + )} + + + { + handleApiKeyChange( + selectedProvider, + e.target.value, + ) + }} + placeholder={`Enter ${getProviderName(selectedProvider, localSettings)} API key`} + $hasError={!!currentProviderError} + $showEditButton={maskInput} + readOnly={maskInput} + onFocus={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: true, + })) + }} + onBlur={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: false, + })) + if (inputRef.current) { + inputRef.current.blur() + } + }} + onMouseDown={(e) => { + if (maskInput) { + e.preventDefault() } - /> - - ) - })} - - ) : ( - - - When you've entered and validated your API key, - you'll be able to select and enable available - models. - - - )} - - - - - - Schema Access - - - - - - handleSchemaAccessChange( - selectedProvider, - e.target.checked, - ) - } - disabled={!currentProviderValidated} - data-hook="ai-settings-schema-access" - /> - - - - Grant schema access to{" "} - {getProviderName( - selectedProvider, - aiAssistantSettings, + }} + tabIndex={maskInput ? -1 : 0} + style={{ + cursor: maskInput ? "default" : "text", + }} + data-hook="ai-settings-api-key" + /> + {maskInput && ( + { + inputRef.current?.focus() + }} + title="Edit API key" + > + + )} - - - When enabled, the AI assistant can access your - database schema information to provide more accurate - suggestions and explanations. Schema information - helps the AI understand your table structures, - column names, and relationships.{" "} - - The AI model will not have access to your data. - - - - - - - - - - - - - - Cancel - - - Save Settings - - - - - - - + + {currentProviderError && ( + {currentProviderError} + )} + {!currentProviderError && ( + + Stored locally in your browser and never sent to + QuestDB servers. This API key is used to + authenticate your requests to the model provider. + + )} + {!currentProviderValidated && + currentProviderApiKey && ( + + handleValidateApiKey(selectedProvider) + } + disabled={ + currentProviderValidationState === + "validating" + } + data-hook="ai-settings-test-api" + > + {currentProviderValidationState === + "validating" ? ( + + + Validating... + + ) : ( + "Validate API Key" + )} + + )} + + )} + + + + + Enable Models + {currentProviderValidated ? ( + + {modelsForProvider.map((model) => { + const isEnabled = enabledModelsForProvider.includes( + model.value, + ) + return ( + + + {model.label} + {model.isSlow && ( + + + + Due to advanced reasoning & thinking + capabilities, responses using this model + can be slow. + + + )} + + + handleModelToggle( + selectedProvider, + model.value, + ) + } + /> + + ) + })} + + ) : ( + + + When you've entered and validated your API key, + you'll be able to select and enable available + models. + + + )} + + + + + + Schema Access + + + + + + handleSchemaAccessChange( + selectedProvider, + e.target.checked, + ) + } + disabled={!currentProviderValidated} + data-hook="ai-settings-schema-access" + /> + + + + Grant schema access to{" "} + {getProviderName(selectedProvider, localSettings)} + + + When enabled, the AI assistant can access your + database schema information to provide more + accurate suggestions and explanations. Schema + information helps the AI understand your table + structures, column names, and relationships.{" "} + + The AI model will not have access to your data. + + + + + + + + + } + type="button" + onClick={() => handleRemoveProvider(selectedProvider)} + > + {isCustomProvider ? "Remove Provider" : "Reset Provider"} + + + + + + + + + Cancel + + + Save Settings + + + + + + + + {customProviderModalOpen && ( + + )} + ) } diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index fc74e53c0..b4d3faf1d 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -558,7 +558,7 @@ export function createAnthropicProvider( for await (const model of anthropic.models.list()) { models.push(model.id) } - return models + return models.sort((a, b) => a.localeCompare(b)) }, classifyError( diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 1a8380ac8..554facc8e 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -531,7 +531,7 @@ export function createOpenAIChatCompletionsProvider( for await (const model of openai.models.list()) { models.push(model.id) } - return models + return models.sort((a, b) => a.localeCompare(b)) }, classifyError( diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 554ab0dfa..c8849f4a6 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -470,7 +470,7 @@ export function createOpenAIProvider( for await (const model of openai.models.list()) { models.push(model.id) } - return models + return models.sort((a, b) => a.localeCompare(b)) }, classifyError( diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index e18771a8f..513737689 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -355,7 +355,8 @@ export const getApiKey = ( const builtinKey = settings.providers?.[providerId]?.apiKey if (builtinKey) return builtinKey const custom = settings.customProviders?.[providerId] - return custom?.apiKey || null + if (custom) return custom.apiKey || "" + return null } export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index 255232209..514cf169b 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -18,6 +18,7 @@ import { ALL_TOOLS, REFERENCE_TOOLS, getUnifiedPrompt, + BUILTIN_PROVIDERS, } from "./ai" import type { AIProvider } from "./ai" @@ -320,7 +321,8 @@ export const generateChatTitle = async ({ firstUserMessage: string settings: ActiveProviderSettings }): Promise => { - if (!settings.apiKey || !settings.model) { + const isCustom = !BUILTIN_PROVIDERS[settings.provider] + if ((!isCustom && !settings.apiKey) || !settings.model) { return null } @@ -382,7 +384,8 @@ export const continueConversation = async ({ compactedConversationHistory?: Array } > => { - if (!settings.apiKey || !settings.model) { + const isCustom = !BUILTIN_PROVIDERS[settings.provider] + if ((!isCustom && !settings.apiKey) || !settings.model) { return { type: "invalid_key", message: "API key or model is missing", From 86ee3055666e85d2733e9a5396285e76d628d20a Mon Sep 17 00:00:00 2001 From: emrberk Date: Mon, 9 Mar 2026 03:47:04 +0300 Subject: [PATCH 14/25] submodule --- e2e/questdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e/questdb b/e2e/questdb index 2263b2adb..834d0fecb 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit 2263b2adb482cf7b8290306a7f736381670f76cc +Subproject commit 834d0fecbefe8e72a05c488ca75f8976cd4cc0f9 From 29c9ff4568ed4d8d291c3ce00df8755cd97e2606 Mon Sep 17 00:00:00 2001 From: emrberk Date: Mon, 9 Mar 2026 17:14:54 +0300 Subject: [PATCH 15/25] custom provider configuration in configuration modal, improve error handling, fix context window input --- .../SetupAIAssistant/ConfigurationModal.tsx | 282 ++++++++++-------- .../SetupAIAssistant/CustomProviderModal.tsx | 26 +- .../SetupAIAssistant/SettingsModal.tsx | 31 +- src/utils/ai/anthropicProvider.ts | 27 +- src/utils/ai/openaiChatCompletionsProvider.ts | 46 ++- src/utils/ai/openaiProvider.ts | 46 ++- src/utils/executeAIFlow.ts | 12 +- 7 files changed, 304 insertions(+), 166 deletions(-) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index 767deded3..92a75ff0e 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo, useCallback } from "react" +import React, { useState, useMemo, useCallback, useEffect } from "react" import styled, { css } from "styled-components" import { Dialog } from "../Dialog" import { MultiStepModal, Step } from "../MultiStepModal" @@ -10,9 +10,12 @@ import { Text } from "../Text" import { useLocalStorage } from "../../providers/LocalStorageProvider" import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" +import type { CustomProviderDefinition } from "../../providers/LocalStorageProvider/types" import { toast } from "../Toast" import { getAllModelOptions, + getAllProviders, + makeCustomModelValue, type ModelOption, type ProviderId, getProviderName, @@ -21,8 +24,9 @@ import { useModalNavigation } from "../MultiStepModal" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" -import { Plugs as PlugsIcon } from "@phosphor-icons/react" +import { PlusIcon, Plugs as PlugsIcon } from "@phosphor-icons/react" import { theme } from "../../theme" +import { CustomProviderModal } from "./CustomProviderModal" const ModalContent = styled.div` display: flex; @@ -115,17 +119,12 @@ const SectionDescription = styled(Text)` color: ${({ theme }) => theme.color.gray2}; ` -const ProviderSelectionContainer = styled(Box).attrs({ - gap: "4rem", - align: "center", -})` - width: 100%; -` - const ProviderCardsContainer = styled(Box).attrs({ gap: "2rem", + alignItems: "flex-start", })` height: 8.5rem; + width: 100%; ` const ProviderCard = styled.button<{ $selected: boolean }>` @@ -167,34 +166,6 @@ const ProviderName = styled(Text)` text-align: center; ` -const ComingSoonContainer = styled(Box).attrs({ - flexDirection: "column", - gap: "0.6rem", - align: "flex-start", -})` - width: 13.2rem; -` - -const ComingSoonIcons = styled(Box).attrs({ - align: "center", -})` - width: 100%; - padding-left: 0; - padding-right: 1.2rem; -` - -const ComingSoonIcon = styled.img` - width: 100%; - height: auto; - object-fit: contain; -` - -const ComingSoonText = styled(Text)` - font-size: 1.3rem; - font-weight: 300; - color: ${({ theme }) => theme.color.gray2}; -` - const InputSection = styled(Box).attrs({ flexDirection: "column", gap: "1.2rem", @@ -408,6 +379,33 @@ const WarningText = styled(Text)` text-align: left; ` +const AddCustomProviderCard = styled.button` + background: transparent; + border: 0.1rem dashed ${({ theme }) => theme.color.gray2}; + border-radius: 0.8rem; + cursor: pointer; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 0.6rem; + padding: 1.2rem 2rem; + width: 10rem; + height: 8.5rem; + transition: all 0.2s; + color: ${({ theme }) => theme.color.gray2}; + + &:hover { + border-color: ${({ theme }) => theme.color.foreground}; + color: ${({ theme }) => theme.color.foreground}; + } + + &:focus-visible { + outline: 0.2rem solid ${({ theme }) => theme.color.foreground}; + outline-offset: 0.2rem; + } +` + type ConfigurationModalProps = { open?: boolean onOpenChange?: (open: boolean) => void @@ -420,6 +418,7 @@ type StepOneContentProps = { providerName: string onProviderSelect: (provider: ProviderId) => void onApiKeyChange: (value: string) => void + onAddCustomProvider: () => void } type StepTwoContentProps = { @@ -460,6 +459,7 @@ const StepOneContent = ({ providerName, onProviderSelect, onApiKeyChange, + onAddCustomProvider, }: StepOneContentProps) => { const navigation = useModalNavigation() const handleClose: () => void = navigation.handleClose @@ -484,74 +484,73 @@ const StepOneContent = ({ Select Provider - We currently only support two model providers, with support for - more coming soon. + Choose a built-in provider or add your own custom provider. + You'll be able to configure and switch between multiple + providers later. - - - onProviderSelect("openai")} - type="button" - data-hook="ai-settings-provider-openai" - > - - {getProviderName("openai")} - - onProviderSelect("anthropic")} - type="button" - data-hook="ai-settings-provider-anthropic" - > - - {getProviderName("anthropic")} - - - - - - - Coming soon... - - + + onProviderSelect("openai")} + type="button" + data-hook="ai-settings-provider-openai" + > + + {getProviderName("openai")} + + onProviderSelect("anthropic")} + type="button" + data-hook="ai-settings-provider-anthropic" + > + + {getProviderName("anthropic")} + + + + Custom + + - - - - API Key - onApiKeyChange(e.target.value)} - placeholder={`Enter${providerName ? ` ${providerName}` : ""} API key`} - $hasError={!!error} - disabled={!selectedProvider} - data-hook="ai-settings-api-key" - /> - {error && ( - {error} - )} - - Stored locally in your browser and never sent to QuestDB servers. - This API key is used to authenticate your requests to the model - provider. - - - + {selectedProvider && ( + <> + + + + API Key + onApiKeyChange(e.target.value)} + placeholder={`Enter${providerName ? ` ${providerName}` : ""} API key`} + $hasError={!!error} + data-hook="ai-settings-api-key" + /> + {error && ( + + {error} + + )} + + Stored locally in your browser and never sent to QuestDB + servers. This API key is used to authenticate your requests to + the model provider. + + + + + )} ) } @@ -706,6 +705,13 @@ export const ConfigurationModal = ({ ) const [apiKey, setApiKey] = useState("") const [error, setError] = useState(null) + const [customProviderModalOpen, setCustomProviderModalOpen] = useState(false) + + useEffect(() => { + if (!open) { + setCustomProviderModalOpen(false) + } + }, [open]) const [enabledModels, setEnabledModels] = useState([]) const [grantSchemaAccess, setGrantSchemaAccess] = useState(true) @@ -850,6 +856,37 @@ export const ConfigurationModal = ({ setGrantSchemaAccess(true) }, []) + const handleCustomProviderSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + const newEnabledModels = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + + const newSettings = { + ...aiAssistantSettings, + selectedModel: newEnabledModels[0], + customProviders: { + ...(aiAssistantSettings.customProviders ?? {}), + [providerId]: definition, + }, + providers: { + ...aiAssistantSettings.providers, + [providerId]: { + apiKey: definition.apiKey ?? "", + enabledModels: newEnabledModels, + grantSchemaAccess: definition.grantSchemaAccess ?? false, + }, + }, + } + + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, newSettings) + toast.success("AI Assistant activated successfully") + setCustomProviderModalOpen(false) + onOpenChange?.(false) + }, + [aiAssistantSettings, updateSettings, onOpenChange], + ) + const steps: Step[] = useMemo( () => [ { @@ -864,6 +901,7 @@ export const ConfigurationModal = ({ providerName={providerName} onProviderSelect={handleProviderSelect} onApiKeyChange={handleApiKeyChange} + onAddCustomProvider={() => setCustomProviderModalOpen(true)} /> ), validate: validateStepOne, @@ -903,21 +941,31 @@ export const ConfigurationModal = ({ ) return ( - { - if (!isOpen) { - handleModalClose() - } - onOpenChange?.(isOpen) - }} - onStepChange={handleStepChange} - steps={steps} - maxWidth="64rem" - onComplete={handleComplete} - canProceed={canProceed} - completeButtonText="Activate Assistant" - showValidationError={false} - /> + <> + { + if (!isOpen) { + handleModalClose() + } + onOpenChange?.(isOpen) + }} + onStepChange={handleStepChange} + steps={steps} + maxWidth="64rem" + onComplete={handleComplete} + canProceed={canProceed} + completeButtonText="Activate Assistant" + showValidationError={false} + /> + {customProviderModalOpen && ( + + )} + ) } diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index b63e4547c..0ca04e264 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -455,7 +455,7 @@ const StepOneContent = ({ Provider Type onProviderTypeChange(e.target.value as ProviderType) } @@ -637,10 +637,10 @@ const StepTwoAutoContent = ({ type="number" value={contextWindow} onChange={(e) => onContextWindowChange(Number(e.target.value))} - min={1} /> - Maximum number of tokens the model can process. + Maximum number of tokens the model can process. AI assistant + requires a minimum of 100,000 tokens. @@ -754,10 +754,10 @@ const StepTwoManualContent = ({ type="number" value={contextWindow} onChange={(e) => onContextWindowChange(Number(e.target.value))} - min={1} /> - Maximum number of tokens the model can process. + Maximum number of tokens the model can process. AI assistant + requires a minimum of 100,000 tokens. @@ -843,16 +843,14 @@ export const CustomProviderModal = ({ if (!trimmed) return if (flowPath === "auto") { - setSelectedModels((prev) => - prev.includes(trimmed) ? prev : [...prev, trimmed], - ) + if (selectedModels.includes(trimmed)) return + setSelectedModels((prev) => [...prev, trimmed]) } else { - setManualModels((prev) => - prev.includes(trimmed) ? prev : [...prev, trimmed], - ) + if (manualModels.includes(trimmed)) return + setManualModels((prev) => [...prev, trimmed]) } setManualModelInput("") - }, [manualModelInput, flowPath]) + }, [manualModelInput, flowPath, selectedModels, manualModels]) const handleRemoveManualModel = useCallback((model: string) => { setManualModels((prev) => prev.filter((m) => m !== model)) @@ -918,8 +916,10 @@ export const CustomProviderModal = ({ if (manualModels.length === 0 && !manualModelInput.trim()) return "Add at least one model" } + if (!contextWindow || contextWindow < 100_000) + return "Context window must be at least 100,000 tokens" return true - }, [flowPath, selectedModels, manualModels, manualModelInput]) + }, [flowPath, selectedModels, manualModels, manualModelInput, contextWindow]) const handleComplete = useCallback(() => { const providerId = generateProviderId(name) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 6c99c5453..dd74f513d 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -857,6 +857,10 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const handleCustomProviderSave = useCallback( (providerId: string, definition: CustomProviderDefinition) => { + const newEnabledModels = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + setLocalCustomProviders((prev) => ({ ...prev, [providerId]: definition, @@ -875,15 +879,31 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { })) setEnabledModels((prev) => ({ ...prev, - [providerId]: definition.models.map((m) => - makeCustomModelValue(providerId, m), - ), + [providerId]: newEnabledModels, })) + const updatedCustomProviders = { + ...(aiAssistantSettings.customProviders ?? {}), + [providerId]: definition, + } + const updatedProviders = { + ...aiAssistantSettings.providers, + [providerId]: { + apiKey: definition.apiKey ?? "", + enabledModels: newEnabledModels, + grantSchemaAccess: definition.grantSchemaAccess ?? false, + }, + } + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { + ...aiAssistantSettings, + customProviders: updatedCustomProviders, + providers: updatedProviders, + }) + setSelectedProvider(providerId) setCustomProviderModalOpen(false) }, - [], + [aiAssistantSettings, updateSettings], ) const currentProviderValidated = validatedApiKeys[selectedProvider] @@ -1005,7 +1025,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { - {isCustomProvider && !currentProviderApiKey ? ( + {isCustomProvider && + !localCustomProviders[selectedProvider]?.apiKey ? ( <> API Key diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index b4d3faf1d..e8f59d8ba 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -121,7 +121,7 @@ async function createAnthropicMessageStreaming( if (errorType === "overloaded_error") { throw new StreamingError( "Service is temporarily overloaded. Please try again.", - "failed", + "network", event, ) } @@ -148,6 +148,9 @@ async function createAnthropicMessageStreaming( if (abortSignal?.aborted) { throw new StreamingError("Operation aborted", "interrupted") } + if (error instanceof Anthropic.APIError) { + throw error + } throw new StreamingError( error instanceof Error ? error.message : "Stream interrupted", "network", @@ -162,6 +165,9 @@ async function createAnthropicMessageStreaming( if (abortSignal?.aborted || error instanceof Anthropic.APIUserAbortError) { throw new StreamingError("Operation aborted", "interrupted") } + if (error instanceof Anthropic.APIError) { + throw error + } throw new StreamingError( "Failed to get final message from the provider", "network", @@ -640,14 +646,15 @@ export function createAnthropicProvider( if (error instanceof Anthropic.APIError) { return { type: "unknown", - message: `Anthropic API error: ${error.message}`, + message: error.message, + details: `Status ${error.status}`, } } return { type: "unknown", message: "An unexpected error occurred. Please try again.", - details: error as string, + details: error instanceof Error ? error.message : String(error), } }, @@ -655,14 +662,18 @@ export function createAnthropicProvider( if (error instanceof StreamingError) { return error.errorType === "interrupted" || error.errorType === "failed" } + if ( + error instanceof Anthropic.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } return ( error instanceof RefusalError || error instanceof MaxTokensError || - error instanceof Anthropic.AuthenticationError || - (error != null && - typeof error === "object" && - "status" in error && - error.status === 429) || error instanceof Anthropic.APIUserAbortError ) }, diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 554facc8e..6512380a1 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -147,6 +147,9 @@ async function createChatCompletionStreaming( if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { throw new StreamingError("Operation aborted", "interrupted") } + if (error instanceof OpenAI.APIError) { + throw error + } throw new StreamingError( error instanceof Error ? error.message : "Stream interrupted", "network", @@ -586,17 +589,42 @@ export function createOpenAIChatCompletionsProvider( } } + if (error instanceof OpenAI.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your OpenAI API key.", + details: error.message, + } + } + + if (error instanceof OpenAI.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your internet connection.", + details: error.message, + } + } + if (error instanceof OpenAI.APIError) { return { type: "unknown", - message: `OpenAI API error: ${error.message}`, + message: error.message, + details: `Status ${error.status}`, } } return { type: "unknown", message: "An unexpected error occurred. Please try again.", - details: error as string, + details: error instanceof Error ? error.message : String(error), } }, @@ -604,14 +632,18 @@ export function createOpenAIChatCompletionsProvider( if (error instanceof StreamingError) { return error.errorType === "interrupted" || error.errorType === "failed" } + if ( + error instanceof OpenAI.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } return ( error instanceof RefusalError || error instanceof MaxTokensError || - error instanceof OpenAI.AuthenticationError || - (error != null && - typeof error === "object" && - "status" in error && - error.status === 429) || error instanceof OpenAI.APIUserAbortError ) }, diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index c8849f4a6..2f4c2b915 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -115,6 +115,9 @@ async function createOpenAIResponseStreaming( if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { throw new StreamingError("Operation aborted", "interrupted") } + if (error instanceof OpenAI.APIError) { + throw error + } throw new StreamingError( error instanceof Error ? error.message : "Stream interrupted", "network", @@ -525,17 +528,42 @@ export function createOpenAIProvider( } } + if (error instanceof OpenAI.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your OpenAI API key.", + details: error.message, + } + } + + if (error instanceof OpenAI.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your internet connection.", + details: error.message, + } + } + if (error instanceof OpenAI.APIError) { return { type: "unknown", - message: `OpenAI API error: ${error.message}`, + message: error.message, + details: `Status ${error.status}`, } } return { type: "unknown", message: "An unexpected error occurred. Please try again.", - details: error as string, + details: error instanceof Error ? error.message : String(error), } }, @@ -543,14 +571,18 @@ export function createOpenAIProvider( if (error instanceof StreamingError) { return error.errorType === "interrupted" || error.errorType === "failed" } + if ( + error instanceof OpenAI.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } return ( error instanceof RefusalError || error instanceof MaxTokensError || - error instanceof OpenAI.AuthenticationError || - (error != null && - typeof error === "object" && - "status" in error && - error.status === 429) || error instanceof OpenAI.APIUserAbortError ) }, diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 4d82bc4aa..677a418ea 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -212,16 +212,10 @@ function buildUserMessage(config: AIFlowConfig): AIFlowUserMessage { } function formatErrorMessage(error: AiAssistantAPIError): string { - switch (error.type) { - case "aborted": - return "Operation has been cancelled" - case "network": - return "Connection interrupted. Please check your network and try again." - case "rate_limit": - return "Rate limit reached. Please wait a moment and try again." - default: - return error.message || "An unexpected error occurred" + if (error.type === "aborted") { + return "Operation has been cancelled" } + return error.message || "An unexpected error occurred" } type ProcessResultConfig = { From e221f0511e8a6cb8ddc59cb2236d0ec0918fc9bc Mon Sep 17 00:00:00 2001 From: emrberk Date: Tue, 10 Mar 2026 00:30:11 +0300 Subject: [PATCH 16/25] dont require fixed test model for custom provider --- src/providers/LocalStorageProvider/types.ts | 1 - src/utils/ai/settings.ts | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/providers/LocalStorageProvider/types.ts b/src/providers/LocalStorageProvider/types.ts index 6e79ac040..48e55986d 100644 --- a/src/providers/LocalStorageProvider/types.ts +++ b/src/providers/LocalStorageProvider/types.ts @@ -10,7 +10,6 @@ export type CustomProviderDefinition = { baseURL: string apiKey?: string contextWindow: number - testModel?: string models: string[] grantSchemaAccess?: boolean } diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts index 513737689..55fc42bb3 100644 --- a/src/utils/ai/settings.ts +++ b/src/utils/ai/settings.ts @@ -289,10 +289,8 @@ export const getTestModel = ( providerId: ProviderId, settings?: AiAssistantSettings, ): string | null => { - const custom = settings?.customProviders?.[providerId] - if (custom) { - const rawModel = custom.testModel ?? custom.models[0] ?? null - return rawModel ? makeCustomModelValue(providerId, rawModel) : null + if (settings?.customProviders?.[providerId]) { + return settings.selectedModel ?? null } return ( MODEL_OPTIONS.find((m) => m.provider === providerId && m.isTestModel) From 2670cc3da4e9c9e94dc52a643a83f3c8efe2bd06 Mon Sep 17 00:00:00 2001 From: emrberk Date: Wed, 11 Mar 2026 15:26:42 +0300 Subject: [PATCH 17/25] flexible response handling, fixes --- package.json | 1 + .../SetupAIAssistant/CustomProviderModal.tsx | 27 +- .../SetupAIAssistant/ModelDropdown.tsx | 2 + .../SetupAIAssistant/SettingsModal.tsx | 4 + src/utils/ai/anthropicProvider.ts | 65 +- src/utils/ai/index.ts | 2 + src/utils/ai/openaiChatCompletionsProvider.ts | 64 +- src/utils/ai/openaiProvider.ts | 69 +- src/utils/ai/registry.ts | 2 + src/utils/ai/shared.test.ts | 969 ++++++++++++++++++ src/utils/ai/shared.ts | 107 +- src/utils/executeAIFlow.ts | 2 +- yarn.lock | 10 + 13 files changed, 1275 insertions(+), 49 deletions(-) create mode 100644 src/utils/ai/shared.test.ts diff --git a/package.json b/package.json index a62d31fa7..6e2187973 100644 --- a/package.json +++ b/package.json @@ -79,6 +79,7 @@ "js-base64": "^3.7.7", "js-sha256": "^0.11.0", "js-tiktoken": "^1.0.21", + "jsonrepair": "^3.13.3", "lodash.isequal": "^4.5.0", "lodash.merge": "^4.6.2", "monaco-editor": "^0.52.2", diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index 0ca04e264..832ffa5e8 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -448,7 +448,7 @@ const StepOneContent = ({ type="text" value={name} onChange={(e) => onNameChange(e.target.value)} - placeholder="e.g., My Ollama, Azure GPT" + placeholder="e.g., OpenRouter, Ollama" /> @@ -868,24 +868,7 @@ export const CustomProviderModal = ({ if (existingProviderIds.includes(providerId)) return `A provider with a similar name already exists` - // First, check that the URL is reachable with a simple fetch - try { - abortControllerRef.current?.abort() - abortControllerRef.current = new AbortController() - - const normalizedURL = baseURL.replace(/\/+$/, "") - await fetch(normalizedURL, { - method: "GET", - signal: abortControllerRef.current.signal, - }) - } catch (err) { - if (err instanceof DOMException && err.name === "AbortError") { - return "Connection check was cancelled" - } - return `Could not connect to ${baseURL}. Please check the URL and make sure the server is running.` - } - - // URL is reachable — try to fetch models + // Try to fetch models, fallback to manual entry try { const tempProvider = createProviderByType( providerType, @@ -901,9 +884,9 @@ export const CustomProviderModal = ({ setFetchedModels(null) setFlowPath("manual") } - } catch (err) { - const message = err instanceof Error ? err.message : "Unknown error" - return `Could not connect to provider: ${message}` + } catch { + setFetchedModels(null) + setFlowPath("manual") } return true diff --git a/src/components/SetupAIAssistant/ModelDropdown.tsx b/src/components/SetupAIAssistant/ModelDropdown.tsx index 3b1f77148..3487fb689 100644 --- a/src/components/SetupAIAssistant/ModelDropdown.tsx +++ b/src/components/SetupAIAssistant/ModelDropdown.tsx @@ -98,6 +98,8 @@ const DropdownContent = styled.div` min-width: 22.8rem; gap: 0.4rem; z-index: 9999; + max-height: 50vh; + overflow-y: auto; ` const Title = styled(Text)` diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index dd74f513d..af59f672c 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -168,6 +168,10 @@ const ProviderTabTitle = styled(Box).attrs({ align: "center", })` width: 100%; + + svg { + flex-shrink: 0; + } ` const ProviderTabName = styled(Text)<{ $active: boolean }>` diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index e8f59d8ba..571982347 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -24,6 +24,8 @@ import { MaxTokensError, extractPartialExplanation, executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, } from "./shared" function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { @@ -202,7 +204,7 @@ async function handleToolCalls( model: string, systemPrompt: string, setStatus: StatusCallback, - outputConfig: OutputConfig, + outputConfig: OutputConfig | undefined, tools: AnthropicTool[], contextWindow: number, abortSignal?: AbortSignal, @@ -267,7 +269,7 @@ async function handleToolCalls( tools, messages: updatedHistory, temperature: 0.3, - output_config: outputConfig, + ...(outputConfig ? { output_config: outputConfig } : {}), } const followUpMessage = streaming @@ -315,7 +317,7 @@ async function handleToolCalls( export function createAnthropicProvider( apiKey: string, providerId: ProviderId = "anthropic", - options?: { baseURL?: string; contextWindow?: number }, + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { const anthropic = new Anthropic({ apiKey, @@ -324,6 +326,7 @@ export function createAnthropicProvider( }) const contextWindow = options?.contextWindow ?? 200_000 + const isCustom = options?.isCustom ?? false return { id: providerId, @@ -365,17 +368,24 @@ export function createAnthropicProvider( }) const anthropicTools = toAnthropicTools(tools) - const outputConfig = toAnthropicOutputConfig(config.responseFormat) + const outputConfig = isCustom + ? undefined + : toAnthropicOutputConfig(config.responseFormat) + + const systemPrompt = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions const resolvedModel = toAnthropicModel(model) const messageParams: Parameters[1] = { model: resolvedModel, - system: config.systemInstructions, + system: systemPrompt, tools: anthropicTools, messages: initialMessages, temperature: 0.3, - output_config: outputConfig, + ...(outputConfig ? { output_config: outputConfig } : {}), } const message = streaming @@ -399,7 +409,7 @@ export function createAnthropicProvider( modelToolsClient, initialMessages, resolvedModel, - config.systemInstructions, + systemPrompt, setStatus, outputConfig, anthropicTools, @@ -439,6 +449,26 @@ export function createAnthropicProvider( } as AiAssistantAPIError } + if (isCustom) { + const json = parseCustomProviderResponse( + textBlock.text, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + try { const json = JSON.parse(textBlock.text) as T setStatus(null) @@ -473,17 +503,34 @@ export function createAnthropicProvider( async generateTitle({ model, prompt, responseFormat }) { try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const titleOutputConfig = isCustom + ? undefined + : toAnthropicOutputConfig(responseFormat) + const messageParams: Parameters[1] = { model: toAnthropicModel(model), - messages: [{ role: "user", content: prompt }], + messages: [{ role: "user", content: userContent }], max_tokens: 100, temperature: 0.3, - output_config: toAnthropicOutputConfig(responseFormat), + ...(titleOutputConfig ? { output_config: titleOutputConfig } : {}), } const message = await createAnthropicMessage(anthropic, messageParams) const textBlock = message.content.find((block) => block.type === "text") if (textBlock && "text" in textBlock) { + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + textBlock.text, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + const parsed = JSON.parse(textBlock.text) as { title: string } return parsed.title?.slice(0, 40) || null } diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts index b4f41913a..a7eeb3d37 100644 --- a/src/utils/ai/index.ts +++ b/src/utils/ai/index.ts @@ -19,6 +19,8 @@ export { safeJsonParse, extractPartialExplanation, executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, } from "./shared" export { DOCS_INSTRUCTION, diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 6512380a1..b613aa3fc 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -26,6 +26,8 @@ import { safeJsonParse, extractPartialExplanation, executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, } from "./shared" import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" @@ -277,7 +279,7 @@ function toChatCompletionsAPIProps(model: string): { export function createOpenAIChatCompletionsProvider( apiKey: string, providerId: ProviderId = "openai", - options?: { baseURL?: string; contextWindow?: number }, + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { const openai = new OpenAI({ apiKey, @@ -286,6 +288,7 @@ export function createOpenAIChatCompletionsProvider( }) const contextWindow = options?.contextWindow ?? 400_000 + const isCustom = options?.isCustom ?? false return { id: providerId, @@ -308,8 +311,13 @@ export function createOpenAIChatCompletionsProvider( abortSignal?: AbortSignal streaming?: StreamingCallback }): Promise { + const systemContent = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions + const messages: ChatCompletionMessageParam[] = [ - { role: "system", content: config.systemInstructions }, + { role: "system", content: systemContent }, ] if (config.conversationHistory && config.conversationHistory.length > 0) { @@ -328,10 +336,14 @@ export function createOpenAIChatCompletionsProvider( let totalOutputTokens = 0 let lastPromptTokens = 0 + const response_format = isCustom + ? undefined + : toResponseFormat(config.responseFormat) + const baseParams = { ...toChatCompletionsAPIProps(model), tools: openaiTools, - response_format: toResponseFormat(config.responseFormat), + response_format, } let result = await executeRequest( @@ -405,6 +417,26 @@ export function createOpenAIChatCompletionsProvider( } as AiAssistantAPIError } + if (isCustom) { + const json = parseCustomProviderResponse( + result.content, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + try { const json = JSON.parse(result.content) as T setStatus(null) @@ -439,13 +471,31 @@ export function createOpenAIChatCompletionsProvider( async generateTitle({ model, prompt, responseFormat }) { try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const response_format = isCustom + ? undefined + : toResponseFormat(responseFormat) + const response = await openai.chat.completions.create({ - ...toChatCompletionsAPIProps(model), - messages: [{ role: "user", content: prompt }], - response_format: toResponseFormat(responseFormat), - max_completion_tokens: 100, + model: toChatCompletionsAPIProps(model).model, + messages: [{ role: "user", content: userContent }], + response_format, + ...(isCustom ? {} : { max_completion_tokens: 100 }), }) const content = response.choices[0]?.message?.content || "" + + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + content, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + const parsed = JSON.parse(content) as { title: string } return parsed.title || null } catch { diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 2f4c2b915..7cb392aaa 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -26,6 +26,8 @@ import { safeJsonParse, extractPartialExplanation, executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, } from "./shared" import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" @@ -206,7 +208,7 @@ function toResponsesAPIProps(model: string): { export function createOpenAIProvider( apiKey: string, providerId: ProviderId = "openai", - options?: { baseURL?: string; contextWindow?: number }, + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { const openai = new OpenAI({ apiKey, @@ -215,6 +217,7 @@ export function createOpenAIProvider( }) const contextWindow = options?.contextWindow ?? 400_000 + const isCustom = options?.isCustom ?? false return { id: providerId, @@ -260,12 +263,21 @@ export function createOpenAIProvider( let totalInputTokens = 0 let totalOutputTokens = 0 + const systemInstructions = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions + + const textConfig = isCustom + ? undefined + : toResponseTextConfig(config.responseFormat) + const requestParams = { ...toResponsesAPIProps(model), - instructions: config.systemInstructions, + instructions: systemInstructions, input, tools: openaiTools, - text: toResponseTextConfig(config.responseFormat), + ...(textConfig ? { text: textConfig } : {}), } as OpenAI.Responses.ResponseCreateParamsNonStreaming let lastResponse = streaming @@ -320,10 +332,10 @@ export function createOpenAIProvider( } const loopRequestParams = { ...toResponsesAPIProps(model), - instructions: config.systemInstructions, + instructions: systemInstructions, input, tools: openaiTools, - text: toResponseTextConfig(config.responseFormat), + ...(textConfig ? { text: textConfig } : {}), } as OpenAI.Responses.ResponseCreateParamsNonStreaming lastResponse = streaming @@ -357,6 +369,26 @@ export function createOpenAIProvider( const rawOutput = text.message + if (isCustom) { + const json = parseCustomProviderResponse( + rawOutput, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + try { const json = JSON.parse(rawOutput) as T setStatus(null) @@ -391,13 +423,32 @@ export function createOpenAIProvider( async generateTitle({ model, prompt, responseFormat }) { try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const titleTextConfig = isCustom + ? undefined + : toResponseTextConfig(responseFormat) + const response = await openai.responses.create({ - ...toResponsesAPIProps(model), - input: [{ role: "user", content: prompt }], - text: toResponseTextConfig(responseFormat), + model: toResponsesAPIProps(model).model, + input: [{ role: "user", content: userContent }], + ...(titleTextConfig ? { text: titleTextConfig } : {}), max_output_tokens: 100, }) - const parsed = JSON.parse(response.output_text) as { title: string } + const rawText = response.output_text + + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + rawText, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + + const parsed = JSON.parse(rawText) as { title: string } return parsed.title || null } catch { return null diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts index 9c8e1994f..947b42e37 100644 --- a/src/utils/ai/registry.ts +++ b/src/utils/ai/registry.ts @@ -9,6 +9,7 @@ import type { ProviderId, ProviderType } from "./settings" type ProviderOptions = { baseURL?: string contextWindow?: number + isCustom?: boolean } export function createProvider( @@ -28,6 +29,7 @@ export function createProvider( return createProviderByType(custom.type, providerId, apiKey, { baseURL: custom.baseURL, contextWindow: custom.contextWindow, + isCustom: true, }) } diff --git a/src/utils/ai/shared.test.ts b/src/utils/ai/shared.test.ts new file mode 100644 index 000000000..b73d923d8 --- /dev/null +++ b/src/utils/ai/shared.test.ts @@ -0,0 +1,969 @@ +import { describe, it, expect } from "vitest" +import { + extractJsonWithExpectedFields, + parseCustomProviderResponse, + safeJsonParse, +} from "./shared" + +type SqlResponse = { sql: string | null; explanation: string } +type TitleResponse = { title: string } +type ExplainResponse = { explanation: string } + +const sqlFields = ["sql", "explanation"] +const titleFields = ["title"] +const explainFields = ["explanation"] + +const sqlFallback = (raw: string): SqlResponse => ({ + explanation: raw, + sql: null, +}) +const titleFallback = (raw: string): TitleResponse => ({ + title: raw.trim().slice(0, 40), +}) +const explainFallback = (raw: string): ExplainResponse => ({ + explanation: raw, +}) + +describe("parseCustomProviderResponse", () => { + // ─── Step 1: Direct JSON.parse ─────────────────────────────────── + + describe("step 1: valid JSON string", () => { + it("parses valid JSON with sql and explanation", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM trades", + explanation: "Fetches all trades", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM trades", + explanation: "Fetches all trades", + }) + }) + + it("parses valid JSON with title", () => { + const text = JSON.stringify({ title: "My Chat" }) + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("parses valid JSON with explanation only", () => { + const text = JSON.stringify({ explanation: "This is an explanation" }) + const result = parseCustomProviderResponse( + text, + explainFields, + explainFallback, + ) + expect(result).toEqual({ explanation: "This is an explanation" }) + }) + + it("parses valid JSON with null sql", () => { + const text = JSON.stringify({ + sql: null, + explanation: "No SQL needed", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: null, explanation: "No SQL needed" }) + }) + + it("returns full parsed object even with extra fields in step 1", () => { + const text = JSON.stringify({ + sql: "SELECT 1", + explanation: "test", + extra: "ignored by caller but present", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Step 1 returns JSON.parse(text) as-is, including extra fields + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("test") + }) + }) + + // ─── Step 2: JSON in ```json ``` block ─────────────────────────── + + describe("step 2: JSON in ```json block", () => { + it("extracts JSON from ```json block", () => { + const text = + 'Here is the result:\n\n```json\n{"sql": "SELECT 1", "explanation": "Returns 1"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "Returns 1" }) + }) + + it("extracts JSON from ```json block with pretty-printed JSON", () => { + const text = `Some preamble text. + +\`\`\`json +{ + "sql": "SELECT * FROM t", + "explanation": "Gets all rows" +} +\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM t", + explanation: "Gets all rows", + }) + }) + + it("handles ```json block with nested markdown code blocks in explanation", () => { + const explanation = + "# ASOF JOIN\n\n```sql\nSELECT * FROM t1 ASOF JOIN t2\n```\n\nMore text\n\n```sql\nSELECT 1\n```" + const json = JSON.stringify({ + sql: "SELECT * FROM t1 ASOF JOIN t2", + explanation, + }) + const text = `Here is the response:\n\n\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM t1 ASOF JOIN t2") + expect(result.explanation).toBe(explanation) + }) + + it("extracts title from ```json block", () => { + const text = + 'Generated title:\n\n```json\n{"title": "Trade Analysis"}\n```' + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "Trade Analysis" }) + }) + + it("only includes expected fields from ```json block", () => { + const text = + '```json\n{"sql": "SELECT 1", "explanation": "test", "confidence": 0.9}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + expect((result as Record).confidence).toBeUndefined() + }) + }) + + // ─── Step 2: Bare JSON with preamble ───────────────────────────── + + describe("step 2: bare JSON without ```json wrapper", () => { + it("extracts bare JSON after preamble text", () => { + const text = + 'Excellent! Here is the response:\n\n{"sql": "SELECT 1", "explanation": "Returns one"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "Returns one" }) + }) + + it("extracts bare JSON with preamble and epilogue", () => { + const text = + 'Here:\n{"sql": null, "explanation": "No query needed"}\nHope this helps!' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: null, explanation: "No query needed" }) + }) + + it("handles preamble text that contains curly braces", () => { + const text = + 'Using {ASOF JOIN} syntax:\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("extracts bare pretty-printed JSON after preamble", () => { + const text = `Let me provide the final response: + +{ + "sql": "SELECT * FROM trades LIMIT 10", + "explanation": "Fetches recent trades" +}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM trades LIMIT 10", + explanation: "Fetches recent trades", + }) + }) + }) + + // ─── Step 2: JSON with complex content ─────────────────────────── + + describe("step 2: complex content in JSON values", () => { + it("handles explanation with curly braces inside strings", () => { + const text = + 'Result:\n\n{"sql": "SELECT 1", "explanation": "Use {curly braces} in templates"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.explanation).toBe("Use {curly braces} in templates") + }) + + it("handles explanation with escaped quotes", () => { + const json = JSON.stringify({ + sql: 'SELECT * FROM "my table"', + explanation: 'Use "double quotes" for identifiers', + }) + const text = `Here:\n${json}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe('SELECT * FROM "my table"') + expect(result.explanation).toBe('Use "double quotes" for identifiers') + }) + + it("handles SQL with complex nested queries", () => { + const sql = + "SELECT t.*, m.bids[1,1] FROM trades t ASOF JOIN market_data m ON (t.symbol = m.symbol) WHERE t.timestamp IN yesterday()" + const json = JSON.stringify({ sql, explanation: "Complex join query" }) + const text = `\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe(sql) + }) + + it("handles real-world DeepSeek response with nested markdown code blocks", () => { + // Simulates the actual failing case from DeepSeek via OpenRouter + const explanation = + "# ASOF JOIN\n\n## Basic Syntax\n\n```sql\nSELECT columns\nFROM left_table\nASOF JOIN right_table ON (matching_columns)\n```\n\n## Example\n\n```sql\nSELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)\n```\n\nMore details here." + const innerJson = JSON.stringify({ + sql: "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + explanation, + }) + const text = `Perfect! Now I'll provide you with a comprehensive response.\n\n\`\`\`json\n${innerJson}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe( + "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + ) + expect(result.explanation).toBe(explanation) + }) + + it("handles real-world DeepSeek response with bare JSON (no ```json wrapper)", () => { + // Simulates the actual failing case from DeepSeek via DeepInfra + const explanation = + "# ASOF JOIN in QuestDB\n\n## Example\n\n```sql\nSELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)\n```" + const innerJson = JSON.stringify({ + sql: "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + explanation, + }) + const text = `Excellent! Now let me provide the final response with examples:\n\n${innerJson}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe( + "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + ) + expect(result.explanation).toBe(explanation) + }) + }) + + // ─── Step 2: Missing expected fields ───────────────────────────── + + describe("step 2: missing expected fields", () => { + it("falls back when JSON in ```json block has wrong fields", () => { + const text = '```json\n{"query": "SELECT 1", "description": "test"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Falls to fallback because "sql" and "explanation" are missing + expect(result).toEqual({ + explanation: text, + sql: null, + }) + }) + + it("falls back when bare JSON has wrong fields", () => { + const text = 'Here:\n{"answer": "SELECT 1", "reasoning": "because"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + explanation: text, + sql: null, + }) + }) + + it("falls back when JSON has only some expected fields", () => { + const text = '{"explanation": "test"}' + // Step 1 parses it successfully — this is valid JSON + // But wait, step 1 returns JSON.parse as-is without field checking + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Step 1 returns it as-is (no field validation in step 1) + expect(result.explanation).toBe("test") + expect(result.sql).toBeUndefined() + }) + }) + + // ─── Step 2: Invalid JSON repaired by jsonrepair ──────────────── + + describe("step 2: malformed JSON repaired by jsonrepair", () => { + it("repairs trailing commas in ```json block", () => { + const text = '```json\n{"sql": "SELECT 1", "explanation": "test",}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs single-quoted strings in bare JSON", () => { + const text = "Here:\n{'sql': 'SELECT 1', 'explanation': 'test'}" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs unquoted keys", () => { + const text = '{sql: "SELECT 1", explanation: "test"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs Python boolean/null constants (True, False, None)", () => { + const text = '{"sql": None, "explanation": "No query needed"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBeNull() + expect(result.explanation).toBe("No query needed") + }) + + it("repairs trailing comma + unquoted keys combined", () => { + const text = "Preamble:\n{sql: 'SELECT 1', explanation: 'works',}" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "works" }) + }) + + it("repairs missing closing brace (truncated JSON)", () => { + // jsonrepair can add the missing brace + const text = '{"sql": "SELECT 1", "explanation": "truncated' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("truncated") + }) + }) + + // ─── Step 2: Unrepairable invalid JSON ───────────────────────── + + describe("step 2: unrepairable invalid JSON", () => { + it("falls back when ```json block contains gibberish", () => { + const text = "```json\n{invalid json here}\n```" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + }) + + // ─── Step 3: jsonrepair on full text ─────────────────────────── + + describe("step 3: jsonrepair on full text", () => { + it("repairs full-text malformed JSON (no preamble)", () => { + const text = "{'title': 'My Chat',}" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("repairs JSON with comments", () => { + // jsonrepair strips JS comments + const text = + '{\n "sql": "SELECT 1", // the query\n "explanation": "test"\n}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("test") + }) + }) + + // ─── Step 4: Fallback to raw text ──────────────────────────────── + + describe("step 4: fallback", () => { + it("returns fallback for plain text response", () => { + const text = "I can help you write a query for that." + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + + it("returns fallback for empty string", () => { + const result = parseCustomProviderResponse( + "", + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: "", sql: null }) + }) + + it("returns title fallback for plain text", () => { + const text = "Trade Analysis Overview" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "Trade Analysis Overview" }) + }) + + it("truncates long title in fallback", () => { + const text = + "This is a very long title that should be truncated to forty characters maximum" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result.title.length).toBeLessThanOrEqual(40) + }) + + it("returns fallback for markdown without JSON", () => { + const text = + "# Query Help\n\nYou can use `SELECT * FROM trades` to get all trades." + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + + it("returns fallback when text has braces but no valid JSON", () => { + const text = + "Use the following syntax: if (x > 0) { return x; } else { return -x; }" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + }) + + // ─── Edge cases ────────────────────────────────────────────────── + + describe("edge cases", () => { + it("handles empty expectedFields (matches any valid JSON object)", () => { + const text = 'Preamble\n{"foo": "bar"}' + const result = parseCustomProviderResponse>( + text, + [], + (raw) => ({ raw }), + ) + // Empty expectedFields means every() is vacuously true, + // but only expected fields are extracted → empty object + expect(result).toEqual({}) + }) + + it("handles JSON array (not object) — falls back", () => { + const text = '[{"sql": "SELECT 1"}]' + // Step 1: JSON.parse succeeds and returns the array + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Arrays are returned as-is from step 1 + expect(Array.isArray(result)).toBe(true) + }) + + it("handles multiple JSON objects in text — picks first with expected fields", () => { + const text = + '{"wrong": true}\n\n{"sql": "SELECT 1", "explanation": "test"}' + // Step 1: JSON.parse fails (two objects aren't valid single JSON) + // Step 2: first { → {"wrong": true} → valid but wrong fields → next { + // second { → finds the correct one + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles deeply nested JSON braces", () => { + const sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t" + const text = `Result:\n${JSON.stringify({ sql, explanation: "CASE expression" })}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe(sql) + }) + + it("handles unicode content", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM données", + explanation: "Récupère toutes les données 日本語テスト", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM données") + expect(result.explanation).toContain("日本語テスト") + }) + + it("handles whitespace-only content", () => { + const result = parseCustomProviderResponse( + " \n\n ", + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: " \n\n ", sql: null }) + }) + + it("handles ``` block without json language tag", () => { + const text = + 'Here:\n\n```\n{"sql": "SELECT 1", "explanation": "test"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles JSON with very long explanation containing multiple code blocks", () => { + const explanation = [ + "# Complex Query", + "", + "First, create the table:", + "```sql", + "CREATE TABLE trades (symbol STRING, price DOUBLE, ts TIMESTAMP) timestamp(ts);", + "```", + "", + "Then insert data:", + "```sql", + "INSERT INTO trades VALUES('BTC', 50000, now());", + "```", + "", + "Finally, query it:", + "```sql", + "SELECT * FROM trades WHERE symbol = 'BTC';", + "```", + ].join("\n") + const json = JSON.stringify({ + sql: "SELECT * FROM trades WHERE symbol = 'BTC'", + explanation, + }) + const text = `\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM trades WHERE symbol = 'BTC'") + expect(result.explanation).toBe(explanation) + }) + + it("handles JSON where explanation contains JSON-like text", () => { + const json = JSON.stringify({ + sql: "SELECT 1", + explanation: + 'The response format is {"key": "value"} and you can nest {objects: {inside}} as needed.', + }) + const text = `Response:\n${json}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toContain('{"key": "value"}') + }) + + it("handles JSON with escaped backslashes and special chars", () => { + const text = JSON.stringify({ + sql: "SELECT '\\n' FROM t", + explanation: "Selects a backslash-n string\ttab", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT '\\n' FROM t") + expect(result.explanation).toBe("Selects a backslash-n string\ttab") + }) + }) +}) + +describe("extractJsonWithExpectedFields", () => { + // ─── Basic extraction ─────────────────────────────────────────── + + it("extracts valid JSON with all expected fields", () => { + const text = '{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("returns only expected fields, stripping extras", () => { + const text = '{"sql": "SELECT 1", "explanation": "test", "confidence": 0.9}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + expect(result).not.toHaveProperty("confidence") + }) + + it("returns null when no JSON found", () => { + const result = extractJsonWithExpectedFields("plain text", ["sql"]) + expect(result).toBeNull() + }) + + it("returns null for empty string", () => { + const result = extractJsonWithExpectedFields("", ["sql"]) + expect(result).toBeNull() + }) + + it("returns null when no opening brace exists", () => { + const result = extractJsonWithExpectedFields("no braces here", ["sql"]) + expect(result).toBeNull() + }) + + // ─── Field matching ───────────────────────────────────────────── + + it("returns null when JSON is valid but missing expected fields", () => { + const text = '{"query": "SELECT 1", "description": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("returns null when only some expected fields present", () => { + const text = '{"sql": "SELECT 1"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("matches with empty expectedFields (vacuously true)", () => { + const text = '{"anything": "works"}' + const result = extractJsonWithExpectedFields(text, []) + expect(result).toEqual({}) + }) + + it("includes fields with null values", () => { + const text = '{"sql": null, "explanation": "No query needed"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: null, explanation: "No query needed" }) + }) + + it("includes fields with falsy values (0, false, empty string)", () => { + const text = '{"count": 0, "active": false, "name": ""}' + const result = extractJsonWithExpectedFields(text, [ + "count", + "active", + "name", + ]) + expect(result).toEqual({ count: 0, active: false, name: "" }) + }) + + // ─── Preamble and epilogue ────────────────────────────────────── + + it("extracts JSON after preamble text", () => { + const text = + 'Here is the result:\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("extracts JSON with both preamble and epilogue", () => { + const text = 'Result:\n{"title": "My Chat"}\nHope that helps!' + const result = extractJsonWithExpectedFields(text, ["title"]) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("extracts JSON from ```json code block", () => { + const text = '```json\n{"sql": "SELECT 1", "explanation": "test"}\n```' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + // ─── Multiple JSON objects ────────────────────────────────────── + + it("skips first JSON object if it lacks expected fields", () => { + const text = '{"wrong": true}\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("when multiple matching objects exist, jsonrepair merges them (last wins)", () => { + const text = + '{"sql": "SELECT 1", "explanation": "first"}\n{"sql": "SELECT 2", "explanation": "second"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + // jsonrepair merges concatenated objects, later keys overwrite earlier ones + expect(result).toEqual({ sql: "SELECT 2", explanation: "second" }) + }) + + it("skips non-JSON brace text to find real JSON", () => { + const text = + 'Use {curly braces} for templates\n{"sql": "SELECT 1", "explanation": "found"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "found" }) + }) + + // ─── jsonrepair fallback ──────────────────────────────────────── + + it("repairs trailing commas via jsonrepair", () => { + const text = '{"sql": "SELECT 1", "explanation": "test",}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs single-quoted strings via jsonrepair", () => { + const text = "{'sql': 'SELECT 1', 'explanation': 'test'}" + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs unquoted keys via jsonrepair", () => { + const text = '{sql: "SELECT 1", explanation: "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs Python None to null via jsonrepair", () => { + const text = '{"sql": None, "explanation": "No query"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: null, explanation: "No query" }) + }) + + // ─── Complex content ──────────────────────────────────────────── + + it("handles nested braces inside string values", () => { + const text = JSON.stringify({ + sql: "SELECT 1", + explanation: "Use {curly braces} in {templates}", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.explanation).toBe("Use {curly braces} in {templates}") + }) + + it("handles escaped quotes inside string values", () => { + const text = JSON.stringify({ + sql: 'SELECT * FROM "my table"', + explanation: 'Use "double quotes" for identifiers', + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe('SELECT * FROM "my table"') + }) + + it("handles newlines inside string values", () => { + const text = JSON.stringify({ + sql: "SELECT *\nFROM trades\nLIMIT 10", + explanation: "Multi-line query", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe("SELECT *\nFROM trades\nLIMIT 10") + }) + + it("handles pretty-printed JSON", () => { + const text = `{ + "sql": "SELECT 1", + "explanation": "test" +}` + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles unicode content", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM données", + explanation: "日本語テスト", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe("SELECT * FROM données") + expect(result!.explanation).toBe("日本語テスト") + }) + + it("returns null when text has braces but no valid JSON", () => { + const text = "if (x > 0) { return x; } else { return -x; }" + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("returns null for nested non-JSON braces", () => { + const text = "function() { if (true) { console.log('hi') } }" + const result = extractJsonWithExpectedFields(text, ["sql"]) + expect(result).toBeNull() + }) + + it("handles single expected field", () => { + const text = '{"title": "My Chat"}' + const result = extractJsonWithExpectedFields(text, ["title"]) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("handles many expected fields", () => { + const text = JSON.stringify({ a: 1, b: 2, c: 3, d: 4, e: 5 }) + const result = extractJsonWithExpectedFields(text, [ + "a", + "b", + "c", + "d", + "e", + ]) + expect(result).toEqual({ a: 1, b: 2, c: 3, d: 4, e: 5 }) + }) + + it("handles JSON embedded deep in markdown", () => { + const text = `# Response + +Here is the analysis: + +Some preamble with {random} braces. + +\`\`\`json +{"sql": "SELECT count() FROM trades", "explanation": "Counts all trades"} +\`\`\` + +And some epilogue text.` + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ + sql: "SELECT count() FROM trades", + explanation: "Counts all trades", + }) + }) +}) + +describe("safeJsonParse", () => { + it("parses valid JSON", () => { + const result = safeJsonParse<{ a: number }>('{"a": 1}') + expect(result).toEqual({ a: 1 }) + }) + + it("returns jsonrepair result for non-JSON text", () => { + // jsonrepair turns plain text into a JSON string + const result = safeJsonParse("not json at all") + expect(result).toBe("not json at all") + }) + + it("returns empty object for empty input", () => { + const result = safeJsonParse("{") + expect(result).toEqual({}) + }) + + it("repairs truncated JSON (missing closing brace)", () => { + // Real-world Qwen case: tool call arguments with missing } + const text = + '{"category": "functions", "items": ["today", "tomorrow", "yesterday"]' + const result = safeJsonParse<{ + category: string + items: string[] + }>(text) + expect(result).toEqual({ + category: "functions", + items: ["today", "tomorrow", "yesterday"], + }) + }) + + it("repairs trailing commas", () => { + const result = safeJsonParse<{ table_name: string }>( + '{"table_name": "trades",}', + ) + expect(result).toEqual({ table_name: "trades" }) + }) + + it("repairs single-quoted strings", () => { + const result = safeJsonParse<{ query: string }>("{'query': 'SELECT 1'}") + expect(result).toEqual({ query: "SELECT 1" }) + }) + + it("repairs unquoted keys", () => { + const result = safeJsonParse<{ table_name: string }>( + '{table_name: "trades"}', + ) + expect(result).toEqual({ table_name: "trades" }) + }) + + it("handles empty string arguments", () => { + const result = safeJsonParse("") + expect(result).toEqual({}) + }) +}) diff --git a/src/utils/ai/shared.ts b/src/utils/ai/shared.ts index 8a211d46a..fbda68515 100644 --- a/src/utils/ai/shared.ts +++ b/src/utils/ai/shared.ts @@ -6,6 +6,8 @@ import { parseDocItems, DocCategory, } from "../questdbDocsRetrieval" +import type { ResponseFormatSchema } from "./types" +import { jsonrepair } from "jsonrepair" export class RefusalError extends Error { constructor(message: string) { @@ -36,7 +38,11 @@ export const safeJsonParse = (text: string): T | object => { try { return JSON.parse(text) as T } catch { - return {} + try { + return JSON.parse(jsonrepair(text)) as T + } catch { + return {} + } } } @@ -198,3 +204,102 @@ export const executeTool = async ( } } } + +export function extractJsonWithExpectedFields( + text: string, + expectedFields: string[], +): Record | null { + let searchStart = 0 + while (true) { + const braceStart = text.indexOf("{", searchStart) + if (braceStart === -1) break + + const textFromBrace = text.slice(braceStart) + let endIdx = textFromBrace.lastIndexOf("}") + while (endIdx > 0) { + const candidate = textFromBrace.slice(0, endIdx + 1) + // Try direct JSON.parse first, then jsonrepair as fallback + let parsed: Record | null = null + try { + parsed = JSON.parse(candidate) as Record + } catch { + try { + parsed = JSON.parse(jsonrepair(candidate)) as Record + } catch { + // jsonrepair couldn't fix it either + } + } + + if (parsed !== null) { + if (expectedFields.every((field) => field in parsed)) { + const result: Record = {} + for (const field of expectedFields) { + result[field] = parsed[field] + } + return result + } + break // Valid JSON but missing expected fields — try next { + } + endIdx = textFromBrace.lastIndexOf("}", endIdx - 1) + } + searchStart = braceStart + 1 + } + return null +} + +export function parseCustomProviderResponse( + text: string, + expectedFields: string[], + fallback: (rawText: string) => T, +): T { + try { + return JSON.parse(text) as T + } catch { + // not valid JSON as-is + } + + const extracted = extractJsonWithExpectedFields(text, expectedFields) + if (extracted) { + return extracted as T + } + + try { + const repaired = JSON.parse(jsonrepair(text)) as Record + if ( + repaired !== null && + typeof repaired === "object" && + !Array.isArray(repaired) && + (expectedFields.length === 0 || + expectedFields.every((field) => field in repaired)) + ) { + return repaired as T + } + } catch { + // jsonrepair couldn't salvage it + } + + // Fallback — caller decides the shape + return fallback(text) +} + +export function responseFormatToPromptInstruction( + format: ResponseFormatSchema, +): string { + const properties = format.schema.properties as Record< + string, + { type: unknown } + > + const required = (format.schema.required as string[]) || [] + + const fields = Object.entries(properties) + .map(([key, value]) => { + const typeStr = Array.isArray(value.type) + ? value.type.join(" | ") + : String(value.type) + const isRequired = required.includes(key) + return ` "${key}": ${typeStr}${isRequired ? " (required)" : " (optional)"}` + }) + .join(",\n") + + return `\nAlways respond with a valid JSON object with the following fields:\n{\n${fields}\n}` +} diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 677a418ea..8b9b1e786 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -327,7 +327,7 @@ function processSQLResult( } } - let assistantContent = result.explanation || "Response received" + let assistantContent = result.explanation || "No explanation received" if (hasSQLInResult) { assistantContent = `SQL Query:\n\`\`\`sql\n${result.sql}\n\`\`\`\n\nExplanation:\n${result.explanation || ""}` } diff --git a/yarn.lock b/yarn.lock index dfef7b0d4..0c53075c3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2545,6 +2545,7 @@ __metadata: js-base64: "npm:^3.7.7" js-sha256: "npm:^0.11.0" js-tiktoken: "npm:^1.0.21" + jsonrepair: "npm:^3.13.3" lint-staged: "npm:^16.2.6" lodash.isequal: "npm:^4.5.0" lodash.merge: "npm:^4.6.2" @@ -8271,6 +8272,15 @@ __metadata: languageName: node linkType: hard +"jsonrepair@npm:^3.13.3": + version: 3.13.3 + resolution: "jsonrepair@npm:3.13.3" + bin: + jsonrepair: bin/cli.js + checksum: 10/cd1d42516e3e03ccc44498c328f87f4ec05b24afe190becced0babf5d608e81b375e8d2040494142760556c1d6583b395073b5253626907e4df968d8cf01115c + languageName: node + linkType: hard + "jsprim@npm:^2.0.2": version: 2.0.2 resolution: "jsprim@npm:2.0.2" From 3cee5b39811aed33c0f329349921b76e22bba7fa Mon Sep 17 00:00:00 2001 From: emrberk Date: Wed, 11 Mar 2026 17:48:09 +0300 Subject: [PATCH 18/25] error wording --- src/utils/ai/anthropicProvider.ts | 2 +- src/utils/ai/openaiChatCompletionsProvider.ts | 2 +- src/utils/ai/openaiProvider.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 571982347..4e1b763db 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -685,7 +685,7 @@ export function createAnthropicProvider( if (error instanceof Anthropic.APIConnectionError) { return { type: "network", - message: "Network error. Please check your internet connection.", + message: "Network error. Please check your connection.", details: error.message, } } diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index b613aa3fc..0c2ac1cd7 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -658,7 +658,7 @@ export function createOpenAIChatCompletionsProvider( if (error instanceof OpenAI.APIConnectionError) { return { type: "network", - message: "Network error. Please check your internet connection.", + message: "Network error. Please check your connection.", details: error.message, } } diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 7cb392aaa..82ec7bfb1 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -598,7 +598,7 @@ export function createOpenAIProvider( if (error instanceof OpenAI.APIConnectionError) { return { type: "network", - message: "Network error. Please check your internet connection.", + message: "Network error. Please check your connection.", details: error.message, } } From 0583f6789ce02f34c8129a9eb5789ca43ae2ffab Mon Sep 17 00:00:00 2001 From: emrberk Date: Thu, 12 Mar 2026 23:55:51 +0300 Subject: [PATCH 19/25] strip unnecessary tracking headers from custom provider --- e2e/questdb | 2 +- .../SetupAIAssistant/CustomProviderModal.tsx | 2 +- src/utils/ai/anthropicProvider.ts | 16 +++++- src/utils/ai/fetchWithFilteredHeaders.ts | 53 +++++++++++++++++++ src/utils/ai/openaiChatCompletionsProvider.ts | 16 +++++- src/utils/ai/openaiProvider.ts | 16 +++++- 6 files changed, 100 insertions(+), 5 deletions(-) create mode 100644 src/utils/ai/fetchWithFilteredHeaders.ts diff --git a/e2e/questdb b/e2e/questdb index 834d0fecb..c5dc7a1eb 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit 834d0fecbefe8e72a05c488ca75f8976cd4cc0f9 +Subproject commit c5dc7a1eb1533a6971322e69515278e711ec4ec7 diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index 832ffa5e8..99395d958 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -874,7 +874,7 @@ export const CustomProviderModal = ({ providerType, "temp", apiKey || "", - { baseURL, contextWindow }, + { baseURL, contextWindow, isCustom: true }, ) const models = await tempProvider.listModels() if (models && models.length > 0) { diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts index 4e1b763db..d57a09ee2 100644 --- a/src/utils/ai/anthropicProvider.ts +++ b/src/utils/ai/anthropicProvider.ts @@ -27,6 +27,10 @@ import { parseCustomProviderResponse, responseFormatToPromptInstruction, } from "./shared" +import { + createHeaderFilteredFetch, + ANTHROPIC_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { return tools.map((t) => ({ @@ -319,14 +323,19 @@ export function createAnthropicProvider( providerId: ProviderId = "anthropic", options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { + const isCustom = options?.isCustom ?? false const anthropic = new Anthropic({ apiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(ANTHROPIC_ALLOWED_HEADERS), + } + : {}), }) const contextWindow = options?.contextWindow ?? 200_000 - const isCustom = options?.isCustom ?? false return { id: providerId, @@ -558,6 +567,11 @@ export function createAnthropicProvider( apiKey: testApiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(ANTHROPIC_ALLOWED_HEADERS), + } + : {}), }) await createAnthropicMessage(testClient, { diff --git a/src/utils/ai/fetchWithFilteredHeaders.ts b/src/utils/ai/fetchWithFilteredHeaders.ts new file mode 100644 index 000000000..5851fde98 --- /dev/null +++ b/src/utils/ai/fetchWithFilteredHeaders.ts @@ -0,0 +1,53 @@ +function filterHeaders(raw: HeadersInit, allowSet: Set): Headers { + const source = new Headers(raw) + const filtered = new Headers() + source.forEach((value, key) => { + if (allowSet.has(key.toLowerCase())) { + filtered.set(key, value) + } + }) + return filtered +} + +export function createHeaderFilteredFetch( + allowedHeaders: string[], +): typeof globalThis.fetch { + const allowSet = new Set(allowedHeaders.map((h) => h.toLowerCase())) + return (input, init) => { + // Collect headers from both the Request object and init + const merged = new Headers() + if (input instanceof Request) { + input.headers.forEach((v, k) => merged.set(k, v)) + } + if (init?.headers) { + new Headers(init.headers).forEach((v, k) => merged.set(k, v)) + } + + const filtered = filterHeaders(merged, allowSet) + + // Normalize to plain URL + init to avoid Request carrying extra headers + const url = input instanceof Request ? input.url : input + const method = + init?.method ?? (input instanceof Request ? input.method : undefined) + const body = + init?.body ?? (input instanceof Request ? input.body : undefined) + const signal = + init?.signal ?? (input instanceof Request ? input.signal : undefined) + + return globalThis.fetch(url, { + ...init, + method, + headers: filtered, + body, + signal, + }) + } +} + +export const OPENAI_ALLOWED_HEADERS = ["content-type", "authorization"] + +export const ANTHROPIC_ALLOWED_HEADERS = [ + "content-type", + "x-api-key", + "anthropic-version", +] diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts index 0c2ac1cd7..a673945a2 100644 --- a/src/utils/ai/openaiChatCompletionsProvider.ts +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -30,6 +30,10 @@ import { responseFormatToPromptInstruction, } from "./shared" import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" +import { + createHeaderFilteredFetch, + OPENAI_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" function toResponseFormat(format: ResponseFormatSchema) { return { @@ -281,14 +285,19 @@ export function createOpenAIChatCompletionsProvider( providerId: ProviderId = "openai", options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { + const isCustom = options?.isCustom ?? false const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), }) const contextWindow = options?.contextWindow ?? 400_000 - const isCustom = options?.isCustom ?? false return { id: providerId, @@ -520,6 +529,11 @@ export function createOpenAIChatCompletionsProvider( apiKey: testApiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), }) await testClient.chat.completions.create({ model: getModelProps(model).model, diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts index 82ec7bfb1..bd349a668 100644 --- a/src/utils/ai/openaiProvider.ts +++ b/src/utils/ai/openaiProvider.ts @@ -30,6 +30,10 @@ import { responseFormatToPromptInstruction, } from "./shared" import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" +import { + createHeaderFilteredFetch, + OPENAI_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" function toResponseTextConfig( format: ResponseFormatSchema, @@ -210,14 +214,19 @@ export function createOpenAIProvider( providerId: ProviderId = "openai", options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, ): AIProvider { + const isCustom = options?.isCustom ?? false const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), }) const contextWindow = options?.contextWindow ?? 400_000 - const isCustom = options?.isCustom ?? false return { id: providerId, @@ -470,6 +479,11 @@ export function createOpenAIProvider( apiKey: testApiKey, dangerouslyAllowBrowser: true, ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), }) await testClient.responses.create({ model: getModelProps(model).model, // testConnection only needs model name From b484648ce4b8bb3c9a81880a3d9e2a219384e6f9 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 03:41:38 +0300 Subject: [PATCH 20/25] update tests --- e2e/commands.js | 6 + e2e/tests/console/aiAssistant.spec.js | 890 +++++++++++++++++- e2e/utils/aiAssistant.js | 434 ++++++--- .../SetupAIAssistant/ConfigurationModal.tsx | 6 +- .../SetupAIAssistant/CustomProviderModal.tsx | 35 +- .../SetupAIAssistant/SettingsModal.tsx | 2 + 6 files changed, 1228 insertions(+), 145 deletions(-) diff --git a/e2e/commands.js b/e2e/commands.js index 6e2559af0..5dd0de036 100644 --- a/e2e/commands.js +++ b/e2e/commands.js @@ -628,3 +628,9 @@ Cypress.Commands.add("waitForAIResponse", (alias) => { cy.wait(alias) cy.waitForStreamingComplete() }) + +Cypress.Commands.add("setupCustomProvider", (config = {}) => { + const { getCustomProviderConfiguredSettings } = require("./utils/aiAssistant") + const settings = getCustomProviderConfiguredSettings(config) + cy.loadConsoleWithAuth(false, settings) +}) diff --git a/e2e/tests/console/aiAssistant.spec.js b/e2e/tests/console/aiAssistant.spec.js index 72d8e733e..ef7b181e1 100644 --- a/e2e/tests/console/aiAssistant.spec.js +++ b/e2e/tests/console/aiAssistant.spec.js @@ -2,6 +2,7 @@ const { PROVIDERS, + CUSTOM_PROVIDER_DEFAULTS, getOpenAIConfiguredSettings, getAnthropicConfiguredSettings, createToolCallFlow, @@ -10,6 +11,8 @@ const { createFinalResponseData, createChatTitleResponse, isTitleRequest, + getCustomProviderConfiguredSettings, + getCustomProviderEndpoint, } = require("../../utils/aiAssistant") /** @@ -266,28 +269,24 @@ describe("ai assistant", () => { // Then cy.getByDataHook("ai-settings-modal-step-one").should("be.visible") - cy.getByDataHook("ai-settings-api-key") - .should("be.visible") - .should("have.attr", "placeholder", "Enter API key") - .should("be.disabled") + // API key input is hidden until a provider is selected + cy.getByDataHook("ai-settings-api-key").should("not.exist") - // When + // When - select Anthropic cy.getByDataHook("ai-settings-provider-anthropic").click() - // Then + // Then - API key input appears cy.getByDataHook("ai-settings-api-key") .should("be.visible") .should("have.attr", "placeholder", "Enter Anthropic API key") - .should("not.be.disabled") - // When + // When - switch to OpenAI cy.getByDataHook("ai-settings-provider-openai").click() // Then cy.getByDataHook("ai-settings-api-key") .should("be.visible") .should("have.attr", "placeholder", "Enter OpenAI API key") - .should("not.be.disabled") ;["anthropic", "openai"].forEach((provider) => { // Given interceptTokenValidation(provider, false) @@ -303,7 +302,6 @@ describe("ai assistant", () => { "placeholder", `Enter ${provider === "anthropic" ? "Anthropic" : "OpenAI"} API key`, ) - .should("not.be.disabled") .should("be.empty") // When @@ -415,9 +413,9 @@ describe("ai assistant", () => { .should("contain", "Inactive") // When - cy.getByDataHook("ai-settings-test-api") + cy.getByDataHook("ai-settings-remove-provider").scrollIntoView() + cy.getByDataHook("ai-settings-remove-provider") .should("be.visible") - .should("contain", "Remove API Key") .click() // Then @@ -2536,3 +2534,871 @@ Syntax: \`avg(column)\` }) }) }) + +describe("custom providers", () => { + beforeEach(() => { + cy.intercept("POST", PROVIDERS.openai.endpoint, (req) => { + throw new Error( + `Unhandled OpenAI request detected! Request body: ${JSON.stringify(req.body).slice(0, 200)}...`, + ) + }).as("unhandledOpenAI") + + cy.intercept("POST", PROVIDERS.anthropic.endpoint, (req) => { + throw new Error( + `Unhandled Anthropic request detected! Request body: ${JSON.stringify(req.body).slice(0, 200)}...`, + ) + }).as("unhandledAnthropic") + }) + + it("should configure provider with auto-fetched models, select/deselect, and verify localStorage", () => { + cy.loadConsoleWithAuth() + + cy.intercept("GET", "**/models*", { + statusCode: 200, + body: { + object: "list", + data: [ + { id: "llama3", object: "model" }, + { id: "mistral", object: "model" }, + { id: "codellama", object: "model" }, + ], + }, + }).as("modelListRequest") + + cy.getByDataHook("ai-assistant-settings-button") + .should("be.visible") + .click() + cy.getByDataHook("ai-promo-continue").should("be.visible").click() + cy.getByDataHook("ai-settings-modal-step-one").should("be.visible") + cy.getByDataHook("ai-settings-provider-custom").should("be.visible").click() + + cy.getByDataHook("custom-provider-name-input") + .should("be.visible") + .type("Ollama") + cy.getByDataHook("custom-provider-type-select").should( + "have.value", + "openai-chat-completions", + ) + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:11434/v1", + ) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.wait("@modelListRequest") + + cy.getByDataHook("custom-provider-model-row").should("have.length", 3) + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "200000", + ) + + cy.getByDataHook("custom-provider-select-all").click() + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("be.checked") + }) + + cy.getByDataHook("custom-provider-deselect-all").click() + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("not.be.checked") + }) + + cy.getByDataHook("custom-provider-model-row").contains("llama3").click() + cy.getByDataHook("custom-provider-model-row").contains("mistral").click() + + cy.getByDataHook("custom-provider-manual-model-input").type( + "custom-finetune", + ) + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-model-chip").should( + "contain", + "custom-finetune", + ) + + cy.getByDataHook("custom-provider-remove-model").click() + cy.getByDataHook("custom-provider-model-chip").should("not.exist") + + cy.getByDataHook("custom-provider-schema-access").check() + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.contains("AI Assistant activated successfully").should("be.visible") + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.window().then((win) => { + const settings = JSON.parse( + win.localStorage.getItem("ai.assistant.settings"), + ) + expect(settings.customProviders.ollama).to.exist + expect(settings.customProviders.ollama.type).to.equal( + "openai-chat-completions", + ) + expect(settings.customProviders.ollama.name).to.equal("Ollama") + expect(settings.customProviders.ollama.baseURL).to.equal( + "http://localhost:11434/v1", + ) + expect(settings.customProviders.ollama.models).to.deep.equal([ + "llama3", + "mistral", + ]) + expect(settings.customProviders.ollama.contextWindow).to.equal(200000) + expect(settings.customProviders.ollama.grantSchemaAccess).to.be.true + + expect(settings.providers.ollama.enabledModels).to.deep.equal([ + "ollama:llama3", + "ollama:mistral", + ]) + expect(settings.selectedModel).to.equal("ollama:llama3") + }) + }) + + it("should reject invalid URL, require models, enforce context window minimum, and prevent duplicates", () => { + cy.loadConsoleWithAuth() + + cy.getByDataHook("ai-assistant-settings-button") + .should("be.visible") + .click() + cy.getByDataHook("ai-promo-continue").should("be.visible").click() + cy.getByDataHook("ai-settings-provider-custom").should("be.visible").click() + + cy.getByDataHook("multi-step-modal-next-button").should("be.disabled") + + cy.getByDataHook("custom-provider-name-input").type("OpenRouter") + cy.getByDataHook("custom-provider-base-url-input").type("ftp://invalid") + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Base URL must start with http:// or https://").should( + "be.visible", + ) + + cy.getByDataHook("custom-provider-base-url-input") + .clear() + .type("https://openrouter.ai/api/v1") + cy.getByDataHook("custom-provider-api-key-input").type("sk-test") + + cy.intercept("GET", "**/models*", { + statusCode: 500, + body: { error: "Internal Server Error" }, + }).as("modelListFail") + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.wait("@modelListFail") + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Add at least one model").should("be.visible") + + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o{enter}") + cy.getByDataHook("custom-provider-model-chip") + .should("have.length", 1) + .should("contain", "gpt-4o") + + // Duplicate model should not create a second chip + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + + // Input not cleared on duplicate, clear manually + cy.getByDataHook("custom-provider-manual-model-input") + .clear() + .type("claude-3.5-sonnet") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 2) + + cy.getByDataHook("custom-provider-add-model-button").should("be.disabled") + + cy.getByDataHook("custom-provider-context-window-input").type( + "{selectall}50000", + ) + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "50000", + ) + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Context window must be at least 100,000 tokens").should( + "be.visible", + ) + + cy.getByDataHook("custom-provider-context-window-input").type( + "{selectall}100000", + ) + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.contains("AI Assistant activated successfully").should("be.visible") + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.window().then((win) => { + const settings = JSON.parse( + win.localStorage.getItem("ai.assistant.settings"), + ) + expect(settings.customProviders.openrouter).to.exist + expect(settings.customProviders.openrouter.models).to.deep.equal([ + "gpt-4o", + "claude-3.5-sonnet", + ]) + expect(settings.customProviders.openrouter.apiKey).to.equal("sk-test") + expect(settings.customProviders.openrouter.contextWindow).to.equal(100000) + }) + }) + + it("should send chat with tool call through custom endpoint and accept SQL suggestion", () => { + const customBaseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL + const customEndpoint = getCustomProviderEndpoint( + customBaseURL, + "openai-chat-completions", + ) + + cy.setupCustomProvider() + cy.createTable("btc_trades") + cy.refreshSchema() + + const assistantResponse = + "Here are the tables in your database. Let me write a query for btc_trades." + const sql = "SELECT * FROM btc_trades LIMIT 10;" + + const flow = createToolCallFlow({ + provider: "openai-chat-completions", + streaming: true, + question: "What tables are in the database?", + endpoint: customEndpoint, + steps: [ + { toolCall: { name: "get_tables", args: {} } }, + { + finalResponse: { + explanation: assistantResponse, + sql: sql, + }, + expectToolResult: { includes: ["btc_trades"] }, + }, + ], + }) + + flow.intercept() + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type(flow.question) + cy.getByDataHook("chat-send-button").click() + + flow.waitForCompletion() + + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", assistantResponse) + + cy.getByDataHook("assistant-mode-processing-collapsed").click() + cy.getByDataHook("assistant-mode-reviewing-tables").should("exist") + + cy.getByDataHook("message-action-accept").should("be.visible") + cy.getByDataHook("message-action-accept").click() + + cy.getByDataHook("diff-status-accepted").should("contain", "Accepted") + cy.getByDataHook("chat-context-badge").should( + "contain", + "SELECT * FROM btc_trades", + ) + + cy.dropTableIfExists("btc_trades") + }) + + it("should toggle models, add second provider, remove first, and update model dropdown", () => { + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "ollama", + name: "Ollama", + models: ["llama3", "mistral", "codellama"], + }), + ) + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("llama3").should("be.visible") + cy.contains("mistral").should("be.visible") + cy.contains("codellama").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + cy.get("[data-model='llama3']").should("exist") + cy.get("[data-model='mistral']").should("exist") + cy.get("[data-model='codellama']").should("exist") + + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 2) + cy.contains("llama3").should("be.visible") + cy.contains("codellama").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='true']").should("exist") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-add-custom-provider") + .should("be.visible") + .click() + + cy.getByDataHook("custom-provider-name-input").type("OpenRouter") + cy.getByDataHook("custom-provider-base-url-input").type( + "https://openrouter.ai/api/v1", + ) + + cy.intercept("GET", "https://openrouter.ai/api/v1/models", { + statusCode: 500, + body: { error: "Server error" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible") + cy.getByDataHook("ai-settings-provider-openrouter").should("be.visible") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 4) + cy.contains("gpt-4o").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").click() + cy.getByDataHook("ai-settings-remove-provider").click() + + cy.getByDataHook("ai-settings-provider-ollama").should("not.exist") + cy.getByDataHook("ai-settings-provider-openrouter").should("be.visible") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 1) + cy.contains("gpt-4o").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.window().then((win) => { + const settings = JSON.parse( + win.localStorage.getItem("ai.assistant.settings"), + ) + expect(settings.customProviders.ollama).to.not.exist + expect(settings.providers.ollama).to.not.exist + expect(settings.customProviders.openrouter).to.exist + expect(settings.providers.openrouter).to.exist + }) + }) + + it("should show error on 401, retry successfully, and show error on network failure", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.setupCustomProvider() + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply({ + statusCode: 401, + body: { + error: { + type: "authentication_error", + message: "Invalid API key", + }, + }, + }) + }).as("errorRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test error handling") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@errorRequest") + cy.getByDataHook("chat-message-error").should("be.visible") + cy.getByDataHook("retry-button").should("be.visible") + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Successful response after retry", + null, + ), + { streaming: true }, + ), + ) + }).as("successRequest") + + cy.getByDataHook("retry-button").click() + + cy.wait("@successRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-error").should("not.exist") + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Successful response after retry") + + cy.getByDataHook("chat-window-new").click() + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.destroy() + }).as("networkError") + + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test network error") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@networkError") + cy.getByDataHook("chat-message-error").should("be.visible") + cy.getByDataHook("retry-button").should("be.visible") + }) + + it("should reject duplicate names against custom and built-in providers, and sanitize special characters in IDs", () => { + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "my-provider", + name: "My Provider", + models: ["test-model"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-add-custom-provider").click() + + cy.getByDataHook("custom-provider-name-input").type("My Provider") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:1234", + ) + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("A provider with a similar name already exists").should( + "be.visible", + ) + + // "OpenAI" collides with built-in provider ID "openai" + cy.getByDataHook("custom-provider-name-input").clear().type("OpenAI") + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("A provider with a similar name already exists").should( + "be.visible", + ) + + // Special characters should be stripped from the generated ID + cy.getByDataHook("custom-provider-name-input") + .clear() + .type("My Provider (v2.0)!") + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + + cy.getByDataHook("custom-provider-manual-model-input").type("test-model-2") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("ai-settings-provider-my-provider").should("be.visible") + cy.getByDataHook("ai-settings-provider-my-provider-v2-0").should( + "be.visible", + ) + + cy.getByDataHook("ai-settings-save").click() + cy.window().then((win) => { + const settings = JSON.parse( + win.localStorage.getItem("ai.assistant.settings"), + ) + expect(settings.customProviders["my-provider-v2-0"]).to.exist + expect(settings.customProviders["my-provider-v2-0"].name).to.equal( + "My Provider (v2.0)!", + ) + }) + }) + + it("should route Anthropic-type provider requests to custom base URL", () => { + const anthropicBaseURL = "http://localhost:8080" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "custom-anthropic", + name: "Custom Anthropic", + type: "anthropic", + baseURL: anthropicBaseURL, + apiKey: "test-anthropic-key", + models: ["claude-custom"], + }), + ) + + // Anthropic SDK appends /v1/messages to baseURL + cy.intercept("POST", "http://localhost:8080/v1/messages", (req) => { + if (isTitleRequest("anthropic", req.body)) { + req.reply(createChatTitleResponse("anthropic", "Test Chat")) + return + } + req.reply( + createResponse( + "anthropic", + createFinalResponseData( + "anthropic", + "Response from custom Anthropic provider", + null, + ), + { streaming: true }, + ), + ) + }).as("anthropicRequest") + + cy.intercept("POST", "https://api.anthropic.com/**", () => { + throw new Error( + "Request should not go to api.anthropic.com for custom provider", + ) + }) + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test Anthropic custom") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@anthropicRequest") + cy.waitForStreamingComplete() + + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from custom Anthropic provider") + }) + + it("should route requests to correct endpoint when switching between built-in and custom models", () => { + const customBaseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL + const customEndpoint = getCustomProviderEndpoint( + customBaseURL, + "openai-chat-completions", + ) + + const openaiSettings = getOpenAIConfiguredSettings() + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings( + { + providerId: "ollama", + name: "Ollama", + models: ["llama3"], + }, + openaiSettings, + ), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Ollama Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Response from Ollama", + null, + ), + { streaming: true }, + ), + ) + }).as("ollamaRequest") + + cy.intercept("POST", PROVIDERS.openai.endpoint, (req) => { + if (isTitleRequest("openai", req.body)) { + req.reply(createChatTitleResponse("openai", "OpenAI Chat")) + return + } + req.reply( + createResponse( + "openai", + createFinalResponseData("openai", "Response from OpenAI", null), + { streaming: true }, + ), + ) + }).as("openaiRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with Ollama") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@ollamaRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from Ollama") + + cy.getByDataHook("ai-settings-model-dropdown").click() + cy.getByDataHook("ai-settings-model-item-label") + .contains("GPT-5 mini") + .click() + + cy.getByDataHook("chat-window-new").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with OpenAI") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@openaiRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from OpenAI") + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-openai").should("be.visible") + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible") + cy.getByDataHook("ai-settings-cancel").click() + }) + + it("should reset fields on cancel, preserve them on back, and add model via Enter key", () => { + cy.loadConsoleWithAuth() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-promo-continue").click() + cy.getByDataHook("ai-settings-provider-custom").click() + + cy.getByDataHook("custom-provider-name-input").type("Partial") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:1234", + ) + + cy.getByDataHook("multi-step-modal-cancel-button").click() + + cy.getByDataHook("ai-settings-provider-custom").click() + cy.getByDataHook("custom-provider-name-input").should("have.value", "") + cy.getByDataHook("custom-provider-base-url-input").should("have.value", "") + + cy.getByDataHook("custom-provider-name-input").type("Test Provider") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:5555", + ) + + cy.intercept("GET", "http://localhost:5555/models", { + statusCode: 500, + body: { error: "fail" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "200000", + ) + cy.getByDataHook("custom-provider-schema-access").should("not.be.checked") + cy.getByDataHook("custom-provider-add-model-button").should("be.disabled") + + cy.getByDataHook("custom-provider-manual-model-input").type( + "enter-model{enter}", + ) + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-manual-model-input").should( + "have.value", + "", + ) + + // Back button preserves step 1 fields + cy.getByDataHook("multi-step-modal-cancel-button").click() + cy.getByDataHook("custom-provider-name-input").should( + "have.value", + "Test Provider", + ) + cy.getByDataHook("custom-provider-base-url-input").should( + "have.value", + "http://localhost:5555", + ) + + cy.intercept("GET", "http://localhost:5555/models", { + statusCode: 500, + body: { error: "fail" }, + }) + cy.getByDataHook("multi-step-modal-next-button").click() + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + }) + + it("should preserve custom provider settings and chat after page reload", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + const settings = getCustomProviderConfiguredSettings() + + cy.loadConsoleWithAuth(false, settings) + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.loadConsoleWithAuth(false, settings) + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Working after reload", + null, + ), + { streaming: true }, + ), + ) + }).as("chatRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test after reload") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@chatRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Working after reload") + }) + + it("should omit auth token without API key and send Bearer token when API key is configured", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + apiKey: "", + }), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + const auth = req.headers["authorization"] || "" + expect(auth).to.not.include("sk-") + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "No auth response", + null, + ), + { streaming: true }, + ), + ) + }).as("noAuthRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test no auth") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@noAuthRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "No auth response") + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + apiKey: "sk-test-key-123", + }), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + expect(req.headers["authorization"]).to.equal("Bearer sk-test-key-123") + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Auth response", + null, + ), + { streaming: true }, + ), + ) + }).as("authRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with auth") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@authRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Auth response") + }) +}) diff --git a/e2e/utils/aiAssistant.js b/e2e/utils/aiAssistant.js index 2113f03dc..30e02efbe 100644 --- a/e2e/utils/aiAssistant.js +++ b/e2e/utils/aiAssistant.js @@ -9,6 +9,16 @@ const PROVIDERS = { }, } +const CUSTOM_PROVIDER_DEFAULTS = { + providerId: "test-provider", + name: "Test Provider", + type: "openai-chat-completions", + baseURL: "http://localhost:11434/v1", + models: ["test-model-1"], + contextWindow: 200000, + grantSchemaAccess: true, +} + function getOpenAIConfiguredSettings(schemaAccess = true) { return { "ai.assistant.settings": JSON.stringify({ @@ -39,6 +49,76 @@ function getAnthropicConfiguredSettings(schemaAccess = true) { } } +/** + * Returns localStorage settings for a pre-configured custom provider. + * Can optionally merge with existing settings (e.g., a built-in provider). + */ +function getCustomProviderConfiguredSettings(config = {}, mergeWith = null) { + const { + providerId = CUSTOM_PROVIDER_DEFAULTS.providerId, + name = CUSTOM_PROVIDER_DEFAULTS.name, + type = CUSTOM_PROVIDER_DEFAULTS.type, + baseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL, + apiKey = "", + models = CUSTOM_PROVIDER_DEFAULTS.models, + contextWindow = CUSTOM_PROVIDER_DEFAULTS.contextWindow, + grantSchemaAccess = CUSTOM_PROVIDER_DEFAULTS.grantSchemaAccess, + } = config + + const enabledModels = models.map((m) => `${providerId}:${m}`) + + const baseSettings = mergeWith + ? JSON.parse(mergeWith["ai.assistant.settings"]) + : {} + + const settings = { + ...baseSettings, + selectedModel: enabledModels[0], + customProviders: { + ...(baseSettings.customProviders || {}), + [providerId]: { + type, + name, + baseURL, + ...(apiKey ? { apiKey } : {}), + contextWindow, + models, + grantSchemaAccess, + }, + }, + providers: { + ...(baseSettings.providers || {}), + [providerId]: { + apiKey: apiKey || "", + enabledModels, + grantSchemaAccess, + }, + }, + } + + return { + "ai.assistant.settings": JSON.stringify(settings), + } +} + +/** + * Returns the API endpoint for a custom provider based on its type. + */ +function getCustomProviderEndpoint(baseURL, type) { + if (type === "openai-chat-completions") { + return `${baseURL}/chat/completions` + } + if (type === "openai") { + return `${baseURL}/responses` + } + // anthropic - SDK appends /v1/messages to baseURL + return `${baseURL}/v1/messages` +} + +// ============================================================================= +// RESPONSE DATA BUILDERS +// ============================================================================= + function createFinalResponseData(provider, explanation, sql = null) { const responseContent = { explanation, sql } @@ -62,6 +142,26 @@ function createFinalResponseData(provider, explanation, sql = null) { } } + if (provider === "openai-chat-completions") { + return { + id: "chatcmpl-mock-final", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: JSON.stringify(responseContent), + }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 200, completion_tokens: 100, total_tokens: 300 }, + } + } + // Anthropic return { id: "msg_mock_final", @@ -97,6 +197,36 @@ function createToolCallResponseData(provider, toolName, toolArguments = {}) { } } + if (provider === "openai-chat-completions") { + return { + id: `chatcmpl-mock-tool-${toolName}`, + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: callId, + type: "function", + function: { + name: toolName, + arguments: JSON.stringify(toolArguments), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { prompt_tokens: 100, completion_tokens: 50, total_tokens: 150 }, + } + } + // Anthropic return { id: `msg_mock_tool_${toolName}`, @@ -137,6 +267,23 @@ function createChatTitleResponseData(provider, title = "Test Chat") { } } + if (provider === "openai-chat-completions") { + return { + id: "chatcmpl-mock-title", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { role: "assistant", content: content }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 50, completion_tokens: 20, total_tokens: 70 }, + } + } + // Anthropic return { id: "msg_mock_title", @@ -149,6 +296,10 @@ function createChatTitleResponseData(provider, title = "Test Chat") { } } +// ============================================================================= +// SSE RESPONSE BUILDERS +// ============================================================================= + function createOpenAISSEResponse(responseData, delay = 0) { const events = [] @@ -189,6 +340,121 @@ function createOpenAISSEResponse(responseData, delay = 0) { return response } +function createChatCompletionsSSEResponse(responseData, delay = 0) { + const events = [] + const choice = responseData.choices?.[0] + const content = choice?.message?.content || "" + const toolCalls = choice?.message?.tool_calls || [] + + // Stream content deltas + if (content) { + const chunkSize = 20 + for (let i = 0; i < content.length; i += chunkSize) { + const chunk = content.slice(i, i + chunkSize) + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { content: chunk }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + } + } + + // Stream tool call deltas + if (toolCalls.length > 0) { + for (const tc of toolCalls) { + // First chunk: tool call start + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + id: tc.id, + type: "function", + function: { name: tc.function.name, arguments: "" }, + }, + ], + }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + // Second chunk: tool call arguments + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: tc.function.arguments }, + }, + ], + }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + } + } + + // Final chunk with finish_reason and usage + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: {}, + finish_reason: choice?.finish_reason || "stop", + }, + ], + usage: responseData.usage, + })}\n\n`, + ) + + // [DONE] marker + events.push("data: [DONE]\n\n") + + const sseBody = events.join("") + + const response = { + statusCode: 200, + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }, + body: sseBody, + } + + if (delay > 0) { + response.delay = delay + } + + return response +} + function createAnthropicSSEResponse(responseData, delay = 0) { const events = [] @@ -327,6 +593,10 @@ function createAnthropicSSEResponse(responseData, delay = 0) { return response } +// ============================================================================= +// RESPONSE WRAPPERS +// ============================================================================= + function createResponse(provider, responseData, options = {}) { const { streaming = true, delay = 0 } = options @@ -347,6 +617,10 @@ function createResponse(provider, responseData, options = {}) { return createOpenAISSEResponse(responseData, delay) } + if (provider === "openai-chat-completions") { + return createChatCompletionsSSEResponse(responseData, delay) + } + return createAnthropicSSEResponse(responseData, delay) } @@ -378,6 +652,10 @@ function createChatTitleResponse(provider, title = "Test Chat") { }) } +// ============================================================================= +// REQUEST INSPECTION HELPERS +// ============================================================================= + function isTitleRequest(provider, body) { if (provider === "openai") { return ( @@ -385,7 +663,7 @@ function isTitleRequest(provider, body) { false ) } - // Anthropic + // Both openai-chat-completions and anthropic use messages array return ( body.messages?.[0]?.content?.includes?.("Generate a concise chat title") || false @@ -396,8 +674,7 @@ function requestMatchesQuestion(provider, body, question) { if (provider === "openai") { return body.input?.[0]?.content === question } - // Anthropic - find the first user message with string content (the original question) - // Subsequent messages may be tool results or assistant responses + // Both openai-chat-completions and anthropic use messages array const firstUserMessage = body.messages?.find( (msg) => msg.role === "user" && typeof msg.content === "string", ) @@ -413,8 +690,14 @@ function extractToolOutputContent(provider, body) { return latestOutput?.output || null } + if (provider === "openai-chat-completions") { + // Chat Completions: tool results are in messages with role "tool" + const toolMessages = body.messages?.filter((msg) => msg.role === "tool") + const latestToolMessage = toolMessages?.[toolMessages.length - 1] + return latestToolMessage?.content || null + } + // Anthropic - tool results are in user messages with content array containing tool_result objects - // Format: { role: "user", content: [{ type: "tool_result", tool_use_id: "...", content: "..." }] } const toolResultMessages = body.messages?.filter( (msg) => msg.role === "user" && @@ -425,7 +708,6 @@ function extractToolOutputContent(provider, body) { const latestToolResult = latestMessage?.content?.find( (c) => c.type === "tool_result", ) - // Anthropic tool result content can be a string directly return latestToolResult?.content || null } @@ -433,7 +715,7 @@ function extractAllInputContent(provider, body) { if (provider === "openai") { return body.input?.map((item) => item.content || "").join("\n") || "" } - // Anthropic + // Both openai-chat-completions and anthropic use messages return ( body.messages ?.map((msg) => { @@ -452,57 +734,22 @@ function extractAllInputContent(provider, body) { // ============================================================================= /** - * Creates a multi-turn tool call flow with automatic intercept handling - * - * @param {Object} config - Flow configuration - * @param {"openai" | "anthropic"} [config.provider="openai"] - The AI provider - * @param {boolean} [config.streaming=true] - Whether to use streaming responses - * @param {string} config.question - The user's question to match - * @param {Array} config.steps - Array of step definitions - * @param {Object} [config.steps[].toolCall] - Tool call definition { name, args } - * @param {Object} [config.steps[].expectToolResult] - Expected result { includes: string[] } - * @param {Object} [config.steps[].finalResponse] - Final response { explanation, sql } - * @returns {Object} Flow controller with intercept() and waitForCompletion() methods - * - * @example - * // OpenAI with streaming (default) - * const flow = createToolCallFlow({ - * question: "Describe the ecommerce_stats table", - * steps: [ - * { toolCall: { name: "get_tables", args: {} } }, - * { finalResponse: { explanation: "Table description...", sql: null } } - * ] - * }) + * Creates a multi-turn tool call flow with automatic intercept handling. + * Supports built-in providers (openai, anthropic) and custom providers. * - * @example - * // Anthropic with streaming - * const flow = createToolCallFlow({ - * provider: "anthropic", - * streaming: true, - * question: "What tables exist?", - * steps: [ - * { toolCall: { name: "get_tables", args: {} } }, - * { finalResponse: { explanation: "Found tables...", sql: null } } - * ] - * }) - * - * @example - * // OpenAI without streaming - * const flow = createToolCallFlow({ - * provider: "openai", - * streaming: false, - * question: "Quick test", - * steps: [ - * { finalResponse: { explanation: "Done", sql: null } } - * ] - * }) + * @param {Object} config + * @param {"openai" | "anthropic" | "openai-chat-completions"} [config.provider="openai"] + * @param {boolean} [config.streaming=true] + * @param {string} config.question + * @param {Array} config.steps + * @param {string} [config.endpoint] - Custom endpoint URL (overrides PROVIDERS lookup) */ function createToolCallFlow(config) { const { provider = "openai", streaming = true, question, steps } = config let requestCount = 0 const totalRequests = steps.length - const endpoint = PROVIDERS[provider].endpoint + const endpoint = config.endpoint || PROVIDERS[provider]?.endpoint const responseOptions = { streaming } return { @@ -510,9 +757,6 @@ function createToolCallFlow(config) { provider, streaming, - /** - * Sets up cy.intercept for both chat title and tool call flow - */ intercept() { // Handle chat title generation (never streamed) cy.intercept("POST", endpoint, (req) => { @@ -528,7 +772,6 @@ function createToolCallFlow(config) { return } - // Check if this request matches our question if (!requestMatchesQuestion(provider, req.body, question)) { return } @@ -571,14 +814,10 @@ function createToolCallFlow(config) { }).as("toolCallRequest") }, - /** - * Waits for all tool call requests to complete and streaming to finish - */ waitForCompletion() { for (let i = 0; i < totalRequests; i++) { cy.wait("@toolCallRequest") } - // Wait for streaming to finish if streaming is enabled if (streaming) { cy.waitForStreamingComplete() } @@ -591,61 +830,27 @@ function createToolCallFlow(config) { // ============================================================================= /** - * Creates a multi-turn conversation flow for testing multiple questions/responses - * in the same chat session. - * - * @param {Object} config - Flow configuration - * @param {"openai" | "anthropic"} [config.provider="openai"] - The AI provider - * @param {boolean} [config.streaming=true] - Whether to use streaming responses - * @param {Array} config.turns - Array of turn definitions - * @param {string} config.turns[].explanation - The AI's explanation text - * @param {string|null} config.turns[].sql - Optional SQL query suggestion - * @param {Object} [config.turns[].expectSystemMessage] - Expected content in system message - * @param {string[]} [config.turns[].expectSystemMessage.includes] - Strings that must appear - * @param {string[]} [config.turns[].expectSystemMessage.excludes] - Strings that must NOT appear - * @returns {Object} Flow controller with intercept(), waitForTurn(), and getRequestBody() methods - * - * @example - * // OpenAI streaming (default) - * const flow = createMultiTurnFlow({ - * turns: [ - * { explanation: "First query.", sql: "SELECT 1;" }, - * { explanation: "Second query.", sql: "SELECT 2;" } - * ] - * }) + * Creates a multi-turn conversation flow. + * Supports built-in providers and custom providers. * - * @example - * // Anthropic streaming - * const flow = createMultiTurnFlow({ - * provider: "anthropic", - * turns: [ - * { explanation: "First response.", sql: null }, - * { - * explanation: "Second response.", - * sql: "SELECT * FROM users;", - * expectSystemMessage: { - * includes: ["User accepted the suggested SQL"], - * excludes: ["User rejected"] - * } - * } - * ] - * }) + * @param {Object} config + * @param {"openai" | "anthropic" | "openai-chat-completions"} [config.provider="openai"] + * @param {boolean} [config.streaming=true] + * @param {Array} config.turns + * @param {string} [config.endpoint] - Custom endpoint URL */ function createMultiTurnFlow(config) { const { provider = "openai", streaming = true, turns } = config let requestCount = 0 const requestBodies = [] - const endpoint = PROVIDERS[provider].endpoint + const endpoint = config.endpoint || PROVIDERS[provider]?.endpoint const responseOptions = { streaming } return { provider, streaming, - /** - * Sets up cy.intercept for chat title and all conversation turns - */ intercept() { // Intercept for chat title generation cy.intercept("POST", endpoint, (req) => { @@ -656,22 +861,17 @@ function createMultiTurnFlow(config) { // Intercept for conversation turns cy.intercept("POST", endpoint, (req) => { - // Skip title requests if (isTitleRequest(provider, req.body)) { return } - // Handle conversation turns const turn = turns[requestCount] if (turn) { - // Store the request body for later assertions requestBodies[requestCount] = req.body - // Verify system message expectations if defined if (turn.expectSystemMessage) { const allInputContent = extractAllInputContent(provider, req.body) - // Check includes if (turn.expectSystemMessage.includes) { for (const expected of turn.expectSystemMessage.includes) { expect(allInputContent).to.include( @@ -681,7 +881,6 @@ function createMultiTurnFlow(config) { } } - // Check excludes if (turn.expectSystemMessage.excludes) { for (const excluded of turn.expectSystemMessage.excludes) { expect(allInputContent).to.not.include( @@ -705,11 +904,6 @@ function createMultiTurnFlow(config) { }).as("multiTurnRequest") }, - /** - * Waits for a specific turn to complete and streaming to finish - * @param {number} turnIndex - The turn index (0-based) - * @returns {Cypress.Chainable} Chainable that resolves when the turn is complete - */ waitForTurn(turnIndex) { return cy .wrap(null) @@ -729,10 +923,6 @@ function createMultiTurnFlow(config) { }) }, - /** - * Waits for all turns to complete - * @returns {Cypress.Chainable} Chainable that yields all request bodies - */ waitForAllTurns() { return cy .wrap(null) @@ -750,20 +940,10 @@ function createMultiTurnFlow(config) { }) }, - /** - * Gets the captured request body for a specific turn. - * Must be called inside cy.then() after waitForTurn() - * @param {number} turnIndex - The turn index (0-based) - * @returns {Object} The request body sent for that turn - */ getRequestBody(turnIndex) { return requestBodies[turnIndex] }, - /** - * Gets all captured request bodies - * @returns {Array} Array of request bodies - */ getAllRequestBodies() { return requestBodies }, @@ -772,11 +952,15 @@ function createMultiTurnFlow(config) { module.exports = { PROVIDERS, + CUSTOM_PROVIDER_DEFAULTS, getOpenAIConfiguredSettings, getAnthropicConfiguredSettings, + getCustomProviderConfiguredSettings, + getCustomProviderEndpoint, createFinalResponseData, createResponse, createFinalResponse, + createToolCallResponse, createChatTitleResponse, createToolCallFlow, createMultiTurnFlow, diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index 92a75ff0e..b380c7a31 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -516,7 +516,11 @@ const StepOneContent = ({ /> {getProviderName("anthropic")} - + Custom diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index 99395d958..d472d2105 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -379,6 +379,7 @@ const SchemaAccessToggle = ({ onChange(e.target.checked)} /> @@ -445,6 +446,7 @@ const StepOneContent = ({ Provider Name onNameChange(e.target.value)} @@ -454,6 +456,7 @@ const StepOneContent = ({ Provider Type @@ -482,6 +485,7 @@ const StepOneContent = ({ Base URL onBaseURLChange(e.target.value)} @@ -494,6 +498,7 @@ const StepOneContent = ({ API Key onApiKeyChange(e.target.value)} @@ -566,17 +571,25 @@ const StepTwoAutoContent = ({ > Select Models - + Select All - + Deselect All {fetchedModels.map((model) => ( - + onToggleModel(model)} @@ -590,6 +603,7 @@ const StepTwoAutoContent = ({ Don't see your model? Add it manually: onManualModelInputChange(e.target.value)} @@ -602,6 +616,7 @@ const StepTwoAutoContent = ({ }} /> !fetchedModels.includes(m)) .map((model) => ( - + {model} onToggleModel(model)} > @@ -634,6 +650,7 @@ const StepTwoAutoContent = ({ Context Window onContextWindowChange(Number(e.target.value))} @@ -698,7 +715,7 @@ const StepTwoManualContent = ({ - + Could not fetch models automatically from this provider. Please @@ -709,6 +726,7 @@ const StepTwoManualContent = ({ Add Models onManualModelInputChange(e.target.value)} @@ -721,6 +739,7 @@ const StepTwoManualContent = ({ }} /> 0 && ( {manualModels.map((model) => ( - + {model} onRemoveManualModel(model)} title={`Remove ${model}`} @@ -751,6 +771,7 @@ const StepTwoManualContent = ({ Context Window onContextWindowChange(Number(e.target.value))} @@ -799,7 +820,7 @@ export const CustomProviderModal = ({ const [baseURL, setBaseURL] = useState("") const [apiKey, setApiKey] = useState("") - const [contextWindow, setContextWindow] = useState(128_000) + const [contextWindow, setContextWindow] = useState(200_000) const [fetchedModels, setFetchedModels] = useState(null) const [selectedModels, setSelectedModels] = useState([]) const [manualModels, setManualModels] = useState([]) diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index af59f672c..6efda6484 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -1018,6 +1018,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { })} { setCustomProviderModalOpen(true) }} @@ -1270,6 +1271,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { skin="error" prefixIcon={} type="button" + data-hook="ai-settings-remove-provider" onClick={() => handleRemoveProvider(selectedProvider)} > {isCustomProvider ? "Remove Provider" : "Reset Provider"} From 868813c6f47d1bd1990ff5bbee0d457285b6e083 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 17:20:19 +0300 Subject: [PATCH 21/25] manage models after setup --- src/components/Overlay/index.tsx | 17 +- .../SetupAIAssistant/CustomProviderModal.tsx | 782 ++---------------- .../SetupAIAssistant/ManageModelsModal.tsx | 161 ++++ .../SetupAIAssistant/ModelSettings.tsx | 668 +++++++++++++++ .../SetupAIAssistant/SettingsModal.tsx | 382 ++++++--- 5 files changed, 1144 insertions(+), 866 deletions(-) create mode 100644 src/components/SetupAIAssistant/ManageModelsModal.tsx create mode 100644 src/components/SetupAIAssistant/ModelSettings.tsx diff --git a/src/components/Overlay/index.tsx b/src/components/Overlay/index.tsx index b18cfc8d4..b54cd4076 100644 --- a/src/components/Overlay/index.tsx +++ b/src/components/Overlay/index.tsx @@ -43,10 +43,13 @@ const StyledOverlay = styled.div` } ` -export const Overlay = ({ - primitive, -}: { - primitive: typeof RadixDialogOverlay | typeof RadixAlertDialogOverlay -}) => { - return -} +export const Overlay = React.forwardRef< + HTMLDivElement, + { + primitive: typeof RadixDialogOverlay | typeof RadixAlertDialogOverlay + } +>(({ primitive }, ref) => { + return +}) + +Overlay.displayName = "Overlay" diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index d472d2105..182a65562 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -1,20 +1,24 @@ -import React, { useState, useMemo, useCallback, useEffect, useRef } from "react" -import styled, { useTheme } from "styled-components" +import React, { useState, useMemo, useCallback, useRef } from "react" +import styled from "styled-components" import { MultiStepModal } from "../MultiStepModal" import type { Step } from "../MultiStepModal" import { useModalNavigation } from "../MultiStepModal" import { Box } from "../Box" -import { Input } from "../Input" -import { Checkbox } from "../Checkbox" -import { Text } from "../Text" import { Dialog } from "../Dialog" -import { createProviderByType } from "../../utils/ai/registry" import type { ProviderType, CustomProviderDefinition, } from "../../utils/ai/settings" import { Select } from "../Select" -import { WarningIcon, XIcon } from "@phosphor-icons/react" +import { toast } from "../Toast" +import { + ModelSettings, + InputSection, + InputLabel, + StyledInput, + HelperText, +} from "./ModelSettings" +import type { ModelSettingsRef } from "./ModelSettings" const ModalContent = styled.div` display: flex; @@ -95,34 +99,6 @@ const ContentSection = styled(Box).attrs({ width: 100%; ` -const InputSection = styled(Box).attrs({ - flexDirection: "column", - gap: "1.2rem", -})` - width: 100%; -` - -const InputLabel = styled(Text)` - font-size: 1.6rem; - font-weight: 600; - color: ${({ theme }) => theme.color.gray2}; -` - -const StyledInput = styled(Input)<{ $hasError?: boolean }>` - width: 100%; - background: #262833; - border: 0.1rem solid - ${({ theme, $hasError }) => ($hasError ? theme.color.red : "#6b7280")}; - border-radius: 0.8rem; - font-size: 1.4rem; - min-height: 3rem; - - &::placeholder { - color: ${({ theme }) => theme.color.gray2}; - font-family: inherit; - } -` - const PasswordInput = styled(StyledInput)` text-security: disc; -webkit-text-security: disc; @@ -150,200 +126,6 @@ const StyledSelect = styled(Select)` } ` -const HelperText = styled(Text)` - font-size: 1.3rem; - font-weight: 300; - color: ${({ theme }) => theme.color.gray2}; -` - -const WarningBanner = styled(Box).attrs({ - flexDirection: "row", - gap: "0.6rem", - align: "center", -})` - width: 100%; - background: rgba(255, 165, 0, 0.08); - border: 0.1rem solid ${({ theme }) => theme.color.orange}; - border-radius: 0.8rem; - padding: 0.75rem; -` - -const WarningText = styled(Text)` - font-size: 1.3rem; - color: ${({ theme }) => theme.color.orange}; -` - -const ModelListContainer = styled.div` - max-height: 30rem; - overflow-y: auto; - display: flex; - flex-direction: column; - gap: 0.25rem; - border: 0.1rem solid #6b7280; - border-radius: 0.4rem; - width: 100%; -` - -const ModelRow = styled.label` - display: flex; - align-items: center; - gap: 0.8rem; - padding: 0.6rem 0.8rem; - cursor: pointer; - font-size: 1.4rem; - color: ${({ theme }) => theme.color.foreground}; - - &:hover { - background: ${({ theme }) => theme.color.selection}; - } -` - -const ModelChipsContainer = styled.div` - display: flex; - flex-wrap: wrap; - gap: 0.6rem; -` - -const ModelChip = styled.div` - display: inline-flex; - align-items: center; - gap: 0.5rem; - background: ${({ theme }) => theme.color.selection}; - border-radius: 0.4rem; - padding: 0.4rem 0.8rem; - font-size: 1.3rem; - color: ${({ theme }) => theme.color.foreground}; -` - -const ChipRemoveButton = styled.button` - background: none; - border: none; - cursor: pointer; - padding: 0; - display: flex; - justify-content: center; - align-items: center; - color: ${({ theme }) => theme.color.gray2}; - - &:hover { - color: ${({ theme }) => theme.color.foreground}; - } -` - -const AddModelRow = styled(Box).attrs({ - gap: "0.8rem", - align: "center", -})` - width: 100%; -` - -const AddModelButton = styled.button` - height: 3rem; - border: 0.1rem solid ${({ theme }) => theme.color.pinkDarker}; - background: ${({ theme }) => theme.color.background}; - color: ${({ theme }) => theme.color.foreground}; - border-radius: 0.4rem; - padding: 0 1.2rem; - font-size: 1.4rem; - font-weight: 500; - cursor: pointer; - white-space: nowrap; - - &:hover:not(:disabled) { - background: ${({ theme }) => theme.color.pinkDarker}; - } - - &:disabled { - opacity: 0.6; - cursor: not-allowed; - } -` - -const SelectAllRow = styled(Box).attrs({ - gap: "2rem", - align: "center", -})` - display: inline-flex; - margin-left: auto; -` - -const SelectAllLink = styled.button` - background: none; - border: none; - cursor: pointer; - color: ${({ theme }) => theme.color.cyan}; - font-size: 1.4rem; - padding: 0; - - &:hover { - text-decoration: underline; - } -` - -const SchemaAccessSection = styled(Box).attrs({ - flexDirection: "column", - gap: "1.6rem", - align: "flex-start", -})` - width: 100%; -` - -const SchemaAccessTitle = styled(Text)` - font-size: 1.6rem; - font-weight: 600; - color: ${({ theme }) => theme.color.gray2}; - flex: 1; -` - -const SchemaCheckboxContainer = styled(Box).attrs({ - gap: "1.5rem", - align: "flex-start", -})` - background: rgba(68, 71, 90, 0.56); - padding: 0.75rem; - border-radius: 0.4rem; - width: 100%; -` - -const SchemaCheckboxInner = styled(Box).attrs({ - gap: "1.5rem", - align: "center", -})` - flex: 1; - padding: 0.75rem; - border-radius: 0.5rem; -` - -const SchemaCheckboxWrapper = styled.div` - flex-shrink: 0; - display: flex; - align-items: center; -` - -const SchemaCheckboxContent = styled(Box).attrs({ - flexDirection: "column", - gap: "0.6rem", -})` - flex: 1; -` - -const SchemaCheckboxLabel = styled(Text)` - font-size: 1.4rem; - font-weight: 500; - color: ${({ theme }) => theme.color.foreground}; -` - -const SchemaCheckboxDescription = styled(Text)` - font-size: 1.3rem; - font-weight: 400; - color: ${({ theme }) => theme.color.gray2}; -` - -const SchemaCheckboxDescriptionBold = styled.span` - font-weight: 500; - color: ${({ theme }) => theme.color.foreground}; -` - const CloseButton = ({ onClick }: { onClick: () => void }) => ( void }) => ( ) -const SchemaAccessToggle = ({ - checked, - onChange, - providerName, -}: { - checked: boolean - onChange: (checked: boolean) => void - providerName: string -}) => ( - - Schema Access - - - - onChange(e.target.checked)} - /> - - - - Grant schema access to {providerName} - - - When enabled, the AI assistant can access your database schema - information to provide more accurate suggestions and explanations. - Schema information helps the AI understand your table structures, - column names, and relationships.{" "} - - The AI model will not have access to your data. - - - - - - -) - type StepOneProps = { name: string providerType: ProviderType @@ -514,283 +257,21 @@ const StepOneContent = ({ ) } -type StepTwoAutoProps = { - fetchedModels: string[] - selectedModels: string[] - contextWindow: number - grantSchemaAccess: boolean - providerName: string - manualModelInput: string - onToggleModel: (model: string) => void - onSelectAll: () => void - onDeselectAll: () => void - onContextWindowChange: (v: number) => void - onSchemaAccessChange: (v: boolean) => void - onManualModelInputChange: (v: string) => void - onAddManualModel: () => void -} - -const StepTwoAutoContent = ({ - fetchedModels, - selectedModels, - contextWindow, - grantSchemaAccess, - providerName, - manualModelInput, - onToggleModel, - onSelectAll, - onDeselectAll, - onContextWindowChange, - onSchemaAccessChange, - onManualModelInputChange, - onAddManualModel, -}: StepTwoAutoProps) => { - const navigation = useModalNavigation() - - return ( - - - - - Configure Settings - - Configure the settings for your custom provider. - - - - - - - - - - Select Models - - - Select All - - - Deselect All - - - - - {fetchedModels.map((model) => ( - - onToggleModel(model)} - /> - {model} - - ))} - - - - Don't see your model? Add it manually: - - onManualModelInputChange(e.target.value)} - placeholder="e.g., llama3" - onKeyDown={(e) => { - if (e.key === "Enter") { - e.preventDefault() - onAddManualModel() - } - }} - /> - - Add - - - {selectedModels.filter((m) => !fetchedModels.includes(m)).length > - 0 && ( - - {selectedModels - .filter((m) => !fetchedModels.includes(m)) - .map((model) => ( - - {model} - onToggleModel(model)} - > - - - - ))} - - )} - - - - - - Context Window - onContextWindowChange(Number(e.target.value))} - /> - - Maximum number of tokens the model can process. AI assistant - requires a minimum of 100,000 tokens. - - - - - - - - ) -} - -type StepTwoManualProps = { - manualModels: string[] - manualModelInput: string - contextWindow: number - grantSchemaAccess: boolean - providerName: string - onManualModelInputChange: (v: string) => void - onAddManualModel: () => void - onRemoveManualModel: (model: string) => void - onContextWindowChange: (v: number) => void - onSchemaAccessChange: (v: boolean) => void -} - -const StepTwoManualContent = ({ - manualModels, - manualModelInput, - contextWindow, - grantSchemaAccess, - providerName, - onManualModelInputChange, - onAddManualModel, - onRemoveManualModel, - onContextWindowChange, - onSchemaAccessChange, -}: StepTwoManualProps) => { - const theme = useTheme() +const StepTwoHeader = () => { const navigation = useModalNavigation() return ( - - - - - Models & Settings - - Configure the models and settings for your custom provider. - - - - - - - - - - - Could not fetch models automatically from this provider. Please - enter model IDs manually. - - - - Add Models - - onManualModelInputChange(e.target.value)} - placeholder="e.g., llama3, gpt-4o, claude-sonnet-4-20250514" - onKeyDown={(e) => { - if (e.key === "Enter") { - e.preventDefault() - onAddManualModel() - } - }} - /> - - Add - - - {manualModels.length > 0 && ( - - {manualModels.map((model) => ( - - {model} - onRemoveManualModel(model)} - title={`Remove ${model}`} - > - - - - ))} - - )} - - - - - - Context Window - onContextWindowChange(Number(e.target.value))} - /> - - Maximum number of tokens the model can process. AI assistant - requires a minimum of 100,000 tokens. - - - - - - - - + + + + Models & Settings + + Configure the models and settings for your custom provider. + + + + + ) } @@ -820,66 +301,9 @@ export const CustomProviderModal = ({ const [baseURL, setBaseURL] = useState("") const [apiKey, setApiKey] = useState("") - const [contextWindow, setContextWindow] = useState(200_000) - const [fetchedModels, setFetchedModels] = useState(null) - const [selectedModels, setSelectedModels] = useState([]) - const [manualModels, setManualModels] = useState([]) - const [manualModelInput, setManualModelInput] = useState("") - - const [flowPath, setFlowPath] = useState<"auto" | "manual">("manual") - - const [grantSchemaAccess, setGrantSchemaAccess] = useState(false) - - const abortControllerRef = useRef(null) + const modelSettingsRef = useRef(null) - useEffect(() => { - return () => { - abortControllerRef.current?.abort() - } - }, []) - - const handleToggleModel = useCallback((model: string) => { - setSelectedModels((prev) => - prev.includes(model) ? prev.filter((m) => m !== model) : [...prev, model], - ) - }, []) - - const handleSelectAll = useCallback(() => { - if (fetchedModels) { - setSelectedModels((prev) => { - const manual = prev.filter((m) => !fetchedModels.includes(m)) - return [...fetchedModels, ...manual] - }) - } - }, [fetchedModels]) - - const handleDeselectAll = useCallback(() => { - setSelectedModels((prev) => - fetchedModels ? prev.filter((m) => !fetchedModels.includes(m)) : [], - ) - }, [fetchedModels]) - - const handleAddManualModel = useCallback(() => { - const trimmed = manualModelInput.trim() - if (!trimmed) return - - if (flowPath === "auto") { - if (selectedModels.includes(trimmed)) return - setSelectedModels((prev) => [...prev, trimmed]) - } else { - if (manualModels.includes(trimmed)) return - setManualModels((prev) => [...prev, trimmed]) - } - setManualModelInput("") - }, [manualModelInput, flowPath, selectedModels, manualModels]) - - const handleRemoveManualModel = useCallback((model: string) => { - setManualModels((prev) => prev.filter((m) => m !== model)) - }, []) - - const connectionValidate = useCallback(async (): Promise< - string | boolean - > => { + const connectionValidate = useCallback((): string | boolean => { if (!name.trim()) return "Provider name is required" if (!baseURL.trim()) return "Base URL is required" if (!baseURL.startsWith("http://") && !baseURL.startsWith("https://")) @@ -887,87 +311,33 @@ export const CustomProviderModal = ({ const providerId = generateProviderId(name) if (existingProviderIds.includes(providerId)) - return `A provider with a similar name already exists` - - // Try to fetch models, fallback to manual entry - try { - const tempProvider = createProviderByType( - providerType, - "temp", - apiKey || "", - { baseURL, contextWindow, isCustom: true }, - ) - const models = await tempProvider.listModels() - if (models && models.length > 0) { - setFetchedModels(models) - setFlowPath("auto") - } else { - setFetchedModels(null) - setFlowPath("manual") - } - } catch { - setFetchedModels(null) - setFlowPath("manual") - } + return "A provider with a similar name already exists" return true - }, [name, baseURL, providerType, apiKey, contextWindow, existingProviderIds]) + }, [name, baseURL, existingProviderIds]) const modelsValidate = useCallback((): string | boolean => { - if (flowPath === "auto") { - if (selectedModels.length === 0) return "Select at least one model" - } else { - if (manualModels.length === 0 && !manualModelInput.trim()) - return "Add at least one model" - } - if (!contextWindow || contextWindow < 100_000) - return "Context window must be at least 100,000 tokens" - return true - }, [flowPath, selectedModels, manualModels, manualModelInput, contextWindow]) + return modelSettingsRef.current?.validate() ?? "Not ready" + }, []) const handleComplete = useCallback(() => { const providerId = generateProviderId(name) - - // Auto-add any pending manual model input - const pendingModel = manualModelInput.trim() - let models: string[] - - if (flowPath === "auto") { - models = - pendingModel && !selectedModels.includes(pendingModel) - ? [...selectedModels, pendingModel] - : selectedModels - } else { - models = - pendingModel && !manualModels.includes(pendingModel) - ? [...manualModels, pendingModel] - : manualModels - } + const values = modelSettingsRef.current?.getValues() + if (!values) return const definition: CustomProviderDefinition = { type: providerType, name: name.trim(), baseURL: baseURL.trim(), apiKey: apiKey || undefined, - contextWindow, - models, - grantSchemaAccess, + contextWindow: values.contextWindow, + models: values.models, + grantSchemaAccess: values.grantSchemaAccess, } onSave(providerId, definition) - }, [ - name, - flowPath, - selectedModels, - manualModels, - manualModelInput, - providerType, - baseURL, - apiKey, - contextWindow, - grantSchemaAccess, - onSave, - ]) + toast.success(`Added custom provider ${name.trim()}.`) + }, [name, providerType, baseURL, apiKey, onSave]) const steps: Step[] = useMemo(() => { const connectionStep: Step = { @@ -989,78 +359,34 @@ export const CustomProviderModal = ({ validate: connectionValidate, } - if (flowPath === "auto" && fetchedModels !== null) { - return [ - connectionStep, - { - id: "select-models", - title: "Add Custom Provider", - stepName: "Models & Settings", - content: ( - - ), - validate: modelsValidate, - }, - ] - } - return [ connectionStep, { - id: "manual-models", + id: "model-settings", title: "Add Custom Provider", stepName: "Models & Settings", content: ( - + + + + + ), validate: modelsValidate, }, ] - }, [ - name, - providerType, - baseURL, - apiKey, - connectionValidate, - flowPath, - fetchedModels, - selectedModels, - contextWindow, - grantSchemaAccess, - manualModelInput, - handleToggleModel, - handleSelectAll, - handleDeselectAll, - handleAddManualModel, - modelsValidate, - manualModels, - handleRemoveManualModel, - ]) + }, [name, providerType, baseURL, apiKey, connectionValidate, modelsValidate]) const canProceed = useCallback( (stepIndex: number): boolean => { diff --git a/src/components/SetupAIAssistant/ManageModelsModal.tsx b/src/components/SetupAIAssistant/ManageModelsModal.tsx new file mode 100644 index 000000000..b6de4da0a --- /dev/null +++ b/src/components/SetupAIAssistant/ManageModelsModal.tsx @@ -0,0 +1,161 @@ +import React, { useState, useCallback, useRef } from "react" +import styled from "styled-components" +import * as RadixDialog from "@radix-ui/react-dialog" +import { Dialog } from "../Dialog" +import { Box } from "../Box" +import { Text } from "../Text" +import { Button } from "../Button" +import { Overlay } from "../Overlay" +import type { CustomProviderDefinition } from "../../utils/ai/settings" +import { ModelSettings } from "./ModelSettings" +import type { ModelSettingsRef } from "./ModelSettings" + +const ModalContent = styled.div` + display: flex; + flex-direction: column; + width: 100%; + overflow-y: auto; +` + +const HeaderSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", + align: "flex-start", +})` + padding: 2rem 2.4rem; +` + +const ModalTitle = styled(Dialog.Title)` + font-size: 2.4rem; + font-weight: 600; + margin: 0; + padding: 0; + color: ${({ theme }) => theme.color.foreground}; + border: 0; +` + +const ModalSubtitle = styled(RadixDialog.Description)` + font-size: 1.4rem; + color: ${({ theme }) => theme.color.gray2}; + margin: 0; +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const FooterSection = styled(Box).attrs({ + justifyContent: "flex-end", + align: "center", + gap: "1.2rem", +})` + padding: 2rem 2.4rem; + width: 100%; +` + +const FooterButton = styled(Button)` + padding: 1.1rem 1.2rem; + font-size: 1.4rem; + font-weight: 500; + height: 4rem; + min-width: 12rem; +` + +const ErrorText = styled(Text)` + font-size: 1.3rem; + color: ${({ theme }) => theme.color.red}; +` + +type ManageModelsModalProps = { + open: boolean + onOpenChange: (open: boolean) => void + providerId: string + definition: CustomProviderDefinition + onSave: (providerId: string, definition: CustomProviderDefinition) => void +} + +export const ManageModelsModal = ({ + open, + onOpenChange, + providerId, + definition, + onSave, +}: ManageModelsModalProps) => { + const [error, setError] = useState(null) + const [modelsLoading, setModelsLoading] = useState(true) + const modelSettingsRef = useRef(null) + + const handleSave = useCallback(() => { + setError(null) + const result = modelSettingsRef.current?.validate() + if (typeof result === "string") { + setError(result) + return + } + const values = modelSettingsRef.current?.getValues() + if (!values) return + onSave(providerId, { + ...definition, + models: values.models, + contextWindow: values.contextWindow, + }) + onOpenChange(false) + }, [definition, providerId, onSave, onOpenChange]) + + return ( + + + + + + + Manage Models + + Add or remove models and update the context window for{" "} + {definition.name}. + + + + {open && ( + + )} + + + {error && {error}} + onOpenChange(false)} + > + Cancel + + + Save + + + + + + + ) +} diff --git a/src/components/SetupAIAssistant/ModelSettings.tsx b/src/components/SetupAIAssistant/ModelSettings.tsx new file mode 100644 index 000000000..f91ef8f17 --- /dev/null +++ b/src/components/SetupAIAssistant/ModelSettings.tsx @@ -0,0 +1,668 @@ +import React, { + useState, + useCallback, + useEffect, + useRef, + useImperativeHandle, + forwardRef, +} from "react" +import styled, { useTheme } from "styled-components" +import { Box } from "../Box" +import { Input } from "../Input" +import { Checkbox } from "../Checkbox" +import { Text } from "../Text" +import { LoadingSpinner } from "../LoadingSpinner" +import { WarningIcon, XIcon } from "@phosphor-icons/react" +import { createProviderByType } from "../../utils/ai/registry" +import type { ProviderType } from "../../utils/ai/settings" + +export const InputSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", +})` + width: 100%; +` + +export const InputLabel = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; +` + +export const StyledInput = styled(Input)<{ $hasError?: boolean }>` + width: 100%; + background: #262833; + border: 0.1rem solid + ${({ theme, $hasError }) => ($hasError ? theme.color.red : "#6b7280")}; + border-radius: 0.8rem; + font-size: 1.4rem; + min-height: 3rem; + + &::placeholder { + color: ${({ theme }) => theme.color.gray2}; + font-family: inherit; + } +` + +export const HelperText = styled(Text)` + font-size: 1.3rem; + font-weight: 300; + color: ${({ theme }) => theme.color.gray2}; +` + +const WarningBanner = styled(Box).attrs({ + flexDirection: "row", + gap: "0.6rem", + align: "center", +})` + width: 100%; + background: rgba(255, 165, 0, 0.08); + border: 0.1rem solid ${({ theme }) => theme.color.orange}; + border-radius: 0.8rem; + padding: 0.75rem; +` + +const WarningText = styled(Text)` + font-size: 1.3rem; + color: ${({ theme }) => theme.color.orange}; +` + +const ModelListContainer = styled.div` + max-height: 30rem; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 0.25rem; + border: 0.1rem solid #6b7280; + border-radius: 0.4rem; + width: 100%; +` + +const ModelRow = styled.label` + display: flex; + align-items: center; + gap: 0.8rem; + padding: 0.6rem 0.8rem; + cursor: pointer; + font-size: 1.4rem; + color: ${({ theme }) => theme.color.foreground}; + + &:hover { + background: ${({ theme }) => theme.color.selection}; + } +` + +const ModelChipsContainer = styled.div` + display: flex; + flex-wrap: wrap; + gap: 0.6rem; +` + +const ModelChip = styled.div` + display: inline-flex; + align-items: center; + gap: 0.5rem; + background: ${({ theme }) => theme.color.selection}; + border-radius: 0.4rem; + padding: 0.4rem 0.8rem; + font-size: 1.3rem; + color: ${({ theme }) => theme.color.foreground}; +` + +const ChipRemoveButton = styled.button` + background: none; + border: none; + cursor: pointer; + padding: 0; + display: flex; + justify-content: center; + align-items: center; + color: ${({ theme }) => theme.color.gray2}; + + &:hover { + color: ${({ theme }) => theme.color.foreground}; + } +` + +const AddModelRow = styled(Box).attrs({ + gap: "0.8rem", + align: "center", +})` + width: 100%; +` + +const AddModelButton = styled.button` + height: 3rem; + border: 0.1rem solid ${({ theme }) => theme.color.pinkDarker}; + background: ${({ theme }) => theme.color.background}; + color: ${({ theme }) => theme.color.foreground}; + border-radius: 0.4rem; + padding: 0 1.2rem; + font-size: 1.4rem; + font-weight: 500; + cursor: pointer; + white-space: nowrap; + + &:hover:not(:disabled) { + background: ${({ theme }) => theme.color.pinkDarker}; + } + + &:disabled { + opacity: 0.6; + cursor: not-allowed; + } +` + +const SelectAllRow = styled(Box).attrs({ + gap: "2rem", + align: "center", +})` + display: inline-flex; + margin-left: auto; +` + +const SelectAllLink = styled.button` + background: none; + border: none; + cursor: pointer; + color: ${({ theme }) => theme.color.cyan}; + font-size: 1.4rem; + padding: 0; + + &:hover { + text-decoration: underline; + } +` + +const SchemaAccessSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.6rem", + align: "flex-start", +})` + width: 100%; +` + +const SchemaAccessTitle = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; + flex: 1; +` + +const SchemaCheckboxContainer = styled(Box).attrs({ + gap: "1.5rem", + align: "flex-start", +})` + background: rgba(68, 71, 90, 0.56); + padding: 0.75rem; + border-radius: 0.4rem; + width: 100%; +` + +const SchemaCheckboxInner = styled(Box).attrs({ + gap: "1.5rem", + align: "center", +})` + flex: 1; + padding: 0.75rem; + border-radius: 0.5rem; +` + +const SchemaCheckboxWrapper = styled.div` + flex-shrink: 0; + display: flex; + align-items: center; +` + +const SchemaCheckboxContent = styled(Box).attrs({ + flexDirection: "column", + gap: "0.6rem", +})` + flex: 1; +` + +const SchemaCheckboxLabel = styled(Text)` + font-size: 1.4rem; + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const SchemaCheckboxDescription = styled(Text)` + font-size: 1.3rem; + font-weight: 400; + color: ${({ theme }) => theme.color.gray2}; +` + +const SchemaCheckboxDescriptionBold = styled.span` + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const ContentSection = styled(Box).attrs({ + flexDirection: "column", + gap: "2rem", +})` + padding: 2.4rem; + width: 100%; +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const LoadingContainer = styled(Box).attrs({ + align: "center", + justifyContent: "center", +})` + width: 100%; + padding: 4rem 0; +` + +// --- Types --- + +export type FetchConfig = { + providerType: ProviderType + providerId: string + apiKey: string + baseURL: string +} + +export type ModelSettingsInitialValues = { + models?: string[] + contextWindow?: number + grantSchemaAccess?: boolean +} + +export type ModelSettingsData = { + models: string[] + contextWindow: number + grantSchemaAccess: boolean +} + +export type ModelSettingsRef = { + getValues: () => ModelSettingsData + validate: () => string | true +} + +export type ModelSettingsProps = { + initialValues?: ModelSettingsInitialValues + fetchConfig: FetchConfig + renderSchemaAccess?: boolean + providerName?: string + onLoadingChange?: (loading: boolean) => void +} + +// --- Utility --- + +async function fetchProviderModels( + config: FetchConfig, + contextWindow: number, +): Promise { + try { + const provider = createProviderByType( + config.providerType, + config.providerId, + config.apiKey, + { baseURL: config.baseURL, contextWindow, isCustom: true }, + ) + const models = await provider.listModels() + return models && models.length > 0 ? models : null + } catch { + return null + } +} + +// --- Component --- + +export const ModelSettings = forwardRef( + ( + { + initialValues, + fetchConfig, + renderSchemaAccess, + providerName, + onLoadingChange, + }, + ref, + ) => { + const theme = useTheme() + + const [fetchedModels, setFetchedModels] = useState(null) + const [selectedModels, setSelectedModels] = useState([]) + const [manualModels, setManualModels] = useState( + () => initialValues?.models ?? [], + ) + const [manualModelInput, setManualModelInput] = useState("") + const [contextWindowInput, setContextWindowInput] = useState(() => + String(initialValues?.contextWindow ?? 128_000), + ) + const [grantSchemaAccess, setGrantSchemaAccess] = useState( + () => initialValues?.grantSchemaAccess ?? true, + ) + const [isLoading, setIsLoading] = useState(true) + + const fetchConfigRef = useRef(fetchConfig) + fetchConfigRef.current = fetchConfig + const initialValuesRef = useRef(initialValues) + initialValuesRef.current = initialValues + + // Fetch models on mount + useEffect(() => { + let cancelled = false + + const doFetch = async () => { + setIsLoading(true) + const config = fetchConfigRef.current + const initModels = initialValuesRef.current?.models ?? [] + const initContextWindow = + initialValuesRef.current?.contextWindow ?? 128_000 + + const models = await fetchProviderModels(config, initContextWindow) + + if (cancelled) return + + if (models) { + // Auto mode: reconcile initialValues.models against fetched list + setFetchedModels(models) + const selected = [ + ...initModels.filter((m) => models.includes(m)), + ...initModels.filter((m) => !models.includes(m)), + ] + setSelectedModels(selected.length > 0 ? selected : []) + setManualModels([]) + } else { + // Manual mode + setFetchedModels(null) + setSelectedModels([]) + setManualModels([...initModels]) + } + setIsLoading(false) + } + + void doFetch() + return () => { + cancelled = true + } + }, []) + + useEffect(() => { + onLoadingChange?.(isLoading) + }, [isLoading, onLoadingChange]) + + const isAutoMode = fetchedModels !== null + + // --- Handlers --- + + const handleToggleModel = useCallback((model: string) => { + setSelectedModels((prev) => + prev.includes(model) + ? prev.filter((m) => m !== model) + : [...prev, model], + ) + }, []) + + const handleSelectAll = useCallback(() => { + setSelectedModels((prev) => { + if (!fetchedModels) return prev + const manual = prev.filter((m) => !fetchedModels.includes(m)) + return [...fetchedModels, ...manual] + }) + }, [fetchedModels]) + + const handleDeselectAll = useCallback(() => { + setSelectedModels((prev) => + fetchedModels ? prev.filter((m) => !fetchedModels.includes(m)) : [], + ) + }, [fetchedModels]) + + const handleAddManualModel = useCallback(() => { + const trimmed = manualModelInput.trim() + if (!trimmed) return + + if (isAutoMode) { + setSelectedModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } else { + setManualModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } + setManualModelInput("") + }, [manualModelInput, isAutoMode]) + + const handleRemoveManualModel = useCallback((model: string) => { + setManualModels((prev) => prev.filter((m) => m !== model)) + }, []) + + // --- Imperative handle --- + + useImperativeHandle( + ref, + () => ({ + getValues: () => { + const pending = manualModelInput.trim() + let models: string[] + + if (isAutoMode) { + models = + pending && !selectedModels.includes(pending) + ? [...selectedModels, pending] + : [...selectedModels] + } else { + models = + pending && !manualModels.includes(pending) + ? [...manualModels, pending] + : [...manualModels] + } + + const contextWindow = Number(contextWindowInput) || 0 + return { models, contextWindow, grantSchemaAccess } + }, + validate: () => { + const pending = manualModelInput.trim() + const models = isAutoMode ? selectedModels : manualModels + const hasModels = models.length > 0 || !!pending + if (!hasModels) return "Select at least one model" + const trimmed = contextWindowInput.trim() + if (!trimmed) return "Context window is required" + const contextWindow = Number(trimmed) + if (isNaN(contextWindow) || !Number.isInteger(contextWindow)) + return "Context window must be a valid number" + if (contextWindow < 100_000) + return "Context window must be at least 100,000 tokens" + return true + }, + }), + [ + manualModelInput, + isAutoMode, + selectedModels, + manualModels, + contextWindowInput, + grantSchemaAccess, + ], + ) + + // --- Render --- + + if (isLoading) { + return ( + + + + + + ) + } + + return ( + <> + + {!isAutoMode && ( + + + + Could not fetch models automatically from this provider. Please + enter model IDs manually. + + + )} + {isAutoMode && ( + + + Select Models + + + Select All + + + Deselect All + + + + + {fetchedModels.map((model) => ( + + handleToggleModel(model)} + /> + {model} + + ))} + + + )} + + {!isAutoMode && Add Models} + {isAutoMode && ( + + Don't see your model? Add it manually: + + )} + + setManualModelInput(e.target.value)} + placeholder="e.g., llama3, gpt-4o, claude-sonnet-4-20250514" + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault() + handleAddManualModel() + } + }} + /> + + Add + + + {isAutoMode && + selectedModels.filter((m) => !fetchedModels.includes(m)).length > + 0 && ( + + {selectedModels + .filter((m) => !fetchedModels.includes(m)) + .map((model) => ( + + {model} + handleToggleModel(model)} + > + + + + ))} + + )} + {!isAutoMode && manualModels.length > 0 && ( + + {manualModels.map((model) => ( + + {model} + handleRemoveManualModel(model)} + title={`Remove ${model}`} + > + + + + ))} + + )} + + + + + + Context Window + setContextWindowInput(e.target.value)} + /> + + Maximum number of tokens the model can process. AI assistant + requires a minimum of 100,000 tokens. + + + + {renderSchemaAccess && ( + <> + + + + Schema Access + + + + setGrantSchemaAccess(e.target.checked)} + /> + + + + Grant schema access to {providerName || "this provider"} + + + When enabled, the AI assistant can access your database + schema information to provide more accurate suggestions + and explanations. Schema information helps the AI + understand your table structures, column names, and + relationships.{" "} + + The AI model will not have access to your data. + + + + + + + + + )} + + ) + }, +) + +ModelSettings.displayName = "ModelSettings" diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 6efda6484..515b9ff19 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -38,6 +38,7 @@ import { ForwardRef } from "../ForwardRef" import { Badge, BadgeType } from "../../components/Badge" import { CheckboxCircle } from "@styled-icons/remix-fill" import { CustomProviderModal } from "./CustomProviderModal" +import { ManageModelsModal } from "./ManageModelsModal" const ModalContent = styled.div` display: flex; @@ -179,6 +180,7 @@ const ProviderTabName = styled(Text)<{ $active: boolean }>` font-weight: ${({ $active }) => ($active ? 600 : 400)}; color: ${({ theme, $active }) => $active ? theme.color.foreground : theme.color.gray2}; + text-align: left; ` const StatusBadge = styled(Box).attrs({ @@ -399,6 +401,19 @@ const EnableModelsTitle = styled(Text)` color: ${({ theme }) => theme.color.foreground}; ` +const ManageModelsButton = styled.button` + background: none; + border: none; + cursor: pointer; + color: ${({ theme }) => theme.color.cyan}; + font-size: 1.3rem; + padding: 0; + + &:hover { + text-decoration: underline; + } +` + const SchemaAccessSection = styled(Box).attrs({ flexDirection: "column", gap: "1.6rem", @@ -636,6 +651,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const inputRef = useRef(null) const [customProviderModalOpen, setCustomProviderModalOpen] = useState(false) + const [manageModelsModalOpen, setManageModelsModalOpen] = useState(false) const [localCustomProviders, setLocalCustomProviders] = useState< Record @@ -661,7 +677,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { (provider: ProviderId, value: string) => { setApiKeys((prev) => ({ ...prev, [provider]: value })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) - // If API key changes, mark as not validated + if (validatedApiKeys[provider]) { setValidatedApiKeys((prev) => ({ ...prev, [provider]: false })) } @@ -910,12 +926,86 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [aiAssistantSettings, updateSettings], ) + const handleManageModelsSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + const newModelValues = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + + // Update local custom providers — only override models and contextWindow, + // preserve everything else (apiKey, grantSchemaAccess, etc.) from local state. + setLocalCustomProviders((prev) => ({ + ...prev, + [providerId]: { + ...prev[providerId], + models: definition.models, + contextWindow: definition.contextWindow, + }, + })) + + // Determine which models are truly new (not in the previous model list) + const oldModelValues = ( + localCustomProviders[providerId]?.models || [] + ).map((m) => makeCustomModelValue(providerId, m)) + const trulyNew = newModelValues.filter((m) => !oldModelValues.includes(m)) + + // Local state: respect unsaved checkbox toggles, add truly new as enabled + const localEnabled = enabledModels[providerId] || [] + const localStillEnabled = localEnabled.filter((m: string) => + newModelValues.includes(m), + ) + setEnabledModels((prev) => ({ + ...prev, + [providerId]: [...localStillEnabled, ...trulyNew], + })) + + // Storage: preserve stored enabled/disabled state, only add truly new models. + // Unsaved toggle changes (apiKey, grantSchemaAccess, enable/disable) are not + // persisted here — they require "Save Settings". + const storedProviderSettings = aiAssistantSettings.providers?.[providerId] + const storedEnabled = storedProviderSettings?.enabledModels || [] + const storedStillEnabled = storedEnabled.filter((m: string) => + newModelValues.includes(m), + ) + const storedCustomProvider = + aiAssistantSettings.customProviders?.[providerId] + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { + ...aiAssistantSettings, + customProviders: { + ...(aiAssistantSettings.customProviders ?? {}), + ...(storedCustomProvider && { + [providerId]: { + ...storedCustomProvider, + models: definition.models, + contextWindow: definition.contextWindow, + }, + }), + }, + providers: { + ...aiAssistantSettings.providers, + ...(storedProviderSettings && { + [providerId]: { + ...storedProviderSettings, + enabledModels: [...storedStillEnabled, ...trulyNew], + }, + }), + }, + }) + + toast.success("Model preferences updated") + }, + [aiAssistantSettings, enabledModels, localCustomProviders, updateSettings], + ) + const currentProviderValidated = validatedApiKeys[selectedProvider] const currentProviderApiKey = apiKeys[selectedProvider] const currentProviderValidationState = validationState[selectedProvider] const currentProviderError = validationErrors[selectedProvider] const currentProviderIsFocused = isInputFocused[selectedProvider] const maskInput = !!(currentProviderApiKey && !currentProviderIsFocused) + const noApiKeyReadonly = + isCustomProvider && !currentProviderApiKey && !currentProviderIsFocused + const showEditButton = maskInput || noApiKeyReadonly const modelsForProvider = useMemo( () => getModelsForProvider(selectedProvider, localSettings), @@ -947,7 +1037,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { return ( <> @@ -1030,24 +1120,16 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { - {isCustomProvider && - !localCustomProviders[selectedProvider]?.apiKey ? ( - <> + <> + API Key - - This provider does not have an API key. - - - ) : ( - <> - - API Key - {validatedApiKeys[selectedProvider] && ( + {validatedApiKeys[selectedProvider] && + currentProviderApiKey && ( } data-hook="ai-settings-validated-badge" @@ -1055,126 +1137,147 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { Validated )} - {!isCustomProvider && ( - - Get your API key from{" "} - - {getProviderName( - selectedProvider, - localSettings, - )} - - . - - )} - - - { - handleApiKeyChange( - selectedProvider, - e.target.value, - ) - }} - placeholder={`Enter ${getProviderName(selectedProvider, localSettings)} API key`} - $hasError={!!currentProviderError} - $showEditButton={maskInput} - readOnly={maskInput} - onFocus={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: true, - })) - }} - onBlur={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: false, - })) - if (inputRef.current) { - inputRef.current.blur() - } - }} - onMouseDown={(e) => { - if (maskInput) { - e.preventDefault() + {!isCustomProvider && ( + + Get your API key from{" "} + - {maskInput && ( - { - inputRef.current?.focus() - }} - title="Edit API key" + target="_blank" + rel="noopener noreferrer" > - - - )} - - {currentProviderError && ( - {currentProviderError} + {getProviderName( + selectedProvider, + localSettings, + )} + + . + )} - {!currentProviderError && ( - - Stored locally in your browser and never sent to - QuestDB servers. This API key is used to - authenticate your requests to the model provider. - + + + { + handleApiKeyChange( + selectedProvider, + e.target.value, + ) + }} + placeholder={ + noApiKeyReadonly + ? "This provider does not have an API key" + : `Enter ${getProviderName(selectedProvider, localSettings)} API key` + } + $hasError={!!currentProviderError} + $showEditButton={showEditButton} + readOnly={maskInput || noApiKeyReadonly} + onFocus={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: true, + })) + }} + onBlur={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: false, + })) + if (inputRef.current) { + inputRef.current.blur() + } + }} + onMouseDown={(e) => { + if (maskInput || noApiKeyReadonly) { + e.preventDefault() + } + }} + tabIndex={maskInput || noApiKeyReadonly ? -1 : 0} + style={{ + cursor: + maskInput || noApiKeyReadonly + ? "default" + : "text", + }} + data-hook="ai-settings-api-key" + /> + {showEditButton && ( + { + inputRef.current?.focus() + }} + title="Edit API key" + > + + )} - {!currentProviderValidated && - currentProviderApiKey && ( - - handleValidateApiKey(selectedProvider) - } - disabled={ - currentProviderValidationState === - "validating" - } - data-hook="ai-settings-test-api" - > - {currentProviderValidationState === - "validating" ? ( - - - Validating... - - ) : ( - "Validate API Key" - )} - + + {currentProviderError && ( + {currentProviderError} + )} + {!currentProviderError && ( + + Stored locally in your browser and never sent to + QuestDB servers. This API key is used to + authenticate your requests to the model provider. + + )} + {!currentProviderValidated && currentProviderApiKey && ( + + handleValidateApiKey(selectedProvider) + } + disabled={ + currentProviderValidationState === "validating" + } + data-hook="ai-settings-test-api" + > + {currentProviderValidationState === "validating" ? ( + + + Validating... + + ) : ( + "Validate API Key" )} - - )} + + )} + - Enable Models - {currentProviderValidated ? ( + + Enable Models + {isCustomProvider && + (currentProviderValidated || + modelsForProvider.length > 0) && ( + setManageModelsModalOpen(true)} + > + Manage models + + )} + + {currentProviderValidated || + (isCustomProvider && modelsForProvider.length > 0) ? ( {modelsForProvider.map((model) => { const isEnabled = enabledModelsForProvider.includes( @@ -1242,7 +1345,13 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { e.target.checked, ) } - disabled={!currentProviderValidated} + disabled={ + !currentProviderValidated && + !( + isCustomProvider && + modelsForProvider.length > 0 + ) + } data-hook="ai-settings-schema-access" /> @@ -1310,6 +1419,17 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { existingProviderIds={allProviders} /> )} + {manageModelsModalOpen && + isCustomProvider && + localCustomProviders[selectedProvider] && ( + + )} ) } From 690ee1719e137ff61023521d346a871f5cf6ea73 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 17:41:28 +0300 Subject: [PATCH 22/25] merge --- e2e/tests/console/aiAssistant.spec.js | 2 +- .../SetupAIAssistant/ModelSettings.tsx | 35 ++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/e2e/tests/console/aiAssistant.spec.js b/e2e/tests/console/aiAssistant.spec.js index ef7b181e1..d4bef43f6 100644 --- a/e2e/tests/console/aiAssistant.spec.js +++ b/e2e/tests/console/aiAssistant.spec.js @@ -3240,7 +3240,7 @@ describe("custom providers", () => { "have.value", "200000", ) - cy.getByDataHook("custom-provider-schema-access").should("not.be.checked") + cy.getByDataHook("custom-provider-schema-access").should("be.checked") cy.getByDataHook("custom-provider-add-model-button").should("be.disabled") cy.getByDataHook("custom-provider-manual-model-input").type( diff --git a/src/components/SetupAIAssistant/ModelSettings.tsx b/src/components/SetupAIAssistant/ModelSettings.tsx index f91ef8f17..c976d89a4 100644 --- a/src/components/SetupAIAssistant/ModelSettings.tsx +++ b/src/components/SetupAIAssistant/ModelSettings.tsx @@ -336,7 +336,7 @@ export const ModelSettings = forwardRef( ) const [manualModelInput, setManualModelInput] = useState("") const [contextWindowInput, setContextWindowInput] = useState(() => - String(initialValues?.contextWindow ?? 128_000), + String(initialValues?.contextWindow ?? 200_000), ) const [grantSchemaAccess, setGrantSchemaAccess] = useState( () => initialValues?.grantSchemaAccess ?? true, @@ -357,7 +357,7 @@ export const ModelSettings = forwardRef( const config = fetchConfigRef.current const initModels = initialValuesRef.current?.models ?? [] const initContextWindow = - initialValuesRef.current?.contextWindow ?? 128_000 + initialValuesRef.current?.contextWindow ?? 200_000 const models = await fetchProviderModels(config, initContextWindow) @@ -465,7 +465,7 @@ export const ModelSettings = forwardRef( const pending = manualModelInput.trim() const models = isAutoMode ? selectedModels : manualModels const hasModels = models.length > 0 || !!pending - if (!hasModels) return "Select at least one model" + if (!hasModels) return "Add at least one model" const trimmed = contextWindowInput.trim() if (!trimmed) return "Context window is required" const contextWindow = Number(trimmed) @@ -502,7 +502,7 @@ export const ModelSettings = forwardRef( <> {!isAutoMode && ( - + ( > Select Models - + Select All - + Deselect All {fetchedModels.map((model) => ( - + handleToggleModel(model)} @@ -555,6 +563,7 @@ export const ModelSettings = forwardRef( setManualModelInput(e.target.value)} placeholder="e.g., llama3, gpt-4o, claude-sonnet-4-20250514" @@ -567,6 +576,7 @@ export const ModelSettings = forwardRef( /> @@ -580,9 +590,13 @@ export const ModelSettings = forwardRef( {selectedModels .filter((m) => !fetchedModels.includes(m)) .map((model) => ( - + {model} handleToggleModel(model)} > @@ -595,9 +609,10 @@ export const ModelSettings = forwardRef( {!isAutoMode && manualModels.length > 0 && ( {manualModels.map((model) => ( - + {model} handleRemoveManualModel(model)} title={`Remove ${model}`} @@ -615,6 +630,7 @@ export const ModelSettings = forwardRef( Context Window setContextWindowInput(e.target.value)} @@ -635,6 +651,7 @@ export const ModelSettings = forwardRef( setGrantSchemaAccess(e.target.checked)} /> From 28b9937316fc32968836aebd0e92290f9db5b0a1 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 17:42:23 +0300 Subject: [PATCH 23/25] submodule --- e2e/questdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e/questdb b/e2e/questdb index c5dc7a1eb..ae67e0bd4 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit c5dc7a1eb1533a6971322e69515278e711ec4ec7 +Subproject commit ae67e0bd43773ca7de8386bef0c0b7bece6f4f88 From 8f1110fa00a8d15a3680d1f4b00b409c4e802687 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 19:15:25 +0300 Subject: [PATCH 24/25] new test cases, uuid generation for providers --- e2e/tests/console/aiAssistant.spec.js | 351 ++++++++++++++---- .../SetupAIAssistant/ConfigurationModal.tsx | 4 +- .../SetupAIAssistant/CustomProviderModal.tsx | 22 +- .../SetupAIAssistant/ManageModelsModal.tsx | 2 + .../SetupAIAssistant/SettingsModal.tsx | 5 +- 5 files changed, 304 insertions(+), 80 deletions(-) diff --git a/e2e/tests/console/aiAssistant.spec.js b/e2e/tests/console/aiAssistant.spec.js index d4bef43f6..bf7991c67 100644 --- a/e2e/tests/console/aiAssistant.spec.js +++ b/e2e/tests/console/aiAssistant.spec.js @@ -2550,7 +2550,7 @@ describe("custom providers", () => { }).as("unhandledAnthropic") }) - it("should configure provider with auto-fetched models, select/deselect, and verify localStorage", () => { + it("should configure provider with auto-fetched models, select/deselect", () => { cy.loadConsoleWithAuth() cy.intercept("GET", "**/models*", { @@ -2627,32 +2627,6 @@ describe("custom providers", () => { cy.contains("AI Assistant activated successfully").should("be.visible") cy.getByDataHook("ai-chat-button").should("be.visible") - - cy.window().then((win) => { - const settings = JSON.parse( - win.localStorage.getItem("ai.assistant.settings"), - ) - expect(settings.customProviders.ollama).to.exist - expect(settings.customProviders.ollama.type).to.equal( - "openai-chat-completions", - ) - expect(settings.customProviders.ollama.name).to.equal("Ollama") - expect(settings.customProviders.ollama.baseURL).to.equal( - "http://localhost:11434/v1", - ) - expect(settings.customProviders.ollama.models).to.deep.equal([ - "llama3", - "mistral", - ]) - expect(settings.customProviders.ollama.contextWindow).to.equal(200000) - expect(settings.customProviders.ollama.grantSchemaAccess).to.be.true - - expect(settings.providers.ollama.enabledModels).to.deep.equal([ - "ollama:llama3", - "ollama:mistral", - ]) - expect(settings.selectedModel).to.equal("ollama:llama3") - }) }) it("should reject invalid URL, require models, enforce context window minimum, and prevent duplicates", () => { @@ -2729,19 +2703,6 @@ describe("custom providers", () => { cy.contains("AI Assistant activated successfully").should("be.visible") cy.getByDataHook("ai-chat-button").should("be.visible") - - cy.window().then((win) => { - const settings = JSON.parse( - win.localStorage.getItem("ai.assistant.settings"), - ) - expect(settings.customProviders.openrouter).to.exist - expect(settings.customProviders.openrouter.models).to.deep.equal([ - "gpt-4o", - "claude-3.5-sonnet", - ]) - expect(settings.customProviders.openrouter.apiKey).to.equal("sk-test") - expect(settings.customProviders.openrouter.contextWindow).to.equal(100000) - }) }) it("should send chat with tool call through custom endpoint and accept SQL suggestion", () => { @@ -2873,7 +2834,9 @@ describe("custom providers", () => { cy.getByDataHook("multi-step-modal-next-button").click() cy.getByDataHook("ai-settings-provider-ollama").should("be.visible") - cy.getByDataHook("ai-settings-provider-openrouter").should("be.visible") + cy.contains("[data-hook^='ai-settings-provider-']", "OpenRouter").should( + "be.visible", + ) cy.getByDataHook("ai-settings-save").click() cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() @@ -2886,23 +2849,15 @@ describe("custom providers", () => { cy.getByDataHook("ai-settings-remove-provider").click() cy.getByDataHook("ai-settings-provider-ollama").should("not.exist") - cy.getByDataHook("ai-settings-provider-openrouter").should("be.visible") + cy.contains("[data-hook^='ai-settings-provider-']", "OpenRouter").should( + "be.visible", + ) cy.getByDataHook("ai-settings-save").click() cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() cy.getByDataHook("ai-settings-model-item").should("have.length", 1) cy.contains("gpt-4o").should("be.visible") cy.getByDataHook("ai-settings-model-dropdown").click() - - cy.window().then((win) => { - const settings = JSON.parse( - win.localStorage.getItem("ai.assistant.settings"), - ) - expect(settings.customProviders.ollama).to.not.exist - expect(settings.providers.ollama).to.not.exist - expect(settings.customProviders.openrouter).to.exist - expect(settings.providers.openrouter).to.exist - }) }) it("should show error on 401, retry successfully, and show error on network failure", () => { @@ -2990,7 +2945,7 @@ describe("custom providers", () => { cy.getByDataHook("retry-button").should("be.visible") }) - it("should reject duplicate names against custom and built-in providers, and sanitize special characters in IDs", () => { + it("should reject duplicate names against custom and built-in providers, and allow unique names", () => { cy.loadConsoleWithAuth( false, getCustomProviderConfiguredSettings({ @@ -3003,6 +2958,7 @@ describe("custom providers", () => { cy.getByDataHook("ai-assistant-settings-button").click() cy.getByDataHook("ai-settings-add-custom-provider").click() + // Exact duplicate name cy.getByDataHook("custom-provider-name-input").type("My Provider") cy.getByDataHook("custom-provider-base-url-input").type( "http://localhost:1234", @@ -3014,12 +2970,12 @@ describe("custom providers", () => { }) cy.getByDataHook("multi-step-modal-next-button").click() - cy.contains("A provider with a similar name already exists").should( + cy.contains("A provider with the same name already exists").should( "be.visible", ) - // "OpenAI" collides with built-in provider ID "openai" - cy.getByDataHook("custom-provider-name-input").clear().type("OpenAI") + // Case-insensitive duplicate of built-in provider name + cy.getByDataHook("custom-provider-name-input").clear().type("openai") cy.intercept("GET", "http://localhost:1234/models", { statusCode: 500, @@ -3027,11 +2983,11 @@ describe("custom providers", () => { }) cy.getByDataHook("multi-step-modal-next-button").click() - cy.contains("A provider with a similar name already exists").should( + cy.contains("A provider with the same name already exists").should( "be.visible", ) - // Special characters should be stripped from the generated ID + // Unique name should proceed to step 2 cy.getByDataHook("custom-provider-name-input") .clear() .type("My Provider (v2.0)!") @@ -3048,18 +3004,22 @@ describe("custom providers", () => { cy.getByDataHook("custom-provider-add-model-button").click() cy.getByDataHook("multi-step-modal-next-button").click() - cy.getByDataHook("ai-settings-provider-my-provider").should("be.visible") - cy.getByDataHook("ai-settings-provider-my-provider-v2-0").should( - "be.visible", - ) + // Both providers visible in sidebar (find by name text) + cy.contains("My Provider").should("be.visible") + cy.contains("My Provider (v2.0)!").should("be.visible") cy.getByDataHook("ai-settings-save").click() cy.window().then((win) => { const settings = JSON.parse( win.localStorage.getItem("ai.assistant.settings"), ) - expect(settings.customProviders["my-provider-v2-0"]).to.exist - expect(settings.customProviders["my-provider-v2-0"].name).to.equal( + // Provider ID should be a UUID + const customIds = Object.keys(settings.customProviders) + const newProviderId = customIds.find((id) => id !== "my-provider") + expect(newProviderId).to.match( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/, + ) + expect(settings.customProviders[newProviderId].name).to.equal( "My Provider (v2.0)!", ) }) @@ -3401,4 +3361,267 @@ describe("custom providers", () => { .should("be.visible") .should("contain", "Auth response") }) + + it("should open manage models, add and remove models, update context window, and reflect in dropdown", () => { + const providerId = "ollama" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Ollama", + models: ["llama3", "mistral", "codellama"], + }), + ) + + // Intercept model fetch before opening modal + cy.intercept("GET", "**/models*", { + statusCode: 200, + body: { + object: "list", + data: [ + { id: "llama3", object: "model" }, + { id: "mistral", object: "model" }, + { id: "codellama", object: "model" }, + ], + }, + }).as("manageModelsFetch") + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + cy.getByDataHook("ai-settings-manage-models").should("be.visible").click() + cy.wait("@manageModelsFetch") + + // All 3 models should be checked + cy.getByDataHook("custom-provider-model-row").should("have.length", 3) + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("be.checked") + }) + + // Add a manual model + cy.getByDataHook("custom-provider-manual-model-input").type( + "custom-finetune", + ) + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-model-chip").should( + "contain", + "custom-finetune", + ) + + // Uncheck codellama + cy.getByDataHook("custom-provider-model-row").contains("codellama").click() + + // Validation: context window too low + cy.getByDataHook("custom-provider-context-window-input") + .clear() + .type("50000") + cy.getByDataHook("manage-models-save").click() + cy.contains("Context window must be at least 100,000 tokens").should( + "be.visible", + ) + + // Fix context window and save + cy.getByDataHook("custom-provider-context-window-input") + .clear() + .type("150000") + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "150000", + ) + cy.getByDataHook("manage-models-save").click() + + // Modal should close, settings modal visible again + cy.getByDataHook("ai-settings-manage-models").should("be.visible") + + // Save the outer settings modal (manage-models toast auto-dismisses) + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").first().click() + + // Dropdown should show 3 models (llama3, mistral, custom-finetune) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("llama3").should("be.visible") + cy.contains("mistral").should("be.visible") + cy.contains("custom-finetune").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) + + it("should auto-enable new models from manage models and preserve unsaved toggle state", () => { + const providerId = "test-provider" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Test Provider", + models: ["model-a", "model-b", "model-c"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-test-provider") + .should("be.visible") + .click() + + // All 3 models should be enabled + cy.get("[data-model='model-a'][data-enabled='true']").should("exist") + cy.get("[data-model='model-b'][data-enabled='true']").should("exist") + cy.get("[data-model='model-c'][data-enabled='true']").should("exist") + + // Disable model-b toggle (unsaved state) + cy.get("[data-model='model-b']").find("button[role='switch']").click() + cy.get("[data-model='model-b'][data-enabled='false']").should("exist") + + // Intercept model fetch → fail to get manual mode + cy.intercept("GET", "**/models*", { + statusCode: 500, + body: { error: "Server error" }, + }).as("manageModelsFetchFail") + + // Open manage models + cy.getByDataHook("ai-settings-manage-models").click() + cy.wait("@manageModelsFetchFail") + + // Manual mode: warning banner + existing models as chips + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-model-chip").should("have.length", 3) + + // Add model-d + cy.getByDataHook("custom-provider-manual-model-input").type("model-d") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 4) + + // Remove model-b + cy.getByDataHook("custom-provider-model-chip") + .filter(":contains('model-b')") + .find("[data-hook='custom-provider-remove-model']") + .click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 3) + + // Save manage models + cy.getByDataHook("manage-models-save").click() + + // Back in SettingsModal: model-b gone, model-d auto-enabled + cy.get("[data-model='model-a'][data-enabled='true']").should("exist") + cy.get("[data-model='model-b']").should("not.exist") + cy.get("[data-model='model-c'][data-enabled='true']").should("exist") + cy.get("[data-model='model-d'][data-enabled='true']").should("exist") + + // Save settings + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").first().click() + + // Dropdown should show 3 models (model-a, model-c, model-d) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("model-a").should("be.visible") + cy.contains("model-c").should("be.visible") + cy.contains("model-d").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) + + it("should handle no-API-key custom provider: models visible, no validated badge, schema toggle enabled, and allow adding an API key", () => { + const providerId = "ollama" + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Ollama", + baseURL: CUSTOM_PROVIDER_DEFAULTS.baseURL, + apiKey: "", + models: ["llama3", "mistral"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + // Part A: No-API-key state + + // No validated badge + cy.getByDataHook("ai-settings-validated-badge").should("not.exist") + + // API key input shows placeholder about no key + cy.getByDataHook("ai-settings-api-key").should( + "have.attr", + "placeholder", + "This provider does not have an API key", + ) + + // Model list visible with both models + cy.get("[data-model='llama3']").should("exist") + cy.get("[data-model='mistral']").should("exist") + + // Schema access toggle is not disabled + cy.getByDataHook("ai-settings-schema-access").should("not.be.disabled") + + // Manage models button visible + cy.getByDataHook("ai-settings-manage-models").should("be.visible") + + // Toggle mistral off + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + // Built-in provider should NOT have manage models button + cy.getByDataHook("ai-settings-provider-openai").click() + cy.getByDataHook("ai-settings-manage-models").should("not.exist") + + // Part B: Add API key + + // Switch back to custom provider + cy.getByDataHook("ai-settings-provider-ollama").click() + + // Click Edit button to make input editable, then type API key + cy.get('button[title="Edit API key"]').click() + cy.getByDataHook("ai-settings-api-key").type("sk-custom-key-123") + + // Intercept validation request to custom endpoint + cy.intercept("POST", customEndpoint, { + statusCode: 200, + delay: 200, + body: { + id: "chatcmpl-mock", + object: "chat.completion", + choices: [ + { + index: 0, + message: { role: "assistant", content: "" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + }, + }).as("customValidation") + + // Click validate + cy.getByDataHook("ai-settings-test-api").should("be.visible").click() + cy.wait("@customValidation") + + // Validated badge should now appear + cy.getByDataHook("ai-settings-validated-badge").should("be.visible") + + // Models still visible, mistral toggle preserved + cy.get("[data-model='llama3'][data-enabled='true']").should("exist") + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + // Part C: Save and verify + + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").click() + + // Dropdown should show only llama3 (mistral was disabled) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 1) + cy.contains("llama3").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) }) diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index b380c7a31..09ea7c807 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -967,7 +967,9 @@ export const ConfigurationModal = ({ open={customProviderModalOpen} onOpenChange={setCustomProviderModalOpen} onSave={handleCustomProviderSave} - existingProviderIds={getAllProviders(aiAssistantSettings)} + existingProviderNames={getAllProviders(aiAssistantSettings).map((p) => + getProviderName(p, aiAssistantSettings), + )} /> )} diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx index 182a65562..84760a982 100644 --- a/src/components/SetupAIAssistant/CustomProviderModal.tsx +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -279,20 +279,14 @@ export type CustomProviderModalProps = { open: boolean onOpenChange: (open: boolean) => void onSave: (providerId: string, provider: CustomProviderDefinition) => void - existingProviderIds: string[] + existingProviderNames: string[] } -const generateProviderId = (name: string): string => - name - .toLowerCase() - .replace(/[^a-z0-9]+/g, "-") - .replace(/^-|-$/g, "") || "custom-provider" - export const CustomProviderModal = ({ open, onOpenChange, onSave, - existingProviderIds, + existingProviderNames, }: CustomProviderModalProps) => { const [name, setName] = useState("") const [providerType, setProviderType] = useState( @@ -309,19 +303,19 @@ export const CustomProviderModal = ({ if (!baseURL.startsWith("http://") && !baseURL.startsWith("https://")) return "Base URL must start with http:// or https://" - const providerId = generateProviderId(name) - if (existingProviderIds.includes(providerId)) - return "A provider with a similar name already exists" + const normalizedName = name.trim().toLowerCase() + if (existingProviderNames.some((n) => n.toLowerCase() === normalizedName)) + return "A provider with the same name already exists" return true - }, [name, baseURL, existingProviderIds]) + }, [name, baseURL, existingProviderNames]) const modelsValidate = useCallback((): string | boolean => { return modelSettingsRef.current?.validate() ?? "Not ready" }, []) const handleComplete = useCallback(() => { - const providerId = generateProviderId(name) + const providerId = crypto.randomUUID() const values = modelSettingsRef.current?.getValues() if (!values) return @@ -374,7 +368,7 @@ export const CustomProviderModal = ({ ref={modelSettingsRef} fetchConfig={{ providerType, - providerId: generateProviderId(name), + providerId: "custom-provider-setup", apiKey: apiKey || "", baseURL, }} diff --git a/src/components/SetupAIAssistant/ManageModelsModal.tsx b/src/components/SetupAIAssistant/ManageModelsModal.tsx index b6de4da0a..c4cf606c9 100644 --- a/src/components/SetupAIAssistant/ManageModelsModal.tsx +++ b/src/components/SetupAIAssistant/ManageModelsModal.tsx @@ -141,12 +141,14 @@ export const ManageModelsModal = ({ {error && {error}} onOpenChange(false)} > Cancel diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index 515b9ff19..5e10fb316 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -1270,6 +1270,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { modelsForProvider.length > 0) && ( setManageModelsModalOpen(true)} > Manage models @@ -1416,7 +1417,9 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { open={customProviderModalOpen} onOpenChange={setCustomProviderModalOpen} onSave={handleCustomProviderSave} - existingProviderIds={allProviders} + existingProviderNames={allProviders.map((p) => + getProviderName(p, localSettings), + )} /> )} {manageModelsModalOpen && From b11b8a01873b26105951bbbd6a16314900286be1 Mon Sep 17 00:00:00 2001 From: emrberk Date: Fri, 13 Mar 2026 19:15:47 +0300 Subject: [PATCH 25/25] submodule --- e2e/questdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e/questdb b/e2e/questdb index ae67e0bd4..c30491b9d 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit ae67e0bd43773ca7de8386bef0c0b7bece6f4f88 +Subproject commit c30491b9dc70125d12f5890120e06aad321c2637