Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions cmd/venat/internal/agentloop/agentloop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package agentloop

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sync"
"time"

"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)

var (
ErrSentinelAbort = errors.New("agentloop: tool requested the agent loop to abort")
ErrSentinelOkay = errors.New("agentloop: tool requested the agent loop to stop (status okay)")

tokensUsed = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "venat",
Subsystem: "agentloop",
Name: "tokens_used",
}, []string{"model", "kind"})
)

type Impl struct {
Name, ID string
Tools map[string]Tool
SystemPrompt string

model string
cli openai.Client
lg *slog.Logger

messages []openai.ChatCompletionMessageParamUnion
lock sync.Mutex
}

func New(name, id, systemPrompt, model string, tools []Tool, cli openai.Client, lg *slog.Logger) *Impl {
if id == "" {
id = uuid.Must(uuid.NewV7()).String()
}

toolMap := map[string]Tool{}
for _, tool := range tools {
toolMap[tool.Name()] = tool
}

result := Impl{
Name: name,
ID: id,
Tools: toolMap,
SystemPrompt: systemPrompt,
model: model,
cli: cli,
lg: lg,
messages: []openai.ChatCompletionMessageParamUnion{
openai.SystemMessage(systemPrompt),
},
}

return &result
}

type Result struct {
Messages []openai.ChatCompletionMessageParamUnion
Response string

PromptTokens int64
PromptCachedTokens int64
CompletionTokens int64
CompletionReasoningTokens int64
}

func (i *Impl) Run(ctx context.Context, prompt string, opts ...func(*openai.ChatCompletionNewParams)) (*Result, error) {
i.lock.Lock()
defer i.lock.Unlock()

lg := i.lg.With("component", "agentloop", "name", i.Name, "id", i.ID, "model", i.model)

i.messages = append(i.messages, openai.UserMessage(prompt))

failCount := 0
const failMax = 5

result := Result{}

for {
select {
case <-ctx.Done():
lg.Error("context done", "err", ctx.Err())
return &result, ctx.Err()
default:
}

params := openai.ChatCompletionNewParams{
Messages: i.messages,
Model: openai.ChatModel(i.model),
}

for _, opt := range opts {
opt(&params)
}

for _, tool := range i.Tools {
params.Tools = append(params.Tools, openai.ChatCompletionFunctionTool(tool.Usage()))
}

completion, err := i.cli.Chat.Completions.New(ctx, params)
if err != nil {
failCount++

if failCount == failMax {
return &result, fmt.Errorf("can't reach remote API: %w", err)
}

lg.Error("can't get completion, sleeping and retrying", "err", err, "failCount", failCount, "failMax", failMax)
time.Sleep(time.Duration(failCount) * time.Second)
continue
}

tokensUsed.WithLabelValues(i.model, "input").Add(float64(completion.Usage.PromptTokens))
tokensUsed.WithLabelValues(i.model, "output").Add(float64(completion.Usage.CompletionTokens))
tokensUsed.WithLabelValues(i.model, "cached").Add(float64(completion.Usage.PromptTokensDetails.CachedTokens))
tokensUsed.WithLabelValues(i.model, "reasoning").Add(float64(completion.Usage.CompletionTokensDetails.ReasoningTokens))

result.PromptTokens += completion.Usage.PromptTokens
result.PromptCachedTokens += completion.Usage.PromptTokensDetails.CachedTokens
result.CompletionTokens += completion.Usage.CompletionTokens
result.CompletionReasoningTokens += completion.Usage.CompletionTokensDetails.ReasoningTokens

choice := completion.Choices[0]
resp := choice.Message

i.messages = append(i.messages, resp.ToParam())
result.Messages = i.messages

if resp.Content != "" {
result.Response = resp.Content
}

lg.Debug("got finish reason", "reason", choice.FinishReason)
if choice.FinishReason == "stop" {
return &result, nil
}

toolCalls := completion.Choices[0].Message.ToolCalls

for _, tc := range toolCalls {
lg := lg.With("tool", tc.Function.Name, "toolcall_id", tc.ID)
tool, ok := i.Tools[tc.Function.Name]
if !ok {
lg.Error("AI model chose tool that did not exist, asking it to try again")
i.messages = append(i.messages, openai.UserMessage(fmt.Sprintf("Tool %q does not exist, please try again.", tc.Function.Name)))
continue
}

args := []byte(tc.Function.Arguments)
if err := tool.Valid(args); err != nil {
lg.Error("AI model produced invalid arguments", "err", err)
i.messages = append(i.messages, openai.UserMessage(fmt.Sprintf("When calling tool %q, you got an argument validation error: %v", tool.Name(), err)))
continue
}

lg.Debug("calling tool", "args", json.RawMessage(args))

toolResult, err := tool.Run(ctx, args)
if err != nil {
switch {
case errors.Is(err, ErrSentinelOkay):
lg.Info("tool requested happy exit", "err", err)
return &result, err
case errors.Is(err, ErrSentinelAbort):
lg.Info("tool requested unhappy abort", "err", err)
return &result, err
default:
lg.Error("failed to run tool", "err", err)
i.messages = append(i.messages, openai.ToolMessage(fmt.Sprintf("internal error when running tool %q: %v", tool.Name(), err), tc.ID))
continue
}
}

lg.Debug("got response", "result", string(toolResult))

i.messages = append(i.messages, openai.ToolMessage(string(toolResult), tc.ID))
}
}
}
7 changes: 7 additions & 0 deletions cmd/venat/internal/agentloop/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package agentloop

