Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ Go Micro’s `ai` package gives every provider the same interface: `Init`, `Gene
|----------|--------|---------------|
| **Anthropic** | `go-micro.dev/v5/ai/anthropic` | `claude-sonnet-4-20250514` |
| **OpenAI** | `go-micro.dev/v5/ai/openai` | `gpt-4o` |
| **Atlas Cloud** | `go-micro.dev/v5/ai/atlascloud` | `llama-3.3-70b` |

Any provider that exposes an OpenAI-compatible API can also be used directly:

Expand Down
14 changes: 14 additions & 0 deletions ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ m := ai.New("openai",
Default model: `gpt-4o`
Default base URL: `https://api.openai.com`

### Atlas Cloud

```go
m := ai.New("atlascloud",
ai.WithAPIKey("your-key"),
ai.WithModel("llama-3.3-70b"), // default
)
```

Default model: `llama-3.3-70b`
Default base URL: `https://api.atlascloud.ai`

Atlas Cloud is an enterprise AI infrastructure platform offering high-performance LLM APIs. It exposes an OpenAI-compatible chat completions endpoint with tool calling support.

## Auto-Detection

Use `AutoDetectProvider()` to detect the provider from a base URL:
Expand Down
208 changes: 208 additions & 0 deletions ai/atlascloud/atlascloud.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
// Package atlascloud implements the Atlas Cloud model provider.
//
// Atlas Cloud is an enterprise AI infrastructure platform offering
// high-performance LLM APIs. It exposes an OpenAI-compatible
// chat completions endpoint.
//
// Usage:
//
// import _ "go-micro.dev/v5/ai/atlascloud"
//
// m := ai.New("atlascloud",
// ai.WithAPIKey("your-api-key"),
// )
package atlascloud

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"go-micro.dev/v5/ai"
)

func init() {
ai.Register("atlascloud", func(opts ...ai.Option) ai.Model {
return NewProvider(opts...)
})
}

// Provider implements the ai.Model interface for Atlas Cloud.
type Provider struct {
opts ai.Options
}

// NewProvider creates a new Atlas Cloud provider.
func NewProvider(opts ...ai.Option) *Provider {
options := ai.NewOptions(opts...)

if options.Model == "" {
options.Model = "llama-3.3-70b"
}
if options.BaseURL == "" {
options.BaseURL = "https://api.atlascloud.ai"
}

return &Provider{opts: options}
}

func (p *Provider) Init(opts ...ai.Option) error {
for _, o := range opts {
o(&p.opts)
}
return nil
}

func (p *Provider) Options() ai.Options { return p.opts }
func (p *Provider) String() string { return "atlascloud" }

func (p *Provider) Generate(ctx context.Context, req *ai.Request, opts ...ai.GenerateOption) (*ai.Response, error) {
var tools []map[string]any
for _, t := range req.Tools {
tools = append(tools, map[string]any{
"type": "function",
"function": map[string]any{
"name": t.Name,
"description": t.Description,
"parameters": map[string]any{
"type": "object",
"properties": t.Properties,
},
},
})
}

messages := []map[string]any{
{"role": "system", "content": req.SystemPrompt},
{"role": "user", "content": req.Prompt},
}

apiReq := map[string]any{
"model": p.opts.Model,
"messages": messages,
}

if len(tools) > 0 {
apiReq["tools"] = tools
}

resp, rawMessage, err := p.callAPI(ctx, apiReq)
if err != nil {
return nil, err
}

if len(resp.ToolCalls) == 0 {
return resp, nil
}

if p.opts.ToolHandler != nil {
followUpMessages := append(messages, map[string]any{
"role": "assistant",
"content": rawMessage["content"],
"tool_calls": rawMessage["tool_calls"],
})

for _, tc := range resp.ToolCalls {
_, content := p.opts.ToolHandler(tc.Name, tc.Input)
followUpMessages = append(followUpMessages, map[string]any{
"role": "tool",
"tool_call_id": tc.ID,
"content": content,
})
}

followUpReq := map[string]any{
"model": p.opts.Model,
"messages": followUpMessages,
}

followUpResp, _, err := p.callAPI(ctx, followUpReq)
if err == nil && followUpResp.Reply != "" {
resp.Answer = followUpResp.Reply
}
}

return resp, nil
}

func (p *Provider) Stream(ctx context.Context, req *ai.Request, opts ...ai.GenerateOption) (ai.Stream, error) {
return nil, fmt.Errorf("streaming not yet implemented for atlascloud provider")
}

