diff --git a/docs/models_api.md b/docs/models_api.md
new file mode 100644
index 00000000..e2310695
--- /dev/null
+++ b/docs/models_api.md
@@ -0,0 +1,56 @@
+# /v1/models API Usage Guide
+
+## Overview
+ccNexus now supports an OpenAI-compatible `/v1/models` API that aggregates the model lists of all backend endpoints.
+
+## Quick Start
+
+### Start the server
+```bash
+go run ./cmd/server
+```
+
+### List models
+```bash
+curl http://localhost:3000/v1/models
+```
+
+### Force a cache refresh
+```bash
+curl "http://localhost:3000/v1/models?refresh=true"
+```
+
+Note: `refresh=true` only takes effect when `modelsCacheRefreshEnabled` is set to `true` in the configuration; otherwise the server responds with `403 Forbidden`.
+
+## Example Response
+```json
+{
+  "object": "list",
+  "data": [
+    {
+      "id": "claude-sonnet-4-20250514",
+      "object": "model",
+      "created": 1700000000,
+      "owned_by": "anthropic",
+      "endpoint_id": "Claude Official"
+    }
+  ]
+}
+```
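+
+### Example: consuming the list from Go
+A minimal client sketch. It assumes the proxy is listening on `localhost:3000`; adjust the URL for your setup.
+```go
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+	"net/http"
+)
+
+func main() {
+	resp, err := http.Get("http://localhost:3000/v1/models")
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer resp.Body.Close()
+
+	// Mirrors the response shape shown above.
+	var list struct {
+		Data []struct {
+			ID         string `json:"id"`
+			OwnedBy    string `json:"owned_by"`
+			EndpointID string `json:"endpoint_id"`
+		} `json:"data"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
+		log.Fatal(err)
+	}
+	for _, m := range list.Data {
+		fmt.Printf("%s (owned_by=%s, endpoint=%s)\n", m.ID, m.OwnedBy, m.EndpointID)
+	}
+}
+```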
+
+## Configuration
+Add the following to `config.json` (JSON does not allow comments, so the options are documented below the block):
+```json
+{
+  "modelsCacheTTL": 30,
+  "modelsCacheRefreshEnabled": false
+}
+```
+- `modelsCacheTTL`: cache TTL in minutes; defaults to 30.
+- `modelsCacheRefreshEnabled`: enables the `?refresh=true` parameter; defaults to false.
+
+## Supported Endpoints
+- **openai/openai2**: queries the backend's /v1/models automatically
+- **gemini**: queries the backend's /v1beta/models automatically
+- **claude**: uses the configured model field (no API query)
+
+## Features
+- ✅ Aggregates model lists across backends
+- ✅ Automatic caching (30 minutes, configurable)
+- ✅ Optional refresh parameter
+- ✅ Graceful degradation on failure (falls back to default models)
diff --git a/internal/config/config.go b/internal/config/config.go
index eb01be81..d9d9ad2d 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -172,6 +172,8 @@ type Config struct {
     CloseWindowBehavior       string `json:"closeWindowBehavior,omitempty"` // "quit", "minimize", "ask"
     ClaudeNotificationEnabled bool   `json:"claudeNotificationEnabled"`     // Enable Claude Code task completion notification
     ClaudeNotificationType    string `json:"claudeNotificationType"`        // Notification type: toast, dialog, disabled
+    ModelsCacheTTL            int    `json:"modelsCacheTTL,omitempty"`            // /v1/models cache TTL in minutes, default 30
+    ModelsCacheRefreshEnabled bool   `json:"modelsCacheRefreshEnabled,omitempty"` // Enable ?refresh=true parameter, default false
     WebDAV                    *WebDAVConfig `json:"webdav,omitempty"` // WebDAV synchronization config
     Backup                    *BackupConfig `json:"backup,omitempty"` // Backup/sync configuration
     Update                    *UpdateConfig `json:"update,omitempty"` // Update configuration
@@ -189,6 +191,8 @@ func DefaultConfig() *Config {
         Language:     "zh-CN", // Default to Chinese
         WindowWidth:  1024,    // Default window width
         WindowHeight: 768,     // Default window height
+        ModelsCacheTTL:            30,    // Default 30 minutes
+        ModelsCacheRefreshEnabled: false, // Default disabled
         Endpoints: []Endpoint{
             {
                 Name: "Claude Official",
@@ -575,6 +579,19 @@ func LoadFromStorage(storage StorageAdapter) (*Config, error) {
         }
     }
 
+    if modelsCacheTTLStr, err := storage.GetConfig("modelsCacheTTL"); err == nil && modelsCacheTTLStr != "" {
+        if modelsCacheTTL, err := strconv.Atoi(modelsCacheTTLStr); err == nil {
+            config.ModelsCacheTTL = modelsCacheTTL
+        }
+    }
+    if config.ModelsCacheTTL == 0 {
+        config.ModelsCacheTTL = 30 // Default 30 minutes
+    }
+
+    if modelsCacheRefreshEnabledStr, err := storage.GetConfig("modelsCacheRefreshEnabled"); err == nil && modelsCacheRefreshEnabledStr != "" {
+        config.ModelsCacheRefreshEnabled = modelsCacheRefreshEnabledStr == "true"
+    }
+
     if lang, err := storage.GetConfig("language"); err == nil {
         config.Language = lang
     }
@@ -820,6 +837,12 @@ func (c *Config) SaveToStorage(storage StorageAdapter) error {
     if err := storage.SetConfig("logLevel", strconv.Itoa(c.LogLevel)); err != nil {
         return fmt.Errorf("failed to save logLevel config: %w", err)
     }
+    if err := storage.SetConfig("modelsCacheTTL", strconv.Itoa(c.ModelsCacheTTL)); err != nil {
+        return fmt.Errorf("failed to save modelsCacheTTL config: %w", err)
+    }
+    if err := storage.SetConfig("modelsCacheRefreshEnabled", strconv.FormatBool(c.ModelsCacheRefreshEnabled)); err != nil {
+        return fmt.Errorf("failed to save modelsCacheRefreshEnabled config: %w", err)
+    }
     if err := storage.SetConfig("language", c.Language); err != nil {
         return fmt.Errorf("failed to save language config: %w", err)
     }
diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go
index b9ff4d63..5ccc1912 100644
--- a/internal/proxy/handler.go
+++ b/internal/proxy/handler.go
@@ -160,6 +160,12 @@ func (p *Proxy) UpdateConfig(cfg *config.Config) error {
         p.currentIndex = 0
     }
 
+    // Clear models cache to force refresh with new endpoints
+    if p.modelsCache != nil {
+        p.modelsCache.Clear()
+        logger.Debug("[CONFIG UPDATE] Cleared models cache")
+    }
+
     logger.Info("Configuration updated: %d endpoints configured", len(cfg.GetEndpoints()))
     return nil
 }
diff --git a/internal/proxy/models.go b/internal/proxy/models.go
new file mode 100644
index 00000000..7d4b5026
--- /dev/null
+++ b/internal/proxy/models.go
@@ -0,0 +1,315 @@
+package proxy
+
+import (
+    "encoding/json"
+    "fmt"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
+    "github.com/lich0821/ccNexus/internal/config"
+    "github.com/lich0821/ccNexus/internal/logger"
+)
+
+// ModelInfo represents a single model entry
+type ModelInfo struct {
+    ID         string `json:"id"`
+    Object     string `json:"object"`
+    Created    int64  `json:"created"`
+    OwnedBy    string `json:"owned_by"`
+    EndpointID string `json:"endpoint_id"` // Source endpoint identifier
+}
+
+// ModelsCache represents cached models data with TTL
+type ModelsCache struct {
+    data      []ModelInfo
+    updatedAt time.Time
+    ttl       time.Duration
+    mu        sync.RWMutex
+}
+
+// NewModelsCache creates a new models cache
+func NewModelsCache(ttlMinutes int) *ModelsCache {
+    if ttlMinutes <= 0 {
+        ttlMinutes = 30 // Default 30 minutes
+    }
+    return &ModelsCache{
+        data:      []ModelInfo{},
+        updatedAt: time.Time{},
+        ttl:       time.Duration(ttlMinutes) * time.Minute,
+    }
+}
+
+// Get returns cached data if valid
+func (c *ModelsCache) Get() ([]ModelInfo, bool) {
+    c.mu.RLock()
+    defer c.mu.RUnlock()
+
+    if time.Since(c.updatedAt) > c.ttl {
+        return nil, false
+    }
+    return c.data, true
+}
+
+// Set updates cached data
+func (c *ModelsCache) Set(data []ModelInfo) {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+
+    c.data = data
+    c.updatedAt = time.Now()
+}
+
+// Clear clears the cache
+func (c *ModelsCache) Clear() {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+
+    c.data = []ModelInfo{}
+    c.updatedAt = time.Time{}
+}
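+
+// Illustrative usage of ModelsCache; the calls mirror what handleModels
+// below does, and every name is defined in this file:
+//
+//	cache := NewModelsCache(30) // 30-minute TTL
+//	if models, ok := cache.Get(); ok {
+//		// hit: the entry is younger than the TTL, serve it as-is
+//	} else {
+//		// miss or expired: re-fetch from endpoints, then cache.Set(models)
+//	}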
+
+// fetchModelsFromEndpoint fetches models from a specific endpoint
+func (p *Proxy) fetchModelsFromEndpoint(ep config.Endpoint) ([]ModelInfo, error) {
+    var modelsURL string
+    var req *http.Request
+    var err error
+
+    switch strings.ToLower(ep.Transformer) {
+    case "openai", "openai2":
+        // OpenAI-compatible endpoints
+        baseURL := strings.TrimSuffix(ep.APIUrl, "/")
+        if strings.Contains(baseURL, "/v1") {
+            modelsURL = baseURL + "/models"
+        } else {
+            modelsURL = baseURL + "/v1/models"
+        }
+        req, err = http.NewRequest("GET", modelsURL, nil)
+        if err != nil {
+            return nil, fmt.Errorf("failed to create request: %w", err)
+        }
+        // Add authorization header
+        if ep.AuthMode == config.AuthModeAPIKey && ep.APIKey != "" {
+            req.Header.Set("Authorization", "Bearer "+ep.APIKey)
+        }
+
+    case "gemini":
+        // Google Gemini endpoints
+        baseURL := strings.TrimSuffix(ep.APIUrl, "/")
+        if strings.Contains(baseURL, "/v1") {
+            modelsURL = baseURL + "/models"
+        } else {
+            modelsURL = baseURL + "/v1beta/models"
+        }
+        // Add API key as query parameter
+        if ep.AuthMode == config.AuthModeAPIKey && ep.APIKey != "" {
+            modelsURL = modelsURL + "?key=" + ep.APIKey
+        }
+        req, err = http.NewRequest("GET", modelsURL, nil)
+        if err != nil {
+            return nil, fmt.Errorf("failed to create request: %w", err)
+        }
+
+    default:
+        // For transformers without /v1/models support (claude, codex)
+        return nil, fmt.Errorf("transformer %s does not support /v1/models", ep.Transformer)
+    }
+
+    // Set User-Agent
+    req.Header.Set("User-Agent", "ccNexus/1.0")
+
+    // Execute request
+    resp, err := p.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("failed to fetch models: %w", err)
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+    }
+
+    // Parse response
+    var result struct {
+        Data []struct {
+            ID      string `json:"id"`
+            Object  string `json:"object"`
+            Created int64  `json:"created"`
+            OwnedBy string `json:"owned_by"`
+        } `json:"data"`
+    }
+
+    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+        return nil, fmt.Errorf("failed to decode response: %w", err)
+    }
+
+    // Convert to ModelInfo with endpoint_id
+    models := make([]ModelInfo, len(result.Data))
+    for i, m := range result.Data {
+        models[i] = ModelInfo{
+            ID:         m.ID,
+            Object:     m.Object,
+            Created:    m.Created,
+            OwnedBy:    m.OwnedBy,
+            EndpointID: ep.Name,
+        }
+    }
+
+    return models, nil
+}
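+
+// For illustration, with hypothetical APIUrl values the construction above
+// resolves to the following URLs (API key shown as a placeholder):
+//
+//	openai, "https://api.openai.com/v1" -> "https://api.openai.com/v1/models"
+//	openai, "https://example.com"       -> "https://example.com/v1/models"
+//	gemini, "https://generativelanguage.googleapis.com"
+//	        -> "https://generativelanguage.googleapis.com/v1beta/models?key=<APIKey>"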
+
+// getDefaultModels returns default models for endpoints that don't support /v1/models
+func (p *Proxy) getDefaultModels(ep config.Endpoint) []ModelInfo {
+    var modelID string
+    var ownedBy string
+
+    switch strings.ToLower(ep.Transformer) {
+    case "claude":
+        // Claude endpoints
+        if ep.Model != "" {
+            modelID = ep.Model
+        } else {
+            modelID = "claude-sonnet-4-20250514" // Default Claude model
+        }
+        ownedBy = "anthropic"
+
+    case "openai2":
+        // Codex endpoints
+        if ep.Model != "" {
+            modelID = ep.Model
+        } else if ep.AuthMode == config.AuthModeCodexTokenPool {
+            modelID = "gpt-5-codex" // Default Codex model
+        } else {
+            modelID = "gpt-4o" // Default OpenAI model
+        }
+        ownedBy = "openai"
+
+    default:
+        // Fallback for any other transformer
+        if ep.Model != "" {
+            modelID = ep.Model
+        } else {
+            modelID = "unknown-model"
+        }
+        ownedBy = strings.ToLower(ep.Transformer)
+    }
+
+    return []ModelInfo{
+        {
+            ID:         modelID,
+            Object:     "model",
+            Created:    time.Now().Unix(),
+            OwnedBy:    ownedBy,
+            EndpointID: ep.Name,
+        },
+    }
+}
+
+// handleModels handles GET /v1/models requests
+func (p *Proxy) handleModels(w http.ResponseWriter, r *http.Request) {
+    if r.Method != http.MethodGet {
+        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+        return
+    }
+
+    // Check for refresh parameter
+    refresh := r.URL.Query().Get("refresh") == "true"
+    refreshEnabled := p.config.ModelsCacheRefreshEnabled
+
+    if refresh && !refreshEnabled {
+        http.Error(w, "Refresh is disabled in configuration", http.StatusForbidden)
+        return
+    }
+
+    // Try to get from cache if not refreshing
+    if !refresh {
+        if cached, ok := p.modelsCache.Get(); ok {
+            p.writeModelsResponse(w, cached)
+            return
+        }
+    }
+
+    // Fetch from endpoints
+    endpoints := p.config.GetEndpoints()
+    allModels := []ModelInfo{}
+    allFailed := true
+
+    for _, ep := range endpoints {
+        if !ep.Enabled {
+            continue
+        }
+
+        // Try the endpoint's /v1/models API first
+        models, err := p.fetchModelsFromEndpoint(ep)
+        if err != nil {
+            // If the fetch fails, fall back to default models for this endpoint
+            logger.Debug("Failed to fetch models from %s: %v", ep.Name, err)
+            models = p.getDefaultModels(ep)
+        } else {
+            allFailed = false
+        }
+
+        allModels = append(allModels, models...)
+    }
+
+    // If all endpoints failed, still return the aggregated default models
+    if allFailed {
+        logger.Debug("All endpoints failed to fetch models, returning default models")
+    }
+
+    // Cache the result
+    p.modelsCache.Set(allModels)
+
+    // Write response
+    p.writeModelsResponse(w, allModels)
+}
+
+// writeModelsResponse writes the models list response
+func (p *Proxy) writeModelsResponse(w http.ResponseWriter, models []ModelInfo) {
+    w.Header().Set("Content-Type", "application/json")
+    w.WriteHeader(http.StatusOK)
+
+    response := struct {
+        Object string      `json:"object"`
+        Data   []ModelInfo `json:"data"`
+    }{
+        Object: "list",
+        Data:   models,
+    }
+
+    if err := json.NewEncoder(w).Encode(response); err != nil {
+        logger.Debug("Failed to encode models response: %v", err)
+    }
+}
+
+// refreshModelsCache refreshes the models cache in the background
+func (p *Proxy) refreshModelsCache() {
+    logger.Debug("Refreshing models cache in background")
+
+    endpoints := p.config.GetEndpoints()
+    allModels := []ModelInfo{}
+
+    for _, ep := range endpoints {
+        if !ep.Enabled {
+            continue
+        }
+
+        models, err := p.fetchModelsFromEndpoint(ep)
+        if err != nil {
+            logger.Debug("Background refresh: failed to fetch models from %s: %v", ep.Name, err)
+            models = p.getDefaultModels(ep)
+        }
+
+        allModels = append(allModels, models...)
+    }
+
+    p.modelsCache.Set(allModels)
+    logger.Debug("Models cache refreshed, total models: %d", len(allModels))
+}
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index bf4d7c0a..e04b3c19 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -47,6 +47,7 @@ type Proxy struct {
     endpointCancel    map[string]context.CancelFunc // cancel functions per endpoint
     ctxMu             sync.RWMutex                  // protects context maps
     onEndpointSuccess func(endpointName string)     // callback when endpoint request succeeds
+    modelsCache       *ModelsCache                  // Cache for /v1/models endpoint
 }
 
 // New creates a new Proxy instance
@@ -79,6 +80,7 @@ func New(cfg *config.Config, statsStorage StatsStorage, sqliteStorage *storage.S
         activeRequests: make(map[string]bool),
         endpointCtx:    make(map[string]context.Context),
         endpointCancel: make(map[string]context.CancelFunc),
+        modelsCache:    NewModelsCache(cfg.ModelsCacheTTL),
     }
 }
 
@@ -106,6 +108,7 @@ func (p *Proxy) StartWithMux(customMux *http.ServeMux) error {
     // Register proxy routes
     mux.HandleFunc("/", p.handleProxy)
     mux.HandleFunc("/v1/messages/count_tokens", p.handleCountTokens)
+    mux.HandleFunc("/v1/models", p.handleModels)
     mux.HandleFunc("/health", p.handleHealth)
     mux.HandleFunc("/stats", p.handleStats)
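
As a quick sanity check of the cache semantics above, a minimal test sketch follows. It assumes placement alongside models.go in package proxy; the file and test names are hypothetical, and it back-dates `updatedAt` directly, which only works from inside the package.

```go
package proxy

import (
	"testing"
	"time"
)

// TestModelsCacheTTL verifies that entries expire once the TTL elapses
// and that Clear resets the cache to an expired state.
func TestModelsCacheTTL(t *testing.T) {
	c := NewModelsCache(30) // 30-minute TTL

	c.Set([]ModelInfo{{ID: "claude-sonnet-4-20250514", Object: "model"}})
	if _, ok := c.Get(); !ok {
		t.Fatal("expected a cache hit immediately after Set")
	}

	// Simulate expiry by back-dating the update time past the TTL.
	c.mu.Lock()
	c.updatedAt = time.Now().Add(-31 * time.Minute)
	c.mu.Unlock()

	if _, ok := c.Get(); ok {
		t.Fatal("expected a cache miss after the TTL has elapsed")
	}

	// Clear resets updatedAt to the zero time, which also reads as expired.
	c.Clear()
	if _, ok := c.Get(); ok {
		t.Fatal("expected a cache miss after Clear")
	}
}
```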