import "github.com/openai/openai-go/v3"

func EnableParallelToolCalling(params *openai.ChatCompletionNewParams) {
params.ParallelToolCalls = openai.Bool(true)
}
14 changes: 14 additions & 0 deletions cmd/venat/internal/agentloop/tool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package agentloop

import (
"context"

"github.com/openai/openai-go/v3"
)

type Tool interface {
Name() string
Usage() openai.FunctionDefinitionParam
Valid(data []byte) (err error)
Run(ctx context.Context, data []byte) ([]byte, error)
}
70 changes: 70 additions & 0 deletions cmd/venat/internal/models/backup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package models

import (
"context"
"fmt"
"log/slog"
"time"

"github.com/ncruces/go-sqlite3"
)

func (d *DAO) Backup() {
slog.Info("starting backup")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
err := d.backup(ctx, d.backupDBLoc)
if err != nil {
slog.Error("failed to backup database", "err", err)
}
slog.Info("backup done")
}

func (d *DAO) backup(ctx context.Context, to string) error {
db, err := d.db.DB()
if err != nil {
return fmt.Errorf("failed to get database connection: %w", err)
}

if err := db.Ping(); err != nil {
return fmt.Errorf("failed to ping database: %w", err)
}

conn, err := db.Conn(ctx)
if err != nil {
return fmt.Errorf("failed to get database connection: %w", err)
}

defer conn.Close()

if err := conn.Raw(func(dca any) error {
conn, ok := dca.(sqlite3.DriverConn)
if !ok {
return fmt.Errorf("db connection is not a sqlite3 connection, it is %T", dca)
}

bu, err := conn.Raw().BackupInit("main", to)
if err != nil {
return fmt.Errorf("failed to initialize backup: %w", err)
}
defer bu.Close()

var done bool
for !done {
done, err = bu.Step(bu.Remaining())
if err != nil {
return fmt.Errorf("failed to backup database: %w", err)
}
}

if err := bu.Close(); err != nil {
return fmt.Errorf("failed to close backup: %w", err)
}

return nil
}); err != nil {
return fmt.Errorf("failed to backup database: %w", err)
}

return nil
}
51 changes: 51 additions & 0 deletions cmd/venat/internal/models/dao.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package models

import (
"context"
"fmt"

_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/gormlite"
slogGorm "github.com/orandin/slog-gorm"
"gorm.io/gorm"
gormPrometheus "gorm.io/plugin/prometheus"
)

type DAO struct {
db *gorm.DB
backupDBLoc string
}

func (d *DAO) DB() *gorm.DB {
return d.db
}

func (d *DAO) Ping(ctx context.Context) error {
if err := d.db.WithContext(ctx).Exec("select 1+1").Error; err != nil {
return err
}

return nil
}

func New(dbLoc, backupDBLoc string) (*DAO, error) {
db, err := gorm.Open(gormlite.Open(dbLoc), &gorm.Config{
Logger: slogGorm.New(
slogGorm.WithErrorField("err"),
slogGorm.WithRecordNotFoundError(),
),
})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

if err := db.AutoMigrate(); err != nil {
return nil, fmt.Errorf("failed to migrate schema: %w", err)
}

db.Use(gormPrometheus.New(gormPrometheus.Config{
DBName: "venat",
}))

return &DAO{db: db, backupDBLoc: backupDBLoc}, nil
}
Loading
Loading