diff --git a/e2e/commands.js b/e2e/commands.js index 6e2559af0..5dd0de036 100644 --- a/e2e/commands.js +++ b/e2e/commands.js @@ -628,3 +628,9 @@ Cypress.Commands.add("waitForAIResponse", (alias) => { cy.wait(alias) cy.waitForStreamingComplete() }) + +Cypress.Commands.add("setupCustomProvider", (config = {}) => { + const { getCustomProviderConfiguredSettings } = require("./utils/aiAssistant") + const settings = getCustomProviderConfiguredSettings(config) + cy.loadConsoleWithAuth(false, settings) +}) diff --git a/e2e/questdb b/e2e/questdb index 8ab362f2d..c30491b9d 160000 --- a/e2e/questdb +++ b/e2e/questdb @@ -1 +1 @@ -Subproject commit 8ab362f2d269f699bdd855870144468aa6e7e5d2 +Subproject commit c30491b9dc70125d12f5890120e06aad321c2637 diff --git a/e2e/tests/console/aiAssistant.spec.js b/e2e/tests/console/aiAssistant.spec.js index 72d8e733e..bf7991c67 100644 --- a/e2e/tests/console/aiAssistant.spec.js +++ b/e2e/tests/console/aiAssistant.spec.js @@ -2,6 +2,7 @@ const { PROVIDERS, + CUSTOM_PROVIDER_DEFAULTS, getOpenAIConfiguredSettings, getAnthropicConfiguredSettings, createToolCallFlow, @@ -10,6 +11,8 @@ const { createFinalResponseData, createChatTitleResponse, isTitleRequest, + getCustomProviderConfiguredSettings, + getCustomProviderEndpoint, } = require("../../utils/aiAssistant") /** @@ -266,28 +269,24 @@ describe("ai assistant", () => { // Then cy.getByDataHook("ai-settings-modal-step-one").should("be.visible") - cy.getByDataHook("ai-settings-api-key") - .should("be.visible") - .should("have.attr", "placeholder", "Enter API key") - .should("be.disabled") + // API key input is hidden until a provider is selected + cy.getByDataHook("ai-settings-api-key").should("not.exist") - // When + // When - select Anthropic cy.getByDataHook("ai-settings-provider-anthropic").click() - // Then + // Then - API key input appears cy.getByDataHook("ai-settings-api-key") .should("be.visible") .should("have.attr", "placeholder", "Enter Anthropic API key") - .should("not.be.disabled") - // When + // When - switch to OpenAI cy.getByDataHook("ai-settings-provider-openai").click() // Then cy.getByDataHook("ai-settings-api-key") .should("be.visible") .should("have.attr", "placeholder", "Enter OpenAI API key") - .should("not.be.disabled") ;["anthropic", "openai"].forEach((provider) => { // Given interceptTokenValidation(provider, false) @@ -303,7 +302,6 @@ describe("ai assistant", () => { "placeholder", `Enter ${provider === "anthropic" ? "Anthropic" : "OpenAI"} API key`, ) - .should("not.be.disabled") .should("be.empty") // When @@ -415,9 +413,9 @@ describe("ai assistant", () => { .should("contain", "Inactive") // When - cy.getByDataHook("ai-settings-test-api") + cy.getByDataHook("ai-settings-remove-provider").scrollIntoView() + cy.getByDataHook("ai-settings-remove-provider") .should("be.visible") - .should("contain", "Remove API Key") .click() // Then @@ -2536,3 +2534,1094 @@ Syntax: \`avg(column)\` }) }) }) + +describe("custom providers", () => { + beforeEach(() => { + cy.intercept("POST", PROVIDERS.openai.endpoint, (req) => { + throw new Error( + `Unhandled OpenAI request detected! Request body: ${JSON.stringify(req.body).slice(0, 200)}...`, + ) + }).as("unhandledOpenAI") + + cy.intercept("POST", PROVIDERS.anthropic.endpoint, (req) => { + throw new Error( + `Unhandled Anthropic request detected! Request body: ${JSON.stringify(req.body).slice(0, 200)}...`, + ) + }).as("unhandledAnthropic") + }) + + it("should configure provider with auto-fetched models, select/deselect", () => { + cy.loadConsoleWithAuth() + + cy.intercept("GET", "**/models*", { + statusCode: 200, + body: { + object: "list", + data: [ + { id: "llama3", object: "model" }, + { id: "mistral", object: "model" }, + { id: "codellama", object: "model" }, + ], + }, + }).as("modelListRequest") + + cy.getByDataHook("ai-assistant-settings-button") + .should("be.visible") + .click() + cy.getByDataHook("ai-promo-continue").should("be.visible").click() + cy.getByDataHook("ai-settings-modal-step-one").should("be.visible") + cy.getByDataHook("ai-settings-provider-custom").should("be.visible").click() + + cy.getByDataHook("custom-provider-name-input") + .should("be.visible") + .type("Ollama") + cy.getByDataHook("custom-provider-type-select").should( + "have.value", + "openai-chat-completions", + ) + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:11434/v1", + ) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.wait("@modelListRequest") + + cy.getByDataHook("custom-provider-model-row").should("have.length", 3) + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "200000", + ) + + cy.getByDataHook("custom-provider-select-all").click() + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("be.checked") + }) + + cy.getByDataHook("custom-provider-deselect-all").click() + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("not.be.checked") + }) + + cy.getByDataHook("custom-provider-model-row").contains("llama3").click() + cy.getByDataHook("custom-provider-model-row").contains("mistral").click() + + cy.getByDataHook("custom-provider-manual-model-input").type( + "custom-finetune", + ) + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-model-chip").should( + "contain", + "custom-finetune", + ) + + cy.getByDataHook("custom-provider-remove-model").click() + cy.getByDataHook("custom-provider-model-chip").should("not.exist") + + cy.getByDataHook("custom-provider-schema-access").check() + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.contains("AI Assistant activated successfully").should("be.visible") + cy.getByDataHook("ai-chat-button").should("be.visible") + }) + + it("should reject invalid URL, require models, enforce context window minimum, and prevent duplicates", () => { + cy.loadConsoleWithAuth() + + cy.getByDataHook("ai-assistant-settings-button") + .should("be.visible") + .click() + cy.getByDataHook("ai-promo-continue").should("be.visible").click() + cy.getByDataHook("ai-settings-provider-custom").should("be.visible").click() + + cy.getByDataHook("multi-step-modal-next-button").should("be.disabled") + + cy.getByDataHook("custom-provider-name-input").type("OpenRouter") + cy.getByDataHook("custom-provider-base-url-input").type("ftp://invalid") + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Base URL must start with http:// or https://").should( + "be.visible", + ) + + cy.getByDataHook("custom-provider-base-url-input") + .clear() + .type("https://openrouter.ai/api/v1") + cy.getByDataHook("custom-provider-api-key-input").type("sk-test") + + cy.intercept("GET", "**/models*", { + statusCode: 500, + body: { error: "Internal Server Error" }, + }).as("modelListFail") + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.wait("@modelListFail") + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Add at least one model").should("be.visible") + + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o{enter}") + cy.getByDataHook("custom-provider-model-chip") + .should("have.length", 1) + .should("contain", "gpt-4o") + + // Duplicate model should not create a second chip + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + + // Input not cleared on duplicate, clear manually + cy.getByDataHook("custom-provider-manual-model-input") + .clear() + .type("claude-3.5-sonnet") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 2) + + cy.getByDataHook("custom-provider-add-model-button").should("be.disabled") + + cy.getByDataHook("custom-provider-context-window-input").type( + "{selectall}50000", + ) + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "50000", + ) + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("Context window must be at least 100,000 tokens").should( + "be.visible", + ) + + cy.getByDataHook("custom-provider-context-window-input").type( + "{selectall}100000", + ) + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.contains("AI Assistant activated successfully").should("be.visible") + cy.getByDataHook("ai-chat-button").should("be.visible") + }) + + it("should send chat with tool call through custom endpoint and accept SQL suggestion", () => { + const customBaseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL + const customEndpoint = getCustomProviderEndpoint( + customBaseURL, + "openai-chat-completions", + ) + + cy.setupCustomProvider() + cy.createTable("btc_trades") + cy.refreshSchema() + + const assistantResponse = + "Here are the tables in your database. Let me write a query for btc_trades." + const sql = "SELECT * FROM btc_trades LIMIT 10;" + + const flow = createToolCallFlow({ + provider: "openai-chat-completions", + streaming: true, + question: "What tables are in the database?", + endpoint: customEndpoint, + steps: [ + { toolCall: { name: "get_tables", args: {} } }, + { + finalResponse: { + explanation: assistantResponse, + sql: sql, + }, + expectToolResult: { includes: ["btc_trades"] }, + }, + ], + }) + + flow.intercept() + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type(flow.question) + cy.getByDataHook("chat-send-button").click() + + flow.waitForCompletion() + + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", assistantResponse) + + cy.getByDataHook("assistant-mode-processing-collapsed").click() + cy.getByDataHook("assistant-mode-reviewing-tables").should("exist") + + cy.getByDataHook("message-action-accept").should("be.visible") + cy.getByDataHook("message-action-accept").click() + + cy.getByDataHook("diff-status-accepted").should("contain", "Accepted") + cy.getByDataHook("chat-context-badge").should( + "contain", + "SELECT * FROM btc_trades", + ) + + cy.dropTableIfExists("btc_trades") + }) + + it("should toggle models, add second provider, remove first, and update model dropdown", () => { + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "ollama", + name: "Ollama", + models: ["llama3", "mistral", "codellama"], + }), + ) + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("llama3").should("be.visible") + cy.contains("mistral").should("be.visible") + cy.contains("codellama").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + cy.get("[data-model='llama3']").should("exist") + cy.get("[data-model='mistral']").should("exist") + cy.get("[data-model='codellama']").should("exist") + + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 2) + cy.contains("llama3").should("be.visible") + cy.contains("codellama").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='true']").should("exist") + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-add-custom-provider") + .should("be.visible") + .click() + + cy.getByDataHook("custom-provider-name-input").type("OpenRouter") + cy.getByDataHook("custom-provider-base-url-input").type( + "https://openrouter.ai/api/v1", + ) + + cy.intercept("GET", "https://openrouter.ai/api/v1/models", { + statusCode: 500, + body: { error: "Server error" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-manual-model-input").type("gpt-4o") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible") + cy.contains("[data-hook^='ai-settings-provider-']", "OpenRouter").should( + "be.visible", + ) + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 4) + cy.contains("gpt-4o").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").click() + cy.getByDataHook("ai-settings-remove-provider").click() + + cy.getByDataHook("ai-settings-provider-ollama").should("not.exist") + cy.contains("[data-hook^='ai-settings-provider-']", "OpenRouter").should( + "be.visible", + ) + cy.getByDataHook("ai-settings-save").click() + + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 1) + cy.contains("gpt-4o").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) + + it("should show error on 401, retry successfully, and show error on network failure", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.setupCustomProvider() + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply({ + statusCode: 401, + body: { + error: { + type: "authentication_error", + message: "Invalid API key", + }, + }, + }) + }).as("errorRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test error handling") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@errorRequest") + cy.getByDataHook("chat-message-error").should("be.visible") + cy.getByDataHook("retry-button").should("be.visible") + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Successful response after retry", + null, + ), + { streaming: true }, + ), + ) + }).as("successRequest") + + cy.getByDataHook("retry-button").click() + + cy.wait("@successRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-error").should("not.exist") + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Successful response after retry") + + cy.getByDataHook("chat-window-new").click() + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.destroy() + }).as("networkError") + + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test network error") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@networkError") + cy.getByDataHook("chat-message-error").should("be.visible") + cy.getByDataHook("retry-button").should("be.visible") + }) + + it("should reject duplicate names against custom and built-in providers, and allow unique names", () => { + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "my-provider", + name: "My Provider", + models: ["test-model"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-add-custom-provider").click() + + // Exact duplicate name + cy.getByDataHook("custom-provider-name-input").type("My Provider") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:1234", + ) + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("A provider with the same name already exists").should( + "be.visible", + ) + + // Case-insensitive duplicate of built-in provider name + cy.getByDataHook("custom-provider-name-input").clear().type("openai") + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.contains("A provider with the same name already exists").should( + "be.visible", + ) + + // Unique name should proceed to step 2 + cy.getByDataHook("custom-provider-name-input") + .clear() + .type("My Provider (v2.0)!") + + cy.intercept("GET", "http://localhost:1234/models", { + statusCode: 500, + body: { error: "not needed" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + + cy.getByDataHook("custom-provider-manual-model-input").type("test-model-2") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("multi-step-modal-next-button").click() + + // Both providers visible in sidebar (find by name text) + cy.contains("My Provider").should("be.visible") + cy.contains("My Provider (v2.0)!").should("be.visible") + + cy.getByDataHook("ai-settings-save").click() + cy.window().then((win) => { + const settings = JSON.parse( + win.localStorage.getItem("ai.assistant.settings"), + ) + // Provider ID should be a UUID + const customIds = Object.keys(settings.customProviders) + const newProviderId = customIds.find((id) => id !== "my-provider") + expect(newProviderId).to.match( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/, + ) + expect(settings.customProviders[newProviderId].name).to.equal( + "My Provider (v2.0)!", + ) + }) + }) + + it("should route Anthropic-type provider requests to custom base URL", () => { + const anthropicBaseURL = "http://localhost:8080" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId: "custom-anthropic", + name: "Custom Anthropic", + type: "anthropic", + baseURL: anthropicBaseURL, + apiKey: "test-anthropic-key", + models: ["claude-custom"], + }), + ) + + // Anthropic SDK appends /v1/messages to baseURL + cy.intercept("POST", "http://localhost:8080/v1/messages", (req) => { + if (isTitleRequest("anthropic", req.body)) { + req.reply(createChatTitleResponse("anthropic", "Test Chat")) + return + } + req.reply( + createResponse( + "anthropic", + createFinalResponseData( + "anthropic", + "Response from custom Anthropic provider", + null, + ), + { streaming: true }, + ), + ) + }).as("anthropicRequest") + + cy.intercept("POST", "https://api.anthropic.com/**", () => { + throw new Error( + "Request should not go to api.anthropic.com for custom provider", + ) + }) + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test Anthropic custom") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@anthropicRequest") + cy.waitForStreamingComplete() + + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from custom Anthropic provider") + }) + + it("should route requests to correct endpoint when switching between built-in and custom models", () => { + const customBaseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL + const customEndpoint = getCustomProviderEndpoint( + customBaseURL, + "openai-chat-completions", + ) + + const openaiSettings = getOpenAIConfiguredSettings() + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings( + { + providerId: "ollama", + name: "Ollama", + models: ["llama3"], + }, + openaiSettings, + ), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Ollama Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Response from Ollama", + null, + ), + { streaming: true }, + ), + ) + }).as("ollamaRequest") + + cy.intercept("POST", PROVIDERS.openai.endpoint, (req) => { + if (isTitleRequest("openai", req.body)) { + req.reply(createChatTitleResponse("openai", "OpenAI Chat")) + return + } + req.reply( + createResponse( + "openai", + createFinalResponseData("openai", "Response from OpenAI", null), + { streaming: true }, + ), + ) + }).as("openaiRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with Ollama") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@ollamaRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from Ollama") + + cy.getByDataHook("ai-settings-model-dropdown").click() + cy.getByDataHook("ai-settings-model-item-label") + .contains("GPT-5 mini") + .click() + + cy.getByDataHook("chat-window-new").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with OpenAI") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@openaiRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Response from OpenAI") + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-openai").should("be.visible") + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible") + cy.getByDataHook("ai-settings-cancel").click() + }) + + it("should reset fields on cancel, preserve them on back, and add model via Enter key", () => { + cy.loadConsoleWithAuth() + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-promo-continue").click() + cy.getByDataHook("ai-settings-provider-custom").click() + + cy.getByDataHook("custom-provider-name-input").type("Partial") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:1234", + ) + + cy.getByDataHook("multi-step-modal-cancel-button").click() + + cy.getByDataHook("ai-settings-provider-custom").click() + cy.getByDataHook("custom-provider-name-input").should("have.value", "") + cy.getByDataHook("custom-provider-base-url-input").should("have.value", "") + + cy.getByDataHook("custom-provider-name-input").type("Test Provider") + cy.getByDataHook("custom-provider-base-url-input").type( + "http://localhost:5555", + ) + + cy.intercept("GET", "http://localhost:5555/models", { + statusCode: 500, + body: { error: "fail" }, + }) + + cy.getByDataHook("multi-step-modal-next-button").click() + + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "200000", + ) + cy.getByDataHook("custom-provider-schema-access").should("be.checked") + cy.getByDataHook("custom-provider-add-model-button").should("be.disabled") + + cy.getByDataHook("custom-provider-manual-model-input").type( + "enter-model{enter}", + ) + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-manual-model-input").should( + "have.value", + "", + ) + + // Back button preserves step 1 fields + cy.getByDataHook("multi-step-modal-cancel-button").click() + cy.getByDataHook("custom-provider-name-input").should( + "have.value", + "Test Provider", + ) + cy.getByDataHook("custom-provider-base-url-input").should( + "have.value", + "http://localhost:5555", + ) + + cy.intercept("GET", "http://localhost:5555/models", { + statusCode: 500, + body: { error: "fail" }, + }) + cy.getByDataHook("multi-step-modal-next-button").click() + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + }) + + it("should preserve custom provider settings and chat after page reload", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + const settings = getCustomProviderConfiguredSettings() + + cy.loadConsoleWithAuth(false, settings) + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.loadConsoleWithAuth(false, settings) + cy.getByDataHook("ai-chat-button").should("be.visible") + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Working after reload", + null, + ), + { streaming: true }, + ), + ) + }).as("chatRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test after reload") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@chatRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Working after reload") + }) + + it("should omit auth token without API key and send Bearer token when API key is configured", () => { + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + apiKey: "", + }), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + const auth = req.headers["authorization"] || "" + expect(auth).to.not.include("sk-") + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "No auth response", + null, + ), + { streaming: true }, + ), + ) + }).as("noAuthRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test no auth") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@noAuthRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "No auth response") + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + apiKey: "sk-test-key-123", + }), + ) + + cy.intercept("POST", customEndpoint, (req) => { + if (isTitleRequest("openai-chat-completions", req.body)) { + req.reply( + createChatTitleResponse("openai-chat-completions", "Test Chat"), + ) + return + } + expect(req.headers["authorization"]).to.equal("Bearer sk-test-key-123") + req.reply( + createResponse( + "openai-chat-completions", + createFinalResponseData( + "openai-chat-completions", + "Auth response", + null, + ), + { streaming: true }, + ), + ) + }).as("authRequest") + + cy.getByDataHook("ai-chat-button").click() + cy.getByDataHook("chat-input-textarea").should("be.visible") + cy.getByDataHook("chat-input-textarea").type("Test with auth") + cy.getByDataHook("chat-send-button").click() + + cy.wait("@authRequest") + cy.waitForStreamingComplete() + cy.getByDataHook("chat-message-assistant") + .should("be.visible") + .should("contain", "Auth response") + }) + + it("should open manage models, add and remove models, update context window, and reflect in dropdown", () => { + const providerId = "ollama" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Ollama", + models: ["llama3", "mistral", "codellama"], + }), + ) + + // Intercept model fetch before opening modal + cy.intercept("GET", "**/models*", { + statusCode: 200, + body: { + object: "list", + data: [ + { id: "llama3", object: "model" }, + { id: "mistral", object: "model" }, + { id: "codellama", object: "model" }, + ], + }, + }).as("manageModelsFetch") + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + cy.getByDataHook("ai-settings-manage-models").should("be.visible").click() + cy.wait("@manageModelsFetch") + + // All 3 models should be checked + cy.getByDataHook("custom-provider-model-row").should("have.length", 3) + cy.getByDataHook("custom-provider-model-row") + .find('input[type="checkbox"]') + .each(($checkbox) => { + cy.wrap($checkbox).should("be.checked") + }) + + // Add a manual model + cy.getByDataHook("custom-provider-manual-model-input").type( + "custom-finetune", + ) + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 1) + cy.getByDataHook("custom-provider-model-chip").should( + "contain", + "custom-finetune", + ) + + // Uncheck codellama + cy.getByDataHook("custom-provider-model-row").contains("codellama").click() + + // Validation: context window too low + cy.getByDataHook("custom-provider-context-window-input") + .clear() + .type("50000") + cy.getByDataHook("manage-models-save").click() + cy.contains("Context window must be at least 100,000 tokens").should( + "be.visible", + ) + + // Fix context window and save + cy.getByDataHook("custom-provider-context-window-input") + .clear() + .type("150000") + cy.getByDataHook("custom-provider-context-window-input").should( + "have.value", + "150000", + ) + cy.getByDataHook("manage-models-save").click() + + // Modal should close, settings modal visible again + cy.getByDataHook("ai-settings-manage-models").should("be.visible") + + // Save the outer settings modal (manage-models toast auto-dismisses) + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").first().click() + + // Dropdown should show 3 models (llama3, mistral, custom-finetune) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("llama3").should("be.visible") + cy.contains("mistral").should("be.visible") + cy.contains("custom-finetune").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) + + it("should auto-enable new models from manage models and preserve unsaved toggle state", () => { + const providerId = "test-provider" + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Test Provider", + models: ["model-a", "model-b", "model-c"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-test-provider") + .should("be.visible") + .click() + + // All 3 models should be enabled + cy.get("[data-model='model-a'][data-enabled='true']").should("exist") + cy.get("[data-model='model-b'][data-enabled='true']").should("exist") + cy.get("[data-model='model-c'][data-enabled='true']").should("exist") + + // Disable model-b toggle (unsaved state) + cy.get("[data-model='model-b']").find("button[role='switch']").click() + cy.get("[data-model='model-b'][data-enabled='false']").should("exist") + + // Intercept model fetch → fail to get manual mode + cy.intercept("GET", "**/models*", { + statusCode: 500, + body: { error: "Server error" }, + }).as("manageModelsFetchFail") + + // Open manage models + cy.getByDataHook("ai-settings-manage-models").click() + cy.wait("@manageModelsFetchFail") + + // Manual mode: warning banner + existing models as chips + cy.getByDataHook("custom-provider-warning-banner").should("be.visible") + cy.getByDataHook("custom-provider-model-chip").should("have.length", 3) + + // Add model-d + cy.getByDataHook("custom-provider-manual-model-input").type("model-d") + cy.getByDataHook("custom-provider-add-model-button").click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 4) + + // Remove model-b + cy.getByDataHook("custom-provider-model-chip") + .filter(":contains('model-b')") + .find("[data-hook='custom-provider-remove-model']") + .click() + cy.getByDataHook("custom-provider-model-chip").should("have.length", 3) + + // Save manage models + cy.getByDataHook("manage-models-save").click() + + // Back in SettingsModal: model-b gone, model-d auto-enabled + cy.get("[data-model='model-a'][data-enabled='true']").should("exist") + cy.get("[data-model='model-b']").should("not.exist") + cy.get("[data-model='model-c'][data-enabled='true']").should("exist") + cy.get("[data-model='model-d'][data-enabled='true']").should("exist") + + // Save settings + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").first().click() + + // Dropdown should show 3 models (model-a, model-c, model-d) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 3) + cy.contains("model-a").should("be.visible") + cy.contains("model-c").should("be.visible") + cy.contains("model-d").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) + + it("should handle no-API-key custom provider: models visible, no validated badge, schema toggle enabled, and allow adding an API key", () => { + const providerId = "ollama" + const customEndpoint = getCustomProviderEndpoint( + CUSTOM_PROVIDER_DEFAULTS.baseURL, + "openai-chat-completions", + ) + + cy.loadConsoleWithAuth( + false, + getCustomProviderConfiguredSettings({ + providerId, + name: "Ollama", + baseURL: CUSTOM_PROVIDER_DEFAULTS.baseURL, + apiKey: "", + models: ["llama3", "mistral"], + }), + ) + + cy.getByDataHook("ai-assistant-settings-button").click() + cy.getByDataHook("ai-settings-provider-ollama").should("be.visible").click() + + // Part A: No-API-key state + + // No validated badge + cy.getByDataHook("ai-settings-validated-badge").should("not.exist") + + // API key input shows placeholder about no key + cy.getByDataHook("ai-settings-api-key").should( + "have.attr", + "placeholder", + "This provider does not have an API key", + ) + + // Model list visible with both models + cy.get("[data-model='llama3']").should("exist") + cy.get("[data-model='mistral']").should("exist") + + // Schema access toggle is not disabled + cy.getByDataHook("ai-settings-schema-access").should("not.be.disabled") + + // Manage models button visible + cy.getByDataHook("ai-settings-manage-models").should("be.visible") + + // Toggle mistral off + cy.get("[data-model='mistral']").find("button[role='switch']").click() + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + // Built-in provider should NOT have manage models button + cy.getByDataHook("ai-settings-provider-openai").click() + cy.getByDataHook("ai-settings-manage-models").should("not.exist") + + // Part B: Add API key + + // Switch back to custom provider + cy.getByDataHook("ai-settings-provider-ollama").click() + + // Click Edit button to make input editable, then type API key + cy.get('button[title="Edit API key"]').click() + cy.getByDataHook("ai-settings-api-key").type("sk-custom-key-123") + + // Intercept validation request to custom endpoint + cy.intercept("POST", customEndpoint, { + statusCode: 200, + delay: 200, + body: { + id: "chatcmpl-mock", + object: "chat.completion", + choices: [ + { + index: 0, + message: { role: "assistant", content: "" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + }, + }).as("customValidation") + + // Click validate + cy.getByDataHook("ai-settings-test-api").should("be.visible").click() + cy.wait("@customValidation") + + // Validated badge should now appear + cy.getByDataHook("ai-settings-validated-badge").should("be.visible") + + // Models still visible, mistral toggle preserved + cy.get("[data-model='llama3'][data-enabled='true']").should("exist") + cy.get("[data-model='mistral'][data-enabled='false']").should("exist") + + // Part C: Save and verify + + cy.getByDataHook("ai-settings-save").click() + cy.get(".toast-success-container").should("be.visible").click() + + // Dropdown should show only llama3 (mistral was disabled) + cy.getByDataHook("ai-settings-model-dropdown").should("be.visible").click() + cy.getByDataHook("ai-settings-model-item").should("have.length", 1) + cy.contains("llama3").should("be.visible") + cy.getByDataHook("ai-settings-model-dropdown").click() + }) +}) diff --git a/e2e/utils/aiAssistant.js b/e2e/utils/aiAssistant.js index 2113f03dc..30e02efbe 100644 --- a/e2e/utils/aiAssistant.js +++ b/e2e/utils/aiAssistant.js @@ -9,6 +9,16 @@ const PROVIDERS = { }, } +const CUSTOM_PROVIDER_DEFAULTS = { + providerId: "test-provider", + name: "Test Provider", + type: "openai-chat-completions", + baseURL: "http://localhost:11434/v1", + models: ["test-model-1"], + contextWindow: 200000, + grantSchemaAccess: true, +} + function getOpenAIConfiguredSettings(schemaAccess = true) { return { "ai.assistant.settings": JSON.stringify({ @@ -39,6 +49,76 @@ function getAnthropicConfiguredSettings(schemaAccess = true) { } } +/** + * Returns localStorage settings for a pre-configured custom provider. + * Can optionally merge with existing settings (e.g., a built-in provider). + */ +function getCustomProviderConfiguredSettings(config = {}, mergeWith = null) { + const { + providerId = CUSTOM_PROVIDER_DEFAULTS.providerId, + name = CUSTOM_PROVIDER_DEFAULTS.name, + type = CUSTOM_PROVIDER_DEFAULTS.type, + baseURL = CUSTOM_PROVIDER_DEFAULTS.baseURL, + apiKey = "", + models = CUSTOM_PROVIDER_DEFAULTS.models, + contextWindow = CUSTOM_PROVIDER_DEFAULTS.contextWindow, + grantSchemaAccess = CUSTOM_PROVIDER_DEFAULTS.grantSchemaAccess, + } = config + + const enabledModels = models.map((m) => `${providerId}:${m}`) + + const baseSettings = mergeWith + ? JSON.parse(mergeWith["ai.assistant.settings"]) + : {} + + const settings = { + ...baseSettings, + selectedModel: enabledModels[0], + customProviders: { + ...(baseSettings.customProviders || {}), + [providerId]: { + type, + name, + baseURL, + ...(apiKey ? { apiKey } : {}), + contextWindow, + models, + grantSchemaAccess, + }, + }, + providers: { + ...(baseSettings.providers || {}), + [providerId]: { + apiKey: apiKey || "", + enabledModels, + grantSchemaAccess, + }, + }, + } + + return { + "ai.assistant.settings": JSON.stringify(settings), + } +} + +/** + * Returns the API endpoint for a custom provider based on its type. + */ +function getCustomProviderEndpoint(baseURL, type) { + if (type === "openai-chat-completions") { + return `${baseURL}/chat/completions` + } + if (type === "openai") { + return `${baseURL}/responses` + } + // anthropic - SDK appends /v1/messages to baseURL + return `${baseURL}/v1/messages` +} + +// ============================================================================= +// RESPONSE DATA BUILDERS +// ============================================================================= + function createFinalResponseData(provider, explanation, sql = null) { const responseContent = { explanation, sql } @@ -62,6 +142,26 @@ function createFinalResponseData(provider, explanation, sql = null) { } } + if (provider === "openai-chat-completions") { + return { + id: "chatcmpl-mock-final", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: JSON.stringify(responseContent), + }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 200, completion_tokens: 100, total_tokens: 300 }, + } + } + // Anthropic return { id: "msg_mock_final", @@ -97,6 +197,36 @@ function createToolCallResponseData(provider, toolName, toolArguments = {}) { } } + if (provider === "openai-chat-completions") { + return { + id: `chatcmpl-mock-tool-${toolName}`, + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: callId, + type: "function", + function: { + name: toolName, + arguments: JSON.stringify(toolArguments), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { prompt_tokens: 100, completion_tokens: 50, total_tokens: 150 }, + } + } + // Anthropic return { id: `msg_mock_tool_${toolName}`, @@ -137,6 +267,23 @@ function createChatTitleResponseData(provider, title = "Test Chat") { } } + if (provider === "openai-chat-completions") { + return { + id: "chatcmpl-mock-title", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "test-model-1", + choices: [ + { + index: 0, + message: { role: "assistant", content: content }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 50, completion_tokens: 20, total_tokens: 70 }, + } + } + // Anthropic return { id: "msg_mock_title", @@ -149,6 +296,10 @@ function createChatTitleResponseData(provider, title = "Test Chat") { } } +// ============================================================================= +// SSE RESPONSE BUILDERS +// ============================================================================= + function createOpenAISSEResponse(responseData, delay = 0) { const events = [] @@ -189,6 +340,121 @@ function createOpenAISSEResponse(responseData, delay = 0) { return response } +function createChatCompletionsSSEResponse(responseData, delay = 0) { + const events = [] + const choice = responseData.choices?.[0] + const content = choice?.message?.content || "" + const toolCalls = choice?.message?.tool_calls || [] + + // Stream content deltas + if (content) { + const chunkSize = 20 + for (let i = 0; i < content.length; i += chunkSize) { + const chunk = content.slice(i, i + chunkSize) + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { content: chunk }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + } + } + + // Stream tool call deltas + if (toolCalls.length > 0) { + for (const tc of toolCalls) { + // First chunk: tool call start + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + id: tc.id, + type: "function", + function: { name: tc.function.name, arguments: "" }, + }, + ], + }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + // Second chunk: tool call arguments + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: tc.function.arguments }, + }, + ], + }, + finish_reason: null, + }, + ], + })}\n\n`, + ) + } + } + + // Final chunk with finish_reason and usage + events.push( + `data: ${JSON.stringify({ + id: responseData.id, + object: "chat.completion.chunk", + choices: [ + { + index: 0, + delta: {}, + finish_reason: choice?.finish_reason || "stop", + }, + ], + usage: responseData.usage, + })}\n\n`, + ) + + // [DONE] marker + events.push("data: [DONE]\n\n") + + const sseBody = events.join("") + + const response = { + statusCode: 200, + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }, + body: sseBody, + } + + if (delay > 0) { + response.delay = delay + } + + return response +} + function createAnthropicSSEResponse(responseData, delay = 0) { const events = [] @@ -327,6 +593,10 @@ function createAnthropicSSEResponse(responseData, delay = 0) { return response } +// ============================================================================= +// RESPONSE WRAPPERS +// ============================================================================= + function createResponse(provider, responseData, options = {}) { const { streaming = true, delay = 0 } = options @@ -347,6 +617,10 @@ function createResponse(provider, responseData, options = {}) { return createOpenAISSEResponse(responseData, delay) } + if (provider === "openai-chat-completions") { + return createChatCompletionsSSEResponse(responseData, delay) + } + return createAnthropicSSEResponse(responseData, delay) } @@ -378,6 +652,10 @@ function createChatTitleResponse(provider, title = "Test Chat") { }) } +// ============================================================================= +// REQUEST INSPECTION HELPERS +// ============================================================================= + function isTitleRequest(provider, body) { if (provider === "openai") { return ( @@ -385,7 +663,7 @@ function isTitleRequest(provider, body) { false ) } - // Anthropic + // Both openai-chat-completions and anthropic use messages array return ( body.messages?.[0]?.content?.includes?.("Generate a concise chat title") || false @@ -396,8 +674,7 @@ function requestMatchesQuestion(provider, body, question) { if (provider === "openai") { return body.input?.[0]?.content === question } - // Anthropic - find the first user message with string content (the original question) - // Subsequent messages may be tool results or assistant responses + // Both openai-chat-completions and anthropic use messages array const firstUserMessage = body.messages?.find( (msg) => msg.role === "user" && typeof msg.content === "string", ) @@ -413,8 +690,14 @@ function extractToolOutputContent(provider, body) { return latestOutput?.output || null } + if (provider === "openai-chat-completions") { + // Chat Completions: tool results are in messages with role "tool" + const toolMessages = body.messages?.filter((msg) => msg.role === "tool") + const latestToolMessage = toolMessages?.[toolMessages.length - 1] + return latestToolMessage?.content || null + } + // Anthropic - tool results are in user messages with content array containing tool_result objects - // Format: { role: "user", content: [{ type: "tool_result", tool_use_id: "...", content: "..." }] } const toolResultMessages = body.messages?.filter( (msg) => msg.role === "user" && @@ -425,7 +708,6 @@ function extractToolOutputContent(provider, body) { const latestToolResult = latestMessage?.content?.find( (c) => c.type === "tool_result", ) - // Anthropic tool result content can be a string directly return latestToolResult?.content || null } @@ -433,7 +715,7 @@ function extractAllInputContent(provider, body) { if (provider === "openai") { return body.input?.map((item) => item.content || "").join("\n") || "" } - // Anthropic + // Both openai-chat-completions and anthropic use messages return ( body.messages ?.map((msg) => { @@ -452,57 +734,22 @@ function extractAllInputContent(provider, body) { // ============================================================================= /** - * Creates a multi-turn tool call flow with automatic intercept handling - * - * @param {Object} config - Flow configuration - * @param {"openai" | "anthropic"} [config.provider="openai"] - The AI provider - * @param {boolean} [config.streaming=true] - Whether to use streaming responses - * @param {string} config.question - The user's question to match - * @param {Array} config.steps - Array of step definitions - * @param {Object} [config.steps[].toolCall] - Tool call definition { name, args } - * @param {Object} [config.steps[].expectToolResult] - Expected result { includes: string[] } - * @param {Object} [config.steps[].finalResponse] - Final response { explanation, sql } - * @returns {Object} Flow controller with intercept() and waitForCompletion() methods - * - * @example - * // OpenAI with streaming (default) - * const flow = createToolCallFlow({ - * question: "Describe the ecommerce_stats table", - * steps: [ - * { toolCall: { name: "get_tables", args: {} } }, - * { finalResponse: { explanation: "Table description...", sql: null } } - * ] - * }) + * Creates a multi-turn tool call flow with automatic intercept handling. + * Supports built-in providers (openai, anthropic) and custom providers. * - * @example - * // Anthropic with streaming - * const flow = createToolCallFlow({ - * provider: "anthropic", - * streaming: true, - * question: "What tables exist?", - * steps: [ - * { toolCall: { name: "get_tables", args: {} } }, - * { finalResponse: { explanation: "Found tables...", sql: null } } - * ] - * }) - * - * @example - * // OpenAI without streaming - * const flow = createToolCallFlow({ - * provider: "openai", - * streaming: false, - * question: "Quick test", - * steps: [ - * { finalResponse: { explanation: "Done", sql: null } } - * ] - * }) + * @param {Object} config + * @param {"openai" | "anthropic" | "openai-chat-completions"} [config.provider="openai"] + * @param {boolean} [config.streaming=true] + * @param {string} config.question + * @param {Array} config.steps + * @param {string} [config.endpoint] - Custom endpoint URL (overrides PROVIDERS lookup) */ function createToolCallFlow(config) { const { provider = "openai", streaming = true, question, steps } = config let requestCount = 0 const totalRequests = steps.length - const endpoint = PROVIDERS[provider].endpoint + const endpoint = config.endpoint || PROVIDERS[provider]?.endpoint const responseOptions = { streaming } return { @@ -510,9 +757,6 @@ function createToolCallFlow(config) { provider, streaming, - /** - * Sets up cy.intercept for both chat title and tool call flow - */ intercept() { // Handle chat title generation (never streamed) cy.intercept("POST", endpoint, (req) => { @@ -528,7 +772,6 @@ function createToolCallFlow(config) { return } - // Check if this request matches our question if (!requestMatchesQuestion(provider, req.body, question)) { return } @@ -571,14 +814,10 @@ function createToolCallFlow(config) { }).as("toolCallRequest") }, - /** - * Waits for all tool call requests to complete and streaming to finish - */ waitForCompletion() { for (let i = 0; i < totalRequests; i++) { cy.wait("@toolCallRequest") } - // Wait for streaming to finish if streaming is enabled if (streaming) { cy.waitForStreamingComplete() } @@ -591,61 +830,27 @@ function createToolCallFlow(config) { // ============================================================================= /** - * Creates a multi-turn conversation flow for testing multiple questions/responses - * in the same chat session. - * - * @param {Object} config - Flow configuration - * @param {"openai" | "anthropic"} [config.provider="openai"] - The AI provider - * @param {boolean} [config.streaming=true] - Whether to use streaming responses - * @param {Array} config.turns - Array of turn definitions - * @param {string} config.turns[].explanation - The AI's explanation text - * @param {string|null} config.turns[].sql - Optional SQL query suggestion - * @param {Object} [config.turns[].expectSystemMessage] - Expected content in system message - * @param {string[]} [config.turns[].expectSystemMessage.includes] - Strings that must appear - * @param {string[]} [config.turns[].expectSystemMessage.excludes] - Strings that must NOT appear - * @returns {Object} Flow controller with intercept(), waitForTurn(), and getRequestBody() methods - * - * @example - * // OpenAI streaming (default) - * const flow = createMultiTurnFlow({ - * turns: [ - * { explanation: "First query.", sql: "SELECT 1;" }, - * { explanation: "Second query.", sql: "SELECT 2;" } - * ] - * }) + * Creates a multi-turn conversation flow. + * Supports built-in providers and custom providers. * - * @example - * // Anthropic streaming - * const flow = createMultiTurnFlow({ - * provider: "anthropic", - * turns: [ - * { explanation: "First response.", sql: null }, - * { - * explanation: "Second response.", - * sql: "SELECT * FROM users;", - * expectSystemMessage: { - * includes: ["User accepted the suggested SQL"], - * excludes: ["User rejected"] - * } - * } - * ] - * }) + * @param {Object} config + * @param {"openai" | "anthropic" | "openai-chat-completions"} [config.provider="openai"] + * @param {boolean} [config.streaming=true] + * @param {Array} config.turns + * @param {string} [config.endpoint] - Custom endpoint URL */ function createMultiTurnFlow(config) { const { provider = "openai", streaming = true, turns } = config let requestCount = 0 const requestBodies = [] - const endpoint = PROVIDERS[provider].endpoint + const endpoint = config.endpoint || PROVIDERS[provider]?.endpoint const responseOptions = { streaming } return { provider, streaming, - /** - * Sets up cy.intercept for chat title and all conversation turns - */ intercept() { // Intercept for chat title generation cy.intercept("POST", endpoint, (req) => { @@ -656,22 +861,17 @@ function createMultiTurnFlow(config) { // Intercept for conversation turns cy.intercept("POST", endpoint, (req) => { - // Skip title requests if (isTitleRequest(provider, req.body)) { return } - // Handle conversation turns const turn = turns[requestCount] if (turn) { - // Store the request body for later assertions requestBodies[requestCount] = req.body - // Verify system message expectations if defined if (turn.expectSystemMessage) { const allInputContent = extractAllInputContent(provider, req.body) - // Check includes if (turn.expectSystemMessage.includes) { for (const expected of turn.expectSystemMessage.includes) { expect(allInputContent).to.include( @@ -681,7 +881,6 @@ function createMultiTurnFlow(config) { } } - // Check excludes if (turn.expectSystemMessage.excludes) { for (const excluded of turn.expectSystemMessage.excludes) { expect(allInputContent).to.not.include( @@ -705,11 +904,6 @@ function createMultiTurnFlow(config) { }).as("multiTurnRequest") }, - /** - * Waits for a specific turn to complete and streaming to finish - * @param {number} turnIndex - The turn index (0-based) - * @returns {Cypress.Chainable} Chainable that resolves when the turn is complete - */ waitForTurn(turnIndex) { return cy .wrap(null) @@ -729,10 +923,6 @@ function createMultiTurnFlow(config) { }) }, - /** - * Waits for all turns to complete - * @returns {Cypress.Chainable} Chainable that yields all request bodies - */ waitForAllTurns() { return cy .wrap(null) @@ -750,20 +940,10 @@ function createMultiTurnFlow(config) { }) }, - /** - * Gets the captured request body for a specific turn. - * Must be called inside cy.then() after waitForTurn() - * @param {number} turnIndex - The turn index (0-based) - * @returns {Object} The request body sent for that turn - */ getRequestBody(turnIndex) { return requestBodies[turnIndex] }, - /** - * Gets all captured request bodies - * @returns {Array} Array of request bodies - */ getAllRequestBodies() { return requestBodies }, @@ -772,11 +952,15 @@ function createMultiTurnFlow(config) { module.exports = { PROVIDERS, + CUSTOM_PROVIDER_DEFAULTS, getOpenAIConfiguredSettings, getAnthropicConfiguredSettings, + getCustomProviderConfiguredSettings, + getCustomProviderEndpoint, createFinalResponseData, createResponse, createFinalResponse, + createToolCallResponse, createChatTitleResponse, createToolCallFlow, createMultiTurnFlow, diff --git a/package.json b/package.json index 778044fb1..6e2187973 100644 --- a/package.json +++ b/package.json @@ -31,7 +31,7 @@ "prepare": "husky" }, "dependencies": { - "@anthropic-ai/sdk": "^0.71.2", + "@anthropic-ai/sdk": "^0.78.0", "@date-fns/tz": "^1.2.0", "@docsearch/css": "^3.5.2", "@docsearch/react": "^3.5.2", @@ -79,6 +79,7 @@ "js-base64": "^3.7.7", "js-sha256": "^0.11.0", "js-tiktoken": "^1.0.21", + "jsonrepair": "^3.13.3", "lodash.isequal": "^4.5.0", "lodash.merge": "^4.6.2", "monaco-editor": "^0.52.2", diff --git a/src/components/AIStatusIndicator/index.tsx b/src/components/AIStatusIndicator/index.tsx index 646680160..bd26a25c3 100644 --- a/src/components/AIStatusIndicator/index.tsx +++ b/src/components/AIStatusIndicator/index.tsx @@ -15,7 +15,7 @@ import { color } from "../../utils" import { slideAnimation } from "../Animation" import { AISparkle } from "../AISparkle" import { pinkLinearGradientHorizontal } from "../../theme" -import { MODEL_OPTIONS } from "../../utils/aiAssistantSettings" +import { getAllModelOptions } from "../../utils/ai" import { useAIConversation } from "../../providers/AIConversationProvider" import { Button } from "../../components/Button" import { BrainIcon } from "../SetupAIAssistant/BrainIcon" @@ -310,6 +310,7 @@ export const AIStatusIndicator: React.FC = () => { currentModel, abortOperation, clearOperation, + aiAssistantSettings, } = useAIStatus() const { chatWindowState, openChatWindow } = useAIConversation() const [expanded, setExpanded] = useState(false) @@ -320,8 +321,10 @@ export const AIStatusIndicator: React.FC = () => { const activeSidebar = useSelector(selectors.console.getActiveSidebar) const statusRef = useRef(null) const hasExtendedThinking = useMemo(() => { - return MODEL_OPTIONS.find((model) => model.value === currentModel)?.isSlow - }, [currentModel]) + return getAllModelOptions(aiAssistantSettings).find( + (model) => model.value === currentModel, + )?.isSlow + }, [currentModel, aiAssistantSettings]) const operationSections = useMemo( () => buildOperationSections(currentOperation, status, true), diff --git a/src/components/ExplainQueryButton/index.tsx b/src/components/ExplainQueryButton/index.tsx index 74ffc5d25..ada476df1 100644 --- a/src/components/ExplainQueryButton/index.tsx +++ b/src/components/ExplainQueryButton/index.tsx @@ -44,6 +44,7 @@ export const ExplainQueryButton = ({ hasSchemaAccess, currentModel: currentModelValue, apiKey: apiKeyValue, + aiAssistantSettings, } = useAIStatus() const { addMessage, @@ -62,6 +63,7 @@ export const ExplainQueryButton = ({ conversationId, queryText, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/components/FixQueryButton/index.tsx b/src/components/FixQueryButton/index.tsx index b79b552e9..37ebe80c9 100644 --- a/src/components/FixQueryButton/index.tsx +++ b/src/components/FixQueryButton/index.tsx @@ -21,8 +21,14 @@ export const FixQueryButton = () => { const { quest } = useContext(QuestContext) const { editorRef, executionRefs } = useEditor() const tables = useSelector(selectors.query.getTables) - const { setStatus, abortController, hasSchemaAccess, currentModel, apiKey } = - useAIStatus() + const { + setStatus, + abortController, + hasSchemaAccess, + currentModel, + apiKey, + aiAssistantSettings, + } = useAIStatus() const { chatWindowState, getConversationMeta, @@ -53,6 +59,7 @@ export const FixQueryButton = () => { errorMessage, errorWord: word ?? undefined, settings: { model: currentModel!, apiKey: apiKey! }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/components/Overlay/index.tsx b/src/components/Overlay/index.tsx index b18cfc8d4..b54cd4076 100644 --- a/src/components/Overlay/index.tsx +++ b/src/components/Overlay/index.tsx @@ -43,10 +43,13 @@ const StyledOverlay = styled.div` } ` -export const Overlay = ({ - primitive, -}: { - primitive: typeof RadixDialogOverlay | typeof RadixAlertDialogOverlay -}) => { - return -} +export const Overlay = React.forwardRef< + HTMLDivElement, + { + primitive: typeof RadixDialogOverlay | typeof RadixAlertDialogOverlay + } +>(({ primitive }, ref) => { + return +}) + +Overlay.displayName = "Overlay" diff --git a/src/components/SetupAIAssistant/ConfigurationModal.tsx b/src/components/SetupAIAssistant/ConfigurationModal.tsx index d48f22ab5..09ea7c807 100644 --- a/src/components/SetupAIAssistant/ConfigurationModal.tsx +++ b/src/components/SetupAIAssistant/ConfigurationModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo, useCallback } from "react" +import React, { useState, useMemo, useCallback, useEffect } from "react" import styled, { css } from "styled-components" import { Dialog } from "../Dialog" import { MultiStepModal, Step } from "../MultiStepModal" @@ -10,17 +10,23 @@ import { Text } from "../Text" import { useLocalStorage } from "../../providers/LocalStorageProvider" import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" +import type { CustomProviderDefinition } from "../../providers/LocalStorageProvider/types" import { toast } from "../Toast" import { - MODEL_OPTIONS, + getAllModelOptions, + getAllProviders, + makeCustomModelValue, type ModelOption, - type Provider, -} from "../../utils/aiAssistantSettings" + type ProviderId, + getProviderName, +} from "../../utils/ai" import { useModalNavigation } from "../MultiStepModal" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" +import { PlusIcon, Plugs as PlugsIcon } from "@phosphor-icons/react" import { theme } from "../../theme" +import { CustomProviderModal } from "./CustomProviderModal" const ModalContent = styled.div` display: flex; @@ -113,17 +119,12 @@ const SectionDescription = styled(Text)` color: ${({ theme }) => theme.color.gray2}; ` -const ProviderSelectionContainer = styled(Box).attrs({ - gap: "4rem", - align: "center", -})` - width: 100%; -` - const ProviderCardsContainer = styled(Box).attrs({ gap: "2rem", + alignItems: "flex-start", })` height: 8.5rem; + width: 100%; ` const ProviderCard = styled.button<{ $selected: boolean }>` @@ -165,34 +166,6 @@ const ProviderName = styled(Text)` text-align: center; ` -const ComingSoonContainer = styled(Box).attrs({ - flexDirection: "column", - gap: "0.6rem", - align: "flex-start", -})` - width: 13.2rem; -` - -const ComingSoonIcons = styled(Box).attrs({ - align: "center", -})` - width: 100%; - padding-left: 0; - padding-right: 1.2rem; -` - -const ComingSoonIcon = styled.img` - width: 100%; - height: auto; - object-fit: contain; -` - -const ComingSoonText = styled(Text)` - font-size: 1.3rem; - font-weight: 300; - color: ${({ theme }) => theme.color.gray2}; -` - const InputSection = styled(Box).attrs({ flexDirection: "column", gap: "1.2rem", @@ -406,30 +379,53 @@ const WarningText = styled(Text)` text-align: left; ` +const AddCustomProviderCard = styled.button` + background: transparent; + border: 0.1rem dashed ${({ theme }) => theme.color.gray2}; + border-radius: 0.8rem; + cursor: pointer; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 0.6rem; + padding: 1.2rem 2rem; + width: 10rem; + height: 8.5rem; + transition: all 0.2s; + color: ${({ theme }) => theme.color.gray2}; + + &:hover { + border-color: ${({ theme }) => theme.color.foreground}; + color: ${({ theme }) => theme.color.foreground}; + } + + &:focus-visible { + outline: 0.2rem solid ${({ theme }) => theme.color.foreground}; + outline-offset: 0.2rem; + } +` + type ConfigurationModalProps = { open?: boolean onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: Provider | null) => { - if (!provider) return "" - return provider === "openai" ? "OpenAI" : "Anthropic" -} - type StepOneContentProps = { - selectedProvider: Provider | null + selectedProvider: ProviderId | null apiKey: string error: string | null providerName: string - onProviderSelect: (provider: Provider) => void + onProviderSelect: (provider: ProviderId) => void onApiKeyChange: (value: string) => void + onAddCustomProvider: () => void } type StepTwoContentProps = { - selectedProvider: Provider | null + selectedProvider: ProviderId | null enabledModels: string[] grantSchemaAccess: boolean - modelsByProvider: { anthropic: ModelOption[]; openai: ModelOption[] } + modelsByProvider: Record onModelToggle: (modelValue: string) => void onSchemaAccessChange: (checked: boolean) => void } @@ -463,6 +459,7 @@ const StepOneContent = ({ providerName, onProviderSelect, onApiKeyChange, + onAddCustomProvider, }: StepOneContentProps) => { const navigation = useModalNavigation() const handleClose: () => void = navigation.handleClose @@ -487,74 +484,77 @@ const StepOneContent = ({ Select Provider - We currently only support two model providers, with support for - more coming soon. + Choose a built-in provider or add your own custom provider. + You'll be able to configure and switch between multiple + providers later. - - - onProviderSelect("openai")} - type="button" - data-hook="ai-settings-provider-openai" - > - - OpenAI - - onProviderSelect("anthropic")} - type="button" - data-hook="ai-settings-provider-anthropic" - > - - Anthropic - - - - - - - Coming soon... - - + + onProviderSelect("openai")} + type="button" + data-hook="ai-settings-provider-openai" + > + + {getProviderName("openai")} + + onProviderSelect("anthropic")} + type="button" + data-hook="ai-settings-provider-anthropic" + > + + {getProviderName("anthropic")} + + + + Custom + + - - - - API Key - onApiKeyChange(e.target.value)} - placeholder={`Enter${providerName ? ` ${providerName}` : ""} API key`} - $hasError={!!error} - disabled={!selectedProvider} - data-hook="ai-settings-api-key" - /> - {error && ( - {error} - )} - - Stored locally in your browser and never sent to QuestDB servers. - This API key is used to authenticate your requests to the model - provider. - - - + {selectedProvider && ( + <> + + + + API Key + onApiKeyChange(e.target.value)} + placeholder={`Enter${providerName ? ` ${providerName}` : ""} API key`} + $hasError={!!error} + data-hook="ai-settings-api-key" + /> + {error && ( + + {error} + + )} + + Stored locally in your browser and never sent to QuestDB + servers. This API key is used to authenticate your requests to + the model provider. + + + + + )} ) } @@ -571,10 +571,8 @@ const StepTwoContent = ({ const handleClose: () => void = navigation.handleClose const currentProvider = selectedProvider - const getModelsForProvider = (provider: Provider) => { - return provider === "openai" - ? modelsByProvider.openai - : modelsByProvider.anthropic + const getModelsForProvider = (provider: ProviderId) => { + return modelsByProvider[provider] || [] } return ( @@ -600,10 +598,12 @@ const StepTwoContent = ({ Enable Models - {currentProvider === "openai" ? ( + {currentProvider === "anthropic" ? ( + + ) : currentProvider === "openai" ? ( ) : ( - + )} {getProviderName(currentProvider)} @@ -700,33 +700,38 @@ export const ConfigurationModal = ({ onOpenChange, }: ConfigurationModalProps) => { const { aiAssistantSettings, updateSettings } = useLocalStorage() - const [selectedProvider, setSelectedProvider] = useState( + const [selectedProvider, setSelectedProvider] = useState( null, ) const providerName = useMemo( - () => getProviderName(selectedProvider), - [selectedProvider], + () => getProviderName(selectedProvider, aiAssistantSettings), + [selectedProvider, aiAssistantSettings], ) const [apiKey, setApiKey] = useState("") const [error, setError] = useState(null) + const [customProviderModalOpen, setCustomProviderModalOpen] = useState(false) + + useEffect(() => { + if (!open) { + setCustomProviderModalOpen(false) + } + }, [open]) const [enabledModels, setEnabledModels] = useState([]) const [grantSchemaAccess, setGrantSchemaAccess] = useState(true) const modelsByProvider = useMemo(() => { - const anthropic: ModelOption[] = [] - const openai: ModelOption[] = [] - MODEL_OPTIONS.forEach((model) => { - if (model.provider === "anthropic") { - anthropic.push(model) - } else { - openai.push(model) + const result: Record = {} + getAllModelOptions(aiAssistantSettings).forEach((model) => { + if (!result[model.provider]) { + result[model.provider] = [] } + result[model.provider].push(model) }) - return { anthropic, openai } - }, []) + return result + }, [aiAssistantSettings]) - const handleProviderSelect = useCallback((provider: Provider) => { + const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) setError(null) setApiKey("") @@ -755,7 +760,9 @@ export const ConfigurationModal = ({ const selectedModel = enabledModels.find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, + (m) => + getAllModelOptions(aiAssistantSettings).find((mo) => mo.value === m) + ?.default, ) ?? enabledModels[0] const newSettings = { @@ -794,20 +801,25 @@ export const ConfigurationModal = ({ } const testModel = - MODEL_OPTIONS.find( + getAllModelOptions(aiAssistantSettings).find( (m) => m.isTestModel && m.provider === selectedProvider, )?.value ?? modelsByProvider[selectedProvider][0].value try { - const result = await testApiKey(apiKey, testModel) + const result = await testApiKey( + apiKey, + testModel, + selectedProvider, + aiAssistantSettings, + ) if (!result.valid) { const errorMsg = result.error || "Invalid API key" setError(errorMsg) return errorMsg } - const defaultModels = MODEL_OPTIONS.filter( - (m) => m.defaultEnabled && m.provider === selectedProvider, - ).map((m) => m.value) + const defaultModels = getAllModelOptions(aiAssistantSettings) + .filter((m) => m.defaultEnabled && m.provider === selectedProvider) + .map((m) => m.value) if (defaultModels.length > 0) { setEnabledModels(defaultModels) } @@ -848,6 +860,37 @@ export const ConfigurationModal = ({ setGrantSchemaAccess(true) }, []) + const handleCustomProviderSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + const newEnabledModels = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + + const newSettings = { + ...aiAssistantSettings, + selectedModel: newEnabledModels[0], + customProviders: { + ...(aiAssistantSettings.customProviders ?? {}), + [providerId]: definition, + }, + providers: { + ...aiAssistantSettings.providers, + [providerId]: { + apiKey: definition.apiKey ?? "", + enabledModels: newEnabledModels, + grantSchemaAccess: definition.grantSchemaAccess ?? false, + }, + }, + } + + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, newSettings) + toast.success("AI Assistant activated successfully") + setCustomProviderModalOpen(false) + onOpenChange?.(false) + }, + [aiAssistantSettings, updateSettings, onOpenChange], + ) + const steps: Step[] = useMemo( () => [ { @@ -862,6 +905,7 @@ export const ConfigurationModal = ({ providerName={providerName} onProviderSelect={handleProviderSelect} onApiKeyChange={handleApiKeyChange} + onAddCustomProvider={() => setCustomProviderModalOpen(true)} /> ), validate: validateStepOne, @@ -901,21 +945,33 @@ export const ConfigurationModal = ({ ) return ( - { - if (!isOpen) { - handleModalClose() - } - onOpenChange?.(isOpen) - }} - onStepChange={handleStepChange} - steps={steps} - maxWidth="64rem" - onComplete={handleComplete} - canProceed={canProceed} - completeButtonText="Activate Assistant" - showValidationError={false} - /> + <> + { + if (!isOpen) { + handleModalClose() + } + onOpenChange?.(isOpen) + }} + onStepChange={handleStepChange} + steps={steps} + maxWidth="64rem" + onComplete={handleComplete} + canProceed={canProceed} + completeButtonText="Activate Assistant" + showValidationError={false} + /> + {customProviderModalOpen && ( + + getProviderName(p, aiAssistantSettings), + )} + /> + )} + ) } diff --git a/src/components/SetupAIAssistant/CustomProviderModal.tsx b/src/components/SetupAIAssistant/CustomProviderModal.tsx new file mode 100644 index 000000000..84760a982 --- /dev/null +++ b/src/components/SetupAIAssistant/CustomProviderModal.tsx @@ -0,0 +1,406 @@ +import React, { useState, useMemo, useCallback, useRef } from "react" +import styled from "styled-components" +import { MultiStepModal } from "../MultiStepModal" +import type { Step } from "../MultiStepModal" +import { useModalNavigation } from "../MultiStepModal" +import { Box } from "../Box" +import { Dialog } from "../Dialog" +import type { + ProviderType, + CustomProviderDefinition, +} from "../../utils/ai/settings" +import { Select } from "../Select" +import { toast } from "../Toast" +import { + ModelSettings, + InputSection, + InputLabel, + StyledInput, + HelperText, +} from "./ModelSettings" +import type { ModelSettingsRef } from "./ModelSettings" + +const ModalContent = styled.div` + display: flex; + flex-direction: column; + width: 100%; +` + +const HeaderSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.6rem", +})` + padding: 2.4rem; + padding-top: 0; + width: 100%; +` + +const HeaderTitleRow = styled(Box).attrs({ + justifyContent: "space-between", + align: "flex-start", + gap: "1rem", +})` + width: 100%; +` + +const HeaderText = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", + align: "flex-start", +})` + flex: 1; +` + +const ModalTitle = styled(Dialog.Title)` + font-size: 2.4rem; + font-weight: 600; + margin: 0; + padding: 0; + color: ${({ theme }) => theme.color.foreground}; + border: 0; +` + +const ModalSubtitle = styled(Dialog.Description)` + color: ${({ theme }) => theme.color.gray2}; + margin: 0; + padding: 0; +` + +const StyledCloseButton = styled.button` + background: transparent; + border: none; + cursor: pointer; + padding: 0; + display: flex; + align-items: center; + justify-content: center; + color: ${({ theme }) => theme.color.gray1}; + border-radius: 0.4rem; + flex-shrink: 0; + width: 2.2rem; + height: 2.2rem; + + &:hover { + color: ${({ theme }) => theme.color.foreground}; + } +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const ContentSection = styled(Box).attrs({ + flexDirection: "column", + gap: "2rem", +})` + padding: 2.4rem; + width: 100%; +` + +const PasswordInput = styled(StyledInput)` + text-security: disc; + -webkit-text-security: disc; + -moz-text-security: disc; +` + +const StyledSelect = styled(Select)` + width: 100%; + background: #262833; + color: ${({ theme }) => theme.color.foreground}; + border: 0.1rem solid #6b7280; + border-radius: 0.8rem; + min-height: 3.2rem; + padding: 0 0.75rem; + cursor: pointer; + + &:focus { + border-color: ${({ theme }) => theme.color.pink}; + outline: none; + } + + option { + background: ${({ theme }) => theme.color.backgroundDarker}; + color: ${({ theme }) => theme.color.foreground}; + } +` + +const CloseButton = ({ onClick }: { onClick: () => void }) => ( + + + + + +) + +type StepOneProps = { + name: string + providerType: ProviderType + baseURL: string + apiKey: string + onNameChange: (v: string) => void + onProviderTypeChange: (v: ProviderType) => void + onBaseURLChange: (v: string) => void + onApiKeyChange: (v: string) => void +} + +const StepOneContent = ({ + name, + providerType, + baseURL, + apiKey, + onNameChange, + onProviderTypeChange, + onBaseURLChange, + onApiKeyChange, +}: StepOneProps) => { + const navigation = useModalNavigation() + + return ( + + + + + Add Custom Provider + + Configure a custom AI provider endpoint. Supports + OpenAI-compatible, Anthropic-compatible, and local providers like + Ollama. + + + + + + + + + Provider Name + onNameChange(e.target.value)} + placeholder="e.g., OpenRouter, Ollama" + /> + + + Provider Type + + onProviderTypeChange(e.target.value as ProviderType) + } + options={[ + { + label: "OpenAI Chat Completions API", + value: "openai-chat-completions", + }, + { + label: "OpenAI Responses API", + value: "openai", + }, + { + label: "Anthropic Messages API", + value: "anthropic", + }, + ]} + /> + + Most third-party providers and local models use the OpenAI Chat + Completions format. + + + + Base URL + onBaseURLChange(e.target.value)} + placeholder="e.g., http://localhost:11434/v1" + /> + + The base URL of your provider's API endpoint. + + + + API Key + onApiKeyChange(e.target.value)} + placeholder="Optional for local providers" + /> + + Stored locally in your browser. Optional for local providers like + Ollama. + + + + + ) +} + +const StepTwoHeader = () => { + const navigation = useModalNavigation() + + return ( + + + + Models & Settings + + Configure the models and settings for your custom provider. + + + + + + ) +} + +export type CustomProviderModalProps = { + open: boolean + onOpenChange: (open: boolean) => void + onSave: (providerId: string, provider: CustomProviderDefinition) => void + existingProviderNames: string[] +} + +export const CustomProviderModal = ({ + open, + onOpenChange, + onSave, + existingProviderNames, +}: CustomProviderModalProps) => { + const [name, setName] = useState("") + const [providerType, setProviderType] = useState( + "openai-chat-completions", + ) + const [baseURL, setBaseURL] = useState("") + const [apiKey, setApiKey] = useState("") + + const modelSettingsRef = useRef(null) + + const connectionValidate = useCallback((): string | boolean => { + if (!name.trim()) return "Provider name is required" + if (!baseURL.trim()) return "Base URL is required" + if (!baseURL.startsWith("http://") && !baseURL.startsWith("https://")) + return "Base URL must start with http:// or https://" + + const normalizedName = name.trim().toLowerCase() + if (existingProviderNames.some((n) => n.toLowerCase() === normalizedName)) + return "A provider with the same name already exists" + + return true + }, [name, baseURL, existingProviderNames]) + + const modelsValidate = useCallback((): string | boolean => { + return modelSettingsRef.current?.validate() ?? "Not ready" + }, []) + + const handleComplete = useCallback(() => { + const providerId = crypto.randomUUID() + const values = modelSettingsRef.current?.getValues() + if (!values) return + + const definition: CustomProviderDefinition = { + type: providerType, + name: name.trim(), + baseURL: baseURL.trim(), + apiKey: apiKey || undefined, + contextWindow: values.contextWindow, + models: values.models, + grantSchemaAccess: values.grantSchemaAccess, + } + + onSave(providerId, definition) + toast.success(`Added custom provider ${name.trim()}.`) + }, [name, providerType, baseURL, apiKey, onSave]) + + const steps: Step[] = useMemo(() => { + const connectionStep: Step = { + id: "connection", + title: "Add Custom Provider", + stepName: "Connection", + content: ( + + ), + validate: connectionValidate, + } + + return [ + connectionStep, + { + id: "model-settings", + title: "Add Custom Provider", + stepName: "Models & Settings", + content: ( + + + + + + ), + validate: modelsValidate, + }, + ] + }, [name, providerType, baseURL, apiKey, connectionValidate, modelsValidate]) + + const canProceed = useCallback( + (stepIndex: number): boolean => { + if (stepIndex === 0) { + return !!name.trim() && !!baseURL.trim() + } + return true + }, + [name, baseURL], + ) + + return ( + + ) +} diff --git a/src/components/SetupAIAssistant/ManageModelsModal.tsx b/src/components/SetupAIAssistant/ManageModelsModal.tsx new file mode 100644 index 000000000..c4cf606c9 --- /dev/null +++ b/src/components/SetupAIAssistant/ManageModelsModal.tsx @@ -0,0 +1,163 @@ +import React, { useState, useCallback, useRef } from "react" +import styled from "styled-components" +import * as RadixDialog from "@radix-ui/react-dialog" +import { Dialog } from "../Dialog" +import { Box } from "../Box" +import { Text } from "../Text" +import { Button } from "../Button" +import { Overlay } from "../Overlay" +import type { CustomProviderDefinition } from "../../utils/ai/settings" +import { ModelSettings } from "./ModelSettings" +import type { ModelSettingsRef } from "./ModelSettings" + +const ModalContent = styled.div` + display: flex; + flex-direction: column; + width: 100%; + overflow-y: auto; +` + +const HeaderSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", + align: "flex-start", +})` + padding: 2rem 2.4rem; +` + +const ModalTitle = styled(Dialog.Title)` + font-size: 2.4rem; + font-weight: 600; + margin: 0; + padding: 0; + color: ${({ theme }) => theme.color.foreground}; + border: 0; +` + +const ModalSubtitle = styled(RadixDialog.Description)` + font-size: 1.4rem; + color: ${({ theme }) => theme.color.gray2}; + margin: 0; +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const FooterSection = styled(Box).attrs({ + justifyContent: "flex-end", + align: "center", + gap: "1.2rem", +})` + padding: 2rem 2.4rem; + width: 100%; +` + +const FooterButton = styled(Button)` + padding: 1.1rem 1.2rem; + font-size: 1.4rem; + font-weight: 500; + height: 4rem; + min-width: 12rem; +` + +const ErrorText = styled(Text)` + font-size: 1.3rem; + color: ${({ theme }) => theme.color.red}; +` + +type ManageModelsModalProps = { + open: boolean + onOpenChange: (open: boolean) => void + providerId: string + definition: CustomProviderDefinition + onSave: (providerId: string, definition: CustomProviderDefinition) => void +} + +export const ManageModelsModal = ({ + open, + onOpenChange, + providerId, + definition, + onSave, +}: ManageModelsModalProps) => { + const [error, setError] = useState(null) + const [modelsLoading, setModelsLoading] = useState(true) + const modelSettingsRef = useRef(null) + + const handleSave = useCallback(() => { + setError(null) + const result = modelSettingsRef.current?.validate() + if (typeof result === "string") { + setError(result) + return + } + const values = modelSettingsRef.current?.getValues() + if (!values) return + onSave(providerId, { + ...definition, + models: values.models, + contextWindow: values.contextWindow, + }) + onOpenChange(false) + }, [definition, providerId, onSave, onOpenChange]) + + return ( + + + + + + + Manage Models + + Add or remove models and update the context window for{" "} + {definition.name}. + + + + {open && ( + + )} + + + {error && {error}} + onOpenChange(false)} + > + Cancel + + + Save + + + + + + + ) +} diff --git a/src/components/SetupAIAssistant/ModelDropdown.tsx b/src/components/SetupAIAssistant/ModelDropdown.tsx index 063d53845..3487fb689 100644 --- a/src/components/SetupAIAssistant/ModelDropdown.tsx +++ b/src/components/SetupAIAssistant/ModelDropdown.tsx @@ -6,12 +6,13 @@ import { PopperToggle } from "../PopperToggle" import { Box } from "../Box" import { Text } from "../Text" import { useLocalStorage } from "../../providers/LocalStorageProvider" -import { MODEL_OPTIONS } from "../../utils/aiAssistantSettings" +import { getAllModelOptions } from "../../utils/ai" import { useAIStatus } from "../../providers/AIStatusProvider" import { StoreKey } from "../../utils/localStorage/types" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" +import { PlugsIcon } from "@phosphor-icons/react" import { Tooltip } from "../Tooltip" const ExpandUpDown = () => ( @@ -97,6 +98,8 @@ const DropdownContent = styled.div` min-width: 22.8rem; gap: 0.4rem; z-index: 9999; + max-height: 50vh; + overflow-y: auto; ` const Title = styled(Text)` @@ -164,10 +167,10 @@ export const ModelDropdown = () => { const [dropdownActive, setDropdownActive] = useState(false) const enabledModels = useMemo(() => { - return MODEL_OPTIONS.filter((model) => + return getAllModelOptions(aiAssistantSettings).filter((model) => enabledModelValues.includes(model.value), ) - }, [enabledModelValues]) + }, [enabledModelValues, aiAssistantSettings]) const handleModelSelect = (modelValue: string) => { updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { @@ -221,10 +224,12 @@ export const ModelDropdown = () => { ]} trigger={ - {displayModel.provider === "openai" ? ( + {displayModel.provider === "anthropic" ? ( + + ) : displayModel.provider === "openai" ? ( ) : ( - + )} {displayModel.label} @@ -246,10 +251,12 @@ export const ModelDropdown = () => { $selected={isSelected} > - {model.provider === "openai" ? ( + {model.provider === "anthropic" ? ( + + ) : model.provider === "openai" ? ( ) : ( - + )} {model.label} diff --git a/src/components/SetupAIAssistant/ModelSettings.tsx b/src/components/SetupAIAssistant/ModelSettings.tsx new file mode 100644 index 000000000..c976d89a4 --- /dev/null +++ b/src/components/SetupAIAssistant/ModelSettings.tsx @@ -0,0 +1,685 @@ +import React, { + useState, + useCallback, + useEffect, + useRef, + useImperativeHandle, + forwardRef, +} from "react" +import styled, { useTheme } from "styled-components" +import { Box } from "../Box" +import { Input } from "../Input" +import { Checkbox } from "../Checkbox" +import { Text } from "../Text" +import { LoadingSpinner } from "../LoadingSpinner" +import { WarningIcon, XIcon } from "@phosphor-icons/react" +import { createProviderByType } from "../../utils/ai/registry" +import type { ProviderType } from "../../utils/ai/settings" + +export const InputSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.2rem", +})` + width: 100%; +` + +export const InputLabel = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; +` + +export const StyledInput = styled(Input)<{ $hasError?: boolean }>` + width: 100%; + background: #262833; + border: 0.1rem solid + ${({ theme, $hasError }) => ($hasError ? theme.color.red : "#6b7280")}; + border-radius: 0.8rem; + font-size: 1.4rem; + min-height: 3rem; + + &::placeholder { + color: ${({ theme }) => theme.color.gray2}; + font-family: inherit; + } +` + +export const HelperText = styled(Text)` + font-size: 1.3rem; + font-weight: 300; + color: ${({ theme }) => theme.color.gray2}; +` + +const WarningBanner = styled(Box).attrs({ + flexDirection: "row", + gap: "0.6rem", + align: "center", +})` + width: 100%; + background: rgba(255, 165, 0, 0.08); + border: 0.1rem solid ${({ theme }) => theme.color.orange}; + border-radius: 0.8rem; + padding: 0.75rem; +` + +const WarningText = styled(Text)` + font-size: 1.3rem; + color: ${({ theme }) => theme.color.orange}; +` + +const ModelListContainer = styled.div` + max-height: 30rem; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 0.25rem; + border: 0.1rem solid #6b7280; + border-radius: 0.4rem; + width: 100%; +` + +const ModelRow = styled.label` + display: flex; + align-items: center; + gap: 0.8rem; + padding: 0.6rem 0.8rem; + cursor: pointer; + font-size: 1.4rem; + color: ${({ theme }) => theme.color.foreground}; + + &:hover { + background: ${({ theme }) => theme.color.selection}; + } +` + +const ModelChipsContainer = styled.div` + display: flex; + flex-wrap: wrap; + gap: 0.6rem; +` + +const ModelChip = styled.div` + display: inline-flex; + align-items: center; + gap: 0.5rem; + background: ${({ theme }) => theme.color.selection}; + border-radius: 0.4rem; + padding: 0.4rem 0.8rem; + font-size: 1.3rem; + color: ${({ theme }) => theme.color.foreground}; +` + +const ChipRemoveButton = styled.button` + background: none; + border: none; + cursor: pointer; + padding: 0; + display: flex; + justify-content: center; + align-items: center; + color: ${({ theme }) => theme.color.gray2}; + + &:hover { + color: ${({ theme }) => theme.color.foreground}; + } +` + +const AddModelRow = styled(Box).attrs({ + gap: "0.8rem", + align: "center", +})` + width: 100%; +` + +const AddModelButton = styled.button` + height: 3rem; + border: 0.1rem solid ${({ theme }) => theme.color.pinkDarker}; + background: ${({ theme }) => theme.color.background}; + color: ${({ theme }) => theme.color.foreground}; + border-radius: 0.4rem; + padding: 0 1.2rem; + font-size: 1.4rem; + font-weight: 500; + cursor: pointer; + white-space: nowrap; + + &:hover:not(:disabled) { + background: ${({ theme }) => theme.color.pinkDarker}; + } + + &:disabled { + opacity: 0.6; + cursor: not-allowed; + } +` + +const SelectAllRow = styled(Box).attrs({ + gap: "2rem", + align: "center", +})` + display: inline-flex; + margin-left: auto; +` + +const SelectAllLink = styled.button` + background: none; + border: none; + cursor: pointer; + color: ${({ theme }) => theme.color.cyan}; + font-size: 1.4rem; + padding: 0; + + &:hover { + text-decoration: underline; + } +` + +const SchemaAccessSection = styled(Box).attrs({ + flexDirection: "column", + gap: "1.6rem", + align: "flex-start", +})` + width: 100%; +` + +const SchemaAccessTitle = styled(Text)` + font-size: 1.6rem; + font-weight: 600; + color: ${({ theme }) => theme.color.gray2}; + flex: 1; +` + +const SchemaCheckboxContainer = styled(Box).attrs({ + gap: "1.5rem", + align: "flex-start", +})` + background: rgba(68, 71, 90, 0.56); + padding: 0.75rem; + border-radius: 0.4rem; + width: 100%; +` + +const SchemaCheckboxInner = styled(Box).attrs({ + gap: "1.5rem", + align: "center", +})` + flex: 1; + padding: 0.75rem; + border-radius: 0.5rem; +` + +const SchemaCheckboxWrapper = styled.div` + flex-shrink: 0; + display: flex; + align-items: center; +` + +const SchemaCheckboxContent = styled(Box).attrs({ + flexDirection: "column", + gap: "0.6rem", +})` + flex: 1; +` + +const SchemaCheckboxLabel = styled(Text)` + font-size: 1.4rem; + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const SchemaCheckboxDescription = styled(Text)` + font-size: 1.3rem; + font-weight: 400; + color: ${({ theme }) => theme.color.gray2}; +` + +const SchemaCheckboxDescriptionBold = styled.span` + font-weight: 500; + color: ${({ theme }) => theme.color.foreground}; +` + +const ContentSection = styled(Box).attrs({ + flexDirection: "column", + gap: "2rem", +})` + padding: 2.4rem; + width: 100%; +` + +const Separator = styled.div` + height: 0.1rem; + width: 100%; + background: ${({ theme }) => theme.color.selection}; +` + +const LoadingContainer = styled(Box).attrs({ + align: "center", + justifyContent: "center", +})` + width: 100%; + padding: 4rem 0; +` + +// --- Types --- + +export type FetchConfig = { + providerType: ProviderType + providerId: string + apiKey: string + baseURL: string +} + +export type ModelSettingsInitialValues = { + models?: string[] + contextWindow?: number + grantSchemaAccess?: boolean +} + +export type ModelSettingsData = { + models: string[] + contextWindow: number + grantSchemaAccess: boolean +} + +export type ModelSettingsRef = { + getValues: () => ModelSettingsData + validate: () => string | true +} + +export type ModelSettingsProps = { + initialValues?: ModelSettingsInitialValues + fetchConfig: FetchConfig + renderSchemaAccess?: boolean + providerName?: string + onLoadingChange?: (loading: boolean) => void +} + +// --- Utility --- + +async function fetchProviderModels( + config: FetchConfig, + contextWindow: number, +): Promise { + try { + const provider = createProviderByType( + config.providerType, + config.providerId, + config.apiKey, + { baseURL: config.baseURL, contextWindow, isCustom: true }, + ) + const models = await provider.listModels() + return models && models.length > 0 ? models : null + } catch { + return null + } +} + +// --- Component --- + +export const ModelSettings = forwardRef( + ( + { + initialValues, + fetchConfig, + renderSchemaAccess, + providerName, + onLoadingChange, + }, + ref, + ) => { + const theme = useTheme() + + const [fetchedModels, setFetchedModels] = useState(null) + const [selectedModels, setSelectedModels] = useState([]) + const [manualModels, setManualModels] = useState( + () => initialValues?.models ?? [], + ) + const [manualModelInput, setManualModelInput] = useState("") + const [contextWindowInput, setContextWindowInput] = useState(() => + String(initialValues?.contextWindow ?? 200_000), + ) + const [grantSchemaAccess, setGrantSchemaAccess] = useState( + () => initialValues?.grantSchemaAccess ?? true, + ) + const [isLoading, setIsLoading] = useState(true) + + const fetchConfigRef = useRef(fetchConfig) + fetchConfigRef.current = fetchConfig + const initialValuesRef = useRef(initialValues) + initialValuesRef.current = initialValues + + // Fetch models on mount + useEffect(() => { + let cancelled = false + + const doFetch = async () => { + setIsLoading(true) + const config = fetchConfigRef.current + const initModels = initialValuesRef.current?.models ?? [] + const initContextWindow = + initialValuesRef.current?.contextWindow ?? 200_000 + + const models = await fetchProviderModels(config, initContextWindow) + + if (cancelled) return + + if (models) { + // Auto mode: reconcile initialValues.models against fetched list + setFetchedModels(models) + const selected = [ + ...initModels.filter((m) => models.includes(m)), + ...initModels.filter((m) => !models.includes(m)), + ] + setSelectedModels(selected.length > 0 ? selected : []) + setManualModels([]) + } else { + // Manual mode + setFetchedModels(null) + setSelectedModels([]) + setManualModels([...initModels]) + } + setIsLoading(false) + } + + void doFetch() + return () => { + cancelled = true + } + }, []) + + useEffect(() => { + onLoadingChange?.(isLoading) + }, [isLoading, onLoadingChange]) + + const isAutoMode = fetchedModels !== null + + // --- Handlers --- + + const handleToggleModel = useCallback((model: string) => { + setSelectedModels((prev) => + prev.includes(model) + ? prev.filter((m) => m !== model) + : [...prev, model], + ) + }, []) + + const handleSelectAll = useCallback(() => { + setSelectedModels((prev) => { + if (!fetchedModels) return prev + const manual = prev.filter((m) => !fetchedModels.includes(m)) + return [...fetchedModels, ...manual] + }) + }, [fetchedModels]) + + const handleDeselectAll = useCallback(() => { + setSelectedModels((prev) => + fetchedModels ? prev.filter((m) => !fetchedModels.includes(m)) : [], + ) + }, [fetchedModels]) + + const handleAddManualModel = useCallback(() => { + const trimmed = manualModelInput.trim() + if (!trimmed) return + + if (isAutoMode) { + setSelectedModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } else { + setManualModels((prev) => + prev.includes(trimmed) ? prev : [...prev, trimmed], + ) + } + setManualModelInput("") + }, [manualModelInput, isAutoMode]) + + const handleRemoveManualModel = useCallback((model: string) => { + setManualModels((prev) => prev.filter((m) => m !== model)) + }, []) + + // --- Imperative handle --- + + useImperativeHandle( + ref, + () => ({ + getValues: () => { + const pending = manualModelInput.trim() + let models: string[] + + if (isAutoMode) { + models = + pending && !selectedModels.includes(pending) + ? [...selectedModels, pending] + : [...selectedModels] + } else { + models = + pending && !manualModels.includes(pending) + ? [...manualModels, pending] + : [...manualModels] + } + + const contextWindow = Number(contextWindowInput) || 0 + return { models, contextWindow, grantSchemaAccess } + }, + validate: () => { + const pending = manualModelInput.trim() + const models = isAutoMode ? selectedModels : manualModels + const hasModels = models.length > 0 || !!pending + if (!hasModels) return "Add at least one model" + const trimmed = contextWindowInput.trim() + if (!trimmed) return "Context window is required" + const contextWindow = Number(trimmed) + if (isNaN(contextWindow) || !Number.isInteger(contextWindow)) + return "Context window must be a valid number" + if (contextWindow < 100_000) + return "Context window must be at least 100,000 tokens" + return true + }, + }), + [ + manualModelInput, + isAutoMode, + selectedModels, + manualModels, + contextWindowInput, + grantSchemaAccess, + ], + ) + + // --- Render --- + + if (isLoading) { + return ( + + + + + + ) + } + + return ( + <> + + {!isAutoMode && ( + + + + Could not fetch models automatically from this provider. Please + enter model IDs manually. + + + )} + {isAutoMode && ( + + + Select Models + + + Select All + + + Deselect All + + + + + {fetchedModels.map((model) => ( + + handleToggleModel(model)} + /> + {model} + + ))} + + + )} + + {!isAutoMode && Add Models} + {isAutoMode && ( + + Don't see your model? Add it manually: + + )} + + setManualModelInput(e.target.value)} + placeholder="e.g., llama3, gpt-4o, claude-sonnet-4-20250514" + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault() + handleAddManualModel() + } + }} + /> + + Add + + + {isAutoMode && + selectedModels.filter((m) => !fetchedModels.includes(m)).length > + 0 && ( + + {selectedModels + .filter((m) => !fetchedModels.includes(m)) + .map((model) => ( + + {model} + handleToggleModel(model)} + > + + + + ))} + + )} + {!isAutoMode && manualModels.length > 0 && ( + + {manualModels.map((model) => ( + + {model} + handleRemoveManualModel(model)} + title={`Remove ${model}`} + > + + + + ))} + + )} + + + + + + Context Window + setContextWindowInput(e.target.value)} + /> + + Maximum number of tokens the model can process. AI assistant + requires a minimum of 100,000 tokens. + + + + {renderSchemaAccess && ( + <> + + + + Schema Access + + + + setGrantSchemaAccess(e.target.checked)} + /> + + + + Grant schema access to {providerName || "this provider"} + + + When enabled, the AI assistant can access your database + schema information to provide more accurate suggestions + and explanations. Schema information helps the AI + understand your table structures, column names, and + relationships.{" "} + + The AI model will not have access to your data. + + + + + + + + + )} + + ) + }, +) + +ModelSettings.displayName = "ModelSettings" diff --git a/src/components/SetupAIAssistant/SettingsModal.tsx b/src/components/SetupAIAssistant/SettingsModal.tsx index be389a38f..5e10fb316 100644 --- a/src/components/SetupAIAssistant/SettingsModal.tsx +++ b/src/components/SetupAIAssistant/SettingsModal.tsx @@ -13,6 +13,7 @@ import { testApiKey } from "../../utils/aiAssistant" import { StoreKey } from "../../utils/localStorage/types" import { toast } from "../Toast" import { Edit } from "@styled-icons/remix-line" +import { TrashIcon, PlugsIcon, PlusIcon } from "@phosphor-icons/react" import { OpenAIIcon } from "./OpenAIIcon" import { AnthropicIcon } from "./AnthropicIcon" import { BrainIcon } from "./BrainIcon" @@ -20,15 +21,24 @@ import { LoadingSpinner } from "../LoadingSpinner" import { Overlay } from "../Overlay" import { getAllProviders, - MODEL_OPTIONS, + getAllModelOptions, + getApiKey, + makeCustomModelValue, + BUILTIN_PROVIDERS, type ModelOption, - type Provider, + type ProviderId, getNextModel, -} from "../../utils/aiAssistantSettings" -import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" + getProviderName, +} from "../../utils/ai" +import type { + AiAssistantSettings, + CustomProviderDefinition, +} from "../../providers/LocalStorageProvider/types" import { ForwardRef } from "../ForwardRef" import { Badge, BadgeType } from "../../components/Badge" import { CheckboxCircle } from "@styled-icons/remix-fill" +import { CustomProviderModal } from "./CustomProviderModal" +import { ManageModelsModal } from "./ManageModelsModal" const ModalContent = styled.div` display: flex; @@ -159,6 +169,10 @@ const ProviderTabTitle = styled(Box).attrs({ align: "center", })` width: 100%; + + svg { + flex-shrink: 0; + } ` const ProviderTabName = styled(Text)<{ $active: boolean }>` @@ -166,6 +180,7 @@ const ProviderTabName = styled(Text)<{ $active: boolean }>` font-weight: ${({ $active }) => ($active ? 600 : 400)}; color: ${({ theme, $active }) => $active ? theme.color.foreground : theme.color.gray2}; + text-align: left; ` const StatusBadge = styled(Box).attrs({ @@ -386,6 +401,19 @@ const EnableModelsTitle = styled(Text)` color: ${({ theme }) => theme.color.foreground}; ` +const ManageModelsButton = styled.button` + background: none; + border: none; + cursor: pointer; + color: ${({ theme }) => theme.color.cyan}; + font-size: 1.3rem; + padding: 0; + + &:hover { + text-decoration: underline; + } +` + const SchemaAccessSection = styled(Box).attrs({ flexDirection: "column", gap: "1.6rem", @@ -457,6 +485,18 @@ const SchemaCheckboxDescriptionBold = styled.span` color: ${({ theme }) => theme.color.foreground}; ` +const RemoveProviderButton = styled(Button)` + border: 0.1rem solid ${({ theme }) => theme.color.red}; + background: ${({ theme }) => theme.color.backgroundDarker}; + color: ${({ theme }) => theme.color.foreground}; + + &:hover:not(:disabled) { + background: ${({ theme }) => theme.color.background}; + border: 0.1rem solid ${({ theme }) => theme.color.red}; + color: ${({ theme }) => theme.color.foreground}; + } +` + const FooterSection = styled(Box).attrs({ flexDirection: "column", gap: "2rem", @@ -495,24 +535,46 @@ const SaveButton = styled(Button)` width: 100%; ` +const AddProviderButton = styled.button` + display: flex; + flex-direction: column; + align-items: center; + gap: 0.5rem; + padding: 0.8rem 1.6rem; + background: none; + border: 0.1rem dashed ${({ theme }) => theme.color.gray2}; + border-radius: 0.4rem; + color: ${({ theme }) => theme.color.gray2}; + cursor: pointer; + font-size: 1.3rem; + justify-content: center; + margin: 0 1rem; + + &:hover { + border-color: ${({ theme }) => theme.color.foreground}; + color: ${({ theme }) => theme.color.foreground}; + } +` + type SettingsModalProps = { open?: boolean onOpenChange?: (open: boolean) => void } -const getProviderName = (provider: Provider) => { - return provider === "openai" ? "OpenAI" : "Anthropic" +const getModelsForProvider = ( + provider: ProviderId, + settings?: AiAssistantSettings, +): ModelOption[] => { + return getAllModelOptions(settings).filter((m) => m.provider === provider) } -const getModelsForProvider = (provider: Provider): ModelOption[] => { - return MODEL_OPTIONS.filter((m) => m.provider === provider) -} - -const getProvidersWithApiKeys = (settings: AiAssistantSettings): Provider[] => { - const providers: Provider[] = [] - const allProviders = getAllProviders() +const getProvidersWithApiKeys = ( + settings: AiAssistantSettings, +): ProviderId[] => { + const providers: ProviderId[] = [] + const allProviders = getAllProviders(settings) for (const provider of allProviders) { - if (settings.providers?.[provider]?.apiKey) { + if (getApiKey(provider, settings)) { providers.push(provider) } } @@ -523,11 +585,11 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const { aiAssistantSettings, updateSettings } = useLocalStorage() const initializeProviderState = useCallback( ( - getValue: (provider: Provider) => T, + getValue: (provider: ProviderId) => T, defaultValue: T, - ): Record => { - const allProviders = getAllProviders() - const state = {} as Record + ): Record => { + const allProviders = getAllProviders(aiAssistantSettings) + const state = {} as Record for (const provider of allProviders) { state[provider] = getValue(provider) ?? defaultValue } @@ -536,18 +598,19 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [], ) - const [selectedProvider, setSelectedProvider] = useState(() => { + const [selectedProvider, setSelectedProvider] = useState(() => { const providersWithKeys = getProvidersWithApiKeys(aiAssistantSettings) - return providersWithKeys[0] || getAllProviders()[0] + return providersWithKeys[0] || getAllProviders(aiAssistantSettings)[0] }) - const [apiKeys, setApiKeys] = useState>(() => + const isCustomProvider = !BUILTIN_PROVIDERS[selectedProvider] + const [apiKeys, setApiKeys] = useState>(() => initializeProviderState( - (provider) => aiAssistantSettings.providers?.[provider]?.apiKey || "", + (provider) => getApiKey(provider, aiAssistantSettings) || "", "", ), ) const [enabledModels, setEnabledModels] = useState< - Record + Record >(() => initializeProviderState( (provider) => @@ -556,43 +619,65 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ), ) const [grantSchemaAccess, setGrantSchemaAccess] = useState< - Record + Record >(() => - initializeProviderState( - (provider) => - aiAssistantSettings.providers?.[provider]?.grantSchemaAccess !== false, - true, - ), + initializeProviderState((provider) => { + const providerSettings = aiAssistantSettings.providers?.[provider] + if (providerSettings) return providerSettings.grantSchemaAccess !== false + const custom = aiAssistantSettings.customProviders?.[provider] + if (custom) return custom.grantSchemaAccess !== false + return true + }, true), ) const [validatedApiKeys, setValidatedApiKeys] = useState< - Record + Record >(() => initializeProviderState( - (provider) => !!aiAssistantSettings.providers?.[provider]?.apiKey, + (provider) => + !BUILTIN_PROVIDERS[provider] || + !!getApiKey(provider, aiAssistantSettings), false, ), ) const [validationState, setValidationState] = useState< - Record + Record >(() => initializeProviderState(() => "idle" as const, "idle" as const)) const [validationErrors, setValidationErrors] = useState< - Record + Record >(() => initializeProviderState(() => null, null)) const [isInputFocused, setIsInputFocused] = useState< - Record + Record >(() => initializeProviderState(() => false, false)) const inputRef = useRef(null) - const handleProviderSelect = useCallback((provider: Provider) => { + const [customProviderModalOpen, setCustomProviderModalOpen] = useState(false) + const [manageModelsModalOpen, setManageModelsModalOpen] = useState(false) + + const [localCustomProviders, setLocalCustomProviders] = useState< + Record + >(() => ({ ...(aiAssistantSettings.customProviders ?? {}) })) + + const localSettings = useMemo( + () => ({ + ...aiAssistantSettings, + customProviders: + Object.keys(localCustomProviders).length > 0 + ? localCustomProviders + : undefined, + }), + [aiAssistantSettings, localCustomProviders], + ) + + const handleProviderSelect = useCallback((provider: ProviderId) => { setSelectedProvider(provider) setValidationErrors((prev) => ({ ...prev, [provider]: null })) }, []) const handleApiKeyChange = useCallback( - (provider: Provider, value: string) => { + (provider: ProviderId, value: string) => { setApiKeys((prev) => ({ ...prev, [provider]: value })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) - // If API key changes, mark as not validated + if (validatedApiKeys[provider]) { setValidatedApiKeys((prev) => ({ ...prev, [provider]: false })) } @@ -601,7 +686,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ) const handleValidateApiKey = useCallback( - async (provider: Provider) => { + async (provider: ProviderId) => { const apiKey = apiKeys[provider] if (!apiKey) { setValidationErrors((prev) => ({ @@ -614,7 +699,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { setValidationState((prev) => ({ ...prev, [provider]: "validating" })) setValidationErrors((prev) => ({ ...prev, [provider]: null })) - const providerModels = getModelsForProvider(provider) + const providerModels = getModelsForProvider(provider, localSettings) if (providerModels.length === 0) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ @@ -628,7 +713,12 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { providerModels.find((m) => m.isTestModel) ?? providerModels[0] ).value try { - const result = await testApiKey(apiKey, testModel) + const result = await testApiKey( + apiKey, + testModel, + provider, + localSettings, + ) if (!result.valid) { setValidationState((prev) => ({ ...prev, [provider]: "error" })) setValidationErrors((prev) => ({ @@ -636,9 +726,9 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [provider]: result.error || "Invalid API key", })) } else { - const defaultModels = MODEL_OPTIONS.filter( - (m) => m.defaultEnabled && m.provider === provider, - ).map((m) => m.value) + const defaultModels = getAllModelOptions(localSettings) + .filter((m) => m.defaultEnabled && m.provider === provider) + .map((m) => m.value) if (defaultModels.length > 0) { setEnabledModels((prev) => ({ ...prev, [provider]: defaultModels })) } @@ -656,18 +746,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [apiKeys], ) - const handleRemoveApiKey = useCallback((provider: Provider) => { - // Remove API key from local state only - // Settings will be persisted when Save Settings is clicked - setApiKeys((prev) => ({ ...prev, [provider]: "" })) - setValidatedApiKeys((prev) => ({ ...prev, [provider]: false })) - setValidationState((prev) => ({ ...prev, [provider]: "idle" })) - setValidationErrors((prev) => ({ ...prev, [provider]: null })) - setIsInputFocused((prev) => ({ ...prev, [provider]: false })) - }, []) - const handleModelToggle = useCallback( - (provider: Provider, modelValue: string) => { + (provider: ProviderId, modelValue: string) => { setEnabledModels((prev) => { const current = prev[provider] const isEnabled = current.includes(modelValue) @@ -683,7 +763,7 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { ) const handleSchemaAccessChange = useCallback( - (provider: Provider, checked: boolean) => { + (provider: ProviderId, checked: boolean) => { setGrantSchemaAccess((prev) => ({ ...prev, [provider]: checked })) }, [], @@ -691,28 +771,54 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { const handleSave = useCallback(() => { const updatedProviders = { ...aiAssistantSettings.providers } - const allProviders = getAllProviders() + const allProviderIds = getAllProviders(localSettings) - for (const provider of allProviders) { - if (validatedApiKeys[provider]) { - // Only save providers with validated API keys + for (const provider of allProviderIds) { + const isCustom = !BUILTIN_PROVIDERS[provider] + if (validatedApiKeys[provider] || isCustom) { updatedProviders[provider] = { - apiKey: apiKeys[provider], + apiKey: apiKeys[provider] ?? "", enabledModels: enabledModels[provider], grantSchemaAccess: grantSchemaAccess[provider], } } else { - // Remove provider entry if no validated API key delete updatedProviders[provider] } } + // Remove provider entries for deleted custom providers + for (const provider of Object.keys(updatedProviders)) { + if (!BUILTIN_PROVIDERS[provider] && !localCustomProviders[provider]) { + delete updatedProviders[provider] + } + } + + // Sync API keys and schema access into custom provider definitions + const updatedCustomProviders = + Object.keys(localCustomProviders).length > 0 + ? { ...localCustomProviders } + : undefined + if (updatedCustomProviders) { + for (const provider of Object.keys(updatedCustomProviders)) { + updatedCustomProviders[provider] = { + ...updatedCustomProviders[provider], + apiKey: apiKeys[provider] || undefined, + grantSchemaAccess: grantSchemaAccess[provider], + } + } + } + const updatedSettings: AiAssistantSettings = { ...aiAssistantSettings, providers: updatedProviders, + customProviders: updatedCustomProviders, } - const nextModel = getNextModel(updatedSettings.selectedModel, enabledModels) + const nextModel = getNextModel( + updatedSettings.selectedModel, + enabledModels, + updatedSettings, + ) updatedSettings.selectedModel = nextModel || undefined updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, updatedSettings) @@ -720,6 +826,8 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { onOpenChange?.(false) }, [ aiAssistantSettings, + localSettings, + localCustomProviders, apiKeys, enabledModels, grantSchemaAccess, @@ -732,16 +840,176 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { onOpenChange?.(false) }, [onOpenChange]) + const handleRemoveProvider = useCallback( + (providerId: ProviderId) => { + const isCustom = !BUILTIN_PROVIDERS[providerId] + + if (isCustom) { + setLocalCustomProviders((prev) => { + const { [providerId]: _, ...rest } = prev + return rest + }) + } + + setApiKeys((prev) => ({ ...prev, [providerId]: "" })) + setGrantSchemaAccess((prev) => ({ ...prev, [providerId]: false })) + setValidatedApiKeys((prev) => ({ ...prev, [providerId]: false })) + setValidationState((prev) => ({ ...prev, [providerId]: "idle" })) + setValidationErrors((prev) => ({ ...prev, [providerId]: null })) + setEnabledModels((prev) => ({ ...prev, [providerId]: [] })) + setIsInputFocused((prev) => ({ ...prev, [providerId]: false })) + + // Switch to first remaining active provider + const updatedCustomProviders = isCustom + ? (() => { + const { [providerId]: _, ...rest } = localCustomProviders + return Object.keys(rest).length > 0 ? rest : undefined + })() + : localSettings.customProviders + const remaining = getAllProviders({ + ...localSettings, + customProviders: updatedCustomProviders, + }).filter((p) => p !== providerId || BUILTIN_PROVIDERS[p]) + setSelectedProvider(remaining[0] ?? "openai") + }, + [localSettings, localCustomProviders], + ) + + const handleCustomProviderSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + const newEnabledModels = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + + setLocalCustomProviders((prev) => ({ + ...prev, + [providerId]: definition, + })) + setApiKeys((prev) => ({ + ...prev, + [providerId]: definition.apiKey ?? "", + })) + setGrantSchemaAccess((prev) => ({ + ...prev, + [providerId]: definition.grantSchemaAccess ?? false, + })) + setValidatedApiKeys((prev) => ({ + ...prev, + [providerId]: true, + })) + setEnabledModels((prev) => ({ + ...prev, + [providerId]: newEnabledModels, + })) + + const updatedCustomProviders = { + ...(aiAssistantSettings.customProviders ?? {}), + [providerId]: definition, + } + const updatedProviders = { + ...aiAssistantSettings.providers, + [providerId]: { + apiKey: definition.apiKey ?? "", + enabledModels: newEnabledModels, + grantSchemaAccess: definition.grantSchemaAccess ?? false, + }, + } + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { + ...aiAssistantSettings, + customProviders: updatedCustomProviders, + providers: updatedProviders, + }) + + setSelectedProvider(providerId) + setCustomProviderModalOpen(false) + }, + [aiAssistantSettings, updateSettings], + ) + + const handleManageModelsSave = useCallback( + (providerId: string, definition: CustomProviderDefinition) => { + const newModelValues = definition.models.map((m) => + makeCustomModelValue(providerId, m), + ) + + // Update local custom providers — only override models and contextWindow, + // preserve everything else (apiKey, grantSchemaAccess, etc.) from local state. + setLocalCustomProviders((prev) => ({ + ...prev, + [providerId]: { + ...prev[providerId], + models: definition.models, + contextWindow: definition.contextWindow, + }, + })) + + // Determine which models are truly new (not in the previous model list) + const oldModelValues = ( + localCustomProviders[providerId]?.models || [] + ).map((m) => makeCustomModelValue(providerId, m)) + const trulyNew = newModelValues.filter((m) => !oldModelValues.includes(m)) + + // Local state: respect unsaved checkbox toggles, add truly new as enabled + const localEnabled = enabledModels[providerId] || [] + const localStillEnabled = localEnabled.filter((m: string) => + newModelValues.includes(m), + ) + setEnabledModels((prev) => ({ + ...prev, + [providerId]: [...localStillEnabled, ...trulyNew], + })) + + // Storage: preserve stored enabled/disabled state, only add truly new models. + // Unsaved toggle changes (apiKey, grantSchemaAccess, enable/disable) are not + // persisted here — they require "Save Settings". + const storedProviderSettings = aiAssistantSettings.providers?.[providerId] + const storedEnabled = storedProviderSettings?.enabledModels || [] + const storedStillEnabled = storedEnabled.filter((m: string) => + newModelValues.includes(m), + ) + const storedCustomProvider = + aiAssistantSettings.customProviders?.[providerId] + updateSettings(StoreKey.AI_ASSISTANT_SETTINGS, { + ...aiAssistantSettings, + customProviders: { + ...(aiAssistantSettings.customProviders ?? {}), + ...(storedCustomProvider && { + [providerId]: { + ...storedCustomProvider, + models: definition.models, + contextWindow: definition.contextWindow, + }, + }), + }, + providers: { + ...aiAssistantSettings.providers, + ...(storedProviderSettings && { + [providerId]: { + ...storedProviderSettings, + enabledModels: [...storedStillEnabled, ...trulyNew], + }, + }), + }, + }) + + toast.success("Model preferences updated") + }, + [aiAssistantSettings, enabledModels, localCustomProviders, updateSettings], + ) + const currentProviderValidated = validatedApiKeys[selectedProvider] const currentProviderApiKey = apiKeys[selectedProvider] const currentProviderValidationState = validationState[selectedProvider] const currentProviderError = validationErrors[selectedProvider] const currentProviderIsFocused = isInputFocused[selectedProvider] const maskInput = !!(currentProviderApiKey && !currentProviderIsFocused) + const noApiKeyReadonly = + isCustomProvider && !currentProviderApiKey && !currentProviderIsFocused + const showEditButton = maskInput || noApiKeyReadonly const modelsForProvider = useMemo( - () => getModelsForProvider(selectedProvider), - [selectedProvider], + () => getModelsForProvider(selectedProvider, localSettings), + [selectedProvider, localSettings], ) const enabledModelsForProvider = useMemo( @@ -749,329 +1017,422 @@ export const SettingsModal = ({ open, onOpenChange }: SettingsModalProps) => { [enabledModels, selectedProvider], ) - const allProviders = useMemo(() => getAllProviders(), []) + const allProviders = useMemo( + () => getAllProviders(localSettings), + [localSettings], + ) - const renderProviderIcon = (provider: Provider, isActive: boolean) => { + const renderProviderIcon = (provider: ProviderId, isActive: boolean) => { const color = isActive ? "#f8f8f2" : "#9ca3af" - if (provider === "openai") { - return + switch (provider) { + case "openai": + return + case "anthropic": + return + default: + return } - return } return ( - - - - - - - - - - - Assistant Settings - - Modify settings for your AI assistant, set up new providers, - and review access. - - - - - - - - - - - - - {allProviders.map((provider) => { - const isActive = selectedProvider === provider - return ( - handleProviderSelect(provider)} - data-hook={`ai-settings-provider-${provider}`} - > - - {renderProviderIcon(provider, isActive)} - - {getProviderName(provider)} - - - - - - {validatedApiKeys[provider] ? "Enabled" : "Inactive"} - - - - ) - })} - - - - - - + + + + + + + + + + + Assistant Settings + + Modify settings for your AI assistant, set up new + providers, and review access. + + + + - API Key - {validatedApiKeys[selectedProvider] && ( - } - data-hook="ai-settings-validated-badge" - > - Validated - - )} - - Get your API key from{" "} - - {getProviderName(selectedProvider)} - - . - - - - { - handleApiKeyChange(selectedProvider, e.target.value) - }} - placeholder={`Enter ${getProviderName(selectedProvider)} API key`} - $hasError={!!currentProviderError} - $showEditButton={maskInput} - readOnly={maskInput} - onFocus={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: true, - })) - }} - onBlur={() => { - setIsInputFocused((prev) => ({ - ...prev, - [selectedProvider]: false, - })) - if (inputRef.current) { - inputRef.current.blur() - } - }} - onMouseDown={(e) => { - if (maskInput) { - e.preventDefault() - } - }} - tabIndex={maskInput ? -1 : 0} - style={{ - cursor: maskInput ? "default" : "text", - }} - data-hook="ai-settings-api-key" + - {maskInput && ( - { - inputRef.current?.focus() - }} - title="Edit API key" + + + + + + + + {allProviders.map((provider) => { + const isActive = selectedProvider === provider + return ( + handleProviderSelect(provider)} + data-hook={`ai-settings-provider-${provider}`} + > + + {renderProviderIcon(provider, isActive)} + + {getProviderName(provider, localSettings)} + + + + + + {validatedApiKeys[provider] + ? "Enabled" + : "Inactive"} + + + + ) + })} + { + setCustomProviderModalOpen(true) + }} + > + Add custom provider + + + + + + + <> + - - - )} - - {currentProviderError && ( - {currentProviderError} - )} - {!currentProviderError && ( - - Stored locally in your browser and never sent to QuestDB - servers. This API key is used to authenticate your - requests to the model provider. - - )} - - currentProviderValidated - ? handleRemoveApiKey(selectedProvider) - : handleValidateApiKey(selectedProvider) - } - disabled={ - currentProviderValidationState === "validating" || - (!currentProviderValidated && !currentProviderApiKey) - } - data-hook="ai-settings-test-api" - > - {currentProviderValidationState === "validating" ? ( - - - Validating... - - ) : currentProviderValidated ? ( - "Remove API Key" - ) : ( - "Validate API Key" - )} - - - - - - Enable Models - {currentProviderValidated ? ( - - {modelsForProvider.map((model) => { - const isEnabled = enabledModelsForProvider.includes( - model.value, - ) - return ( - - - {model.label} - {model.isSlow && ( - - - - Due to advanced reasoning & thinking - capabilities, responses using this model - can be slow. - - - )} - - - handleModelToggle( - selectedProvider, - model.value, - ) + API Key + {validatedApiKeys[selectedProvider] && + currentProviderApiKey && ( + } + data-hook="ai-settings-validated-badge" + > + Validated + + )} + {!isCustomProvider && ( + + Get your API key from{" "} + - - ) - })} - - ) : ( - - - When you've entered and validated your API key, - you'll be able to select and enable available - models. - - - )} - - - - - - Schema Access - - - - - + {getProviderName( + selectedProvider, + localSettings, + )} + + . + + )} + + + - handleSchemaAccessChange( + autoComplete="off" + onChange={(e) => { + handleApiKeyChange( selectedProvider, - e.target.checked, + e.target.value, ) + }} + placeholder={ + noApiKeyReadonly + ? "This provider does not have an API key" + : `Enter ${getProviderName(selectedProvider, localSettings)} API key` } - disabled={!currentProviderValidated} - data-hook="ai-settings-schema-access" + $hasError={!!currentProviderError} + $showEditButton={showEditButton} + readOnly={maskInput || noApiKeyReadonly} + onFocus={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: true, + })) + }} + onBlur={() => { + setIsInputFocused((prev) => ({ + ...prev, + [selectedProvider]: false, + })) + if (inputRef.current) { + inputRef.current.blur() + } + }} + onMouseDown={(e) => { + if (maskInput || noApiKeyReadonly) { + e.preventDefault() + } + }} + tabIndex={maskInput || noApiKeyReadonly ? -1 : 0} + style={{ + cursor: + maskInput || noApiKeyReadonly + ? "default" + : "text", + }} + data-hook="ai-settings-api-key" /> - - - - Grant schema access to{" "} - {getProviderName(selectedProvider)} - - - When enabled, the AI assistant can access your - database schema information to provide more accurate - suggestions and explanations. Schema information - helps the AI understand your table structures, - column names, and relationships.{" "} - - The AI model will not have access to your data. - - - - - - - - - - - - - - Cancel - - - Save Settings - - - - - - - + {showEditButton && ( + { + inputRef.current?.focus() + }} + title="Edit API key" + > + + + )} + + {currentProviderError && ( + {currentProviderError} + )} + {!currentProviderError && ( + + Stored locally in your browser and never sent to + QuestDB servers. This API key is used to + authenticate your requests to the model provider. + + )} + {!currentProviderValidated && currentProviderApiKey && ( + + handleValidateApiKey(selectedProvider) + } + disabled={ + currentProviderValidationState === "validating" + } + data-hook="ai-settings-test-api" + > + {currentProviderValidationState === "validating" ? ( + + + Validating... + + ) : ( + "Validate API Key" + )} + + )} + + + + + + + Enable Models + {isCustomProvider && + (currentProviderValidated || + modelsForProvider.length > 0) && ( + setManageModelsModalOpen(true)} + > + Manage models + + )} + + {currentProviderValidated || + (isCustomProvider && modelsForProvider.length > 0) ? ( + + {modelsForProvider.map((model) => { + const isEnabled = enabledModelsForProvider.includes( + model.value, + ) + return ( + + + {model.label} + {model.isSlow && ( + + + + Due to advanced reasoning & thinking + capabilities, responses using this model + can be slow. + + + )} + + + handleModelToggle( + selectedProvider, + model.value, + ) + } + /> + + ) + })} + + ) : ( + + + When you've entered and validated your API key, + you'll be able to select and enable available + models. + + + )} + + + + + + Schema Access + + + + + + handleSchemaAccessChange( + selectedProvider, + e.target.checked, + ) + } + disabled={ + !currentProviderValidated && + !( + isCustomProvider && + modelsForProvider.length > 0 + ) + } + data-hook="ai-settings-schema-access" + /> + + + + Grant schema access to{" "} + {getProviderName(selectedProvider, localSettings)} + + + When enabled, the AI assistant can access your + database schema information to provide more + accurate suggestions and explanations. Schema + information helps the AI understand your table + structures, column names, and relationships.{" "} + + The AI model will not have access to your data. + + + + + + + + + } + type="button" + data-hook="ai-settings-remove-provider" + onClick={() => handleRemoveProvider(selectedProvider)} + > + {isCustomProvider ? "Remove Provider" : "Reset Provider"} + + + + + + + + + Cancel + + + Save Settings + + + + + + + + {customProviderModalOpen && ( + + getProviderName(p, localSettings), + )} + /> + )} + {manageModelsModalOpen && + isCustomProvider && + localCustomProviders[selectedProvider] && ( + + )} + ) } diff --git a/src/hooks/useAIQuickActions.ts b/src/hooks/useAIQuickActions.ts index 1fdd31f29..7c2c1aab1 100644 --- a/src/hooks/useAIQuickActions.ts +++ b/src/hooks/useAIQuickActions.ts @@ -49,6 +49,7 @@ export const useAIQuickActions = () => { hasSchemaAccess, currentModel, apiKey, + aiAssistantSettings, } = useAIStatus() const { @@ -130,6 +131,7 @@ export const useAIQuickActions = () => { ...schemaDisplayData, }, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -171,6 +173,7 @@ export const useAIQuickActions = () => { ...schemaDisplayData, }, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -248,6 +251,7 @@ export const useAIQuickActions = () => { monitoringDocs, trendSamples, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -295,6 +299,7 @@ export const useAIQuickActions = () => { monitoringDocs, trendSamples, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/providers/AIStatusProvider/index.tsx b/src/providers/AIStatusProvider/index.tsx index e9dd5a43f..fdb479632 100644 --- a/src/providers/AIStatusProvider/index.tsx +++ b/src/providers/AIStatusProvider/index.tsx @@ -14,7 +14,10 @@ import { hasSchemaAccess, providerForModel, canUseAiAssistant, -} from "../../utils/aiAssistantSettings" + getAllEnabledModels, + getApiKey, +} from "../../utils/ai" +import type { AiAssistantSettings } from "../LocalStorageProvider/types" import { useAIConversation } from "../AIConversationProvider" export const useAIStatus = () => { @@ -78,6 +81,7 @@ type BaseAIStatusContextType = { models: string[] currentOperation: OperationHistory clearOperation: () => void + aiAssistantSettings: AiAssistantSettings } export type AIStatusContextType = @@ -135,19 +139,15 @@ export const AIStatusProvider: React.FC = ({ const apiKey = useMemo(() => { if (!currentModel) return null - const provider = providerForModel(currentModel) - return aiAssistantSettings.providers?.[provider]?.apiKey || null + const provider = providerForModel(currentModel, aiAssistantSettings) + if (!provider) return null + return getApiKey(provider, aiAssistantSettings) }, [currentModel, aiAssistantSettings]) - const models = useMemo(() => { - const allModels: string[] = [] - const anthropicModels = - aiAssistantSettings.providers?.anthropic?.enabledModels || [] - const openaiModels = - aiAssistantSettings.providers?.openai?.enabledModels || [] - allModels.push(...anthropicModels, ...openaiModels) - return allModels - }, [aiAssistantSettings]) + const models = useMemo( + () => getAllEnabledModels(aiAssistantSettings), + [aiAssistantSettings], + ) const setStatus = useCallback( ( @@ -251,6 +251,7 @@ export const AIStatusProvider: React.FC = ({ apiKey: apiKey!, models, currentOperation, + aiAssistantSettings, } : { status, @@ -265,6 +266,7 @@ export const AIStatusProvider: React.FC = ({ apiKey, models, currentOperation, + aiAssistantSettings, } return ( diff --git a/src/providers/LocalStorageProvider/index.tsx b/src/providers/LocalStorageProvider/index.tsx index ec3ecf258..7f53aa549 100644 --- a/src/providers/LocalStorageProvider/index.tsx +++ b/src/providers/LocalStorageProvider/index.tsx @@ -33,6 +33,7 @@ import { LeftPanelState, LeftPanelType, } from "./types" +import { reconcileSettings } from "../../utils/ai/settings" export const DEFAULT_AI_ASSISTANT_SETTINGS: AiAssistantSettings = { providers: {}, @@ -139,10 +140,17 @@ export const LocalStorageProvider = ({ if (stored) { try { const parsed = JSON.parse(stored) as AiAssistantSettings - return { + const reconciled = reconcileSettings({ selectedModel: parsed.selectedModel, providers: parsed.providers || {}, + ...(parsed.customProviders && { + customProviders: parsed.customProviders, + }), + }) + if (JSON.stringify(reconciled) !== stored) { + setValue(StoreKey.AI_ASSISTANT_SETTINGS, JSON.stringify(reconciled)) } + return reconciled } catch (e) { return defaultConfig.aiAssistantSettings } diff --git a/src/providers/LocalStorageProvider/types.ts b/src/providers/LocalStorageProvider/types.ts index 7a7335d09..48e55986d 100644 --- a/src/providers/LocalStorageProvider/types.ts +++ b/src/providers/LocalStorageProvider/types.ts @@ -4,12 +4,20 @@ export type ProviderSettings = { grantSchemaAccess: boolean } +export type CustomProviderDefinition = { + type: "anthropic" | "openai" | "openai-chat-completions" + name: string + baseURL: string + apiKey?: string + contextWindow: number + models: string[] + grantSchemaAccess?: boolean +} + export type AiAssistantSettings = { selectedModel?: string - providers: { - anthropic?: ProviderSettings - openai?: ProviderSettings - } + providers: Partial> + customProviders?: Record } export type SettingsType = string | boolean | number | AiAssistantSettings diff --git a/src/scenes/Editor/AIChatWindow/index.tsx b/src/scenes/Editor/AIChatWindow/index.tsx index c4a202b3d..b0a28e0ad 100644 --- a/src/scenes/Editor/AIChatWindow/index.tsx +++ b/src/scenes/Editor/AIChatWindow/index.tsx @@ -215,6 +215,7 @@ const AIChatWindow: React.FC = () => { hasSchemaAccess, currentModel, apiKey, + aiAssistantSettings, } = useAIStatus() const tables = useSelector(selectors.query.getTables) const running = useSelector(selectors.query.getRunning) @@ -384,6 +385,7 @@ const AIChatWindow: React.FC = () => { conversationHistory: conversation.messages, isFirstMessage: !hasAssistantMessages, settings: { model: currentModel, apiKey }, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, @@ -560,6 +562,7 @@ const AIChatWindow: React.FC = () => { const settings = { model: currentModel, apiKey } const commonConfig = { settings, + aiAssistantSettings, questClient: quest, tables, hasSchemaAccess, diff --git a/src/scenes/Schema/VirtualTables/index.tsx b/src/scenes/Schema/VirtualTables/index.tsx index c0edfd492..e5b32affb 100644 --- a/src/scenes/Schema/VirtualTables/index.tsx +++ b/src/scenes/Schema/VirtualTables/index.tsx @@ -135,7 +135,7 @@ const TableRow = styled(Row)<{ $contextMenuOpen: boolean }>` $contextMenuOpen && ` background: ${theme.color.tableSelection}; - border: 1px solid ${theme.color.cyan}; + box-shadow: inset 0 0 0 1px ${theme.color.cyan}; `} ` diff --git a/src/utils/ai/anthropicProvider.ts b/src/utils/ai/anthropicProvider.ts new file mode 100644 index 000000000..d57a09ee2 --- /dev/null +++ b/src/utils/ai/anthropicProvider.ts @@ -0,0 +1,742 @@ +import Anthropic from "@anthropic-ai/sdk" +import type { MessageParam } from "@anthropic-ai/sdk/resources/messages" +import type { OutputConfig } from "@anthropic-ai/sdk/resources/messages" +import type { Tool as AnthropicTool } from "@anthropic-ai/sdk/resources/messages" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + extractPartialExplanation, + executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, +} from "./shared" +import { + createHeaderFilteredFetch, + ANTHROPIC_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" + +function toAnthropicTools(tools: ToolDefinition[]): AnthropicTool[] { + return tools.map((t) => ({ + name: t.name, + description: t.description, + input_schema: { + type: "object" as const, + properties: t.inputSchema.properties, + ...(t.inputSchema.required ? { required: t.inputSchema.required } : {}), + }, + })) +} + +function toAnthropicModel(model: string): string { + return getModelProps(model).model +} + +function toAnthropicOutputConfig(format: ResponseFormatSchema): OutputConfig { + return { + format: { + type: "json_schema", + schema: format.schema, + }, + } +} + +async function createAnthropicMessage( + anthropic: Anthropic, + params: Omit & { + max_tokens?: number + }, + signal?: AbortSignal, +): Promise { + const message = await anthropic.messages.create( + { + ...params, + stream: false, + max_tokens: params.max_tokens ?? 8192, + }, + { + signal, + }, + ) + + if (message.stop_reason === "refusal") { + throw new RefusalError( + "The model refused to generate a response for this request.", + ) + } + if (message.stop_reason === "max_tokens") { + throw new MaxTokensError( + "The response exceeded the maximum token limit. Please try again with a different prompt or model.", + ) + } + + return message +} + +async function createAnthropicMessageStreaming( + anthropic: Anthropic, + params: Omit & { + max_tokens?: number + }, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + let accumulatedText = "" + let lastExplanation = "" + + const stream = anthropic.messages.stream( + { + ...params, + max_tokens: params.max_tokens ?? 8192, + }, + { + signal: abortSignal, + }, + ) + + try { + for await (const event of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + const eventWithType = event as { type: string } + if (eventWithType.type === "error") { + const errorEvent = event as { + error?: { type?: string; message?: string } + } + const errorType = errorEvent.error?.type + const errorMessage = errorEvent.error?.message || "Stream error" + + if (errorType === "overloaded_error") { + throw new StreamingError( + "Service is temporarily overloaded. Please try again.", + "network", + event, + ) + } + throw new StreamingError(errorMessage, "failed", event) + } + + if ( + event.type === "content_block_delta" && + event.delta.type === "text_delta" + ) { + accumulatedText += event.delta.text + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const chunk = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(chunk, explanation) + } + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + if (error instanceof Anthropic.APIError) { + throw error + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + let finalMessage: Anthropic.Messages.Message + try { + finalMessage = await stream.finalMessage() + } catch (error) { + if (abortSignal?.aborted || error instanceof Anthropic.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + if (error instanceof Anthropic.APIError) { + throw error + } + throw new StreamingError( + "Failed to get final message from the provider", + "network", + error, + ) + } + + if (finalMessage.stop_reason === "refusal") { + throw new RefusalError( + "The model refused to generate a response for this request.", + ) + } + if (finalMessage.stop_reason === "max_tokens") { + throw new MaxTokensError( + "The response exceeded the maximum token limit. Please try again with a different prompt or model.", + ) + } + + return finalMessage +} + +interface AnthropicToolCallResult { + message: Anthropic.Messages.Message + accumulatedTokens: TokenUsage +} + +async function handleToolCalls( + message: Anthropic.Messages.Message, + anthropic: Anthropic, + modelToolsClient: ModelToolsClient, + conversationHistory: Array, + model: string, + systemPrompt: string, + setStatus: StatusCallback, + outputConfig: OutputConfig | undefined, + tools: AnthropicTool[], + contextWindow: number, + abortSignal?: AbortSignal, + accumulatedTokens: TokenUsage = { inputTokens: 0, outputTokens: 0 }, + streaming?: StreamingCallback, +): Promise { + const toolUseBlocks = message.content.filter( + (block) => block.type === "tool_use", + ) + const toolResults = [] + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + for (const toolUse of toolUseBlocks) { + if ("name" in toolUse) { + const exec = await executeTool( + toolUse.name, + toolUse.input, + modelToolsClient, + setStatus, + ) + toolResults.push({ + type: "tool_result" as const, + tool_use_id: toolUse.id, + content: exec.content, + is_error: exec.is_error, + }) + } + } + + const updatedHistory = [ + ...conversationHistory, + { + role: "assistant" as const, + content: message.content, + }, + { + role: "user" as const, + content: toolResults, + }, + ] + + const criticalTokenUsage = + message.usage.input_tokens >= contextWindow - 50_000 && + toolResults.length > 0 + if (criticalTokenUsage) { + updatedHistory.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + + const followUpParams: Parameters[1] = { + model, + system: systemPrompt, + tools, + messages: updatedHistory, + temperature: 0.3, + ...(outputConfig ? { output_config: outputConfig } : {}), + } + + const followUpMessage = streaming + ? await createAnthropicMessageStreaming( + anthropic, + followUpParams, + streaming, + abortSignal, + ) + : await createAnthropicMessage(anthropic, followUpParams, abortSignal) + + const newAccumulatedTokens: TokenUsage = { + inputTokens: + accumulatedTokens.inputTokens + + (followUpMessage.usage?.input_tokens || 0), + outputTokens: + accumulatedTokens.outputTokens + + (followUpMessage.usage?.output_tokens || 0), + } + + if (followUpMessage.stop_reason === "tool_use") { + return handleToolCalls( + followUpMessage, + anthropic, + modelToolsClient, + updatedHistory, + model, + systemPrompt, + setStatus, + outputConfig, + tools, + contextWindow, + abortSignal, + newAccumulatedTokens, + streaming, + ) + } + + return { + message: followUpMessage, + accumulatedTokens: newAccumulatedTokens, + } +} + +export function createAnthropicProvider( + apiKey: string, + providerId: ProviderId = "anthropic", + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, +): AIProvider { + const isCustom = options?.isCustom ?? false + const anthropic = new Anthropic({ + apiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(ANTHROPIC_ALLOWED_HEADERS), + } + : {}), + }) + + const contextWindow = options?.contextWindow ?? 200_000 + + return { + id: providerId, + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + const initialMessages: MessageParam[] = [] + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + initialMessages.push({ + role: msg.role, + content: msg.content, + }) + } + } + + initialMessages.push({ + role: "user" as const, + content: config.initialUserContent, + }) + + const anthropicTools = toAnthropicTools(tools) + const outputConfig = isCustom + ? undefined + : toAnthropicOutputConfig(config.responseFormat) + + const systemPrompt = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions + + const resolvedModel = toAnthropicModel(model) + + const messageParams: Parameters[1] = { + model: resolvedModel, + system: systemPrompt, + tools: anthropicTools, + messages: initialMessages, + temperature: 0.3, + ...(outputConfig ? { output_config: outputConfig } : {}), + } + + const message = streaming + ? await createAnthropicMessageStreaming( + anthropic, + messageParams, + streaming, + abortSignal, + ) + : await createAnthropicMessage(anthropic, messageParams, abortSignal) + + let totalInputTokens = message.usage?.input_tokens || 0 + let totalOutputTokens = message.usage?.output_tokens || 0 + + let responseMessage: Anthropic.Messages.Message + + if (message.stop_reason === "tool_use") { + const toolCallResult = await handleToolCalls( + message, + anthropic, + modelToolsClient, + initialMessages, + resolvedModel, + systemPrompt, + setStatus, + outputConfig, + anthropicTools, + contextWindow, + abortSignal, + { inputTokens: 0, outputTokens: 0 }, + streaming, + ) + + if ("type" in toolCallResult && "message" in toolCallResult) { + return toolCallResult + } + + const result = toolCallResult + responseMessage = result.message + totalInputTokens += result.accumulatedTokens.inputTokens + totalOutputTokens += result.accumulatedTokens.outputTokens + } else { + responseMessage = message + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const textBlock = responseMessage.content.find( + (block) => block.type === "text", + ) + if (!textBlock || !("text" in textBlock)) { + setStatus(null) + return { + type: "unknown", + message: "No text response received from assistant.", + } as AiAssistantAPIError + } + + if (isCustom) { + const json = parseCustomProviderResponse( + textBlock.text, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + + try { + const json = JSON.parse(textBlock.text) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const titleOutputConfig = isCustom + ? undefined + : toAnthropicOutputConfig(responseFormat) + + const messageParams: Parameters[1] = { + model: toAnthropicModel(model), + messages: [{ role: "user", content: userContent }], + max_tokens: 100, + temperature: 0.3, + ...(titleOutputConfig ? { output_config: titleOutputConfig } : {}), + } + const message = await createAnthropicMessage(anthropic, messageParams) + + const textBlock = message.content.find((block) => block.type === "text") + if (textBlock && "text" in textBlock) { + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + textBlock.text, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + + const parsed = JSON.parse(textBlock.text) as { title: string } + return parsed.title?.slice(0, 40) || null + } + return null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await anthropic.messages.create({ + ...getModelProps(model), + max_tokens: 8192, + messages: [{ role: "user", content: userMessage }], + system: systemPrompt, + }) + + const textBlock = response.content.find((block) => block.type === "text") + return textBlock?.type === "text" ? textBlock.text : "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new Anthropic({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(ANTHROPIC_ALLOWED_HEADERS), + } + : {}), + }) + + await createAnthropicMessage(testClient, { + model: toAnthropicModel(model), + messages: [{ role: "user", content: "ping" }], + }) + return { valid: true } + } catch (error: unknown) { + if (error instanceof Anthropic.AuthenticationError) { + return { valid: false, error: "Invalid API key" } + } + if (error instanceof Anthropic.RateLimitError) { + return { valid: true } + } + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt, model }) { + const anthropicMessages: Anthropic.MessageParam[] = messages.map((m) => ({ + role: m.role, + content: m.content, + })) + + const response = await anthropic.messages.countTokens({ + model: toAnthropicModel(model), + system: systemPrompt, + messages: anthropicMessages, + }) + + return response.input_tokens + }, + + async listModels(): Promise { + const models: string[] = [] + for await (const model of anthropic.models.list()) { + models.push(model.id) + } + return models.sort((a, b) => a.localeCompare(b)) + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof Anthropic.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof Anthropic.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your Anthropic API key.", + details: error.message, + } + } + + if (error instanceof Anthropic.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof Anthropic.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your connection.", + details: error.message, + } + } + + if (error instanceof Anthropic.APIError) { + return { + type: "unknown", + message: error.message, + details: `Status ${error.status}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error instanceof Error ? error.message : String(error), + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + if ( + error instanceof Anthropic.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof Anthropic.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/fetchWithFilteredHeaders.ts b/src/utils/ai/fetchWithFilteredHeaders.ts new file mode 100644 index 000000000..5851fde98 --- /dev/null +++ b/src/utils/ai/fetchWithFilteredHeaders.ts @@ -0,0 +1,53 @@ +function filterHeaders(raw: HeadersInit, allowSet: Set): Headers { + const source = new Headers(raw) + const filtered = new Headers() + source.forEach((value, key) => { + if (allowSet.has(key.toLowerCase())) { + filtered.set(key, value) + } + }) + return filtered +} + +export function createHeaderFilteredFetch( + allowedHeaders: string[], +): typeof globalThis.fetch { + const allowSet = new Set(allowedHeaders.map((h) => h.toLowerCase())) + return (input, init) => { + // Collect headers from both the Request object and init + const merged = new Headers() + if (input instanceof Request) { + input.headers.forEach((v, k) => merged.set(k, v)) + } + if (init?.headers) { + new Headers(init.headers).forEach((v, k) => merged.set(k, v)) + } + + const filtered = filterHeaders(merged, allowSet) + + // Normalize to plain URL + init to avoid Request carrying extra headers + const url = input instanceof Request ? input.url : input + const method = + init?.method ?? (input instanceof Request ? input.method : undefined) + const body = + init?.body ?? (input instanceof Request ? input.body : undefined) + const signal = + init?.signal ?? (input instanceof Request ? input.signal : undefined) + + return globalThis.fetch(url, { + ...init, + method, + headers: filtered, + body, + signal, + }) + } +} + +export const OPENAI_ALLOWED_HEADERS = ["content-type", "authorization"] + +export const ANTHROPIC_ALLOWED_HEADERS = [ + "content-type", + "x-api-key", + "anthropic-version", +] diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts new file mode 100644 index 000000000..a7eeb3d37 --- /dev/null +++ b/src/utils/ai/index.ts @@ -0,0 +1,57 @@ +export type { + AIProvider, + ToolDefinition, + ResponseFormatSchema, + FlowConfig, +} from "./types" +export { createProvider } from "./registry" +export { SCHEMA_TOOLS, REFERENCE_TOOLS, ALL_TOOLS } from "./tools" +export { + ExplainFormat, + FixSQLFormat, + ConversationResponseFormat, + ChatTitleFormat, +} from "./responseFormats" +export { + RefusalError, + MaxTokensError, + StreamingError, + safeJsonParse, + extractPartialExplanation, + executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, +} from "./shared" +export { + DOCS_INSTRUCTION, + getUnifiedPrompt, + getExplainSchemaPrompt, + getHealthIssuePrompt, +} from "./prompts" +export type { HealthIssuePromptData } from "./prompts" +export { + MODEL_OPTIONS, + BUILTIN_PROVIDERS, + providerForModel, + getModelProps, + getProviderName, + getAllProviders, + getAllModelOptions, + getAllEnabledModels, + getSelectedModel, + getNextModel, + getTestModel, + getProviderContextWindow, + getApiKey, + makeCustomModelValue, + parseModelValue, + isAiAssistantConfigured, + canUseAiAssistant, + hasSchemaAccess, +} from "./settings" +export type { + ProviderId, + ProviderType, + ModelOption, + CustomProviderDefinition, +} from "./settings" diff --git a/src/utils/ai/openaiChatCompletionsProvider.ts b/src/utils/ai/openaiChatCompletionsProvider.ts new file mode 100644 index 000000000..a673945a2 --- /dev/null +++ b/src/utils/ai/openaiChatCompletionsProvider.ts @@ -0,0 +1,715 @@ +import OpenAI from "openai" +import type { + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, +} from "openai/resources/chat/completions" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + safeJsonParse, + extractPartialExplanation, + executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, +} from "./shared" +import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" +import { + createHeaderFilteredFetch, + OPENAI_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" + +function toResponseFormat(format: ResponseFormatSchema) { + return { + type: "json_schema" as const, + json_schema: { + name: format.name, + schema: format.schema, + strict: format.strict, + }, + } +} + +function toOpenAITools( + tools: ToolDefinition[], +): OpenAI.Chat.Completions.ChatCompletionTool[] { + return tools.map((t) => ({ + type: "function" as const, + function: { + name: t.name, + description: t.description, + parameters: { ...t.inputSchema, additionalProperties: false }, + strict: true, + }, + })) +} + +interface RequestResult { + content: string + toolCalls: { id: string; name: string; arguments: unknown }[] + promptTokens: number + completionTokens: number + assistantMessage: ChatCompletionMessageParam +} + +async function createChatCompletionStreaming( + openai: OpenAI, + params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise<{ + content: string + refusal: string | null + finishReason: string | null + toolCalls: ChatCompletionMessageToolCall[] + usage: { prompt_tokens: number; completion_tokens: number } | null +}> { + let accumulatedText = "" + let accumulatedRefusal = "" + let lastExplanation = "" + let finishReason: string | null = null + const toolCallAccumulator: Map< + number, + { id: string; name: string; arguments: string } + > = new Map() + let usage: { prompt_tokens: number; completion_tokens: number } | null = null + + try { + const stream = await openai.chat.completions.create({ + ...params, + stream: true, + stream_options: { include_usage: true }, + }) + + for await (const chunk of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + const choice = chunk.choices?.[0] + + if (choice?.delta?.content) { + accumulatedText += choice.delta.content + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const delta = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(delta, explanation) + } + } + + if (choice?.delta?.refusal) { + accumulatedRefusal += choice.delta.refusal + } + + if (choice?.finish_reason) { + finishReason = choice.finish_reason + } + + if (choice?.delta?.tool_calls) { + for (const tc of choice.delta.tool_calls) { + const existing = toolCallAccumulator.get(tc.index) + if (existing) { + if (tc.id) existing.id = tc.id + if (tc.function?.name) existing.name = tc.function.name + existing.arguments += tc.function?.arguments ?? "" + } else { + toolCallAccumulator.set(tc.index, { + id: tc.id ?? "", + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }) + } + } + } + + if (chunk.usage) { + usage = { + prompt_tokens: chunk.usage.prompt_tokens, + completion_tokens: chunk.usage.completion_tokens, + } + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + if (error instanceof OpenAI.APIError) { + throw error + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + const toolCalls: ChatCompletionMessageToolCall[] = Array.from( + toolCallAccumulator.values(), + ).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments }, + })) + + return { + content: accumulatedText, + refusal: accumulatedRefusal || null, + finishReason, + toolCalls, + usage, + } +} + +function extractToolCallsFromMessage( + toolCalls: ChatCompletionMessageToolCall[], +): { id: string; name: string; arguments: unknown }[] { + return toolCalls + .filter((tc) => tc.type === "function") + .map((tc) => ({ + id: tc.id, + name: tc.function.name, + arguments: safeJsonParse(tc.function.arguments), + })) +} + +function buildAssistantMessage( + content: string | null, + toolCalls: ChatCompletionMessageToolCall[], +): ChatCompletionMessageParam { + return { + role: "assistant" as const, + content, + ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}), + } +} + +async function executeRequest( + openai: OpenAI, + params: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming?: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + if (streaming) { + const accumulated = await createChatCompletionStreaming( + openai, + params, + streaming, + abortSignal, + ) + + if (accumulated.refusal) { + throw new RefusalError(accumulated.refusal) + } + if (accumulated.finishReason === "length") { + throw new MaxTokensError( + "Response truncated: the model ran out of tokens.", + ) + } + + return { + content: accumulated.content, + toolCalls: extractToolCallsFromMessage(accumulated.toolCalls), + promptTokens: accumulated.usage?.prompt_tokens ?? 0, + completionTokens: accumulated.usage?.completion_tokens ?? 0, + assistantMessage: buildAssistantMessage( + accumulated.content || null, + accumulated.toolCalls, + ), + } + } + + const response = await openai.chat.completions.create(params) + const message = response.choices[0]?.message + + if (message?.refusal) { + throw new RefusalError(message.refusal) + } + if (response.choices[0]?.finish_reason === "length") { + throw new MaxTokensError("Response truncated: the model ran out of tokens.") + } + + const rawToolCalls = message?.tool_calls ?? [] + + return { + content: message?.content ?? "", + toolCalls: extractToolCallsFromMessage(rawToolCalls), + promptTokens: response.usage?.prompt_tokens ?? 0, + completionTokens: response.usage?.completion_tokens ?? 0, + assistantMessage: buildAssistantMessage( + message?.content ?? null, + rawToolCalls.filter( + (tc): tc is Extract => + tc.type === "function", + ), + ), + } +} + +let tiktokenEncoder: Tiktoken | null = null + +function toChatCompletionsAPIProps(model: string): { + model: string + reasoning_effort?: OpenAI.ReasoningEffort +} { + const props = getModelProps(model) + return { + model: props.model, + ...(props.reasoningEffort + ? { reasoning_effort: props.reasoningEffort as OpenAI.ReasoningEffort } + : {}), + } +} + +export function createOpenAIChatCompletionsProvider( + apiKey: string, + providerId: ProviderId = "openai", + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, +): AIProvider { + const isCustom = options?.isCustom ?? false + const openai = new OpenAI({ + apiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), + }) + + const contextWindow = options?.contextWindow ?? 400_000 + + return { + id: providerId, + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + const systemContent = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions + + const messages: ChatCompletionMessageParam[] = [ + { role: "system", content: systemContent }, + ] + + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + messages.push({ role: msg.role, content: msg.content }) + } + } + + messages.push({ role: "user", content: config.initialUserContent }) + + const openaiTools = toOpenAITools(tools) + let totalInputTokens = 0 + let totalOutputTokens = 0 + let lastPromptTokens = 0 + + const response_format = isCustom + ? undefined + : toResponseFormat(config.responseFormat) + + const baseParams = { + ...toChatCompletionsAPIProps(model), + tools: openaiTools, + response_format, + } + + let result = await executeRequest( + openai, + { + ...baseParams, + messages, + } as OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming, + abortSignal, + ) + totalInputTokens += result.promptTokens + totalOutputTokens += result.completionTokens + lastPromptTokens = result.promptTokens + messages.push(result.assistantMessage) + + while (true) { + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + if (!result.toolCalls.length) break + + for (const tc of result.toolCalls) { + const exec = await executeTool( + tc.name, + tc.arguments, + modelToolsClient, + setStatus, + ) + messages.push({ + role: "tool", + tool_call_id: tc.id, + content: exec.content, + }) + } + + if ( + lastPromptTokens >= contextWindow - 50_000 && + result.toolCalls.length > 0 + ) { + messages.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + + result = await executeRequest( + openai, + { + ...baseParams, + messages, + } as OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, + streaming, + abortSignal, + ) + totalInputTokens += result.promptTokens + totalOutputTokens += result.completionTokens + lastPromptTokens = result.promptTokens + messages.push(result.assistantMessage) + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + if (isCustom) { + const json = parseCustomProviderResponse( + result.content, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + + try { + const json = JSON.parse(result.content) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const response_format = isCustom + ? undefined + : toResponseFormat(responseFormat) + + const response = await openai.chat.completions.create({ + model: toChatCompletionsAPIProps(model).model, + messages: [{ role: "user", content: userContent }], + response_format, + ...(isCustom ? {} : { max_completion_tokens: 100 }), + }) + const content = response.choices[0]?.message?.content || "" + + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + content, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + + const parsed = JSON.parse(content) as { title: string } + return parsed.title || null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await openai.chat.completions.create({ + ...toChatCompletionsAPIProps(model), + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userMessage }, + ], + }) + return response.choices[0]?.message?.content || "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new OpenAI({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), + }) + await testClient.chat.completions.create({ + model: getModelProps(model).model, + messages: [{ role: "user", content: "ping" }], + max_completion_tokens: 16, + }) + return { valid: true } + } catch (error: unknown) { + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt }) { + // Custom providers (non-default baseURL) use chars/3.5 estimation + // because the actual tokenizer is unknown and tiktoken underestimates + // Claude tokens by 15-25% (dangerous for compaction). + if (options?.baseURL) { + const totalChars = + systemPrompt.length + + messages.reduce((sum, m) => sum + m.content.length, 0) + return Math.ceil(totalChars / 3.5) + } + + if (!tiktokenEncoder) { + const { Tiktoken } = await import("js-tiktoken/lite") + const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( + (module: { default: TiktokenBPE }) => module.default, + ) + tiktokenEncoder = new Tiktoken(o200k_base) + } + + let totalTokens = 0 + totalTokens += tiktokenEncoder.encode(systemPrompt).length + totalTokens += 4 // system message formatting overhead + + for (const message of messages) { + totalTokens += 4 // role markers overhead + totalTokens += tiktokenEncoder.encode(message.content).length + } + + totalTokens += 2 // assistant reply priming + return totalTokens + }, + + async listModels(): Promise { + const models: string[] = [] + for await (const model of openai.models.list()) { + models.push(model.id) + } + return models.sort((a, b) => a.localeCompare(b)) + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof OpenAI.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof OpenAI.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your OpenAI API key.", + details: error.message, + } + } + + if (error instanceof OpenAI.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your connection.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIError) { + return { + type: "unknown", + message: error.message, + details: `Status ${error.status}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error instanceof Error ? error.message : String(error), + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + if ( + error instanceof OpenAI.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof OpenAI.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/openaiProvider.ts b/src/utils/ai/openaiProvider.ts new file mode 100644 index 000000000..bd349a668 --- /dev/null +++ b/src/utils/ai/openaiProvider.ts @@ -0,0 +1,655 @@ +import OpenAI from "openai" +import type { + ResponseOutputItem, + ResponseTextConfig, +} from "openai/resources/responses/responses" +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, + TokenUsage, +} from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { getModelProps } from "./settings" +import type { ProviderId } from "./settings" +import type { + AIProvider, + FlowConfig, + ResponseFormatSchema, + ToolDefinition, +} from "./types" +import { + StreamingError, + RefusalError, + MaxTokensError, + safeJsonParse, + extractPartialExplanation, + executeTool, + parseCustomProviderResponse, + responseFormatToPromptInstruction, +} from "./shared" +import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" +import { + createHeaderFilteredFetch, + OPENAI_ALLOWED_HEADERS, +} from "./fetchWithFilteredHeaders" + +function toResponseTextConfig( + format: ResponseFormatSchema, +): ResponseTextConfig { + return { + format: { + type: "json_schema" as const, + name: format.name, + schema: format.schema, + strict: format.strict, + }, + } +} + +function toOpenAIFunctions(tools: ToolDefinition[]): OpenAI.Responses.Tool[] { + return tools.map((t) => ({ + type: "function" as const, + name: t.name, + description: t.description, + parameters: { ...t.inputSchema, additionalProperties: false }, + strict: true, + })) as OpenAI.Responses.Tool[] +} + +async function createOpenAIResponseStreaming( + openai: OpenAI, + params: OpenAI.Responses.ResponseCreateParamsNonStreaming, + streamCallback: StreamingCallback, + abortSignal?: AbortSignal, +): Promise { + let accumulatedText = "" + let lastExplanation = "" + let finalResponse: OpenAI.Responses.Response | null = null + + try { + const stream = await openai.responses.create({ + ...params, + stream: true, + } as OpenAI.Responses.ResponseCreateParamsStreaming) + + for await (const event of stream) { + if (abortSignal?.aborted) { + throw new StreamingError("Operation aborted", "interrupted") + } + + if (event.type === "error") { + const errorEvent = event as { error?: { message?: string } } + throw new StreamingError( + errorEvent.error?.message || "Stream error occurred", + "failed", + event, + ) + } + + if (event.type === "response.failed") { + const failedEvent = event as { + response?: { error?: { message?: string } } + } + throw new StreamingError( + failedEvent.response?.error?.message || + "Provider failed to return a response", + "failed", + event, + ) + } + + if (event.type === "response.output_text.delta") { + accumulatedText += event.delta + const explanation = extractPartialExplanation(accumulatedText) + if (explanation !== lastExplanation) { + const chunk = explanation.slice(lastExplanation.length) + lastExplanation = explanation + streamCallback.onTextChunk(chunk, explanation) + } + } + + if (event.type === "response.completed") { + finalResponse = event.response + } + } + } catch (error) { + if (error instanceof StreamingError) { + throw error + } + if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { + throw new StreamingError("Operation aborted", "interrupted") + } + if (error instanceof OpenAI.APIError) { + throw error + } + throw new StreamingError( + error instanceof Error ? error.message : "Stream interrupted", + "network", + error, + ) + } + + if (!finalResponse) { + throw new StreamingError("Provider failed to return a response", "failed") + } + + return finalResponse +} + +function extractOpenAIToolCalls( + response: OpenAI.Responses.Response, +): { id?: string; name: string; arguments: unknown; call_id: string }[] { + const calls = [] + for (const item of response.output) { + if (item?.type === "function_call") { + const args = + typeof item.arguments === "string" + ? safeJsonParse(item.arguments) + : item.arguments || {} + calls.push({ + id: item.id, + name: item.name, + arguments: args, + call_id: item.call_id, + }) + } + } + return calls +} + +function getOpenAIText(response: OpenAI.Responses.Response): { + type: "refusal" | "text" + message: string +} { + const out = response.output || [] + if ( + out.find( + (item: ResponseOutputItem) => + item.type === "message" && + item.content.some((c) => c.type === "refusal"), + ) + ) { + return { + type: "refusal", + message: "The model refused to generate a response for this request.", + } + } + + for (const item of out) { + if (item.type === "message" && item.content) { + for (const content of item.content) { + if (content.type === "output_text" && "text" in content) { + return { type: "text", message: content.text } + } + } + } + } + + return { type: "text", message: "" } +} + +let tiktokenEncoder: Tiktoken | null = null + +function toResponsesAPIProps(model: string): { + model: string + reasoning?: OpenAI.Reasoning +} { + const props = getModelProps(model) + return { + model: props.model, + ...(props.reasoningEffort + ? { + reasoning: { + effort: props.reasoningEffort as OpenAI.ReasoningEffort, + }, + } + : {}), + } +} + +export function createOpenAIProvider( + apiKey: string, + providerId: ProviderId = "openai", + options?: { baseURL?: string; contextWindow?: number; isCustom?: boolean }, +): AIProvider { + const isCustom = options?.isCustom ?? false + const openai = new OpenAI({ + apiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), + }) + + const contextWindow = options?.contextWindow ?? 400_000 + + return { + id: providerId, + contextWindow, + + async executeFlow({ + model, + config, + modelToolsClient, + tools, + setStatus, + abortSignal, + streaming, + }: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise { + let input: OpenAI.Responses.ResponseInput = [] + if (config.conversationHistory && config.conversationHistory.length > 0) { + const validMessages = config.conversationHistory.filter( + (msg) => msg.content && msg.content.trim() !== "", + ) + for (const msg of validMessages) { + input.push({ + role: msg.role, + content: msg.content, + }) + } + } + + input.push({ + role: "user", + content: config.initialUserContent, + }) + + const openaiTools = toOpenAIFunctions(tools) + + let totalInputTokens = 0 + let totalOutputTokens = 0 + + const systemInstructions = isCustom + ? config.systemInstructions + + responseFormatToPromptInstruction(config.responseFormat) + : config.systemInstructions + + const textConfig = isCustom + ? undefined + : toResponseTextConfig(config.responseFormat) + + const requestParams = { + ...toResponsesAPIProps(model), + instructions: systemInstructions, + input, + tools: openaiTools, + ...(textConfig ? { text: textConfig } : {}), + } as OpenAI.Responses.ResponseCreateParamsNonStreaming + + let lastResponse = streaming + ? await createOpenAIResponseStreaming( + openai, + requestParams, + streaming, + abortSignal, + ) + : await openai.responses.create(requestParams) + input = [...input, ...lastResponse.output] + + totalInputTokens += lastResponse.usage?.input_tokens ?? 0 + totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 + + while (true) { + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const toolCalls = extractOpenAIToolCalls(lastResponse) + if (!toolCalls.length) break + const tool_outputs: OpenAI.Responses.ResponseFunctionToolCallOutputItem[] = + [] + for (const tc of toolCalls) { + const exec = await executeTool( + tc.name, + tc.arguments, + modelToolsClient, + setStatus, + ) + tool_outputs.push({ + type: "function_call_output", + call_id: tc.call_id, + output: exec.content, + } as OpenAI.Responses.ResponseFunctionToolCallOutputItem) + } + input = [...input, ...tool_outputs] + + if ( + (lastResponse.usage?.input_tokens ?? 0) >= contextWindow - 50_000 && + tool_outputs.length > 0 + ) { + input.push({ + role: "user" as const, + content: + "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", + }) + } + const loopRequestParams = { + ...toResponsesAPIProps(model), + instructions: systemInstructions, + input, + tools: openaiTools, + ...(textConfig ? { text: textConfig } : {}), + } as OpenAI.Responses.ResponseCreateParamsNonStreaming + + lastResponse = streaming + ? await createOpenAIResponseStreaming( + openai, + loopRequestParams, + streaming, + abortSignal, + ) + : await openai.responses.create(loopRequestParams) + input = [...input, ...lastResponse.output] + + totalInputTokens += lastResponse.usage?.input_tokens ?? 0 + totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 + } + + if (abortSignal?.aborted) { + return { + type: "aborted", + message: "Operation was cancelled", + } as AiAssistantAPIError + } + + const text = getOpenAIText(lastResponse) + if (text.type === "refusal") { + return { + type: "unknown", + message: text.message, + } as AiAssistantAPIError + } + + const rawOutput = text.message + + if (isCustom) { + const json = parseCustomProviderResponse( + rawOutput, + (config.responseFormat.schema.required as string[]) || [], + (raw) => ({ explanation: raw }) as unknown as T, + ) + setStatus(null) + + const tokenUsage = { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { ...processed, tokenUsage } as T & { tokenUsage: TokenUsage } + } + return { ...json, tokenUsage } as T & { tokenUsage: TokenUsage } + } + + try { + const json = JSON.parse(rawOutput) as T + setStatus(null) + + const resultWithTokens = { + ...json, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + + if (config.postProcess) { + const processed = config.postProcess(json) + return { + ...processed, + tokenUsage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + }, + } as T & { tokenUsage: TokenUsage } + } + return resultWithTokens + } catch { + setStatus(null) + return { + type: "unknown", + message: "Failed to parse assistant response.", + } as AiAssistantAPIError + } + }, + + async generateTitle({ model, prompt, responseFormat }) { + try { + const userContent = isCustom + ? prompt + responseFormatToPromptInstruction(responseFormat) + : prompt + + const titleTextConfig = isCustom + ? undefined + : toResponseTextConfig(responseFormat) + + const response = await openai.responses.create({ + model: toResponsesAPIProps(model).model, + input: [{ role: "user", content: userContent }], + ...(titleTextConfig ? { text: titleTextConfig } : {}), + max_output_tokens: 100, + }) + const rawText = response.output_text + + if (isCustom) { + const parsed = parseCustomProviderResponse<{ title: string }>( + rawText, + (responseFormat.schema.required as string[]) || [], + (raw) => ({ title: raw.trim().slice(0, 40) }), + ) + return parsed.title || null + } + + const parsed = JSON.parse(rawText) as { title: string } + return parsed.title || null + } catch { + return null + } + }, + + async generateSummary({ model, systemPrompt, userMessage }) { + const response = await openai.responses.create({ + ...toResponsesAPIProps(model), + instructions: systemPrompt, + input: userMessage, + }) + return response.output_text || "" + }, + + async testConnection({ apiKey: testApiKey, model }) { + try { + const testClient = new OpenAI({ + apiKey: testApiKey, + dangerouslyAllowBrowser: true, + ...(options?.baseURL ? { baseURL: options.baseURL } : {}), + ...(isCustom + ? { + fetch: createHeaderFilteredFetch(OPENAI_ALLOWED_HEADERS), + } + : {}), + }) + await testClient.responses.create({ + model: getModelProps(model).model, // testConnection only needs model name + input: [{ role: "user", content: "ping" }], + max_output_tokens: 16, + }) + return { valid: true } + } catch (error: unknown) { + const status = + (error as { status?: number })?.status || + (error as { error?: { status?: number } })?.error?.status + if (status === 401) { + return { valid: false, error: "Invalid API key" } + } + if (status === 429) { + return { valid: true } + } + return { + valid: false, + error: + error instanceof Error + ? error.message + : "Failed to validate API key", + } + } + }, + + async countTokens({ messages, systemPrompt }) { + if (!tiktokenEncoder) { + const { Tiktoken } = await import("js-tiktoken/lite") + const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( + (module: { default: TiktokenBPE }) => module.default, + ) + tiktokenEncoder = new Tiktoken(o200k_base) + } + + let totalTokens = 0 + totalTokens += tiktokenEncoder.encode(systemPrompt).length + totalTokens += 4 // system message formatting overhead + + for (const message of messages) { + totalTokens += 4 // role markers overhead + totalTokens += tiktokenEncoder.encode(message.content).length + } + + totalTokens += 2 // assistant reply priming + return totalTokens + }, + + async listModels(): Promise { + const models: string[] = [] + for await (const model of openai.models.list()) { + models.push(model.id) + } + return models.sort((a, b) => a.localeCompare(b)) + }, + + classifyError( + error: unknown, + setStatus: StatusCallback, + ): AiAssistantAPIError { + if ( + error instanceof OpenAI.APIUserAbortError || + (error instanceof StreamingError && error.errorType === "interrupted") + ) { + setStatus(AIOperationStatus.Aborted) + return { type: "aborted", message: "Operation was cancelled" } + } + setStatus(null) + + if (error instanceof RefusalError) { + return { + type: "unknown", + message: "The model refused to generate a response for this request.", + details: error.message, + } + } + + if (error instanceof MaxTokensError) { + return { + type: "unknown", + message: + "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", + details: error.message, + } + } + + if (error instanceof StreamingError) { + switch (error.errorType) { + case "network": + return { + type: "network", + message: + "Network error during streaming. Please check your connection.", + details: error.message, + } + case "failed": + default: + return { + type: "unknown", + message: error.message || "Stream failed unexpectedly.", + details: + error.originalError instanceof Error + ? error.originalError.message + : undefined, + } + } + } + + if (error instanceof OpenAI.AuthenticationError) { + return { + type: "invalid_key", + message: "Invalid API key. Please check your OpenAI API key.", + details: error.message, + } + } + + if (error instanceof OpenAI.RateLimitError) { + return { + type: "rate_limit", + message: "Rate limit exceeded. Please try again later.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIConnectionError) { + return { + type: "network", + message: "Network error. Please check your connection.", + details: error.message, + } + } + + if (error instanceof OpenAI.APIError) { + return { + type: "unknown", + message: error.message, + details: `Status ${error.status}`, + } + } + + return { + type: "unknown", + message: "An unexpected error occurred. Please try again.", + details: error instanceof Error ? error.message : String(error), + } + }, + + isNonRetryableError(error: unknown): boolean { + if (error instanceof StreamingError) { + return error.errorType === "interrupted" || error.errorType === "failed" + } + if ( + error instanceof OpenAI.APIError && + error.status != null && + error.status >= 400 && + error.status < 500 && + error.status !== 429 + ) { + return true + } + return ( + error instanceof RefusalError || + error instanceof MaxTokensError || + error instanceof OpenAI.APIUserAbortError + ) + }, + } +} diff --git a/src/utils/ai/prompts.ts b/src/utils/ai/prompts.ts new file mode 100644 index 000000000..c7c6f9a21 --- /dev/null +++ b/src/utils/ai/prompts.ts @@ -0,0 +1,172 @@ +export const DOCS_INSTRUCTION = ` +CRITICAL: Always follow this documentation approach: +1. Use get_questdb_toc to see available functions, operators, SQL syntax, AND cookbook recipes +2. If user's request matches a cookbook recipe description, fetch it FIRST - recipes provide complete, tested SQL patterns +3. Use get_questdb_documentation for specific function/syntax details + +When a cookbook recipe matches the user's intent, ALWAYS use it as the foundation and adapt column/table names and use case to their schema.` + +export const getUnifiedPrompt = (grantSchemaAccess?: boolean) => { + const base = `You are a SQL expert coding assistant specializing in QuestDB, a high-performance time-series database. You help users with: +- Generating QuestDB SQL queries from natural language descriptions +- Explaining what QuestDB SQL queries do +- Fixing errors in QuestDB SQL queries +- Refining and modifying existing queries based on user requests + +## CRITICAL: Tool and Response Sequencing +Follow this EXACT sequence for every query generation request: + +**PHASE 1 - INFORMATION GATHERING (NO TEXT OUTPUT)** +1. Call available tools to gather information if you need, including documentation, schema, and validation tools. +2. Complete ALL information gathering before Phase 2. DO NOT CALL any tool after Phase 2. + +**PHASE 2 - FINAL RESPONSE (NO MORE TOOL CALLS)** +3. Return your JSON response with "sql" and "explanation" fields. Always return sql field first, then explanation field. + +NEVER interleave phases. NEVER use any tool after starting to return a response. + +## When Explaining Queries +- Focus on the business logic and what the query achieves, not the SQL syntax itself +- Pay special attention to QuestDB-specific features: + - Time-series operations (SAMPLE BY, LATEST ON, designated timestamp columns) + - Time-based filtering and aggregations + - Real-time data ingestion patterns + - Performance optimizations specific to time-series data + +## When Generating SQL +- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. +- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a generated SQL query +- Generate only valid QuestDB SQL syntax referring to the documentation about functions, operators, and SQL keywords +- Use appropriate time-series functions (SAMPLE BY, LATEST ON, etc.) and common table expressions when relevant +- Use \`IN\` with \`today()\`, \`tomorrow()\`, \`yesterday()\` interval functions when relevant +- Follow QuestDB best practices for performance referring to the documentation +- Use proper timestamp handling for time-series data +- Use correct data types and functions specific to QuestDB referring to the documentation. Do not use any word that is not in the documentation. + +## When Fixing Queries +- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. +- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a fixed SQL query +- Analyze the error message carefully to understand what went wrong +- Generate only valid QuestDB SQL syntax by always referring to the documentation about functions, operators, and SQL keywords +- Preserve the original intent of the query while fixing the error +- Follow QuestDB best practices and syntax rules referring to the documentation +- Consider common issues like: + - Missing or incorrect column names + - Invalid syntax for time-series operations + - Data type mismatches + - Incorrect function usage + +## Response Guidelines +- You are working as a coding assistant inside an IDE. Every time you return a query in "sql" field, you provide a suggestion to the user to accept or reject. When the user accepts the suggestion, you are informed and the query in the editor is updated with your suggestion. +- Modify a query by returning "sql" field only if the user asks you to generate, fix, or make changes to the query. If the user does not ask for fixing/changing/generating a query, return null in the "sql" field. Every time you provide a SQL query, the current SQL is updated. +- Provide the "explanation" field if you haven't provided it yet. Explanation should be in GFM (GitHub Flavored Markdown) format. Explanation field is cumulative, every time you provide an explanation, it is added to the previous explanations. + +## Tools +- Use the validate_query tool to validate the query in "sql" field before returning a response only if the user asks you to generate, fix, or make changes to the query. +` + const schemaAccess = grantSchemaAccess + ? `- Use the get_tables tool to retrieve all tables and materialized views in the database instance +- Use the get_table_schema tool to get detailed schema information for a specific table or a materialized view +- Use the get_table_details tool to get detailed information for a specific table or a materialized view. Each property is described in meta functions docs. +` + : "" + return base + schemaAccess + DOCS_INSTRUCTION +} + +export const getExplainSchemaPrompt = ( + tableName: string, + schema: string, + kindLabel: string, +) => `You are a SQL expert assistant specializing in QuestDB, a high-performance time-series database. +Explain the following ${kindLabel} schema. Include: +- The purpose of the ${kindLabel} +- What each column represents and its data type +- Any important properties like WAL enablement, partitioning strategy, designated timestamps +- Any performance or storage considerations + +${kindLabel} Name: ${tableName} + +Schema: +\`\`\`sql +${schema} +\`\`\` + +**IMPORTANT: Format your response in markdown exactly as follows:** + +1. Start with a brief paragraph explaining the purpose and general characteristics of this ${kindLabel}. + +2. Add a "## Columns" section with a markdown table: +| Column | Type | Description | +|--------|------|-------------| +| column_name | \`data_type\` | Brief description | + +3. If this is a table or materialized view (not a view), add a "## Storage Details" section with bullet points about: +- WAL enablement +- Partitioning strategy +- Designated timestamp column +- Any other storage considerations + +For views, skip the Storage Details section.` + +export type HealthIssuePromptData = { + tableName: string + issue: { + id: string + field: string + message: string + currentValue?: string + } + tableDetails: string + monitoringDocs: string + trendSamples?: Array<{ value: number; timestamp: number }> +} + +export const getHealthIssuePrompt = (data: HealthIssuePromptData): string => { + const { tableName, issue, tableDetails, monitoringDocs, trendSamples } = data + + let trendSection = "" + if (trendSamples && trendSamples.length > 0) { + const recentSamples = trendSamples.slice(-30) + trendSection = ` + +### Trend Data (Recent Samples) +| Timestamp | Value | +|-----------|-------| +${recentSamples.map((s) => `| ${new Date(s.timestamp).toISOString()} | ${s.value.toLocaleString()} |`).join("\n")} +` + } + + return `You are a QuestDB expert assistant helping diagnose and resolve table health issues. + +A user is viewing the health monitoring panel for their table and has asked for help with a detected issue. + +## Table: ${tableName} + +## Health Issue Detected +- **Issue ID**: ${issue.id} +- **Field**: ${issue.field} +- **Message**: ${issue.message} +${issue.currentValue ? `- **Current Value**: ${issue.currentValue}` : ""}${trendSection} + +## Table Details (from tables() function) +\`\`\`json +${tableDetails} +\`\`\` + +## QuestDB Monitoring Documentation +${monitoringDocs} + +--- + +**Your Task:** +1. Explain what this health issue means in the context of this specific table +2. Analyze the table details to identify potential root causes +3. Provide specific, actionable recommendations to resolve or mitigate the issue +4. If there is a clear SQL command that can help fix the issue (like \`ALTER TABLE ... RESUME WAL\`), include it in the "sql" field of your response. **Only provide SQL if it directly addresses the root cause** - do not provide SQL just to inspect the problem. + +**IMPORTANT: Be concise and thorough, format your response in markdown with clear sections:** +- Use ## headings for main sections +- Use bullet points for lists +- Use \`code\` for configuration values and SQL +` +} diff --git a/src/utils/ai/registry.ts b/src/utils/ai/registry.ts new file mode 100644 index 000000000..947b42e37 --- /dev/null +++ b/src/utils/ai/registry.ts @@ -0,0 +1,55 @@ +import type { AIProvider } from "./types" +import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" +import { createOpenAIProvider } from "./openaiProvider" +import { createOpenAIChatCompletionsProvider } from "./openaiChatCompletionsProvider" +import { createAnthropicProvider } from "./anthropicProvider" +import { BUILTIN_PROVIDERS } from "./settings" +import type { ProviderId, ProviderType } from "./settings" + +type ProviderOptions = { + baseURL?: string + contextWindow?: number + isCustom?: boolean +} + +export function createProvider( + providerId: ProviderId, + apiKey: string, + settings?: AiAssistantSettings, +): AIProvider { + // Check built-in providers first + const builtin = BUILTIN_PROVIDERS[providerId] + if (builtin) { + return createProviderByType(builtin.type, providerId, apiKey) + } + + // Check custom providers + const custom = settings?.customProviders?.[providerId] + if (custom) { + return createProviderByType(custom.type, providerId, apiKey, { + baseURL: custom.baseURL, + contextWindow: custom.contextWindow, + isCustom: true, + }) + } + + throw new Error(`Unknown provider: ${providerId}`) +} + +export function createProviderByType( + providerType: ProviderType, + providerId: ProviderId, + apiKey: string, + options?: ProviderOptions, +): AIProvider { + switch (providerType) { + case "openai": + return createOpenAIProvider(apiKey, providerId, options) + case "openai-chat-completions": + return createOpenAIChatCompletionsProvider(apiKey, providerId, options) + case "anthropic": + return createAnthropicProvider(apiKey, providerId, options) + default: + throw new Error(`Unknown provider type: ${providerType}`) + } +} diff --git a/src/utils/ai/responseFormats.ts b/src/utils/ai/responseFormats.ts new file mode 100644 index 000000000..60d92636f --- /dev/null +++ b/src/utils/ai/responseFormats.ts @@ -0,0 +1,55 @@ +import type { ResponseFormatSchema } from "./types" + +export const ExplainFormat: ResponseFormatSchema = { + name: "explain_format", + schema: { + type: "object", + properties: { + explanation: { type: "string" }, + }, + required: ["explanation"], + additionalProperties: false, + }, + strict: true, +} + +export const FixSQLFormat: ResponseFormatSchema = { + name: "fix_sql_format", + schema: { + type: "object", + properties: { + sql: { type: ["string", "null"] }, + explanation: { type: "string" }, + }, + required: ["explanation", "sql"], + additionalProperties: false, + }, + strict: true, +} + +export const ConversationResponseFormat: ResponseFormatSchema = { + name: "conversation_response_format", + schema: { + type: "object", + properties: { + sql: { type: ["string", "null"] }, + explanation: { type: "string" }, + }, + required: ["sql", "explanation"], + additionalProperties: false, + }, + strict: true, +} + +export const ChatTitleFormat: ResponseFormatSchema = { + name: "chat_title_format", + schema: { + type: "object", + properties: { + title: { type: "string" }, + }, + required: ["title"], + additionalProperties: false, + }, + strict: true, +} diff --git a/src/utils/ai/settings.test.ts b/src/utils/ai/settings.test.ts new file mode 100644 index 000000000..fdbe5b7dc --- /dev/null +++ b/src/utils/ai/settings.test.ts @@ -0,0 +1,375 @@ +import { describe, it, expect, afterEach } from "vitest" +import { reconcileSettings, getSelectedModel, MODEL_OPTIONS } from "./settings" +import type { ModelOption } from "./settings" + +import type { AiAssistantSettings } from "../../providers/LocalStorageProvider/types" + +const makeSettings = ( + overrides: Partial = {}, +): AiAssistantSettings => ({ + providers: {}, + ...overrides, +}) + +describe("reconcileSettings", () => { + it("removes stale model IDs from enabledModels", () => { + const settings = makeSettings({ + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "removed-model", "also-removed"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.openai!.enabledModels).toEqual(["gpt-5-mini"]) + }) + + it("does not add defaultEnabled models when user has valid models", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.anthropic!.enabledModels).toEqual([ + "claude-sonnet-4-5", + ]) + }) + + it("leaves enabledModels empty when all previous models were removed", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["removed-model-1", "removed-model-2"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.anthropic!.enabledModels).toEqual([]) + }) + + it("does not add defaults for unconfigured providers", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.providers.openai).toBeUndefined() + }) + + it("is idempotent", () => { + const settings = makeSettings({ + selectedModel: "claude-sonnet-4-5", + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5", "stale-model"], + grantSchemaAccess: true, + }, + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const once = reconcileSettings(settings) + const twice = reconcileSettings(once) + expect(twice).toEqual(once) + }) + + it("preserves unknown fields (forward compat)", () => { + const settings = makeSettings({ + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["claude-sonnet-4-5"], + grantSchemaAccess: false, + }, + }, + }) + const settingsWithFutureField = settings as unknown as Record< + string, + string + > + settingsWithFutureField.futureField = "preserved" + const result = reconcileSettings(settings) + expect((result as unknown as Record).futureField).toBe( + "preserved", + ) + }) + + it("clears selectedModel if not in any enabledModels", () => { + const settings = makeSettings({ + selectedModel: "removed-model", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.selectedModel).toEqual("gpt-5-mini") + }) + + it("preserves selectedModel if it is in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "gpt-5-mini", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + const result = reconcileSettings(settings) + expect(result.selectedModel).toBe("gpt-5-mini") + }) + + it("handles empty providers gracefully", () => { + const settings = makeSettings({ providers: {} }) + const result = reconcileSettings(settings) + expect(result.providers).toEqual({}) + }) + + it("does not mutate the input settings", () => { + const settings = makeSettings({ + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "stale-model"], + grantSchemaAccess: false, + }, + }, + }) + const originalModels = [...settings.providers.openai!.enabledModels] + reconcileSettings(settings) + expect(settings.providers.openai!.enabledModels).toEqual(originalModels) + }) +}) + +describe("getSelectedModel", () => { + it("returns selectedModel when it is in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "gpt-5-mini", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini", "gpt-5"], + grantSchemaAccess: false, + }, + }, + }) + expect(getSelectedModel(settings)).toBe("gpt-5-mini") + }) + + it("does not return selectedModel if not in enabledModels", () => { + const settings = makeSettings({ + selectedModel: "claude-sonnet-4-5", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["gpt-5-mini"], + grantSchemaAccess: false, + }, + }, + }) + expect(getSelectedModel(settings)).not.toBe("claude-sonnet-4-5") + expect(getSelectedModel(settings)).toBe("gpt-5-mini") + }) + + it("returns null when no models are enabled", () => { + const settings = makeSettings({ providers: {} }) + expect(getSelectedModel(settings)).toBeNull() + }) +}) + +/** + * Simulates version upgrades by temporarily replacing MODEL_OPTIONS contents. + * Tests verify that user settings from a previous version are handled correctly + * when the app is updated with a different model list. + */ +describe("version compatibility scenarios", () => { + let originalOptions: ModelOption[] + + function setModelOptions(options: ModelOption[]) { + originalOptions = [...MODEL_OPTIONS] + MODEL_OPTIONS.length = 0 + MODEL_OPTIONS.push(...options) + } + + afterEach(() => { + MODEL_OPTIONS.length = 0 + MODEL_OPTIONS.push(...originalOptions) + }) + + it("upgrade: model removed, selectedModel was that model", () => { + // v1: user had model-A and model-B, selected model-A + setModelOptions([ + { label: "A", value: "model-a", provider: "openai" }, + { label: "B", value: "model-b", provider: "openai" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-a", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: model-A removed, model-C added + setModelOptions([ + { label: "B", value: "model-b", provider: "openai" }, + { + label: "C", + value: "model-c", + provider: "openai", + defaultEnabled: true, + }, + ]) + + const reconciled = reconcileSettings(v1Settings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-b"]) + expect(reconciled.selectedModel).toBe("model-b") + }) + + it("upgrade: all models removed for a provider", () => { + setModelOptions([{ label: "A", value: "model-a", provider: "openai" }]) + + const v1Settings = makeSettings({ + selectedModel: "model-a", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: provider's models completely replaced + setModelOptions([ + { + label: "X", + value: "model-x", + provider: "openai", + defaultEnabled: true, + }, + { label: "Y", value: "model-y", provider: "openai" }, + ]) + + const reconciled = reconcileSettings(v1Settings) + // all old models gone, empty list — user must re-enable in settings + expect(reconciled.providers.openai!.enabledModels).toEqual([]) + expect(reconciled.selectedModel).toBeUndefined() + expect(getSelectedModel(reconciled)).toBeNull() + }) + + it("upgrade: new models added, user keeps their selection", () => { + setModelOptions([ + { label: "A", value: "model-a", provider: "anthropic", default: true }, + { label: "B", value: "model-b", provider: "anthropic" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-b", + providers: { + anthropic: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b"], + grantSchemaAccess: true, + }, + }, + }) + + // v2: model-C added + setModelOptions([ + { label: "A", value: "model-a", provider: "anthropic", default: true }, + { label: "B", value: "model-b", provider: "anthropic" }, + { + label: "C", + value: "model-c", + provider: "anthropic", + defaultEnabled: true, + }, + ]) + + const reconciled = reconcileSettings(v1Settings) + // existing models preserved, new model NOT auto-added + expect(reconciled.providers.anthropic!.enabledModels).toEqual([ + "model-a", + "model-b", + ]) + expect(reconciled.selectedModel).toBe("model-b") + expect(getSelectedModel(reconciled)).toBe("model-b") + }) + + it("upgrade: selected model survives but some enabled models removed", () => { + setModelOptions([ + { label: "A", value: "model-a", provider: "openai" }, + { label: "B", value: "model-b", provider: "openai" }, + { label: "C", value: "model-c", provider: "openai" }, + ]) + + const v1Settings = makeSettings({ + selectedModel: "model-b", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-b", "model-c"], + grantSchemaAccess: false, + }, + }, + }) + + // v2: model-A and model-C removed + setModelOptions([ + { label: "B", value: "model-b", provider: "openai" }, + { label: "D", value: "model-d", provider: "openai" }, + ]) + + const reconciled = reconcileSettings(v1Settings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-b"]) + expect(reconciled.selectedModel).toBe("model-b") + expect(getSelectedModel(reconciled)).toBe("model-b") + }) + + it("downgrade: user has models from a newer version", () => { + setModelOptions([{ label: "A", value: "model-a", provider: "openai" }]) + + const futureSettings = makeSettings({ + selectedModel: "model-future", + providers: { + openai: { + apiKey: "sk-test", + enabledModels: ["model-a", "model-future"], + grantSchemaAccess: false, + }, + }, + }) + + const reconciled = reconcileSettings(futureSettings) + expect(reconciled.providers.openai!.enabledModels).toEqual(["model-a"]) + expect(reconciled.selectedModel).toBe("model-a") + }) +}) diff --git a/src/utils/ai/settings.ts b/src/utils/ai/settings.ts new file mode 100644 index 000000000..55fc42bb3 --- /dev/null +++ b/src/utils/ai/settings.ts @@ -0,0 +1,371 @@ +import type { + AiAssistantSettings, + CustomProviderDefinition, +} from "../../providers/LocalStorageProvider/types" + +export type ProviderType = "anthropic" | "openai" | "openai-chat-completions" + +/** Provider ID — built-in ("anthropic", "openai") or user-defined string for custom providers. */ +export type ProviderId = string + +export type ProviderDefinition = { + type: ProviderType + name: string +} + +export { type CustomProviderDefinition } + +export const BUILTIN_PROVIDERS: Record = { + anthropic: { type: "anthropic", name: "Anthropic" }, + openai: { type: "openai", name: "OpenAI" }, +} + +export const getProviderName = ( + providerId: ProviderId | null, + settings?: AiAssistantSettings, +): string => { + if (!providerId) return "" + if (BUILTIN_PROVIDERS[providerId]) return BUILTIN_PROVIDERS[providerId].name + const custom = settings?.customProviders?.[providerId] + if (custom) return custom.name + return providerId +} + +export type ModelOption = { + label: string + value: string + provider: ProviderId + isSlow?: boolean + isTestModel?: boolean + default?: boolean + defaultEnabled?: boolean +} + +export const MODEL_OPTIONS: ModelOption[] = [ + { + label: "Claude Opus 4.6", + value: "claude-opus-4-6", + provider: "anthropic", + isSlow: true, + defaultEnabled: true, + }, + { + label: "Claude Sonnet 4.6", + value: "claude-sonnet-4-6", + provider: "anthropic", + default: true, + defaultEnabled: true, + }, + { + label: "Claude Sonnet 4.5", + value: "claude-sonnet-4-5", + provider: "anthropic", + }, + { + label: "Claude Haiku 4.5", + value: "claude-haiku-4-5", + provider: "anthropic", + isTestModel: true, + }, + { + label: "GPT-5.4 (High Reasoning)", + value: "gpt-5.4@reasoning=high", + provider: "openai", + }, + { + label: "GPT-5.4 (Medium Reasoning)", + value: "gpt-5.4@reasoning=medium", + provider: "openai", + defaultEnabled: true, + }, + { + label: "GPT-5.4 (Low Reasoning)", + value: "gpt-5.4@reasoning=low", + provider: "openai", + defaultEnabled: true, + default: true, + }, + { + label: "GPT-5 mini", + value: "gpt-5-mini", + provider: "openai", + defaultEnabled: true, + }, + { + label: "GPT-5 nano", + value: "gpt-5-nano", + provider: "openai", + defaultEnabled: true, + isTestModel: true, + }, +] + +export type ReasoningEffort = "high" | "medium" | "low" + +export type ModelProps = { + model: string + reasoningEffort?: ReasoningEffort +} + +const CUSTOM_MODEL_SEP = ":" + +export const makeCustomModelValue = ( + providerId: ProviderId, + modelId: string, +): string => `${providerId}${CUSTOM_MODEL_SEP}${modelId}` + +export const parseModelValue = ( + value: string, +): { customProviderId: string; rawModel: string } | { rawModel: string } => { + const sepIndex = value.indexOf(CUSTOM_MODEL_SEP) + if (sepIndex === -1) return { rawModel: value } + const candidateProvider = value.slice(0, sepIndex) + // Only treat as namespaced if the prefix is NOT a built-in provider. + if (BUILTIN_PROVIDERS[candidateProvider]) return { rawModel: value } + return { + customProviderId: candidateProvider, + rawModel: value.slice(sepIndex + 1), + } +} + +export const getAllModelOptions = ( + settings?: AiAssistantSettings, +): ModelOption[] => { + if (!settings?.customProviders) return MODEL_OPTIONS + const customModels: ModelOption[] = [] + for (const [providerId, def] of Object.entries(settings.customProviders)) { + for (const modelId of def.models) { + customModels.push({ + label: modelId, + value: makeCustomModelValue(providerId, modelId), + provider: providerId, + }) + } + } + return [...MODEL_OPTIONS, ...customModels] +} + +export const providerForModel = ( + model: ModelOption["value"], + _settings?: AiAssistantSettings, +): ProviderId | null => { + // Check for namespaced custom model value (providerId:modelId) + const parsed = parseModelValue(model) + if ("customProviderId" in parsed) return parsed.customProviderId + // Fall back to built-in model lookup + return MODEL_OPTIONS.find((m) => m.value === model)?.provider ?? null +} + +export const getModelProps = (model: ModelOption["value"]): ModelProps => { + const { rawModel } = parseModelValue(model) + const parts = rawModel.split("@") + const modelName = parts[0] + const extraParams = parts[1] + ?.split(",") + ?.map((p) => ({ key: p.split("=")[0], value: p.split("=")[1] })) + if (extraParams) { + const reasoningParam = extraParams.find((p) => p.key === "reasoning") + if (reasoningParam && reasoningParam.value) { + return { + model: modelName, + reasoningEffort: reasoningParam.value as ReasoningEffort, + } + } + } + return { model: modelName } +} + +export const getAllProviders = ( + settings?: AiAssistantSettings, +): ProviderId[] => { + const providers = new Set() + MODEL_OPTIONS.forEach((model) => { + providers.add(model.provider) + }) + if (settings?.customProviders) { + for (const id of Object.keys(settings.customProviders)) { + providers.add(id) + } + } + return Array.from(providers) +} + +export const getSelectedModel = ( + settings: AiAssistantSettings, +): string | null => { + const enabledModels = getAllEnabledModels(settings) + const selectedModel = settings.selectedModel + if ( + selectedModel && + typeof selectedModel === "string" && + enabledModels.includes(selectedModel) + ) { + return selectedModel + } + + const allModels = getAllModelOptions(settings) + // Fall back to first enabled default model, then first enabled model + return ( + enabledModels.find( + (id) => allModels.find((m) => m.value === id)?.default, + ) ?? + enabledModels[0] ?? + null + ) +} + +export const getAllEnabledModels = ( + settings: AiAssistantSettings, +): string[] => { + const models: string[] = [] + for (const provider of getAllProviders(settings)) { + const providerModels = settings.providers?.[provider]?.enabledModels + if (providerModels) { + models.push(...providerModels) + } else if (settings.customProviders?.[provider]) { + models.push( + ...settings.customProviders[provider].models.map((m) => + makeCustomModelValue(provider, m), + ), + ) + } + } + return models +} + +export const getNextModel = ( + currentModel: string | undefined, + enabledModels: Record, + settings?: AiAssistantSettings, +): string | null => { + let nextModel: string | null | undefined = currentModel + + const allModels = getAllModelOptions(settings) + const modelProvider = currentModel + ? providerForModel(currentModel, settings) + : null + if (modelProvider && enabledModels[modelProvider]?.length > 0) { + // Current model is still enabled, so we can use it + if (currentModel && enabledModels[modelProvider].includes(currentModel)) { + return currentModel + } + // Take the default model of this provider, otherwise the first enabled model of this provider + nextModel = + enabledModels[modelProvider].find( + (m) => allModels.find((mo) => mo.value === m)?.default, + ) ?? enabledModels[modelProvider][0] + } else { + // No other enabled models for this provider, we have to choose from another provider if exists + const otherProviderWithEnabledModel = getAllProviders(settings).find( + (p) => enabledModels[p]?.length > 0, + ) + if (otherProviderWithEnabledModel) { + nextModel = + enabledModels[otherProviderWithEnabledModel].find( + (m) => allModels.find((mo) => mo.value === m)?.default, + ) ?? enabledModels[otherProviderWithEnabledModel][0] + } else { + nextModel = null + } + } + return nextModel ?? null +} + +export const isAiAssistantConfigured = ( + settings: AiAssistantSettings, +): boolean => { + const builtinConfigured = Object.keys(BUILTIN_PROVIDERS).some( + (provider) => !!settings.providers?.[provider]?.apiKey, + ) + if (builtinConfigured) return true + return Object.keys(settings.customProviders ?? {}).length > 0 +} + +export const canUseAiAssistant = (settings: AiAssistantSettings): boolean => { + return isAiAssistantConfigured(settings) && !!settings.selectedModel +} + +export const getTestModel = ( + providerId: ProviderId, + settings?: AiAssistantSettings, +): string | null => { + if (settings?.customProviders?.[providerId]) { + return settings.selectedModel ?? null + } + return ( + MODEL_OPTIONS.find((m) => m.provider === providerId && m.isTestModel) + ?.value ?? null + ) +} + +/** + * Returns the context window for a given provider. + * For custom providers, returns the configured value. + * For built-in providers, returns null (factory uses its own default). + */ +export const getProviderContextWindow = ( + providerId: ProviderId, + settings?: AiAssistantSettings, +): number | null => { + const custom = settings?.customProviders?.[providerId] + return custom?.contextWindow ?? null +} + +/** + * Reconciles persisted AI assistant settings against current model options. + * Removes stale model IDs from built-in providers' enabledModels. + * Preserves custom provider models (validated against customProviders definitions). + * + * Pure function — does not write to localStorage. + * Idempotent: applying it multiple times produces the same result. + */ +export const reconcileSettings = ( + settings: AiAssistantSettings, +): AiAssistantSettings => { + const allValidIds = new Set(getAllModelOptions(settings).map((m) => m.value)) + const result = { + ...settings, + providers: { ...settings.providers }, + } + + for (const providerKey of Object.keys(result.providers)) { + const providerSettings = result.providers[providerKey] + if (!providerSettings?.enabledModels) continue + + const models = providerSettings.enabledModels.filter((id) => + allValidIds.has(id), + ) + result.providers[providerKey] = { + ...providerSettings, + enabledModels: models, + } + } + + result.selectedModel = getSelectedModel(result) ?? undefined + + return result +} + +export const getApiKey = ( + providerId: ProviderId, + settings: AiAssistantSettings, +): string | null => { + const builtinKey = settings.providers?.[providerId]?.apiKey + if (builtinKey) return builtinKey + const custom = settings.customProviders?.[providerId] + if (custom) return custom.apiKey || "" + return null +} + +export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { + const selectedModel = getSelectedModel(settings) + if (!selectedModel) return false + + const provider = providerForModel(selectedModel, settings) + if (!provider) return false + + return ( + settings.providers?.[provider]?.grantSchemaAccess === true || + settings.customProviders?.[provider]?.grantSchemaAccess === true + ) +} diff --git a/src/utils/ai/shared.test.ts b/src/utils/ai/shared.test.ts new file mode 100644 index 000000000..b73d923d8 --- /dev/null +++ b/src/utils/ai/shared.test.ts @@ -0,0 +1,969 @@ +import { describe, it, expect } from "vitest" +import { + extractJsonWithExpectedFields, + parseCustomProviderResponse, + safeJsonParse, +} from "./shared" + +type SqlResponse = { sql: string | null; explanation: string } +type TitleResponse = { title: string } +type ExplainResponse = { explanation: string } + +const sqlFields = ["sql", "explanation"] +const titleFields = ["title"] +const explainFields = ["explanation"] + +const sqlFallback = (raw: string): SqlResponse => ({ + explanation: raw, + sql: null, +}) +const titleFallback = (raw: string): TitleResponse => ({ + title: raw.trim().slice(0, 40), +}) +const explainFallback = (raw: string): ExplainResponse => ({ + explanation: raw, +}) + +describe("parseCustomProviderResponse", () => { + // ─── Step 1: Direct JSON.parse ─────────────────────────────────── + + describe("step 1: valid JSON string", () => { + it("parses valid JSON with sql and explanation", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM trades", + explanation: "Fetches all trades", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM trades", + explanation: "Fetches all trades", + }) + }) + + it("parses valid JSON with title", () => { + const text = JSON.stringify({ title: "My Chat" }) + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("parses valid JSON with explanation only", () => { + const text = JSON.stringify({ explanation: "This is an explanation" }) + const result = parseCustomProviderResponse( + text, + explainFields, + explainFallback, + ) + expect(result).toEqual({ explanation: "This is an explanation" }) + }) + + it("parses valid JSON with null sql", () => { + const text = JSON.stringify({ + sql: null, + explanation: "No SQL needed", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: null, explanation: "No SQL needed" }) + }) + + it("returns full parsed object even with extra fields in step 1", () => { + const text = JSON.stringify({ + sql: "SELECT 1", + explanation: "test", + extra: "ignored by caller but present", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Step 1 returns JSON.parse(text) as-is, including extra fields + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("test") + }) + }) + + // ─── Step 2: JSON in ```json ``` block ─────────────────────────── + + describe("step 2: JSON in ```json block", () => { + it("extracts JSON from ```json block", () => { + const text = + 'Here is the result:\n\n```json\n{"sql": "SELECT 1", "explanation": "Returns 1"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "Returns 1" }) + }) + + it("extracts JSON from ```json block with pretty-printed JSON", () => { + const text = `Some preamble text. + +\`\`\`json +{ + "sql": "SELECT * FROM t", + "explanation": "Gets all rows" +} +\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM t", + explanation: "Gets all rows", + }) + }) + + it("handles ```json block with nested markdown code blocks in explanation", () => { + const explanation = + "# ASOF JOIN\n\n```sql\nSELECT * FROM t1 ASOF JOIN t2\n```\n\nMore text\n\n```sql\nSELECT 1\n```" + const json = JSON.stringify({ + sql: "SELECT * FROM t1 ASOF JOIN t2", + explanation, + }) + const text = `Here is the response:\n\n\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM t1 ASOF JOIN t2") + expect(result.explanation).toBe(explanation) + }) + + it("extracts title from ```json block", () => { + const text = + 'Generated title:\n\n```json\n{"title": "Trade Analysis"}\n```' + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "Trade Analysis" }) + }) + + it("only includes expected fields from ```json block", () => { + const text = + '```json\n{"sql": "SELECT 1", "explanation": "test", "confidence": 0.9}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + expect((result as Record).confidence).toBeUndefined() + }) + }) + + // ─── Step 2: Bare JSON with preamble ───────────────────────────── + + describe("step 2: bare JSON without ```json wrapper", () => { + it("extracts bare JSON after preamble text", () => { + const text = + 'Excellent! Here is the response:\n\n{"sql": "SELECT 1", "explanation": "Returns one"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "Returns one" }) + }) + + it("extracts bare JSON with preamble and epilogue", () => { + const text = + 'Here:\n{"sql": null, "explanation": "No query needed"}\nHope this helps!' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: null, explanation: "No query needed" }) + }) + + it("handles preamble text that contains curly braces", () => { + const text = + 'Using {ASOF JOIN} syntax:\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("extracts bare pretty-printed JSON after preamble", () => { + const text = `Let me provide the final response: + +{ + "sql": "SELECT * FROM trades LIMIT 10", + "explanation": "Fetches recent trades" +}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + sql: "SELECT * FROM trades LIMIT 10", + explanation: "Fetches recent trades", + }) + }) + }) + + // ─── Step 2: JSON with complex content ─────────────────────────── + + describe("step 2: complex content in JSON values", () => { + it("handles explanation with curly braces inside strings", () => { + const text = + 'Result:\n\n{"sql": "SELECT 1", "explanation": "Use {curly braces} in templates"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.explanation).toBe("Use {curly braces} in templates") + }) + + it("handles explanation with escaped quotes", () => { + const json = JSON.stringify({ + sql: 'SELECT * FROM "my table"', + explanation: 'Use "double quotes" for identifiers', + }) + const text = `Here:\n${json}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe('SELECT * FROM "my table"') + expect(result.explanation).toBe('Use "double quotes" for identifiers') + }) + + it("handles SQL with complex nested queries", () => { + const sql = + "SELECT t.*, m.bids[1,1] FROM trades t ASOF JOIN market_data m ON (t.symbol = m.symbol) WHERE t.timestamp IN yesterday()" + const json = JSON.stringify({ sql, explanation: "Complex join query" }) + const text = `\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe(sql) + }) + + it("handles real-world DeepSeek response with nested markdown code blocks", () => { + // Simulates the actual failing case from DeepSeek via OpenRouter + const explanation = + "# ASOF JOIN\n\n## Basic Syntax\n\n```sql\nSELECT columns\nFROM left_table\nASOF JOIN right_table ON (matching_columns)\n```\n\n## Example\n\n```sql\nSELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)\n```\n\nMore details here." + const innerJson = JSON.stringify({ + sql: "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + explanation, + }) + const text = `Perfect! Now I'll provide you with a comprehensive response.\n\n\`\`\`json\n${innerJson}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe( + "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + ) + expect(result.explanation).toBe(explanation) + }) + + it("handles real-world DeepSeek response with bare JSON (no ```json wrapper)", () => { + // Simulates the actual failing case from DeepSeek via DeepInfra + const explanation = + "# ASOF JOIN in QuestDB\n\n## Example\n\n```sql\nSELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)\n```" + const innerJson = JSON.stringify({ + sql: "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + explanation, + }) + const text = `Excellent! Now let me provide the final response with examples:\n\n${innerJson}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe( + "SELECT t.*, m.bids[1,1]\nFROM trades t\nASOF JOIN market_data m ON (symbol)", + ) + expect(result.explanation).toBe(explanation) + }) + }) + + // ─── Step 2: Missing expected fields ───────────────────────────── + + describe("step 2: missing expected fields", () => { + it("falls back when JSON in ```json block has wrong fields", () => { + const text = '```json\n{"query": "SELECT 1", "description": "test"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Falls to fallback because "sql" and "explanation" are missing + expect(result).toEqual({ + explanation: text, + sql: null, + }) + }) + + it("falls back when bare JSON has wrong fields", () => { + const text = 'Here:\n{"answer": "SELECT 1", "reasoning": "because"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ + explanation: text, + sql: null, + }) + }) + + it("falls back when JSON has only some expected fields", () => { + const text = '{"explanation": "test"}' + // Step 1 parses it successfully — this is valid JSON + // But wait, step 1 returns JSON.parse as-is without field checking + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Step 1 returns it as-is (no field validation in step 1) + expect(result.explanation).toBe("test") + expect(result.sql).toBeUndefined() + }) + }) + + // ─── Step 2: Invalid JSON repaired by jsonrepair ──────────────── + + describe("step 2: malformed JSON repaired by jsonrepair", () => { + it("repairs trailing commas in ```json block", () => { + const text = '```json\n{"sql": "SELECT 1", "explanation": "test",}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs single-quoted strings in bare JSON", () => { + const text = "Here:\n{'sql': 'SELECT 1', 'explanation': 'test'}" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs unquoted keys", () => { + const text = '{sql: "SELECT 1", explanation: "test"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs Python boolean/null constants (True, False, None)", () => { + const text = '{"sql": None, "explanation": "No query needed"}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBeNull() + expect(result.explanation).toBe("No query needed") + }) + + it("repairs trailing comma + unquoted keys combined", () => { + const text = "Preamble:\n{sql: 'SELECT 1', explanation: 'works',}" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "works" }) + }) + + it("repairs missing closing brace (truncated JSON)", () => { + // jsonrepair can add the missing brace + const text = '{"sql": "SELECT 1", "explanation": "truncated' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("truncated") + }) + }) + + // ─── Step 2: Unrepairable invalid JSON ───────────────────────── + + describe("step 2: unrepairable invalid JSON", () => { + it("falls back when ```json block contains gibberish", () => { + const text = "```json\n{invalid json here}\n```" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + }) + + // ─── Step 3: jsonrepair on full text ─────────────────────────── + + describe("step 3: jsonrepair on full text", () => { + it("repairs full-text malformed JSON (no preamble)", () => { + const text = "{'title': 'My Chat',}" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("repairs JSON with comments", () => { + // jsonrepair strips JS comments + const text = + '{\n "sql": "SELECT 1", // the query\n "explanation": "test"\n}' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toBe("test") + }) + }) + + // ─── Step 4: Fallback to raw text ──────────────────────────────── + + describe("step 4: fallback", () => { + it("returns fallback for plain text response", () => { + const text = "I can help you write a query for that." + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + + it("returns fallback for empty string", () => { + const result = parseCustomProviderResponse( + "", + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: "", sql: null }) + }) + + it("returns title fallback for plain text", () => { + const text = "Trade Analysis Overview" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result).toEqual({ title: "Trade Analysis Overview" }) + }) + + it("truncates long title in fallback", () => { + const text = + "This is a very long title that should be truncated to forty characters maximum" + const result = parseCustomProviderResponse( + text, + titleFields, + titleFallback, + ) + expect(result.title.length).toBeLessThanOrEqual(40) + }) + + it("returns fallback for markdown without JSON", () => { + const text = + "# Query Help\n\nYou can use `SELECT * FROM trades` to get all trades." + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + + it("returns fallback when text has braces but no valid JSON", () => { + const text = + "Use the following syntax: if (x > 0) { return x; } else { return -x; }" + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: text, sql: null }) + }) + }) + + // ─── Edge cases ────────────────────────────────────────────────── + + describe("edge cases", () => { + it("handles empty expectedFields (matches any valid JSON object)", () => { + const text = 'Preamble\n{"foo": "bar"}' + const result = parseCustomProviderResponse>( + text, + [], + (raw) => ({ raw }), + ) + // Empty expectedFields means every() is vacuously true, + // but only expected fields are extracted → empty object + expect(result).toEqual({}) + }) + + it("handles JSON array (not object) — falls back", () => { + const text = '[{"sql": "SELECT 1"}]' + // Step 1: JSON.parse succeeds and returns the array + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + // Arrays are returned as-is from step 1 + expect(Array.isArray(result)).toBe(true) + }) + + it("handles multiple JSON objects in text — picks first with expected fields", () => { + const text = + '{"wrong": true}\n\n{"sql": "SELECT 1", "explanation": "test"}' + // Step 1: JSON.parse fails (two objects aren't valid single JSON) + // Step 2: first { → {"wrong": true} → valid but wrong fields → next { + // second { → finds the correct one + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles deeply nested JSON braces", () => { + const sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t" + const text = `Result:\n${JSON.stringify({ sql, explanation: "CASE expression" })}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe(sql) + }) + + it("handles unicode content", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM données", + explanation: "Récupère toutes les données 日本語テスト", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM données") + expect(result.explanation).toContain("日本語テスト") + }) + + it("handles whitespace-only content", () => { + const result = parseCustomProviderResponse( + " \n\n ", + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ explanation: " \n\n ", sql: null }) + }) + + it("handles ``` block without json language tag", () => { + const text = + 'Here:\n\n```\n{"sql": "SELECT 1", "explanation": "test"}\n```' + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles JSON with very long explanation containing multiple code blocks", () => { + const explanation = [ + "# Complex Query", + "", + "First, create the table:", + "```sql", + "CREATE TABLE trades (symbol STRING, price DOUBLE, ts TIMESTAMP) timestamp(ts);", + "```", + "", + "Then insert data:", + "```sql", + "INSERT INTO trades VALUES('BTC', 50000, now());", + "```", + "", + "Finally, query it:", + "```sql", + "SELECT * FROM trades WHERE symbol = 'BTC';", + "```", + ].join("\n") + const json = JSON.stringify({ + sql: "SELECT * FROM trades WHERE symbol = 'BTC'", + explanation, + }) + const text = `\`\`\`json\n${json}\n\`\`\`` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT * FROM trades WHERE symbol = 'BTC'") + expect(result.explanation).toBe(explanation) + }) + + it("handles JSON where explanation contains JSON-like text", () => { + const json = JSON.stringify({ + sql: "SELECT 1", + explanation: + 'The response format is {"key": "value"} and you can nest {objects: {inside}} as needed.', + }) + const text = `Response:\n${json}` + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT 1") + expect(result.explanation).toContain('{"key": "value"}') + }) + + it("handles JSON with escaped backslashes and special chars", () => { + const text = JSON.stringify({ + sql: "SELECT '\\n' FROM t", + explanation: "Selects a backslash-n string\ttab", + }) + const result = parseCustomProviderResponse( + text, + sqlFields, + sqlFallback, + ) + expect(result.sql).toBe("SELECT '\\n' FROM t") + expect(result.explanation).toBe("Selects a backslash-n string\ttab") + }) + }) +}) + +describe("extractJsonWithExpectedFields", () => { + // ─── Basic extraction ─────────────────────────────────────────── + + it("extracts valid JSON with all expected fields", () => { + const text = '{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("returns only expected fields, stripping extras", () => { + const text = '{"sql": "SELECT 1", "explanation": "test", "confidence": 0.9}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + expect(result).not.toHaveProperty("confidence") + }) + + it("returns null when no JSON found", () => { + const result = extractJsonWithExpectedFields("plain text", ["sql"]) + expect(result).toBeNull() + }) + + it("returns null for empty string", () => { + const result = extractJsonWithExpectedFields("", ["sql"]) + expect(result).toBeNull() + }) + + it("returns null when no opening brace exists", () => { + const result = extractJsonWithExpectedFields("no braces here", ["sql"]) + expect(result).toBeNull() + }) + + // ─── Field matching ───────────────────────────────────────────── + + it("returns null when JSON is valid but missing expected fields", () => { + const text = '{"query": "SELECT 1", "description": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("returns null when only some expected fields present", () => { + const text = '{"sql": "SELECT 1"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("matches with empty expectedFields (vacuously true)", () => { + const text = '{"anything": "works"}' + const result = extractJsonWithExpectedFields(text, []) + expect(result).toEqual({}) + }) + + it("includes fields with null values", () => { + const text = '{"sql": null, "explanation": "No query needed"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: null, explanation: "No query needed" }) + }) + + it("includes fields with falsy values (0, false, empty string)", () => { + const text = '{"count": 0, "active": false, "name": ""}' + const result = extractJsonWithExpectedFields(text, [ + "count", + "active", + "name", + ]) + expect(result).toEqual({ count: 0, active: false, name: "" }) + }) + + // ─── Preamble and epilogue ────────────────────────────────────── + + it("extracts JSON after preamble text", () => { + const text = + 'Here is the result:\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("extracts JSON with both preamble and epilogue", () => { + const text = 'Result:\n{"title": "My Chat"}\nHope that helps!' + const result = extractJsonWithExpectedFields(text, ["title"]) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("extracts JSON from ```json code block", () => { + const text = '```json\n{"sql": "SELECT 1", "explanation": "test"}\n```' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + // ─── Multiple JSON objects ────────────────────────────────────── + + it("skips first JSON object if it lacks expected fields", () => { + const text = '{"wrong": true}\n\n{"sql": "SELECT 1", "explanation": "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("when multiple matching objects exist, jsonrepair merges them (last wins)", () => { + const text = + '{"sql": "SELECT 1", "explanation": "first"}\n{"sql": "SELECT 2", "explanation": "second"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + // jsonrepair merges concatenated objects, later keys overwrite earlier ones + expect(result).toEqual({ sql: "SELECT 2", explanation: "second" }) + }) + + it("skips non-JSON brace text to find real JSON", () => { + const text = + 'Use {curly braces} for templates\n{"sql": "SELECT 1", "explanation": "found"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "found" }) + }) + + // ─── jsonrepair fallback ──────────────────────────────────────── + + it("repairs trailing commas via jsonrepair", () => { + const text = '{"sql": "SELECT 1", "explanation": "test",}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs single-quoted strings via jsonrepair", () => { + const text = "{'sql': 'SELECT 1', 'explanation': 'test'}" + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs unquoted keys via jsonrepair", () => { + const text = '{sql: "SELECT 1", explanation: "test"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("repairs Python None to null via jsonrepair", () => { + const text = '{"sql": None, "explanation": "No query"}' + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: null, explanation: "No query" }) + }) + + // ─── Complex content ──────────────────────────────────────────── + + it("handles nested braces inside string values", () => { + const text = JSON.stringify({ + sql: "SELECT 1", + explanation: "Use {curly braces} in {templates}", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.explanation).toBe("Use {curly braces} in {templates}") + }) + + it("handles escaped quotes inside string values", () => { + const text = JSON.stringify({ + sql: 'SELECT * FROM "my table"', + explanation: 'Use "double quotes" for identifiers', + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe('SELECT * FROM "my table"') + }) + + it("handles newlines inside string values", () => { + const text = JSON.stringify({ + sql: "SELECT *\nFROM trades\nLIMIT 10", + explanation: "Multi-line query", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe("SELECT *\nFROM trades\nLIMIT 10") + }) + + it("handles pretty-printed JSON", () => { + const text = `{ + "sql": "SELECT 1", + "explanation": "test" +}` + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ sql: "SELECT 1", explanation: "test" }) + }) + + it("handles unicode content", () => { + const text = JSON.stringify({ + sql: "SELECT * FROM données", + explanation: "日本語テスト", + }) + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result!.sql).toBe("SELECT * FROM données") + expect(result!.explanation).toBe("日本語テスト") + }) + + it("returns null when text has braces but no valid JSON", () => { + const text = "if (x > 0) { return x; } else { return -x; }" + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toBeNull() + }) + + it("returns null for nested non-JSON braces", () => { + const text = "function() { if (true) { console.log('hi') } }" + const result = extractJsonWithExpectedFields(text, ["sql"]) + expect(result).toBeNull() + }) + + it("handles single expected field", () => { + const text = '{"title": "My Chat"}' + const result = extractJsonWithExpectedFields(text, ["title"]) + expect(result).toEqual({ title: "My Chat" }) + }) + + it("handles many expected fields", () => { + const text = JSON.stringify({ a: 1, b: 2, c: 3, d: 4, e: 5 }) + const result = extractJsonWithExpectedFields(text, [ + "a", + "b", + "c", + "d", + "e", + ]) + expect(result).toEqual({ a: 1, b: 2, c: 3, d: 4, e: 5 }) + }) + + it("handles JSON embedded deep in markdown", () => { + const text = `# Response + +Here is the analysis: + +Some preamble with {random} braces. + +\`\`\`json +{"sql": "SELECT count() FROM trades", "explanation": "Counts all trades"} +\`\`\` + +And some epilogue text.` + const result = extractJsonWithExpectedFields(text, ["sql", "explanation"]) + expect(result).toEqual({ + sql: "SELECT count() FROM trades", + explanation: "Counts all trades", + }) + }) +}) + +describe("safeJsonParse", () => { + it("parses valid JSON", () => { + const result = safeJsonParse<{ a: number }>('{"a": 1}') + expect(result).toEqual({ a: 1 }) + }) + + it("returns jsonrepair result for non-JSON text", () => { + // jsonrepair turns plain text into a JSON string + const result = safeJsonParse("not json at all") + expect(result).toBe("not json at all") + }) + + it("returns empty object for empty input", () => { + const result = safeJsonParse("{") + expect(result).toEqual({}) + }) + + it("repairs truncated JSON (missing closing brace)", () => { + // Real-world Qwen case: tool call arguments with missing } + const text = + '{"category": "functions", "items": ["today", "tomorrow", "yesterday"]' + const result = safeJsonParse<{ + category: string + items: string[] + }>(text) + expect(result).toEqual({ + category: "functions", + items: ["today", "tomorrow", "yesterday"], + }) + }) + + it("repairs trailing commas", () => { + const result = safeJsonParse<{ table_name: string }>( + '{"table_name": "trades",}', + ) + expect(result).toEqual({ table_name: "trades" }) + }) + + it("repairs single-quoted strings", () => { + const result = safeJsonParse<{ query: string }>("{'query': 'SELECT 1'}") + expect(result).toEqual({ query: "SELECT 1" }) + }) + + it("repairs unquoted keys", () => { + const result = safeJsonParse<{ table_name: string }>( + '{table_name: "trades"}', + ) + expect(result).toEqual({ table_name: "trades" }) + }) + + it("handles empty string arguments", () => { + const result = safeJsonParse("") + expect(result).toEqual({}) + }) +}) diff --git a/src/utils/ai/shared.ts b/src/utils/ai/shared.ts new file mode 100644 index 000000000..fbda68515 --- /dev/null +++ b/src/utils/ai/shared.ts @@ -0,0 +1,305 @@ +import type { ModelToolsClient, StatusCallback } from "../aiAssistant" +import { AIOperationStatus } from "../../providers/AIStatusProvider" +import { + getQuestDBTableOfContents, + getSpecificDocumentation, + parseDocItems, + DocCategory, +} from "../questdbDocsRetrieval" +import type { ResponseFormatSchema } from "./types" +import { jsonrepair } from "jsonrepair" + +export class RefusalError extends Error { + constructor(message: string) { + super(message) + this.name = "RefusalError" + } +} + +export class MaxTokensError extends Error { + constructor(message: string) { + super(message) + this.name = "MaxTokensError" + } +} + +export class StreamingError extends Error { + constructor( + message: string, + public readonly errorType: "failed" | "network" | "interrupted" | "unknown", + public readonly originalError?: unknown, + ) { + super(message) + this.name = "StreamingError" + } +} + +export const safeJsonParse = (text: string): T | object => { + try { + return JSON.parse(text) as T + } catch { + try { + return JSON.parse(jsonrepair(text)) as T + } catch { + return {} + } + } +} + +export function extractPartialExplanation(partialJson: string): string { + const explanationMatch = partialJson.match( + /"explanation"\s*:\s*"((?:[^"\\]|\\.)*)/, + ) + if (!explanationMatch) { + return "" + } + + return explanationMatch[1] + .replace(/\\n/g, "\n") + .replace(/\\r/g, "\r") + .replace(/\\t/g, "\t") + .replace(/\\"/g, '"') + .replace(/\\\\/g, "\\") +} + +export const executeTool = async ( + toolName: string, + input: unknown, + modelToolsClient: ModelToolsClient, + setStatus: StatusCallback, +): Promise<{ content: string; is_error?: boolean }> => { + try { + switch (toolName) { + case "get_tables": { + setStatus(AIOperationStatus.RetrievingTables) + if (!modelToolsClient.getTables) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + const result = await modelToolsClient.getTables() + const MAX_TABLES = 1000 + if (result.length > MAX_TABLES) { + const truncated = result.slice(0, MAX_TABLES) + return { + content: JSON.stringify( + { + tables: truncated, + total_count: result.length, + truncated: true, + message: `Showing ${MAX_TABLES} of ${result.length} tables. Use get_table_schema with a specific table name to get details if you are interested in a specific table.`, + }, + null, + 2, + ), + } + } + return { content: JSON.stringify(result, null, 2) } + } + case "get_table_schema": { + const tableName = (input as { table_name: string })?.table_name + if (!modelToolsClient.getTableSchema) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + if (!tableName) { + return { + content: "Error: table_name parameter is required", + is_error: true, + } + } + setStatus(AIOperationStatus.InvestigatingTable, { + name: tableName, + tableOpType: "schema", + }) + const result = await modelToolsClient.getTableSchema(tableName) + return { + content: + result || `Table '${tableName}' not found or schema unavailable`, + } + } + case "get_table_details": { + const tableName = (input as { table_name: string })?.table_name + if (!modelToolsClient.getTableDetails) { + return { + content: + "Error: Schema access is not granted. This tool is not available.", + is_error: true, + } + } + if (!tableName) { + return { + content: "Error: table_name parameter is required", + is_error: true, + } + } + setStatus(AIOperationStatus.InvestigatingTable, { + name: tableName, + tableOpType: "details", + }) + const result = await modelToolsClient.getTableDetails(tableName) + return { + content: result + ? JSON.stringify(result, null, 2) + : "Table details not found", + is_error: !result, + } + } + case "validate_query": { + setStatus(AIOperationStatus.ValidatingQuery) + const query = (input as { query: string })?.query + if (!query) { + return { + content: "Error: query parameter is required", + is_error: true, + } + } + const result = await modelToolsClient.validateQuery(query) + const content = { + valid: result.valid, + error: result.valid ? undefined : result.error, + position: result.valid ? undefined : result.position, + } + return { content: JSON.stringify(content, null, 2) } + } + case "get_questdb_toc": { + setStatus(AIOperationStatus.RetrievingDocumentation) + const tocContent = await getQuestDBTableOfContents() + return { content: tocContent } + } + case "get_questdb_documentation": { + const { category, items } = + (input as { category: string; items: string[] }) || {} + if (!category || !items || !Array.isArray(items)) { + return { + content: "Error: category and items parameters are required", + is_error: true, + } + } + const parsedItems = parseDocItems(items) + + if (parsedItems.length > 0) { + setStatus(AIOperationStatus.InvestigatingDocs, { items: parsedItems }) + } else { + setStatus(AIOperationStatus.InvestigatingDocs) + } + const documentation = await getSpecificDocumentation( + category as DocCategory, + items, + ) + return { content: documentation } + } + default: + return { content: `Unknown tool: ${toolName}`, is_error: true } + } + } catch (error) { + return { + content: `Tool execution error: ${error instanceof Error ? error.message : "Unknown error"}`, + is_error: true, + } + } +} + +export function extractJsonWithExpectedFields( + text: string, + expectedFields: string[], +): Record | null { + let searchStart = 0 + while (true) { + const braceStart = text.indexOf("{", searchStart) + if (braceStart === -1) break + + const textFromBrace = text.slice(braceStart) + let endIdx = textFromBrace.lastIndexOf("}") + while (endIdx > 0) { + const candidate = textFromBrace.slice(0, endIdx + 1) + // Try direct JSON.parse first, then jsonrepair as fallback + let parsed: Record | null = null + try { + parsed = JSON.parse(candidate) as Record + } catch { + try { + parsed = JSON.parse(jsonrepair(candidate)) as Record + } catch { + // jsonrepair couldn't fix it either + } + } + + if (parsed !== null) { + if (expectedFields.every((field) => field in parsed)) { + const result: Record = {} + for (const field of expectedFields) { + result[field] = parsed[field] + } + return result + } + break // Valid JSON but missing expected fields — try next { + } + endIdx = textFromBrace.lastIndexOf("}", endIdx - 1) + } + searchStart = braceStart + 1 + } + return null +} + +export function parseCustomProviderResponse( + text: string, + expectedFields: string[], + fallback: (rawText: string) => T, +): T { + try { + return JSON.parse(text) as T + } catch { + // not valid JSON as-is + } + + const extracted = extractJsonWithExpectedFields(text, expectedFields) + if (extracted) { + return extracted as T + } + + try { + const repaired = JSON.parse(jsonrepair(text)) as Record + if ( + repaired !== null && + typeof repaired === "object" && + !Array.isArray(repaired) && + (expectedFields.length === 0 || + expectedFields.every((field) => field in repaired)) + ) { + return repaired as T + } + } catch { + // jsonrepair couldn't salvage it + } + + // Fallback — caller decides the shape + return fallback(text) +} + +export function responseFormatToPromptInstruction( + format: ResponseFormatSchema, +): string { + const properties = format.schema.properties as Record< + string, + { type: unknown } + > + const required = (format.schema.required as string[]) || [] + + const fields = Object.entries(properties) + .map(([key, value]) => { + const typeStr = Array.isArray(value.type) + ? value.type.join(" | ") + : String(value.type) + const isRequired = required.includes(key) + return ` "${key}": ${typeStr}${isRequired ? " (required)" : " (optional)"}` + }) + .join(",\n") + + return `\nAlways respond with a valid JSON object with the following fields:\n{\n${fields}\n}` +} diff --git a/src/utils/ai/tools.ts b/src/utils/ai/tools.ts new file mode 100644 index 000000000..557a64e2b --- /dev/null +++ b/src/utils/ai/tools.ts @@ -0,0 +1,105 @@ +import type { ToolDefinition } from "./types" + +export const SCHEMA_TOOLS: ToolDefinition[] = [ + { + name: "get_tables", + description: + "Get a list of all tables and materialized views in the QuestDB database", + inputSchema: { + type: "object", + properties: {}, + }, + }, + { + name: "get_table_schema", + description: + "Get the full schema definition (DDL) for a specific table or materialized view", + inputSchema: { + type: "object", + properties: { + table_name: { + type: "string", + description: + "The name of the table or materialized view to get schema for", + }, + }, + required: ["table_name"], + }, + }, + { + name: "get_table_details", + description: + "Get the runtime details/statistics of a specific table or materialized view", + inputSchema: { + type: "object", + properties: { + table_name: { + type: "string", + description: + "The name of the table or materialized view to get details for", + }, + }, + required: ["table_name"], + }, + }, +] + +export const REFERENCE_TOOLS: ToolDefinition[] = [ + { + name: "validate_query", + description: + "Validate the syntax correctness of a SQL query using QuestDB's SQL syntax validator. All generated SQL queries should be validated using this tool before responding to the user.", + inputSchema: { + type: "object", + properties: { + query: { + type: "string", + description: "The SQL query to validate", + }, + }, + required: ["query"], + }, + }, + { + name: "get_questdb_toc", + description: + "Get a table of contents listing all available QuestDB functions, operators, and SQL keywords. Use this first to see what documentation is available before requesting specific items.", + inputSchema: { + type: "object", + properties: {}, + }, + }, + { + name: "get_questdb_documentation", + description: + "Get documentation for specific QuestDB functions, operators, or SQL keywords. This is much more efficient than loading all documentation.", + inputSchema: { + type: "object", + properties: { + category: { + type: "string", + enum: [ + "functions", + "operators", + "sql", + "concepts", + "schema", + "cookbook", + ], + description: "The category of documentation to retrieve", + }, + items: { + type: "array", + items: { + type: "string", + }, + description: + "List of specific docs items in the category. IMPORTANT: Category of these items must match the category parameter. Name of these items should exactly match the entry in the table of contents you get with get_questdb_toc.", + }, + }, + required: ["category", "items"], + }, + }, +] + +export const ALL_TOOLS: ToolDefinition[] = [...SCHEMA_TOOLS, ...REFERENCE_TOOLS] diff --git a/src/utils/ai/types.ts b/src/utils/ai/types.ts new file mode 100644 index 000000000..9667e0704 --- /dev/null +++ b/src/utils/ai/types.ts @@ -0,0 +1,74 @@ +import type { + AiAssistantAPIError, + ModelToolsClient, + StatusCallback, + StreamingCallback, +} from "../aiAssistant" +import type { ProviderId } from "./settings" + +export interface ToolDefinition { + name: string + description?: string + inputSchema: { + type: "object" + properties: Record + required?: string[] + } +} + +export interface ResponseFormatSchema { + name: string + schema: Record + strict: boolean +} + +export interface FlowConfig { + systemInstructions: string + initialUserContent: string + conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> + responseFormat: ResponseFormatSchema + postProcess?: (formatted: T) => T +} + +export interface AIProvider { + readonly id: ProviderId + readonly contextWindow: number + + executeFlow(params: { + model: string + config: FlowConfig + modelToolsClient: ModelToolsClient + tools: ToolDefinition[] + setStatus: StatusCallback + abortSignal?: AbortSignal + streaming?: StreamingCallback + }): Promise + + generateTitle(params: { + model: string + prompt: string + responseFormat: ResponseFormatSchema + }): Promise + + generateSummary(params: { + model: string + systemPrompt: string + userMessage: string + }): Promise + + testConnection(params: { + apiKey: string + model: string + }): Promise<{ valid: boolean; error?: string }> + + countTokens(params: { + messages: Array<{ role: "user" | "assistant"; content: string }> + systemPrompt: string + model: string + }): Promise + + listModels(): Promise + + classifyError(error: unknown, setStatus: StatusCallback): AiAssistantAPIError + isNonRetryableError(error: unknown): boolean +} diff --git a/src/utils/aiAssistant.ts b/src/utils/aiAssistant.ts index 0ed9e6567..514cf169b 100644 --- a/src/utils/aiAssistant.ts +++ b/src/utils/aiAssistant.ts @@ -1,34 +1,32 @@ -import Anthropic from "@anthropic-ai/sdk" -import OpenAI from "openai" import { Client } from "./questdb/client" import { Type, Table } from "./questdb/types" -import { getModelProps, MODEL_OPTIONS } from "./aiAssistantSettings" -import type { ModelOption, Provider } from "./aiAssistantSettings" +import type { ProviderId } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" import { formatSql } from "./formatSql" import { AIOperationStatus, StatusArgs } from "../providers/AIStatusProvider" -import { - getQuestDBTableOfContents, - getSpecificDocumentation, - parseDocItems, - DocCategory, -} from "./questdbDocsRetrieval" -import { MessageParam } from "@anthropic-ai/sdk/resources/messages" -import type { - ResponseOutputItem, - ResponseTextConfig, -} from "openai/resources/responses/responses" -import type { Tool as AnthropicTool } from "@anthropic-ai/sdk/resources/messages" import type { ConversationId, ConversationMessage, } from "../providers/AIConversationProvider/types" import { compactConversationIfNeeded } from "./contextCompaction" -import { COMPACTION_THRESHOLDS } from "./tokenCounting" +import { + createProvider, + ExplainFormat, + FixSQLFormat, + ConversationResponseFormat, + ChatTitleFormat, + ALL_TOOLS, + REFERENCE_TOOLS, + getUnifiedPrompt, + BUILTIN_PROVIDERS, +} from "./ai" +import type { AIProvider } from "./ai" export type ActiveProviderSettings = { model: string - provider: Provider + provider: ProviderId apiKey: string + aiAssistantSettings?: AiAssistantSettings } export interface AiAssistantAPIError { @@ -74,218 +72,6 @@ export type StreamingCallback = { cleanup?: () => void } -type ProviderClients = - | { - provider: "anthropic" - anthropic: Anthropic - } - | { - provider: "openai" - openai: OpenAI - } - -const ExplainFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "explain_format", - schema: { - type: "object", - properties: { - explanation: { type: "string" }, - }, - required: ["explanation"], - additionalProperties: false, - }, - strict: true, - }, -} - -const FixSQLFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "fix_sql_format", - schema: { - type: "object", - properties: { - sql: { type: ["string", "null"] }, - explanation: { type: "string" }, - }, - required: ["explanation", "sql"], - additionalProperties: false, - }, - strict: true, - }, -} - -const ConversationResponseFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "conversation_response_format", - schema: { - type: "object", - properties: { - sql: { type: ["string", "null"] }, - explanation: { type: "string" }, - }, - required: ["sql", "explanation"], - additionalProperties: false, - }, - strict: true, - }, -} - -const inferProviderFromModel = (model: string): Provider => { - const found: ModelOption | undefined = MODEL_OPTIONS.find( - (m) => m.value === model, - ) - if (found) return found.provider - return model.startsWith("claude") ? "anthropic" : "openai" -} - -const createProviderClients = ( - settings: ActiveProviderSettings, -): ProviderClients => { - if (!settings.apiKey) { - throw new Error(`No API key found for ${settings.provider}`) - } - - if (settings.provider === "openai") { - return { - provider: settings.provider, - openai: new OpenAI({ - apiKey: settings.apiKey, - dangerouslyAllowBrowser: true, - }), - } - } - return { - provider: settings.provider, - anthropic: new Anthropic({ - apiKey: settings.apiKey, - dangerouslyAllowBrowser: true, - }), - } -} - -const SCHEMA_TOOLS: Array = [ - { - name: "get_tables", - description: - "Get a list of all tables and materialized views in the QuestDB database", - input_schema: { - type: "object" as const, - properties: {}, - }, - }, - { - name: "get_table_schema", - description: - "Get the full schema definition (DDL) for a specific table or materialized view", - input_schema: { - type: "object" as const, - properties: { - table_name: { - type: "string" as const, - description: - "The name of the table or materialized view to get schema for", - }, - }, - required: ["table_name"], - }, - }, - { - name: "get_table_details", - description: "Get the details of a specific table or materialized view", - input_schema: { - type: "object" as const, - properties: { - table_name: { - type: "string" as const, - description: - "The name of the table or materialized view to get details for", - }, - }, - required: ["table_name"], - }, - }, -] - -const REFERENCE_TOOLS = [ - { - name: "validate_query", - description: - "Validate the syntax correctness of a SQL query using QuestDB's SQL syntax validator. All generated SQL queries should be validated using this tool before responding to the user.", - input_schema: { - type: "object" as const, - properties: { - query: { - type: "string" as const, - description: "The SQL query to validate", - }, - }, - required: ["query"], - }, - }, - { - name: "get_questdb_toc", - description: - "Get a table of contents listing all available QuestDB functions, operators, and SQL keywords. Use this first to see what documentation is available before requesting specific items.", - input_schema: { - type: "object" as const, - properties: {}, - }, - }, - { - name: "get_questdb_documentation", - description: - "Get documentation for specific QuestDB functions, operators, or SQL keywords. This is much more efficient than loading all documentation.", - input_schema: { - type: "object" as const, - properties: { - category: { - type: "string" as const, - enum: [ - "functions", - "operators", - "sql", - "concepts", - "schema", - "cookbook", - ], - description: "The category of documentation to retrieve", - }, - items: { - type: "array" as const, - items: { - type: "string" as const, - }, - description: - "List of specific docs items in the category. IMPORTANT: Category of these items must match the category parameter. Name of these items should exactly match the entry in the table of contents you get with get_questdb_toc.", - }, - }, - required: ["category", "items"], - }, - }, -] - -const ALL_TOOLS = [...SCHEMA_TOOLS, ...REFERENCE_TOOLS] - -const toOpenAIFunctions = ( - tools: Array<{ - name: string - description?: string - input_schema: AnthropicTool["input_schema"] - }>, -) => { - return tools.map((t) => ({ - type: "function" as const, - name: t.name, - description: t.description, - parameters: { ...t.input_schema, additionalProperties: false }, - strict: true, - })) as OpenAI.Responses.Tool[] -} - export const normalizeSql = (sql: string, insertSemicolon: boolean = true) => { if (!sql) return "" let result = sql.trim() @@ -460,179 +246,6 @@ export const createStreamingCallback = ( } } -const DOCS_INSTRUCTION_ANTHROPIC = ` -CRITICAL: Always follow this documentation approach: -1. Use get_questdb_toc to see available functions, operators, SQL syntax, AND cookbook recipes -2. If user's request matches a cookbook recipe description, fetch it FIRST - recipes provide complete, tested SQL patterns -3. Use get_questdb_documentation for specific function/syntax details - -When a cookbook recipe matches the user's intent, ALWAYS use it as the foundation and adapt column/table names and use case to their schema.` - -const getUnifiedPrompt = (grantSchemaAccess?: boolean) => { - const base = `You are a SQL expert coding assistant specializing in QuestDB, a high-performance time-series database. You help users with: -- Generating QuestDB SQL queries from natural language descriptions -- Explaining what QuestDB SQL queries do -- Fixing errors in QuestDB SQL queries -- Refining and modifying existing queries based on user requests - -## CRITICAL: Tool and Response Sequencing -Follow this EXACT sequence for every query generation request: - -**PHASE 1 - INFORMATION GATHERING (NO TEXT OUTPUT)** -1. Call available tools to gather information if you need, including documentation, schema, and validation tools. -2. Complete ALL information gathering before Phase 2. DO NOT CALL any tool after Phase 2. - -**PHASE 2 - FINAL RESPONSE (NO MORE TOOL CALLS)** -3. Return your JSON response with "sql" and "explanation" fields. Always return sql field first, then explanation field. - -NEVER interleave phases. NEVER use any tool after starting to return a response. - -## When Explaining Queries -- Focus on the business logic and what the query achieves, not the SQL syntax itself -- Pay special attention to QuestDB-specific features: - - Time-series operations (SAMPLE BY, LATEST ON, designated timestamp columns) - - Time-based filtering and aggregations - - Real-time data ingestion patterns - - Performance optimizations specific to time-series data - -## When Generating SQL -- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. -- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a generated SQL query -- Generate only valid QuestDB SQL syntax referring to the documentation about functions, operators, and SQL keywords -- Use appropriate time-series functions (SAMPLE BY, LATEST ON, etc.) and common table expressions when relevant -- Use \`IN\` with \`today()\`, \`tomorrow()\`, \`yesterday()\` interval functions when relevant -- Follow QuestDB best practices for performance referring to the documentation -- Use proper timestamp handling for time-series data -- Use correct data types and functions specific to QuestDB referring to the documentation. Do not use any word that is not in the documentation. - -## When Fixing Queries -- DO NOT return any content before completing your tool calls including documentation and validation tools. You should NOT CALL any tool after starting to return a response. -- Always validate the query in "sql" field using the validate_query tool before returning an explanation or a fixed SQL query -- Analyze the error message carefully to understand what went wrong -- Generate only valid QuestDB SQL syntax by always referring to the documentation about functions, operators, and SQL keywords -- Preserve the original intent of the query while fixing the error -- Follow QuestDB best practices and syntax rules referring to the documentation -- Consider common issues like: - - Missing or incorrect column names - - Invalid syntax for time-series operations - - Data type mismatches - - Incorrect function usage - -## Response Guidelines -- You are working as a coding assistant inside an IDE. Every time you return a query in "sql" field, you provide a suggestion to the user to accept or reject. When the user accepts the suggestion, you are informed and the query in the editor is updated with your suggestion. -- Modify a query by returning "sql" field only if the user asks you to generate, fix, or make changes to the query. If the user does not ask for fixing/changing/generating a query, return null in the "sql" field. Every time you provide a SQL query, the current SQL is updated. -- Provide the "explanation" field if you haven't provided it yet. Explanation should be in GFM (GitHub Flavored Markdown) format. Explanation field is cumulative, every time you provide an explanation, it is added to the previous explanations. - -## Tools -- Use the validate_query tool to validate the query in "sql" field before returning a response only if the user asks you to generate, fix, or make changes to the query. -` - const schemaAccess = grantSchemaAccess - ? `- Use the get_tables tool to retrieve all tables and materialized views in the database instance -- Use the get_table_schema tool to get detailed schema information for a specific table or a materialized view -- Use the get_table_details tool to get detailed information for a specific table or a materialized view. Each property is described in meta functions docs. -` - : "" - return base + schemaAccess + DOCS_INSTRUCTION_ANTHROPIC -} - -export const getExplainSchemaPrompt = ( - tableName: string, - schema: string, - kindLabel: string, -) => `You are a SQL expert assistant specializing in QuestDB, a high-performance time-series database. -Explain the following ${kindLabel} schema. Include: -- The purpose of the ${kindLabel} -- What each column represents and its data type -- Any important properties like WAL enablement, partitioning strategy, designated timestamps -- Any performance or storage considerations - -${kindLabel} Name: ${tableName} - -Schema: -\`\`\`sql -${schema} -\`\`\` - -**IMPORTANT: Format your response in markdown exactly as follows:** - -1. Start with a brief paragraph explaining the purpose and general characteristics of this ${kindLabel}. - -2. Add a "## Columns" section with a markdown table: -| Column | Type | Description | -|--------|------|-------------| -| column_name | \`data_type\` | Brief description | - -3. If this is a table or materialized view (not a view), add a "## Storage Details" section with bullet points about: -- WAL enablement -- Partitioning strategy -- Designated timestamp column -- Any other storage considerations - -For views, skip the Storage Details section.` - -export type HealthIssuePromptData = { - tableName: string - issue: { - id: string - field: string - message: string - currentValue?: string - } - tableDetails: string - monitoringDocs: string - trendSamples?: Array<{ value: number; timestamp: number }> -} - -export const getHealthIssuePrompt = (data: HealthIssuePromptData): string => { - const { tableName, issue, tableDetails, monitoringDocs, trendSamples } = data - - let trendSection = "" - if (trendSamples && trendSamples.length > 0) { - const recentSamples = trendSamples.slice(-30) - trendSection = ` - -### Trend Data (Recent Samples) -| Timestamp | Value | -|-----------|-------| -${recentSamples.map((s) => `| ${new Date(s.timestamp).toISOString()} | ${s.value.toLocaleString()} |`).join("\n")} -` - } - - return `You are a QuestDB expert assistant helping diagnose and resolve table health issues. - -A user is viewing the health monitoring panel for their table and has asked for help with a detected issue. - -## Table: ${tableName} - -## Health Issue Detected -- **Issue ID**: ${issue.id} -- **Field**: ${issue.field} -- **Message**: ${issue.message} -${issue.currentValue ? `- **Current Value**: ${issue.currentValue}` : ""}${trendSection} - -## Table Details (from tables() function) -\`\`\`json -${tableDetails} -\`\`\` - -## QuestDB Monitoring Documentation -${monitoringDocs} - ---- - -**Your Task:** -1. Explain what this health issue means in the context of this specific table -2. Analyze the table details to identify potential root causes -3. Provide specific, actionable recommendations to resolve or mitigate the issue -4. If there is a clear SQL command that can help fix the issue (like \`ALTER TABLE ... RESUME WAL\`), include it in the "sql" field of your response. **Only provide SQL if it directly addresses the root cause** - do not provide SQL just to inspect the problem. - -**IMPORTANT: Be concise and thorough, format your response in markdown with clear sections:** -- Use ## headings for main sections -- Use bullet points for lists -- Use \`code\` for configuration values and SQL -` -} - const MAX_RETRIES = 2 const RETRY_DELAY = 1000 @@ -650,451 +263,9 @@ const handleRateLimit = async () => { lastRequestTime = Date.now() } -const isNonRetryableError = (error: unknown) => { - if (error instanceof StreamingError) { - return error.errorType === "interrupted" || error.errorType === "failed" - } - return ( - error instanceof RefusalError || - error instanceof MaxTokensError || - error instanceof Anthropic.AuthenticationError || - (typeof OpenAI !== "undefined" && - error instanceof OpenAI.AuthenticationError) || - // @ts-expect-error no proper rate limit error type - ("status" in error && error.status === 429) || - error instanceof OpenAI.APIUserAbortError || - error instanceof Anthropic.APIUserAbortError - ) -} - -const executeTool = async ( - toolName: string, - input: unknown, - modelToolsClient: ModelToolsClient, - setStatus: StatusCallback, -): Promise<{ content: string; is_error?: boolean }> => { - try { - switch (toolName) { - case "get_tables": { - setStatus(AIOperationStatus.RetrievingTables) - if (!modelToolsClient.getTables) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - const result = await modelToolsClient.getTables() - const MAX_TABLES = 1000 - if (result.length > MAX_TABLES) { - const truncated = result.slice(0, MAX_TABLES) - return { - content: JSON.stringify( - { - tables: truncated, - total_count: result.length, - truncated: true, - message: `Showing ${MAX_TABLES} of ${result.length} tables. Use get_table_schema with a specific table name to get details if you are interested in a specific table.`, - }, - null, - 2, - ), - } - } - return { content: JSON.stringify(result, null, 2) } - } - case "get_table_schema": { - const tableName = (input as { table_name: string })?.table_name - if (!modelToolsClient.getTableSchema) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - if (!tableName) { - return { - content: "Error: table_name parameter is required", - is_error: true, - } - } - setStatus(AIOperationStatus.InvestigatingTable, { - name: tableName, - tableOpType: "schema", - }) - const result = await modelToolsClient.getTableSchema(tableName) - return { - content: - result || `Table '${tableName}' not found or schema unavailable`, - } - } - case "get_table_details": { - const tableName = (input as { table_name: string })?.table_name - if (!modelToolsClient.getTableDetails) { - return { - content: - "Error: Schema access is not granted. This tool is not available.", - is_error: true, - } - } - if (!tableName) { - return { - content: "Error: table_name parameter is required", - is_error: true, - } - } - setStatus(AIOperationStatus.InvestigatingTable, { - name: tableName, - tableOpType: "details", - }) - const result = await modelToolsClient.getTableDetails(tableName) - return { - content: result - ? JSON.stringify(result, null, 2) - : "Table details not found", - is_error: !result, - } - } - case "validate_query": { - setStatus(AIOperationStatus.ValidatingQuery) - const query = (input as { query: string })?.query - if (!query) { - return { - content: "Error: query parameter is required", - is_error: true, - } - } - const result = await modelToolsClient.validateQuery(query) - const content = { - valid: result.valid, - error: result.valid ? undefined : result.error, - position: result.valid ? undefined : result.position, - } - return { content: JSON.stringify(content, null, 2) } - } - case "get_questdb_toc": { - setStatus(AIOperationStatus.RetrievingDocumentation) - const tocContent = await getQuestDBTableOfContents() - return { content: tocContent } - } - case "get_questdb_documentation": { - const { category, items } = - (input as { category: string; items: string[] }) || {} - if (!category || !items || !Array.isArray(items)) { - return { - content: "Error: category and items parameters are required", - is_error: true, - } - } - const parsedItems = parseDocItems(items) - - if (parsedItems.length > 0) { - setStatus(AIOperationStatus.InvestigatingDocs, { items: parsedItems }) - } else { - setStatus(AIOperationStatus.InvestigatingDocs) - } - const documentation = await getSpecificDocumentation( - category as DocCategory, - items, - ) - return { content: documentation } - } - default: - return { content: `Unknown tool: ${toolName}`, is_error: true } - } - } catch (error) { - return { - content: `Tool execution error: ${error instanceof Error ? error.message : "Unknown error"}`, - is_error: true, - } - } -} - -interface AnthropicToolCallResult { - message: Anthropic.Messages.Message - accumulatedTokens: TokenUsage -} - -async function handleToolCalls( - message: Anthropic.Messages.Message, - anthropic: Anthropic, - modelToolsClient: ModelToolsClient, - conversationHistory: Array, - model: string, - setStatus: StatusCallback, - responseFormat: ResponseTextConfig, - abortSignal?: AbortSignal, - accumulatedTokens: TokenUsage = { inputTokens: 0, outputTokens: 0 }, - streaming?: StreamingCallback, -): Promise { - const toolUseBlocks = message.content.filter( - (block) => block.type === "tool_use", - ) - const toolResults = [] - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - for (const toolUse of toolUseBlocks) { - if ("name" in toolUse) { - const exec = await executeTool( - toolUse.name, - toolUse.input, - modelToolsClient, - setStatus, - ) - toolResults.push({ - type: "tool_result" as const, - tool_use_id: toolUse.id, - content: exec.content, - is_error: exec.is_error, - }) - } - } - - const updatedHistory = [ - ...conversationHistory, - { - role: "assistant" as const, - content: message.content, - }, - { - role: "user" as const, - content: toolResults, - }, - ] - - const criticalTokenUsage = - message.usage.input_tokens >= COMPACTION_THRESHOLDS["anthropic"] && - toolResults.length > 0 - if (criticalTokenUsage) { - updatedHistory.push({ - role: "user" as const, - content: - "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", - }) - } - - const followUpParams: Parameters[1] = { - model, - tools: modelToolsClient.getTables ? ALL_TOOLS : REFERENCE_TOOLS, - messages: updatedHistory, - temperature: 0.3, - } - - const format = responseFormat.format as { type: string; schema?: object } - if (format.type === "json_schema" && format.schema) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - followUpParams.output_format = { - type: "json_schema", - schema: format.schema, - } - } - - const followUpMessage = streaming - ? await createAnthropicMessageStreaming( - anthropic, - followUpParams, - streaming, - abortSignal, - ) - : await createAnthropicMessage(anthropic, followUpParams, abortSignal) - - // Accumulate tokens from this response - const newAccumulatedTokens: TokenUsage = { - inputTokens: - accumulatedTokens.inputTokens + - (followUpMessage.usage?.input_tokens || 0), - outputTokens: - accumulatedTokens.outputTokens + - (followUpMessage.usage?.output_tokens || 0), - } - - if (followUpMessage.stop_reason === "tool_use") { - return handleToolCalls( - followUpMessage, - anthropic, - modelToolsClient, - updatedHistory, - model, - setStatus, - responseFormat, - abortSignal, - newAccumulatedTokens, - streaming, - ) - } - - return { - message: followUpMessage, - accumulatedTokens: newAccumulatedTokens, - } -} - -async function createOpenAIResponseStreaming( - openai: OpenAI, - params: OpenAI.Responses.ResponseCreateParamsNonStreaming, - streamCallback: StreamingCallback, - abortSignal?: AbortSignal, -): Promise { - let accumulatedText = "" - let lastExplanation = "" - let finalResponse: OpenAI.Responses.Response | null = null - - try { - const stream = await openai.responses.create({ - ...params, - stream: true, - } as OpenAI.Responses.ResponseCreateParamsStreaming) - - for await (const event of stream) { - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - - if (event.type === "error") { - const errorEvent = event as { error?: { message?: string } } - throw new StreamingError( - errorEvent.error?.message || "Stream error occurred", - "failed", - event, - ) - } - - if (event.type === "response.failed") { - const failedEvent = event as { - response?: { error?: { message?: string } } - } - throw new StreamingError( - failedEvent.response?.error?.message || - "Provider failed to return a response", - "failed", - event, - ) - } - - if (event.type === "response.output_text.delta") { - accumulatedText += event.delta - const explanation = extractPartialExplanation(accumulatedText) - if (explanation !== lastExplanation) { - const chunk = explanation.slice(lastExplanation.length) - lastExplanation = explanation - streamCallback.onTextChunk(chunk, explanation) - } - } - - if (event.type === "response.completed") { - finalResponse = event.response - } - } - } catch (error) { - if (error instanceof StreamingError) { - throw error - } - if (abortSignal?.aborted || error instanceof OpenAI.APIUserAbortError) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - error instanceof Error ? error.message : "Stream interrupted", - "network", - error, - ) - } - - if (!finalResponse) { - throw new StreamingError("Provider failed to return a response", "failed") - } - - return finalResponse -} - -const extractOpenAIToolCalls = ( - response: OpenAI.Responses.Response, -): { id?: string; name: string; arguments: unknown; call_id: string }[] => { - const calls = [] - for (const item of response.output) { - if (item?.type === "function_call") { - const args = - typeof item.arguments === "string" - ? safeJsonParse(item.arguments) - : item.arguments || {} - calls.push({ - id: item.id, - name: item.name, - arguments: args, - call_id: item.call_id, - }) - } - } - return calls -} - -const getOpenAIText = ( - response: OpenAI.Responses.Response, -): { type: "refusal" | "text"; message: string } => { - const out = response.output || [] - if ( - out.find( - (item: ResponseOutputItem) => - item.type === "message" && - item.content.some((c) => c.type === "refusal"), - ) - ) { - return { - type: "refusal", - message: "The model refused to generate a response for this request.", - } - } - - for (const item of out) { - if (item.type === "message" && item.content) { - for (const content of item.content) { - if (content.type === "output_text" && "text" in content) { - return { type: "text", message: content.text } - } - } - } - } - - return { type: "text", message: "" } -} - -const safeJsonParse = (text: string): T | object => { - try { - return JSON.parse(text) as T - } catch { - return {} - } -} - -/** - * Extracts partial explanation text from incomplete JSON during streaming. - * Handles JSON escape sequences and partial content. - */ -function extractPartialExplanation(partialJson: string): string { - // Match "explanation": "content... where content may be incomplete - const explanationMatch = partialJson.match( - /"explanation"\s*:\s*"((?:[^"\\]|\\.)*)/, - ) - if (!explanationMatch) { - return "" - } - - // Unescape JSON string escape sequences - return explanationMatch[1] - .replace(/\\n/g, "\n") - .replace(/\\r/g, "\r") - .replace(/\\t/g, "\t") - .replace(/\\"/g, '"') - .replace(/\\\\/g, "\\") -} - const tryWithRetries = async ( fn: () => Promise, + provider: AIProvider, setStatus: StatusCallback, abortSignal?: AbortSignal, ): Promise => { @@ -1113,13 +284,13 @@ const tryWithRetries = async ( console.error( "AI Assistant error:", error instanceof Error ? error.message : String(error), - isNonRetryableError(error) + provider.isNonRetryableError(error) ? "Non-retryable error." : "Remaining retries: " + (MAX_RETRIES - retries) + ".", ) retries++ - if (retries > MAX_RETRIES || isNonRetryableError(error)) { - return handleAiAssistantError(error, setStatus) + if (retries > MAX_RETRIES || provider.isNonRetryableError(error)) { + return provider.classifyError(error, setStatus) } await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY * retries)) @@ -1133,709 +304,34 @@ const tryWithRetries = async ( } } -interface OpenAIFlowConfig { - systemInstructions: string - initialUserContent: string - conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> - responseFormat: ResponseTextConfig - postProcess?: (formatted: T) => T +export const testApiKey = async ( + apiKey: string, + model: string, + providerId: ProviderId, + settings?: AiAssistantSettings, +): Promise<{ valid: boolean; error?: string }> => { + const provider = createProvider(providerId, apiKey, settings) + return provider.testConnection({ apiKey, model }) } -interface AnthropicFlowConfig { - systemInstructions: string - initialUserContent: string - conversationHistory?: Array<{ role: "user" | "assistant"; content: string }> - responseFormat: ResponseTextConfig - postProcess?: (formatted: T) => T -} +export const generateChatTitle = async ({ + firstUserMessage, + settings, +}: { + firstUserMessage: string + settings: ActiveProviderSettings +}): Promise => { + const isCustom = !BUILTIN_PROVIDERS[settings.provider] + if ((!isCustom && !settings.apiKey) || !settings.model) { + return null + } -interface ExecuteAnthropicFlowParams { - anthropic: Anthropic - model: string - config: AnthropicFlowConfig - modelToolsClient: ModelToolsClient - setStatus: StatusCallback - abortSignal?: AbortSignal - streaming?: StreamingCallback -} - -interface ExecuteOpenAIFlowParams { - openai: OpenAI - model: string - config: OpenAIFlowConfig - modelToolsClient: ModelToolsClient - setStatus: StatusCallback - abortSignal?: AbortSignal - streaming?: StreamingCallback -} - -const executeOpenAIFlow = async ({ - openai, - model, - config, - modelToolsClient, - setStatus, - abortSignal, - streaming, -}: ExecuteOpenAIFlowParams): Promise => { - let input: OpenAI.Responses.ResponseInput = [] - if (config.conversationHistory && config.conversationHistory.length > 0) { - const validMessages = config.conversationHistory.filter( - (msg) => msg.content && msg.content.trim() !== "", - ) - for (const msg of validMessages) { - input.push({ - role: msg.role, - content: msg.content, - }) - } - } - - input.push({ - role: "user", - content: config.initialUserContent, - }) - - const grantSchemaAccess = !!modelToolsClient.getTables - const openaiTools = toOpenAIFunctions( - grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS, - ) - - // Accumulate tokens across all iterations - let totalInputTokens = 0 - let totalOutputTokens = 0 - - const requestParams = { - ...getModelProps(model), - instructions: config.systemInstructions, - input, - tools: openaiTools, - text: config.responseFormat, - } as OpenAI.Responses.ResponseCreateParamsNonStreaming - - // Use streaming for the initial call if callback provided - let lastResponse = streaming - ? await createOpenAIResponseStreaming( - openai, - requestParams, - streaming, - abortSignal, - ) - : await openai.responses.create(requestParams) - input = [...input, ...lastResponse.output] - - // Add tokens from first response - totalInputTokens += lastResponse.usage?.input_tokens ?? 0 - totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 - - while (true) { - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const toolCalls = extractOpenAIToolCalls(lastResponse) - if (!toolCalls.length) break - const tool_outputs: OpenAI.Responses.ResponseFunctionToolCallOutputItem[] = - [] - for (const tc of toolCalls) { - const exec = await executeTool( - tc.name, - tc.arguments, - modelToolsClient, - setStatus, - ) - tool_outputs.push({ - type: "function_call_output", - call_id: tc.call_id, - output: exec.content, - } as OpenAI.Responses.ResponseFunctionToolCallOutputItem) - } - input = [...input, ...tool_outputs] - - if ( - (lastResponse.usage?.input_tokens ?? 0) >= - COMPACTION_THRESHOLDS["openai"] && - tool_outputs.length > 0 - ) { - input.push({ - role: "user" as const, - content: - "**CRITICAL TOKEN USAGE: The conversation is getting too long to fit the context window. If you are planning to use more tools, summarize your findings to the user first, and wait for user confirmation to continue working on the task.**", - }) - } - const loopRequestParams = { - ...getModelProps(model), - instructions: config.systemInstructions, - input, - tools: openaiTools, - text: config.responseFormat, - } as OpenAI.Responses.ResponseCreateParamsNonStreaming - - // Use streaming for follow-up calls if callback provided - lastResponse = streaming - ? await createOpenAIResponseStreaming( - openai, - loopRequestParams, - streaming, - abortSignal, - ) - : await openai.responses.create(loopRequestParams) - input = [...input, ...lastResponse.output] - - // Accumulate tokens from each iteration - totalInputTokens += lastResponse.usage?.input_tokens ?? 0 - totalOutputTokens += lastResponse.usage?.output_tokens ?? 0 - } - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const text = getOpenAIText(lastResponse) - if (text.type === "refusal") { - return { - type: "unknown", - message: text.message, - } as AiAssistantAPIError - } - - const rawOutput = text.message - - try { - const json = JSON.parse(rawOutput) as T - setStatus(null) - - const resultWithTokens = { - ...json, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - - if (config.postProcess) { - const processed = config.postProcess(json) - return { - ...processed, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - } - return resultWithTokens - } catch (error) { - setStatus(null) - return { - type: "unknown", - message: "Failed to parse assistant response.", - } as AiAssistantAPIError - } -} - -const executeAnthropicFlow = async ({ - anthropic, - model, - config, - modelToolsClient, - setStatus, - abortSignal, - streaming, -}: ExecuteAnthropicFlowParams): Promise => { - const initialMessages: MessageParam[] = [] - if (config.conversationHistory && config.conversationHistory.length > 0) { - const validMessages = config.conversationHistory.filter( - (msg) => msg.content && msg.content.trim() !== "", - ) - for (const msg of validMessages) { - initialMessages.push({ - role: msg.role, - content: msg.content, - }) - } - } - - initialMessages.push({ - role: "user" as const, - content: config.initialUserContent, - }) - - const grantSchemaAccess = !!modelToolsClient.getTables - - const messageParams: Parameters[1] = { - model, - system: config.systemInstructions, - tools: grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS, - messages: initialMessages, - temperature: 0.3, - } - - if (config.responseFormat?.format) { - const format = config.responseFormat.format as { - type: string - schema?: object - } - if (format.type === "json_schema" && format.schema) { - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = { - type: "json_schema", - schema: format.schema, - } - } - } - - // Use streaming for the initial call if callback provided - const message = streaming - ? await createAnthropicMessageStreaming( - anthropic, - messageParams, - streaming, - abortSignal, - ) - : await createAnthropicMessage(anthropic, messageParams, abortSignal) - - let totalInputTokens = message.usage?.input_tokens || 0 - let totalOutputTokens = message.usage?.output_tokens || 0 - - let responseMessage: Anthropic.Messages.Message - - if (message.stop_reason === "tool_use") { - const toolCallResult = await handleToolCalls( - message, - anthropic, - modelToolsClient, - initialMessages, - model, - setStatus, - config.responseFormat, - abortSignal, - { inputTokens: 0, outputTokens: 0 }, // Start fresh, we already counted initial message - streaming, - ) - - if ("type" in toolCallResult && "message" in toolCallResult) { - return toolCallResult - } - - const result = toolCallResult - responseMessage = result.message - totalInputTokens += result.accumulatedTokens.inputTokens - totalOutputTokens += result.accumulatedTokens.outputTokens - } else { - responseMessage = message - } - - if (abortSignal?.aborted) { - return { - type: "aborted", - message: "Operation was cancelled", - } as AiAssistantAPIError - } - - const textBlock = responseMessage.content.find( - (block) => block.type === "text", - ) - if (!textBlock || !("text" in textBlock)) { - setStatus(null) - return { - type: "unknown", - message: "No text response received from assistant.", - } as AiAssistantAPIError - } - - try { - const json = JSON.parse(textBlock.text) as T - setStatus(null) - - const resultWithTokens = { - ...json, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - - if (config.postProcess) { - const processed = config.postProcess(json) - return { - ...processed, - tokenUsage: { - inputTokens: totalInputTokens, - outputTokens: totalOutputTokens, - }, - } as T & { tokenUsage: TokenUsage } - } - return resultWithTokens - } catch (error) { - setStatus(null) - return { - type: "unknown", - message: "Failed to parse assistant response.", - } as AiAssistantAPIError - } -} - -class RefusalError extends Error { - constructor(message: string) { - super(message) - this.name = "RefusalError" - } -} - -class MaxTokensError extends Error { - constructor(message: string) { - super(message) - this.name = "MaxTokensError" - } -} - -class StreamingError extends Error { - constructor( - message: string, - public readonly errorType: "failed" | "network" | "interrupted" | "unknown", - public readonly originalError?: unknown, - ) { - super(message) - this.name = "StreamingError" - } -} - -async function createAnthropicMessage( - anthropic: Anthropic, - params: Omit & { - max_tokens?: number - }, - signal?: AbortSignal, -): Promise { - const message = await anthropic.messages.create( - { - ...params, - stream: false, - max_tokens: params.max_tokens ?? 8192, - }, - { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, - signal, - }, - ) - - if (message.stop_reason === "refusal") { - throw new RefusalError( - "The model refused to generate a response for this request.", - ) - } - if (message.stop_reason === "max_tokens") { - throw new MaxTokensError( - "The response exceeded the maximum token limit. Please try again with a different prompt or model.", - ) - } - - return message -} - -async function createAnthropicMessageStreaming( - anthropic: Anthropic, - params: Omit & { - max_tokens?: number - }, - streamCallback: StreamingCallback, - abortSignal?: AbortSignal, -): Promise { - let accumulatedText = "" - let lastExplanation = "" - - const stream = anthropic.messages.stream( - { - ...params, - max_tokens: params.max_tokens ?? 8192, - }, - { - headers: { - "anthropic-beta": "structured-outputs-2025-11-13", - }, - signal: abortSignal, - }, - ) - - try { - for await (const event of stream) { - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - - const eventWithType = event as { type: string } - if (eventWithType.type === "error") { - const errorEvent = event as { - error?: { type?: string; message?: string } - } - const errorType = errorEvent.error?.type - const errorMessage = errorEvent.error?.message || "Stream error" - - if (errorType === "overloaded_error") { - throw new StreamingError( - "Service is temporarily overloaded. Please try again.", - "failed", - event, - ) - } - throw new StreamingError(errorMessage, "failed", event) - } - - if ( - event.type === "content_block_delta" && - event.delta.type === "text_delta" - ) { - accumulatedText += event.delta.text - const explanation = extractPartialExplanation(accumulatedText) - if (explanation !== lastExplanation) { - const chunk = explanation.slice(lastExplanation.length) - lastExplanation = explanation - streamCallback.onTextChunk(chunk, explanation) - } - } - } - } catch (error) { - if (error instanceof StreamingError) { - throw error - } - if (abortSignal?.aborted) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - error instanceof Error ? error.message : "Stream interrupted", - "network", - error, - ) - } - - let finalMessage: Anthropic.Messages.Message - try { - finalMessage = await stream.finalMessage() - } catch (error) { - if (abortSignal?.aborted || error instanceof Anthropic.APIUserAbortError) { - throw new StreamingError("Operation aborted", "interrupted") - } - throw new StreamingError( - "Failed to get final message from the provider", - "network", - error, - ) - } - - if (finalMessage.stop_reason === "refusal") { - throw new RefusalError( - "The model refused to generate a response for this request.", - ) - } - if (finalMessage.stop_reason === "max_tokens") { - throw new MaxTokensError( - "The response exceeded the maximum token limit. Please try again with a different prompt or model.", - ) - } - - return finalMessage -} - -function handleAiAssistantError( - error: unknown, - setStatus: StatusCallback, -): AiAssistantAPIError { - if ( - error instanceof OpenAI.APIUserAbortError || - error instanceof Anthropic.APIUserAbortError || - (error instanceof StreamingError && error.errorType === "interrupted") - ) { - setStatus(AIOperationStatus.Aborted) - return { - type: "aborted", - message: "Operation was cancelled", - } - } - setStatus(null) - - if (error instanceof RefusalError) { - return { - type: "unknown", - message: "The model refused to generate a response for this request.", - details: error.message, - } - } - - if (error instanceof MaxTokensError) { - return { - type: "unknown", - message: - "The response exceeded the maximum token limit for the selected model. Please try again with a different prompt or model.", - details: error.message, - } - } - - if (error instanceof StreamingError) { - switch (error.errorType) { - case "network": - return { - type: "network", - message: - "Network error during streaming. Please check your connection.", - details: error.message, - } - case "failed": - default: - return { - type: "unknown", - message: error.message || "Stream failed unexpectedly.", - details: - error.originalError instanceof Error - ? error.originalError.message - : undefined, - } - } - } - - if (error instanceof Anthropic.AuthenticationError) { - return { - type: "invalid_key", - message: "Invalid API key. Please check your Anthropic API key.", - details: error.message, - } - } - - if (error instanceof Anthropic.RateLimitError) { - return { - type: "rate_limit", - message: "Rate limit exceeded. Please try again later.", - details: error.message, - } - } - - if (error instanceof Anthropic.APIConnectionError) { - return { - type: "network", - message: "Network error. Please check your internet connection.", - details: error.message, - } - } - - if (error instanceof Anthropic.APIError) { - return { - type: "unknown", - message: `Anthropic API error: ${error.message}`, - } - } - - if (error instanceof OpenAI.APIError) { - return { - type: "unknown", - message: `OpenAI API error: ${error.message}`, - } - } - - return { - type: "unknown", - message: "An unexpected error occurred. Please try again.", - details: error as string, - } -} - -export const testApiKey = async ( - apiKey: string, - model: string, -): Promise<{ valid: boolean; error?: string }> => { - try { - if (inferProviderFromModel(model) === "anthropic") { - const anthropic = new Anthropic({ - apiKey, - dangerouslyAllowBrowser: true, - }) - - await createAnthropicMessage(anthropic, { - model, - messages: [ - { - role: "user", - content: "ping", - }, - ], - }) - } else { - const openai = new OpenAI({ apiKey, dangerouslyAllowBrowser: true }) - await openai.responses.create({ - model: getModelProps(model).model, - input: [{ role: "user", content: "ping" }], - max_output_tokens: 16, - }) - } - - return { valid: true } - } catch (error: unknown) { - if (error instanceof Anthropic.AuthenticationError) { - return { - valid: false, - error: "Invalid API key", - } - } - - if (error instanceof Anthropic.RateLimitError) { - return { - valid: true, - } - } - - const status = - (error as { status?: number })?.status || - (error as { error?: { status?: number } })?.error?.status - if (status === 401) { - return { valid: false, error: "Invalid API key" } - } - if (status === 429) { - return { valid: true } - } - - return { - valid: false, - error: - error instanceof Error ? error.message : "Failed to validate API key", - } - } -} - -const ChatTitleFormat: ResponseTextConfig = { - format: { - type: "json_schema" as const, - name: "chat_title_format", - schema: { - type: "object", - properties: { - title: { type: "string" }, - }, - required: ["title"], - additionalProperties: false, - }, - strict: true, - }, -} - -export const generateChatTitle = async ({ - firstUserMessage, - settings, -}: { - firstUserMessage: string - settings: ActiveProviderSettings -}): Promise => { - if (!settings.apiKey || !settings.model) { - return null - } - - try { - const clients = createProviderClients(settings) + try { + const provider = createProvider( + settings.provider, + settings.apiKey, + settings.aiAssistantSettings, + ) const prompt = `Generate a concise chat title (max 30 characters) for this conversation with QuestDB AI Assistant. The title should capture the main topic or intent. @@ -1844,54 +340,12 @@ ${firstUserMessage} Return a JSON object with the following structure: { "title": "Your title here" }` - if (clients.provider === "openai") { - const response = await clients.openai.responses.create({ - ...getModelProps(settings.model), - input: [{ role: "user", content: prompt }], - text: ChatTitleFormat, - max_output_tokens: 100, - }) - try { - const parsed = JSON.parse(response.output_text) as { title: string } - return parsed.title || null - } catch { - return null - } - } - - const messageParams: Parameters[1] = { + return await provider.generateTitle({ model: settings.model, - messages: [{ role: "user", content: prompt }], - max_tokens: 100, - temperature: 0.3, - } - const titleFormat = ChatTitleFormat.format as { - type: string - schema?: object - } - // @ts-expect-error - output_format is a new field not yet in the type definitions - messageParams.output_format = { - type: "json_schema", - schema: titleFormat.schema, - } - - const message = await createAnthropicMessage( - clients.anthropic, - messageParams, - ) - - const textBlock = message.content.find((block) => block.type === "text") - if (textBlock && "text" in textBlock) { - try { - const parsed = JSON.parse(textBlock.text) as { title: string } - return parsed.title?.slice(0, 40) || null - } catch { - return null - } - } - return null + prompt, + responseFormat: ChatTitleFormat, + }) } catch (error) { - // Silently fail - title generation is not critical console.warn("Failed to generate chat title:", error) return null } @@ -1930,7 +384,8 @@ export const continueConversation = async ({ compactedConversationHistory?: Array } > => { - if (!settings.apiKey || !settings.model) { + const isCustom = !BUILTIN_PROVIDERS[settings.provider] + if ((!isCustom && !settings.apiKey) || !settings.model) { return { type: "invalid_key", message: "API key or model is missing", @@ -1953,9 +408,25 @@ export const continueConversation = async ({ health_issue: ConversationResponseFormat, }[operation] + let provider: ReturnType + try { + provider = createProvider( + settings.provider, + settings.apiKey, + settings.aiAssistantSettings, + ) + } catch (error) { + return { + type: "unknown", + message: + error instanceof Error + ? error.message + : "Failed to initialize provider", + } + } + return tryWithRetries( async () => { - const clients = createProviderClients(settings) const grantSchemaAccess = !!modelToolsClient.getTables const systemPrompt = getUnifiedPrompt(grantSchemaAccess) @@ -1966,16 +437,13 @@ export const continueConversation = async ({ if (conversationHistory.length > 0) { const compactionResult = await compactConversationIfNeeded( conversationHistory, - settings.provider, + provider, systemPrompt, userMessage, () => setStatus(AIOperationStatus.Compacting), { - anthropicClient: - clients.provider === "anthropic" ? clients.anthropic : undefined, - openaiClient: - clients.provider === "openai" ? clients.openai : undefined, model: settings.model, + aiAssistantSettings: settings.aiAssistantSettings, }, ) @@ -2025,57 +493,13 @@ export const continueConversation = async ({ } } - if (clients.provider === "openai") { - const result = await executeOpenAIFlow<{ - sql?: string | null - explanation: string - tokenUsage?: TokenUsage - }>({ - openai: clients.openai, - model: settings.model, - config: { - systemInstructions: getUnifiedPrompt(grantSchemaAccess), - initialUserContent: userMessage, - conversationHistory: workingConversationHistory.filter( - (m) => !m.isCompacted, - ), - responseFormat, - postProcess: (formatted) => { - const sql = - formatted?.sql === null - ? null - : formatted?.sql - ? normalizeSql(formatted.sql) - : currentSQL || "" - return { - sql, - explanation: formatted?.explanation || "", - tokenUsage: formatted.tokenUsage, - } - }, - }, - modelToolsClient, - setStatus, - abortSignal, - streaming, - }) - if (isAiAssistantError(result)) { - return result - } - return { - ...postProcess(result), - compactedConversationHistory: isCompacted - ? workingConversationHistory - : undefined, - } - } + const tools = grantSchemaAccess ? ALL_TOOLS : REFERENCE_TOOLS - const result = await executeAnthropicFlow<{ + const result = await provider.executeFlow<{ sql?: string | null explanation: string tokenUsage?: TokenUsage }>({ - anthropic: clients.anthropic, model: settings.model, config: { systemInstructions: getUnifiedPrompt(grantSchemaAccess), @@ -2084,25 +508,14 @@ export const continueConversation = async ({ (m) => !m.isCompacted, ), responseFormat, - postProcess: (formatted) => { - const sql = - formatted?.sql === null - ? null - : formatted?.sql - ? normalizeSql(formatted.sql) - : currentSQL || "" - return { - sql, - explanation: formatted?.explanation || "", - tokenUsage: formatted.tokenUsage, - } - }, }, modelToolsClient, + tools, setStatus, abortSignal, streaming, }) + if (isAiAssistantError(result)) { return result } @@ -2113,6 +526,7 @@ export const continueConversation = async ({ : undefined, } }, + provider, setStatus, abortSignal, ) diff --git a/src/utils/aiAssistantSettings.ts b/src/utils/aiAssistantSettings.ts deleted file mode 100644 index 7892e29b5..000000000 --- a/src/utils/aiAssistantSettings.ts +++ /dev/null @@ -1,193 +0,0 @@ -import { ReasoningEffort } from "openai/resources/shared" -import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" - -export type Provider = "anthropic" | "openai" - -export type ModelOption = { - label: string - value: string - provider: Provider - isSlow?: boolean - isTestModel?: boolean - default?: boolean - defaultEnabled?: boolean -} - -export const MODEL_OPTIONS: ModelOption[] = [ - { - label: "Claude Sonnet 4.5", - value: "claude-sonnet-4-5", - provider: "anthropic", - default: true, - defaultEnabled: true, - }, - { - label: "Claude Opus 4.5", - value: "claude-opus-4-5", - provider: "anthropic", - isSlow: true, - defaultEnabled: true, - }, - { - label: "Claude Sonnet 4", - value: "claude-sonnet-4", - provider: "anthropic", - }, - { - label: "Claude Haiku 4.5", - value: "claude-haiku-4-5", - provider: "anthropic", - isTestModel: true, - }, - { - label: "GPT-5.1 (High Reasoning)", - value: "gpt-5.1@reasoning=high", - provider: "openai", - isSlow: true, - }, - { - label: "GPT-5.1 (Medium Reasoning)", - value: "gpt-5.1@reasoning=medium", - provider: "openai", - isSlow: true, - defaultEnabled: true, - }, - { - label: "GPT-5.1 (No Reasoning)", - value: "gpt-5.1", - provider: "openai", - defaultEnabled: true, - isTestModel: true, - }, - { - label: "GPT-5", - value: "gpt-5", - provider: "openai", - defaultEnabled: true, - }, - { - label: "GPT-5 mini", - value: "gpt-5-mini", - provider: "openai", - default: true, - defaultEnabled: true, - }, -] - -export const providerForModel = (model: ModelOption["value"]): Provider => { - return MODEL_OPTIONS.find((m) => m.value === model)!.provider -} - -export const getModelProps = ( - model: ModelOption["value"], -): { - model: string - reasoning?: { effort: ReasoningEffort } -} => { - const modelOption = MODEL_OPTIONS.find((m) => m.value === model) - if (!modelOption) { - return { model } - } - const parts = modelOption.value.split("@") - const modelName = parts[0] - const extraParams = parts[1] - if (extraParams) { - const params = extraParams.split("=") - const paramName = params[0] - const paramValue = params[1] - if (paramName === "reasoning" && paramValue) { - return { - model: modelName, - reasoning: { effort: paramValue as ReasoningEffort }, - } - } - } - return { model: modelName } -} - -export const getAllProviders = (): Provider[] => { - const providers = new Set() - MODEL_OPTIONS.forEach((model) => { - providers.add(model.provider) - }) - return Array.from(providers) -} - -export const getSelectedModel = ( - settings: AiAssistantSettings, -): string | null => { - const selectedModel = settings.selectedModel - if ( - selectedModel && - typeof selectedModel === "string" && - MODEL_OPTIONS.find((m) => m.value === selectedModel) - ) { - return selectedModel - } - - return MODEL_OPTIONS.find((m) => m.default)?.value ?? null -} - -export const getNextModel = ( - currentModel: string | undefined, - enabledModels: Record, -): string | null => { - let nextModel: string | null | undefined = currentModel - - const modelProvider = currentModel ? providerForModel(currentModel) : null - if (modelProvider && enabledModels[modelProvider].length > 0) { - // Current model is still enabled, so we can use it - if (currentModel && enabledModels[modelProvider].includes(currentModel)) { - return currentModel - } - // Take the default model of this provider, otherwise the first enabled model of this provider - nextModel = - enabledModels[modelProvider].find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, - ) ?? enabledModels[modelProvider][0] - } else { - // No other enabled models for this provider, we have to choose from another provider if exists - const otherProviderWithEnabledModel = getAllProviders().find( - (p) => enabledModels[p].length > 0, - ) - if (otherProviderWithEnabledModel) { - nextModel = - enabledModels[otherProviderWithEnabledModel].find( - (m) => MODEL_OPTIONS.find((mo) => mo.value === m)?.default, - ) ?? enabledModels[otherProviderWithEnabledModel][0] - } else { - nextModel = null - } - } - return nextModel ?? null -} - -export const isAiAssistantConfigured = ( - settings: AiAssistantSettings, -): boolean => { - return getAllProviders().some( - (provider) => !!settings.providers?.[provider]?.apiKey, - ) -} - -export const canUseAiAssistant = (settings: AiAssistantSettings): boolean => { - return isAiAssistantConfigured(settings) && !!settings.selectedModel -} - -export const hasSchemaAccess = (settings: AiAssistantSettings): boolean => { - const selectedModel = getSelectedModel(settings) - if (!selectedModel) return false - - const anthropicModels = settings.providers?.anthropic?.enabledModels || [] - const openaiModels = settings.providers?.openai?.enabledModels || [] - - if (anthropicModels.includes(selectedModel)) { - return settings.providers?.anthropic?.grantSchemaAccess === true - } - - if (openaiModels.includes(selectedModel)) { - return settings.providers?.openai?.grantSchemaAccess === true - } - - return false -} diff --git a/src/utils/contextCompaction.ts b/src/utils/contextCompaction.ts index 63c5dd90a..c8e36ae41 100644 --- a/src/utils/contextCompaction.ts +++ b/src/utils/contextCompaction.ts @@ -1,16 +1,7 @@ -import Anthropic from "@anthropic-ai/sdk" -import OpenAI from "openai" import type { ConversationMessage } from "../providers/AIConversationProvider/types" -import { - countTokens, - COMPACTION_THRESHOLDS, - type ConversationMessage as TokenConversationMessage, -} from "./tokenCounting" -import { - type Provider, - MODEL_OPTIONS, - getModelProps, -} from "./aiAssistantSettings" +import { getTestModel } from "./ai" +import type { AIProvider } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" type CompactionResultSuccess = { compactedMessage: string @@ -73,7 +64,7 @@ ${summary} function toTokenMessages( messages: [...ConversationMessage[], Omit], -): TokenConversationMessage[] { +): Array<{ role: "user" | "assistant"; content: string }> { return messages .filter((m) => m.content && m.content.trim() !== "") .map((m) => ({ @@ -84,14 +75,11 @@ function toTokenMessages( async function generateSummary( middleMessages: ConversationMessage[], - provider: Provider, - anthropicClient?: Anthropic, - openaiClient?: OpenAI, + aiProvider: AIProvider, + settings?: AiAssistantSettings, ): Promise { - const testModel = MODEL_OPTIONS.find( - (m) => m.provider === provider && m.isTestModel, - ) - if (!testModel) { + const testModelValue = getTestModel(aiProvider.id, settings) + if (!testModelValue) { throw new Error("No test model found for provider") } @@ -101,41 +89,22 @@ async function generateSummary( const userMessage = `Please summarize the following conversation:\n\n${conversationText}` - if (provider === "anthropic" && anthropicClient) { - const response = await anthropicClient.messages.create({ - ...getModelProps(testModel.value), - max_tokens: 8192, - messages: [{ role: "user", content: userMessage }], - system: SUMMARIZATION_PROMPT, - }) - - const textBlock = response.content.find((block) => block.type === "text") - return textBlock?.type === "text" ? textBlock.text : "" - } else if (provider === "openai" && openaiClient) { - const response = await openaiClient.responses.create({ - ...getModelProps(testModel.value), - instructions: SUMMARIZATION_PROMPT, - input: userMessage, - }) - - return response.output_text || "" - } - - throw new Error("No valid client provided for summarization") + return aiProvider.generateSummary({ + model: testModelValue, + systemPrompt: SUMMARIZATION_PROMPT, + userMessage, + }) } export async function compactConversationIfNeeded( conversationHistory: ConversationMessage[], - provider: Provider, + aiProvider: AIProvider, systemPrompt: string, userMessage: string, setStatusCompacting: () => void, - options: { - anthropicClient?: Anthropic - openaiClient?: OpenAI - model?: string - } = {}, + options: { model?: string; aiAssistantSettings?: AiAssistantSettings } = {}, ): Promise { + const compactionThreshold = aiProvider.contextWindow - 50_000 const messages = [ ...conversationHistory, { @@ -144,33 +113,31 @@ export async function compactConversationIfNeeded( timestamp: Date.now(), } as Omit, ] as [...ConversationMessage[], Omit] + const totalChars = systemPrompt.length + messages.reduce((sum, m) => sum + m.content.length, 0) - if (totalChars < COMPACTION_THRESHOLDS[provider]) { + + if (totalChars < compactionThreshold) { return { wasCompacted: false } } const tokenMessages = toTokenMessages(messages) - const estimatedTokens = await countTokens( - provider, - tokenMessages, - systemPrompt, - { - anthropicClient: options.anthropicClient, - model: options.model, - }, - ) - if (estimatedTokens === -1) { + let estimatedTokens: number + try { + estimatedTokens = await aiProvider.countTokens({ + messages: tokenMessages, + systemPrompt, + model: options.model ?? "", + }) + } catch { console.error( "Failed to estimate tokens for conversation, using full messages list.", ) - return { - wasCompacted: false, - } + return { wasCompacted: false } } - if (estimatedTokens <= COMPACTION_THRESHOLDS[provider]) { + if (estimatedTokens <= compactionThreshold) { return { wasCompacted: false } } @@ -184,9 +151,9 @@ export async function compactConversationIfNeeded( const result = await compactConversationInternal( conversationHistory, - provider, + aiProvider, setStatusCompacting, - options, + options.aiAssistantSettings, ) if (!result.wasCompacted) { @@ -202,13 +169,9 @@ export async function compactConversationIfNeeded( async function compactConversationInternal( messages: ConversationMessage[], - provider: Provider, + aiProvider: AIProvider, setStatusCompacting: () => void, - options: { - anthropicClient?: Anthropic - openaiClient?: OpenAI - model?: string - } = {}, + settings?: AiAssistantSettings, ): Promise { if (messages.length === 0) { return { wasCompacted: false } @@ -217,12 +180,7 @@ async function compactConversationInternal( setStatusCompacting() try { - const summary = await generateSummary( - messages, - provider, - options.anthropicClient, - options.openaiClient, - ) + const summary = await generateSummary(messages, aiProvider, settings) return { compactedMessage: buildContinuationPrompt(summary), diff --git a/src/utils/executeAIFlow.ts b/src/utils/executeAIFlow.ts index 5f244165b..8b9b1e786 100644 --- a/src/utils/executeAIFlow.ts +++ b/src/utils/executeAIFlow.ts @@ -18,15 +18,15 @@ import { createStreamingCallback, isAiAssistantError, generateChatTitle, - getExplainSchemaPrompt, - getHealthIssuePrompt, type ActiveProviderSettings, type GeneratedSQL, type AiAssistantExplanation, type AiAssistantAPIError, type AIOperation, } from "./aiAssistant" -import { providerForModel, MODEL_OPTIONS } from "./aiAssistantSettings" +import { getExplainSchemaPrompt, getHealthIssuePrompt } from "./ai" +import { providerForModel, getTestModel } from "./ai" +import type { AiAssistantSettings } from "../providers/LocalStorageProvider/types" import { eventBus } from "../modules/EventBus" import { EventType } from "../modules/EventBus/types" @@ -36,6 +36,7 @@ type BaseFlowConfig = { model: string apiKey: string } + aiAssistantSettings?: AiAssistantSettings questClient: Client tables?: Array hasSchemaAccess: boolean @@ -211,16 +212,10 @@ function buildUserMessage(config: AIFlowConfig): AIFlowUserMessage { } function formatErrorMessage(error: AiAssistantAPIError): string { - switch (error.type) { - case "aborted": - return "Operation has been cancelled" - case "network": - return "Connection interrupted. Please check your network and try again." - case "rate_limit": - return "Rate limit reached. Please wait a moment and try again." - default: - return error.message || "An unexpected error occurred" + if (error.type === "aborted") { + return "Operation has been cancelled" } + return error.message || "An unexpected error occurred" } type ProcessResultConfig = { @@ -332,7 +327,7 @@ function processSQLResult( } } - let assistantContent = result.explanation || "Response received" + let assistantContent = result.explanation || "No explanation received" if (hasSQLInResult) { assistantContent = `SQL Query:\n\`\`\`sql\n${result.sql}\n\`\`\`\n\nExplanation:\n${result.explanation || ""}` } @@ -360,20 +355,23 @@ async function generateChatTitleIfNeeded( return } - const provider = providerForModel(config.settings.model) - const testModel = MODEL_OPTIONS.find( - (m) => m.isTestModel && m.provider === provider, + const provider = providerForModel( + config.settings.model, + config.aiAssistantSettings, ) + if (!provider) return - if (!testModel) return + const testModelValue = getTestModel(provider, config.aiAssistantSettings) + if (!testModelValue) return try { const title = await generateChatTitle({ firstUserMessage: userMessageContent, settings: { - model: testModel.value, + model: testModelValue, provider, apiKey: config.settings.apiKey, + aiAssistantSettings: config.aiAssistantSettings, }, }) @@ -448,11 +446,21 @@ export async function executeAIFlow( eventBus.publish(EventType.AI_QUERY_HIGHLIGHT, conversationId) } - const provider = providerForModel(settings.model) + const provider = providerForModel(settings.model, config.aiAssistantSettings) + if (!provider) { + callbacks.updateMessage(conversationId, assistantMessageId, { + error: `No provider found for model: ${settings.model}`, + }) + return { + success: false, + error: `No provider found for model: ${settings.model}`, + } + } const providerSettings: ActiveProviderSettings = { model: settings.model, provider, apiKey: settings.apiKey, + aiAssistantSettings: config.aiAssistantSettings, } const modelToolsClient = createModelToolsClient( diff --git a/src/utils/tokenCounting.ts b/src/utils/tokenCounting.ts deleted file mode 100644 index 35600b38d..000000000 --- a/src/utils/tokenCounting.ts +++ /dev/null @@ -1,101 +0,0 @@ -import Anthropic from "@anthropic-ai/sdk" -import type { Provider } from "./aiAssistantSettings" -import type { Tiktoken, TiktokenBPE } from "js-tiktoken/lite" - -export interface ConversationMessage { - role: "user" | "assistant" - content: string -} - -export const CONTEXT_LIMITS: Record = { - anthropic: 200_000, - openai: 400_000, -} - -export const COMPACTION_THRESHOLDS: Record = { - anthropic: 150_000, - openai: 350_000, -} - -export async function countTokensAnthropic( - client: Anthropic, - messages: ConversationMessage[], - systemPrompt: string, - model: string, -): Promise { - const anthropicMessages: Anthropic.MessageParam[] = messages.map((m) => ({ - role: m.role, - content: m.content, - })) - - const response = await client.messages.countTokens({ - model, - system: systemPrompt, - messages: anthropicMessages, - }) - - return response.input_tokens -} -let tiktokenEncoder: Tiktoken | null = null - -export async function countTokensOpenAI( - messages: ConversationMessage[], - systemPrompt: string, -): Promise { - if (!tiktokenEncoder) { - const { Tiktoken } = await import("js-tiktoken/lite") - const o200k_base = await import("js-tiktoken/ranks/o200k_base").then( - (module: { default: TiktokenBPE }) => module.default, - ) - tiktokenEncoder = new Tiktoken(o200k_base) - } - - let totalTokens = 0 - - totalTokens += tiktokenEncoder.encode(systemPrompt).length - // Add overhead for system message formatting - totalTokens += 4 // <|start|>system<|end|> overhead - - for (const message of messages) { - // Each message has overhead for role markers - totalTokens += 4 // <|start|>{role}<|end|> overhead - totalTokens += tiktokenEncoder.encode(message.content).length - } - - // Add 2 tokens for assistant reply priming - totalTokens += 2 - - return totalTokens -} - -export async function countTokens( - provider: Provider, - messages: ConversationMessage[], - systemPrompt: string, - options: { - anthropicClient?: Anthropic - model?: string - } = {}, -): Promise { - try { - if (provider === "anthropic") { - if (!options.anthropicClient || !options.model) { - return -1 - } - return await countTokensAnthropic( - options.anthropicClient, - messages, - systemPrompt, - options.model, - ) - } else { - return countTokensOpenAI(messages, systemPrompt) - } - } catch (error) { - console.warn( - "Failed to estimate tokens for conversation, using full messages list.", - error, - ) - return -1 - } -} diff --git a/yarn.lock b/yarn.lock index 5e869060a..0c53075c3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -211,9 +211,9 @@ __metadata: languageName: node linkType: hard -"@anthropic-ai/sdk@npm:^0.71.2": - version: 0.71.2 - resolution: "@anthropic-ai/sdk@npm:0.71.2" +"@anthropic-ai/sdk@npm:^0.78.0": + version: 0.78.0 + resolution: "@anthropic-ai/sdk@npm:0.78.0" dependencies: json-schema-to-ts: "npm:^3.1.1" peerDependencies: @@ -223,7 +223,7 @@ __metadata: optional: true bin: anthropic-ai-sdk: bin/cli - checksum: 10/a8190f9e860079dd97a544a95f36bd4b0b3a9a941610d7e067c431dc47febe03e3e761fc371166b261af9629d832533eeb3d8e72298e9f73dd52994a61881a2c + checksum: 10/7cb34e36d4fc766f0765b2581596825996073b03eec97a1193f07c6ca4ab48a021310dae9df630d61550ae2aa7fb3a6cf54236f7418932b25ea1a0e32624fdf1 languageName: node linkType: hard @@ -2461,7 +2461,7 @@ __metadata: resolution: "@questdb/web-console@workspace:." dependencies: "@4tw/cypress-drag-drop": "npm:^2.2.5" - "@anthropic-ai/sdk": "npm:^0.71.2" + "@anthropic-ai/sdk": "npm:^0.78.0" "@babel/core": "npm:^7.28.5" "@babel/preset-env": "npm:^7.20.2" "@babel/preset-react": "npm:^7.17.12" @@ -2545,6 +2545,7 @@ __metadata: js-base64: "npm:^3.7.7" js-sha256: "npm:^0.11.0" js-tiktoken: "npm:^1.0.21" + jsonrepair: "npm:^3.13.3" lint-staged: "npm:^16.2.6" lodash.isequal: "npm:^4.5.0" lodash.merge: "npm:^4.6.2" @@ -8271,6 +8272,15 @@ __metadata: languageName: node linkType: hard +"jsonrepair@npm:^3.13.3": + version: 3.13.3 + resolution: "jsonrepair@npm:3.13.3" + bin: + jsonrepair: bin/cli.js + checksum: 10/cd1d42516e3e03ccc44498c328f87f4ec05b24afe190becced0babf5d608e81b375e8d2040494142760556c1d6583b395073b5253626907e4df968d8cf01115c + languageName: node + linkType: hard + "jsprim@npm:^2.0.2": version: 2.0.2 resolution: "jsprim@npm:2.0.2"