diff --git a/.changeset/webview-current-model-indicator.md b/.changeset/webview-current-model-indicator.md new file mode 100644 index 00000000000..a93b79d488f --- /dev/null +++ b/.changeset/webview-current-model-indicator.md @@ -0,0 +1,5 @@ +--- +"roo-cline": patch +--- + +Show the current model in the Roo Code chat input footer. diff --git a/packages/types/src/vscode-extension-host.ts b/packages/types/src/vscode-extension-host.ts index b20539afe49..2a0becac22c 100644 --- a/packages/types/src/vscode-extension-host.ts +++ b/packages/types/src/vscode-extension-host.ts @@ -314,6 +314,7 @@ export type ExtensionState = Pick< currentTaskItem?: HistoryItem currentTaskTodos?: TodoItem[] // Initial todos for the current task apiConfiguration: ProviderSettings + currentModelId?: string uriScheme?: string shouldShowAnnouncement: boolean diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 1106d340050..42289af273e 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -2129,6 +2129,7 @@ export class ClineProvider const { apiConfiguration, + currentModelId, lastShownAnnouncementId, customInstructions, alwaysAllowReadOnly, @@ -2242,6 +2243,7 @@ export class ClineProvider return { version: this.context.extension?.packageJSON?.version ?? "", apiConfiguration, + currentModelId, customInstructions, alwaysAllowReadOnly: alwaysAllowReadOnly ?? false, alwaysAllowReadOnlyOutsideWorkspace: alwaysAllowReadOnlyOutsideWorkspace ?? false, @@ -2395,6 +2397,8 @@ export class ClineProvider providerSettings.apiProvider = apiProvider } + const currentModelId = getModelId(providerSettings) + let organizationAllowList = ORGANIZATION_ALLOW_ALL try { @@ -2471,6 +2475,7 @@ export class ClineProvider // Return the same structure as before. return { apiConfiguration: providerSettings, + currentModelId, lastShownAnnouncementId: stateValues.lastShownAnnouncementId, customInstructions: stateValues.customInstructions, apiModelId: stateValues.apiModelId, diff --git a/webview-ui/src/components/chat/ChatTextArea.tsx b/webview-ui/src/components/chat/ChatTextArea.tsx index e72c1726f35..b7e6869eaf2 100644 --- a/webview-ui/src/components/chat/ChatTextArea.tsx +++ b/webview-ui/src/components/chat/ChatTextArea.tsx @@ -1,9 +1,20 @@ import React, { forwardRef, useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState } from "react" import { useEvent } from "react-use" import DynamicTextArea from "react-textarea-autosize" -import { VolumeX, Image, WandSparkles, SendHorizontal, X, ListEnd, Square } from "lucide-react" +import { VolumeX, Image, WandSparkles, SendHorizontal, X, ListEnd, Square, Check, ChevronsUpDown } from "lucide-react" -import type { ExtensionMessage } from "@roo-code/types" +import { + isDynamicProvider, + isRetiredProvider, + modelIdKeysByProvider, + openAiModelInfoSaneDefaults, + type ExtensionMessage, + type ModelIdKey, + type ModelRecord, + type OrganizationAllowList, + type ProviderName, + type ProviderSettings, +} from "@roo-code/types" import { mentionRegex, mentionRegexGlobal, commandRegexGlobal, unescapeSpaces } from "@roo/context-mentions" import { WebviewMessage } from "@roo/WebviewMessage" @@ -22,9 +33,25 @@ import { } from "@src/utils/context-mentions" import { cn } from "@src/lib/utils" import { convertToMentionPath } from "@src/utils/path-mentions" -import { StandardTooltip } from "@src/components/ui" +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + Popover, + PopoverContent, + PopoverTrigger, + StandardTooltip, +} from "@src/components/ui" +import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" +import { useLmStudioModels } from "@src/components/ui/hooks/useLmStudioModels" +import { useOllamaModels } from "@src/components/ui/hooks/useOllamaModels" import Thumbnails from "../common/Thumbnails" +import { MODELS_BY_PROVIDER as STATIC_MODELS_BY_PROVIDER } from "../settings/constants" +import { filterModels } from "../settings/utils/organizationFilters" import { ModeSelector } from "./ModeSelector" import { ApiConfigSelector } from "./ApiConfigSelector" import { AutoApproveDropdown } from "./AutoApproveDropdown" @@ -34,6 +61,243 @@ import { IndexingStatusBadge } from "./IndexingStatusBadge" import { usePromptHistory } from "./hooks/usePromptHistory" import { CloudAccountSwitcher } from "../cloud/CloudAccountSwitcher" +const QUICK_MODEL_ID_KEYS: Partial> = { + openai: "openAiModelId", + openrouter: "openRouterModelId", + requesty: "requestyModelId", + unbound: "unboundModelId", + litellm: "litellmModelId", + "vercel-ai-gateway": "vercelAiGatewayModelId", + ollama: "ollamaModelId", + lmstudio: "lmStudioModelId", + "openai-native": "apiModelId", +} + +interface CurrentModelSelectorProps { + apiConfiguration?: ProviderSettings + currentApiConfigName?: string + currentModelId?: string + currentModelDisplayName?: string + disabled?: boolean + organizationAllowList?: OrganizationAllowList + setApiConfiguration: (config: ProviderSettings) => void +} + +const CurrentModelSelector = ({ + apiConfiguration, + currentApiConfigName, + currentModelId, + currentModelDisplayName, + disabled, + organizationAllowList, + setApiConfiguration, +}: CurrentModelSelectorProps) => { + const { t } = useAppTranslation() + const [open, setOpen] = useState(false) + const [searchValue, setSearchValue] = useState("") + const [openAiModels, setOpenAiModels] = useState(null) + + const provider = apiConfiguration?.apiProvider + const activeProvider: ProviderName | undefined = + provider && !isRetiredProvider(provider) ? (provider as ProviderName) : undefined + const dynamicProvider = activeProvider && isDynamicProvider(activeProvider) ? activeProvider : undefined + + const routerModels = useRouterModels({ provider: dynamicProvider, enabled: !!dynamicProvider }) + const lmStudioModels = useLmStudioModels( + activeProvider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined, + ) + const ollamaModels = useOllamaModels(activeProvider === "ollama" ? apiConfiguration?.ollamaModelId : undefined) + + const onMessage = useCallback((event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "openAiModels") { + setOpenAiModels( + Object.fromEntries( + (message.openAiModels ?? []).map((modelId) => [modelId, openAiModelInfoSaneDefaults]), + ), + ) + } + }, []) + + useEvent("message", onMessage) + + useEffect(() => { + if (open && activeProvider === "openai") { + vscode.postMessage({ + type: "requestOpenAiModels", + values: { + baseUrl: apiConfiguration?.openAiBaseUrl, + apiKey: apiConfiguration?.openAiApiKey, + customHeaders: apiConfiguration?.openAiHeaders ?? {}, + }, + }) + } + }, [ + activeProvider, + apiConfiguration?.openAiApiKey, + apiConfiguration?.openAiBaseUrl, + apiConfiguration?.openAiHeaders, + open, + ]) + + const models = useMemo(() => { + if (!activeProvider) { + return null + } + + if (activeProvider === "openai") { + return openAiModels + } + + if (activeProvider === "lmstudio") { + return lmStudioModels.data ?? null + } + + if (activeProvider === "ollama") { + return ollamaModels.data ?? null + } + + if (dynamicProvider) { + return routerModels.data?.[dynamicProvider] ?? null + } + + return STATIC_MODELS_BY_PROVIDER[activeProvider] ?? null + }, [activeProvider, dynamicProvider, lmStudioModels.data, ollamaModels.data, openAiModels, routerModels.data]) + + const modelIds = useMemo(() => { + const filteredModels = filterModels(models, activeProvider, organizationAllowList) + + return Object.entries(filteredModels ?? {}) + .filter(([modelId, modelInfo]) => modelId === currentModelId || !modelInfo.deprecated) + .map(([modelId]) => modelId) + .sort((a, b) => a.localeCompare(b)) + }, [activeProvider, currentModelId, models, organizationAllowList]) + + const modelIdKey = useMemo(() => { + if (!activeProvider) { + return undefined + } + + return ( + QUICK_MODEL_ID_KEYS[activeProvider] ?? + modelIdKeysByProvider[activeProvider as keyof typeof modelIdKeysByProvider] + ) + }, [activeProvider]) + + const handleModelSelect = useCallback( + (modelId: string) => { + if (!apiConfiguration || !currentApiConfigName || !modelIdKey) { + return + } + + const updatedConfiguration = { + ...apiConfiguration, + [modelIdKey]: modelId, + } as ProviderSettings + + setOpen(false) + setSearchValue("") + setApiConfiguration(updatedConfiguration) + vscode.postMessage({ + type: "upsertApiConfiguration", + text: currentApiConfigName, + apiConfiguration: updatedConfiguration, + }) + }, + [apiConfiguration, currentApiConfigName, modelIdKey, setApiConfiguration], + ) + + if (!currentModelDisplayName || !currentModelId) { + return null + } + + const isLoading = + (activeProvider === "openai" && openAiModels === null) || + (!!dynamicProvider && routerModels.isLoading) || + (activeProvider === "lmstudio" && lmStudioModels.isLoading) || + (activeProvider === "ollama" && ollamaModels.isLoading) + + const canSelectModel = !disabled && !!modelIdKey && (modelIds.length > 0 || activeProvider === "openai") + + return ( + + + + + + + + + +
+ {isLoading ? "Loading..." : t("settings:modelPicker.noMatchFound")} +
+
+ + {modelIds.map((modelId) => ( + + + {formatModelDisplayName(modelId)} + + + + ))} + +
+
+
+
+ ) +} + +const formatModelDisplayName = (modelId: string) => { + return modelId + .split(/[-_\s]+/) + .filter(Boolean) + .map((part) => { + if (part.toLowerCase() === "deepseek") { + return "DeepSeek" + } + + if (/^v\d+$/i.test(part)) { + return part.toUpperCase() + } + + return part.charAt(0).toUpperCase() + part.slice(1) + }) + .join(" ") +} + interface ChatTextAreaProps { inputValue: string setInputValue: (value: string) => void @@ -99,6 +363,10 @@ export const ChatTextArea = forwardRef( cloudUserInfo, enterBehavior, lockApiConfigAcrossModes, + apiConfiguration, + currentModelId, + organizationAllowList, + setApiConfiguration, } = useExtensionState() // Find the ID and display text for the currently selected API configuration. @@ -110,6 +378,14 @@ export const ChatTextArea = forwardRef( } }, [listApiConfigMeta, currentApiConfigName]) + const currentModelDisplayName = useMemo(() => { + if (!currentModelId) { + return undefined + } + + return formatModelDisplayName(currentModelId) + }, [currentModelId]) + const [gitCommits, setGitCommits] = useState([]) const [showDropdown, setShowDropdown] = useState(false) const [fileSearchResults, setFileSearchResults] = useState([]) @@ -1320,6 +1596,15 @@ export const ChatTextArea = forwardRef( lockApiConfigAcrossModes={!!lockApiConfigAcrossModes} onToggleLockApiConfig={handleToggleLockApiConfig} /> +
{ }, taskHistory: [], cwd: "/test/workspace", + setApiConfiguration: vi.fn(), }) }) @@ -89,6 +90,64 @@ describe("ChatTextArea", () => { }) }) + describe("current model indicator", () => { + it("shows the model id from extension state in the footer", () => { + ;(useExtensionState as ReturnType).mockReturnValue({ + filePaths: [], + openedTabs: [], + apiConfiguration: { + apiProvider: "deepseek", + }, + currentModelId: "deepseek-v4-pro", + currentApiConfigName: "DeepSeek", + taskHistory: [], + cwd: "/test/workspace", + setApiConfiguration: vi.fn(), + }) + + render() + + const indicator = screen.getByTestId("current-model-indicator") + expect(indicator).toHaveTextContent("DeepSeek V4 Pro") + expect(indicator).toHaveAttribute("title", "deepseek-v4-pro") + }) + + it("switches the current provider model from the footer picker", async () => { + const setApiConfiguration = vi.fn() + const apiConfiguration = { + apiProvider: "deepseek", + apiModelId: "deepseek-chat", + } + + ;(useExtensionState as ReturnType).mockReturnValue({ + filePaths: [], + openedTabs: [], + apiConfiguration, + currentModelId: "deepseek-chat", + currentApiConfigName: "DeepSeek", + taskHistory: [], + cwd: "/test/workspace", + setApiConfiguration, + }) + + render() + + fireEvent.click(screen.getByTestId("current-model-indicator")) + fireEvent.click(await screen.findByTestId("quick-model-option-deepseek-reasoner")) + + const updatedConfiguration = { + apiProvider: "deepseek", + apiModelId: "deepseek-reasoner", + } + expect(setApiConfiguration).toHaveBeenCalledWith(updatedConfiguration) + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "upsertApiConfiguration", + text: "DeepSeek", + apiConfiguration: updatedConfiguration, + }) + }) + }) + describe("handleEnhancePrompt", () => { it("should send message with correct configuration when clicked", () => { const apiConfiguration = { diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index ce7a607d9a8..6cb0a5d71d4 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -192,6 +192,7 @@ export const mergeExtensionState = (prevState: ExtensionState, newState: Partial export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => { const [state, setState] = useState({ apiConfiguration: {}, + currentModelId: undefined, version: "", clineMessages: [], taskHistory: [],