diff --git a/README.md b/README.md index b81bbdbd56..5228c35305 100644 --- a/README.md +++ b/README.md @@ -397,6 +397,7 @@ Go Micro’s `ai` package gives every provider the same interface: `Init`, `Gene |----------|--------|---------------| | **Anthropic** | `go-micro.dev/v5/ai/anthropic` | `claude-sonnet-4-20250514` | | **OpenAI** | `go-micro.dev/v5/ai/openai` | `gpt-4o` | +| **Atlas Cloud** | `go-micro.dev/v5/ai/atlascloud` | `llama-3.3-70b` | Any provider that exposes an OpenAI-compatible API can also be used directly: diff --git a/ai/README.md b/ai/README.md index 21abe3944e..17808c2607 100644 --- a/ai/README.md +++ b/ai/README.md @@ -152,6 +152,20 @@ m := ai.New("openai", Default model: `gpt-4o` Default base URL: `https://api.openai.com` +### Atlas Cloud + +```go +m := ai.New("atlascloud", + ai.WithAPIKey("your-key"), + ai.WithModel("llama-3.3-70b"), // default +) +``` + +Default model: `llama-3.3-70b` +Default base URL: `https://api.atlascloud.ai` + +Atlas Cloud is an enterprise AI infrastructure platform offering high-performance LLM APIs. It exposes an OpenAI-compatible chat completions endpoint with tool calling support. + ## Auto-Detection Use `AutoDetectProvider()` to detect the provider from a base URL: diff --git a/ai/atlascloud/atlascloud.go b/ai/atlascloud/atlascloud.go new file mode 100644 index 0000000000..a12e634911 --- /dev/null +++ b/ai/atlascloud/atlascloud.go @@ -0,0 +1,208 @@ +// Package atlascloud implements the Atlas Cloud model provider. +// +// Atlas Cloud is an enterprise AI infrastructure platform offering +// high-performance LLM APIs. It exposes an OpenAI-compatible +// chat completions endpoint. +// +// Usage: +// +// import _ "go-micro.dev/v5/ai/atlascloud" +// +// m := ai.New("atlascloud", +// ai.WithAPIKey("your-api-key"), +// ) +package atlascloud + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "go-micro.dev/v5/ai" +) + +func init() { + ai.Register("atlascloud", func(opts ...ai.Option) ai.Model { + return NewProvider(opts...) + }) +} + +// Provider implements the ai.Model interface for Atlas Cloud. +type Provider struct { + opts ai.Options +} + +// NewProvider creates a new Atlas Cloud provider. +func NewProvider(opts ...ai.Option) *Provider { + options := ai.NewOptions(opts...) + + if options.Model == "" { + options.Model = "llama-3.3-70b" + } + if options.BaseURL == "" { + options.BaseURL = "https://api.atlascloud.ai" + } + + return &Provider{opts: options} +} + +func (p *Provider) Init(opts ...ai.Option) error { + for _, o := range opts { + o(&p.opts) + } + return nil +} + +func (p *Provider) Options() ai.Options { return p.opts } +func (p *Provider) String() string { return "atlascloud" } + +func (p *Provider) Generate(ctx context.Context, req *ai.Request, opts ...ai.GenerateOption) (*ai.Response, error) { + var tools []map[string]any + for _, t := range req.Tools { + tools = append(tools, map[string]any{ + "type": "function", + "function": map[string]any{ + "name": t.Name, + "description": t.Description, + "parameters": map[string]any{ + "type": "object", + "properties": t.Properties, + }, + }, + }) + } + + messages := []map[string]any{ + {"role": "system", "content": req.SystemPrompt}, + {"role": "user", "content": req.Prompt}, + } + + apiReq := map[string]any{ + "model": p.opts.Model, + "messages": messages, + } + + if len(tools) > 0 { + apiReq["tools"] = tools + } + + resp, rawMessage, err := p.callAPI(ctx, apiReq) + if err != nil { + return nil, err + } + + if len(resp.ToolCalls) == 0 { + return resp, nil + } + + if p.opts.ToolHandler != nil { + followUpMessages := append(messages, map[string]any{ + "role": "assistant", + "content": rawMessage["content"], + "tool_calls": rawMessage["tool_calls"], + }) + + for _, tc := range resp.ToolCalls { + _, content := p.opts.ToolHandler(tc.Name, tc.Input) + followUpMessages = append(followUpMessages, map[string]any{ + "role": "tool", + "tool_call_id": tc.ID, + "content": content, + }) + } + + followUpReq := map[string]any{ + "model": p.opts.Model, + "messages": followUpMessages, + } + + followUpResp, _, err := p.callAPI(ctx, followUpReq) + if err == nil && followUpResp.Reply != "" { + resp.Answer = followUpResp.Reply + } + } + + return resp, nil +} + +func (p *Provider) Stream(ctx context.Context, req *ai.Request, opts ...ai.GenerateOption) (ai.Stream, error) { + return nil, fmt.Errorf("streaming not yet implemented for atlascloud provider") +} + +func (p *Provider) callAPI(ctx context.Context, req map[string]any) (*ai.Response, map[string]any, error) { + reqBody, err := json.Marshal(req) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal request: %w", err) + } + + apiURL := strings.TrimRight(p.opts.BaseURL, "/") + "/v1/chat/completions" + httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(reqBody)) + if err != nil { + return nil, nil, fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.opts.APIKey) + + httpResp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return nil, nil, fmt.Errorf("API request failed: %w", err) + } + defer httpResp.Body.Close() + + respBody, _ := io.ReadAll(httpResp.Body) + if httpResp.StatusCode != 200 { + return nil, nil, fmt.Errorf("API error (%s): %s", httpResp.Status, string(respBody)) + } + + var chatResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.Unmarshal(respBody, &chatResp); err != nil { + return nil, nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(chatResp.Choices) == 0 { + return nil, nil, fmt.Errorf("no response from API") + } + + choice := chatResp.Choices[0] + response := &ai.Response{ + Reply: choice.Message.Content, + } + + for _, tc := range choice.Message.ToolCalls { + var input map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil { + input = map[string]any{} + } + response.ToolCalls = append(response.ToolCalls, ai.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Input: input, + }) + } + + rawMessage := map[string]any{ + "content": choice.Message.Content, + "tool_calls": choice.Message.ToolCalls, + } + + return response, rawMessage, nil +} diff --git a/ai/atlascloud/atlascloud_test.go b/ai/atlascloud/atlascloud_test.go new file mode 100644 index 0000000000..706e296a77 --- /dev/null +++ b/ai/atlascloud/atlascloud_test.go @@ -0,0 +1,104 @@ +package atlascloud + +import ( + "context" + "testing" + + "go-micro.dev/v5/ai" +) + +func TestProvider_String(t *testing.T) { + p := NewProvider() + if p.String() != "atlascloud" { + t.Errorf("Expected provider name 'atlascloud', got '%s'", p.String()) + } +} + +func TestProvider_Init(t *testing.T) { + p := NewProvider() + + err := p.Init( + ai.WithModel("test-model"), + ai.WithAPIKey("test-key"), + ai.WithBaseURL("https://test.com"), + ) + + if err != nil { + t.Fatalf("Init failed: %v", err) + } + + opts := p.Options() + if opts.Model != "test-model" { + t.Errorf("Expected model 'test-model', got '%s'", opts.Model) + } + if opts.APIKey != "test-key" { + t.Errorf("Expected API key 'test-key', got '%s'", opts.APIKey) + } + if opts.BaseURL != "https://test.com" { + t.Errorf("Expected base URL 'https://test.com', got '%s'", opts.BaseURL) + } +} + +func TestProvider_Options(t *testing.T) { + p := NewProvider( + ai.WithModel("custom-model"), + ai.WithAPIKey("my-key"), + ) + + opts := p.Options() + if opts.Model != "custom-model" { + t.Errorf("Expected model 'custom-model', got '%s'", opts.Model) + } + if opts.APIKey != "my-key" { + t.Errorf("Expected API key 'my-key', got '%s'", opts.APIKey) + } +} + +func TestProvider_Defaults(t *testing.T) { + p := NewProvider() + + opts := p.Options() + if opts.Model != "llama-3.3-70b" { + t.Errorf("Expected default model 'llama-3.3-70b', got '%s'", opts.Model) + } + if opts.BaseURL != "https://api.atlascloud.ai" { + t.Errorf("Expected default base URL 'https://api.atlascloud.ai', got '%s'", opts.BaseURL) + } +} + +func TestProvider_Generate_NoAPIKey(t *testing.T) { + p := NewProvider() + + req := &ai.Request{ + Prompt: "Hello", + SystemPrompt: "You are helpful", + } + + _, err := p.Generate(context.Background(), req) + if err == nil { + t.Error("Expected error when API key is missing, got nil") + } +} + +func TestProvider_Stream_NotImplemented(t *testing.T) { + p := NewProvider() + + req := &ai.Request{ + Prompt: "Hello", + } + + _, err := p.Stream(context.Background(), req) + if err == nil { + t.Error("Expected error for unimplemented streaming, got nil") + } +} + +func TestProvider_Registration(t *testing.T) { + m := ai.New("atlascloud", ai.WithAPIKey("test")) + if m == nil { + t.Fatal("ai.New('atlascloud') returned nil — provider not registered") + } + if m.String() != "atlascloud" { + t.Errorf("Expected 'atlascloud', got '%s'", m.String()) + } +}