func (p *Provider) callAPI(ctx context.Context, req map[string]any) (*ai.Response, map[string]any, error) {
reqBody, err := json.Marshal(req)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal request: %w", err)
}

apiURL := strings.TrimRight(p.opts.BaseURL, "/") + "/v1/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(reqBody))
if err != nil {
return nil, nil, fmt.Errorf("failed to create request: %w", err)
}

httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+p.opts.APIKey)

httpResp, err := http.DefaultClient.Do(httpReq)
if err != nil {
return nil, nil, fmt.Errorf("API request failed: %w", err)
}
defer httpResp.Body.Close()

respBody, _ := io.ReadAll(httpResp.Body)
if httpResp.StatusCode != 200 {
return nil, nil, fmt.Errorf("API error (%s): %s", httpResp.Status, string(respBody))
}

var chatResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
} `json:"choices"`
}

if err := json.Unmarshal(respBody, &chatResp); err != nil {
return nil, nil, fmt.Errorf("failed to parse response: %w", err)
}

if len(chatResp.Choices) == 0 {
return nil, nil, fmt.Errorf("no response from API")
}

choice := chatResp.Choices[0]
response := &ai.Response{
Reply: choice.Message.Content,
}

for _, tc := range choice.Message.ToolCalls {
var input map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &input); err != nil {
input = map[string]any{}
}
response.ToolCalls = append(response.ToolCalls, ai.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Input: input,
})
}

rawMessage := map[string]any{
"content": choice.Message.Content,
"tool_calls": choice.Message.ToolCalls,
}

return response, rawMessage, nil
}
104 changes: 104 additions & 0 deletions ai/atlascloud/atlascloud_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package atlascloud

import (
"context"
"testing"

"go-micro.dev/v5/ai"
)

func TestProvider_String(t *testing.T) {
p := NewProvider()
if p.String() != "atlascloud" {
t.Errorf("Expected provider name 'atlascloud', got '%s'", p.String())
}
}

func TestProvider_Init(t *testing.T) {
p := NewProvider()

err := p.Init(
ai.WithModel("test-model"),
ai.WithAPIKey("test-key"),
ai.WithBaseURL("https://test.com"),
)

if err != nil {
t.Fatalf("Init failed: %v", err)
}

opts := p.Options()
if opts.Model != "test-model" {
t.Errorf("Expected model 'test-model', got '%s'", opts.Model)
}
if opts.APIKey != "test-key" {
t.Errorf("Expected API key 'test-key', got '%s'", opts.APIKey)
}
if opts.BaseURL != "https://test.com" {
t.Errorf("Expected base URL 'https://test.com', got '%s'", opts.BaseURL)
}
}

func TestProvider_Options(t *testing.T) {
p := NewProvider(
ai.WithModel("custom-model"),
ai.WithAPIKey("my-key"),
)

opts := p.Options()
if opts.Model != "custom-model" {
t.Errorf("Expected model 'custom-model', got '%s'", opts.Model)
}
if opts.APIKey != "my-key" {
t.Errorf("Expected API key 'my-key', got '%s'", opts.APIKey)
}
}

func TestProvider_Defaults(t *testing.T) {
p := NewProvider()

opts := p.Options()
if opts.Model != "llama-3.3-70b" {
t.Errorf("Expected default model 'llama-3.3-70b', got '%s'", opts.Model)
}
if opts.BaseURL != "https://api.atlascloud.ai" {
t.Errorf("Expected default base URL 'https://api.atlascloud.ai', got '%s'", opts.BaseURL)
}
}

func TestProvider_Generate_NoAPIKey(t *testing.T) {
p := NewProvider()

req := &ai.Request{
Prompt: "Hello",
SystemPrompt: "You are helpful",
}

_, err := p.Generate(context.Background(), req)
if err == nil {
t.Error("Expected error when API key is missing, got nil")
}
}

func TestProvider_Stream_NotImplemented(t *testing.T) {
p := NewProvider()

req := &ai.Request{
Prompt: "Hello",
}

_, err := p.Stream(context.Background(), req)
if err == nil {
t.Error("Expected error for unimplemented streaming, got nil")
}
}

func TestProvider_Registration(t *testing.T) {
m := ai.New("atlascloud", ai.WithAPIKey("test"))
if m == nil {
t.Fatal("ai.New('atlascloud') returned nil — provider not registered")
}
if m.String() != "atlascloud" {
t.Errorf("Expected 'atlascloud', got '%s'", m.String())
}
}
Loading