diff --git a/providers/openai/computer_use.go b/providers/openai/computer_use.go new file mode 100644 index 000000000..c4d35aa51 --- /dev/null +++ b/providers/openai/computer_use.go @@ -0,0 +1,563 @@ +package openai + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "math" + + "charm.land/fantasy" + "github.com/charmbracelet/openai-go/packages/param" + "github.com/charmbracelet/openai-go/responses" +) + +const ( + computerUseToolID = "openai.computer" + computerUseAPIName = "computer" + computerUseStoreError = "openai computer use requires store to be true in openai responses provider options" + maxExactIntFloat64 = float64(1<<53 - 1) +) + +// ComputerUseToolOptions configures the OpenAI computer use tool. +type ComputerUseToolOptions struct { + DisplayWidthPx int64 + DisplayHeightPx int64 + Environment responses.ComputerUsePreviewToolEnvironment +} + +// NewComputerUseTool creates a new executable OpenAI computer use tool. +func NewComputerUseTool( + opts ComputerUseToolOptions, + run func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error), +) fantasy.ExecutableProviderTool { + args := map[string]any{ + "display_width_px": opts.DisplayWidthPx, + "display_height_px": opts.DisplayHeightPx, + "environment": opts.Environment, + } + return fantasy.NewExecutableProviderTool(fantasy.ProviderDefinedTool{ + ID: computerUseToolID, + Name: computerUseAPIName, + Args: args, + }, run) +} + +// ComputerUseActionType identifies a single OpenAI computer action. 
type ComputerUseActionType string

// Discriminator values carried in the "type" field of a computer action
// payload.
const (
	ComputerUseActionTypeClick       ComputerUseActionType = "click"
	ComputerUseActionTypeDoubleClick ComputerUseActionType = "double_click"
	ComputerUseActionTypeDrag        ComputerUseActionType = "drag"
	ComputerUseActionTypeKeypress    ComputerUseActionType = "keypress"
	ComputerUseActionTypeMove        ComputerUseActionType = "move"
	ComputerUseActionTypeScreenshot  ComputerUseActionType = "screenshot"
	ComputerUseActionTypeScroll      ComputerUseActionType = "scroll"
	ComputerUseActionTypeType        ComputerUseActionType = "type"
	ComputerUseActionTypeWait        ComputerUseActionType = "wait"
)

// ComputerUseAction is implemented by every parsed OpenAI computer action.
// The unexported marker method keeps the set of implementations closed to
// this package.
type ComputerUseAction interface {
	// Type reports which kind of action the value holds.
	Type() ComputerUseActionType
	isComputerUseAction()
}

// ComputerUseInput is the parsed form of a computer tool call.
// A single-action payload fills Action; a batched payload fills Actions.
type ComputerUseInput struct {
	Action  ComputerUseAction
	Actions []ComputerUseAction
}

// ComputerUsePoint is an x/y coordinate pair.
type ComputerUsePoint struct {
	X int64 `json:"x"`
	Y int64 `json:"y"`
}

// ComputerUseClickAction is a click with the named button at (X, Y).
type ComputerUseClickAction struct {
	Button string `json:"button"`
	X      int64  `json:"x"`
	Y      int64  `json:"y"`
}

// isComputerUseAction marks ComputerUseClickAction as a ComputerUseAction.
func (ComputerUseClickAction) isComputerUseAction() {}

// Type returns the action discriminator.
func (ComputerUseClickAction) Type() ComputerUseActionType {
	return ComputerUseActionTypeClick
}

// ComputerUseDoubleClickAction is a double-click at (X, Y).
type ComputerUseDoubleClickAction struct {
	X int64 `json:"x"`
	Y int64 `json:"y"`
}

// isComputerUseAction marks ComputerUseDoubleClickAction as a
// ComputerUseAction.
func (ComputerUseDoubleClickAction) isComputerUseAction() {}

// Type returns the action discriminator.
func (ComputerUseDoubleClickAction) Type() ComputerUseActionType {
	return ComputerUseActionTypeDoubleClick
}

// ComputerUseDragAction represents a drag action.
+type ComputerUseDragAction struct { + Path []ComputerUsePoint `json:"path"` +} + +func (ComputerUseDragAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseDragAction) Type() ComputerUseActionType { return ComputerUseActionTypeDrag } + +// ComputerUseKeypressAction represents a keypress action. +type ComputerUseKeypressAction struct { + Keys []string `json:"keys"` +} + +func (ComputerUseKeypressAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseKeypressAction) Type() ComputerUseActionType { + return ComputerUseActionTypeKeypress +} + +// ComputerUseMoveAction represents a move action. +type ComputerUseMoveAction struct { + X int64 `json:"x"` + Y int64 `json:"y"` +} + +func (ComputerUseMoveAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseMoveAction) Type() ComputerUseActionType { return ComputerUseActionTypeMove } + +// ComputerUseScreenshotAction represents a screenshot action. +type ComputerUseScreenshotAction struct{} + +func (ComputerUseScreenshotAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseScreenshotAction) Type() ComputerUseActionType { + return ComputerUseActionTypeScreenshot +} + +// ComputerUseScrollAction represents a scroll action. +type ComputerUseScrollAction struct { + X int64 `json:"x"` + Y int64 `json:"y"` + ScrollX int64 `json:"scroll_x"` + ScrollY int64 `json:"scroll_y"` +} + +func (ComputerUseScrollAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseScrollAction) Type() ComputerUseActionType { return ComputerUseActionTypeScroll } + +// ComputerUseTypeAction represents a type action. +type ComputerUseTypeAction struct { + Text string `json:"text"` +} + +func (ComputerUseTypeAction) isComputerUseAction() {} + +// Type returns the action discriminator. 
+func (ComputerUseTypeAction) Type() ComputerUseActionType { return ComputerUseActionTypeType } + +// ComputerUseWaitAction represents a wait action. +type ComputerUseWaitAction struct{} + +func (ComputerUseWaitAction) isComputerUseAction() {} + +// Type returns the action discriminator. +func (ComputerUseWaitAction) Type() ComputerUseActionType { return ComputerUseActionTypeWait } + +// ParseComputerUseInput parses a raw OpenAI computer-use input payload. +// Single actions are encoded as a JSON object. Batched actions are encoded as a +// JSON array. +func ParseComputerUseInput(input []byte) (ComputerUseInput, error) { + trimmed := bytes.TrimSpace(input) + if len(trimmed) == 0 { + return ComputerUseInput{}, fmt.Errorf("computer use input is empty") + } + + switch trimmed[0] { + case '{': + action, err := parseComputerUseAction(trimmed) + if err != nil { + return ComputerUseInput{}, err + } + return ComputerUseInput{Action: action}, nil + case '[': + var rawActions []json.RawMessage + if err := json.Unmarshal(trimmed, &rawActions); err != nil { + return ComputerUseInput{}, err + } + if len(rawActions) == 0 { + return ComputerUseInput{}, fmt.Errorf("openai computer use input requires at least one action") + } + actions := make([]ComputerUseAction, 0, len(rawActions)) + for _, rawAction := range rawActions { + action, err := parseComputerUseAction(rawAction) + if err != nil { + return ComputerUseInput{}, err + } + actions = append(actions, action) + } + return ComputerUseInput{Actions: actions}, nil + default: + return ComputerUseInput{}, fmt.Errorf("computer use input must be a JSON object or array") + } +} + +func parseComputerUseAction(input []byte) (ComputerUseAction, error) { + var header struct { + Type ComputerUseActionType `json:"type"` + } + if err := json.Unmarshal(input, &header); err != nil { + return nil, err + } + + switch header.Type { + case ComputerUseActionTypeClick: + var action ComputerUseClickAction + return action, json.Unmarshal(input, 
&action) + case ComputerUseActionTypeDoubleClick: + var action ComputerUseDoubleClickAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeDrag: + var action ComputerUseDragAction + if err := json.Unmarshal(input, &action); err != nil { + return nil, err + } + if len(action.Path) == 0 { + return nil, fmt.Errorf("computer use drag action requires a non-empty path") + } + return action, nil + case ComputerUseActionTypeKeypress: + var action ComputerUseKeypressAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeMove: + var action ComputerUseMoveAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeScreenshot: + var action ComputerUseScreenshotAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeScroll: + var action ComputerUseScrollAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeType: + var action ComputerUseTypeAction + return action, json.Unmarshal(input, &action) + case ComputerUseActionTypeWait: + var action ComputerUseWaitAction + return action, json.Unmarshal(input, &action) + default: + return nil, fmt.Errorf("unsupported computer use action type %q", header.Type) + } +} + +// NewComputerUseScreenshotResult returns a screenshot tool result with PNG data. +// Screenshot and media outputs are the only computer-use tool results that +// round-trip through the OpenAI Responses API. Use this helper for OpenAI +// computer-use submissions. +func NewComputerUseScreenshotResult(toolCallID string, screenshotPNG []byte) fantasy.ToolResultPart { + return fantasy.ToolResultPart{ + ToolCallID: toolCallID, + Output: fantasy.ToolResultOutputContentMedia{ + Data: base64.StdEncoding.EncodeToString(screenshotPNG), + MediaType: "image/png", + }, + } +} + +// NewComputerUseScreenshotResultWithMediaType returns a screenshot tool result +// with caller-provided base64 data and media type. 
Like +// NewComputerUseScreenshotResult, this round-trips through OpenAI Responses +// computer-use because it produces media output. +func NewComputerUseScreenshotResultWithMediaType( + toolCallID string, + base64Data string, + mediaType string, +) fantasy.ToolResultPart { + return fantasy.ToolResultPart{ + ToolCallID: toolCallID, + Output: fantasy.ToolResultOutputContentMedia{ + Data: base64Data, + MediaType: mediaType, + }, + } +} + +// NewComputerUseErrorResult returns an error tool result. +// OpenAI Responses computer-use does not accept this output format on replay. +// Use it for local display, logging, or other non-OpenAI flows that want to +// preserve the failure alongside the tool call ID. +func NewComputerUseErrorResult(toolCallID string, err error) fantasy.ToolResultPart { + return fantasy.ToolResultPart{ + ToolCallID: toolCallID, + Output: fantasy.ToolResultOutputContentError{ + Error: err, + }, + } +} + +// NewComputerUseTextResult returns a text tool result. +// OpenAI Responses computer-use does not accept this output format on replay. +// Use it for local display, logging, or tests that need to keep a textual +// association with the original tool call ID. 
+func NewComputerUseTextResult(toolCallID string, text string) fantasy.ToolResultPart { + return fantasy.ToolResultPart{ + ToolCallID: toolCallID, + Output: fantasy.ToolResultOutputContentText{ + Text: text, + }, + } +} + +func asProviderDefinedTool(tool fantasy.Tool) (fantasy.ProviderDefinedTool, bool) { + if pdt, ok := tool.(fantasy.ProviderDefinedTool); ok { + return pdt, true + } + if ept, ok := tool.(fantasy.ExecutableProviderTool); ok { + return ept.Definition(), true + } + return fantasy.ProviderDefinedTool{}, false +} + +func hasComputerUseTool(tools []fantasy.Tool) bool { + for _, tool := range tools { + pt, ok := asProviderDefinedTool(tool) + if ok && pt.ID == computerUseToolID { + return true + } + } + return false +} + +func anyToInt64(v any) (int64, bool) { + switch typed := v.(type) { + case int: + return int64(typed), true + case int8: + return int64(typed), true + case int16: + return int64(typed), true + case int32: + return int64(typed), true + case int64: + return typed, true + case uint: + u64 := uint64(typed) + if u64 > math.MaxInt64 { + return 0, false + } + return int64(u64), true + case uint8: + return int64(typed), true + case uint16: + return int64(typed), true + case uint32: + return int64(typed), true + case uint64: + if typed > math.MaxInt64 { + return 0, false + } + return int64(typed), true + case float32: + f := float64(typed) + if math.Trunc(f) != f || math.IsNaN(f) || math.IsInf(f, 0) || f < -maxExactIntFloat64 || f > maxExactIntFloat64 { + return 0, false + } + return int64(f), true + case float64: + if math.Trunc(typed) != typed || math.IsNaN(typed) || math.IsInf(typed, 0) || typed < -maxExactIntFloat64 || typed > maxExactIntFloat64 { + return 0, false + } + return int64(typed), true + case json.Number: + parsed, err := typed.Int64() + if err != nil { + return 0, false + } + return parsed, true + default: + return 0, false + } +} + +func anyToComputerUseEnvironment(v any) (responses.ComputerUsePreviewToolEnvironment, bool) { + 
switch typed := v.(type) { + case responses.ComputerUsePreviewToolEnvironment: + return typed, true + case string: + return responses.ComputerUsePreviewToolEnvironment(typed), typed != "" + default: + return "", false + } +} + +func toComputerUseToolParam(pt fantasy.ProviderDefinedTool) (responses.ToolUnionParam, error) { + displayHeight, ok := anyToInt64(pt.Args["display_height_px"]) + if !ok || displayHeight <= 0 { + return responses.ToolUnionParam{}, fmt.Errorf("computer use tool has invalid display_height_px") + } + displayWidth, ok := anyToInt64(pt.Args["display_width_px"]) + if !ok || displayWidth <= 0 { + return responses.ToolUnionParam{}, fmt.Errorf("computer use tool has invalid display_width_px") + } + environment, ok := anyToComputerUseEnvironment(pt.Args["environment"]) + if !ok { + return responses.ToolUnionParam{}, fmt.Errorf("computer use tool has invalid environment") + } + return responses.ToolParamOfComputerUsePreview(displayHeight, displayWidth, environment), nil +} + +func getComputerUseCallMetadata(options fantasy.ProviderOptions) *OpenAIComputerUseCallMetadata { + if options == nil { + return nil + } + if providerOptions, ok := options[Name]; ok { + if metadata, ok := providerOptions.(*OpenAIComputerUseCallMetadata); ok { + return metadata + } + } + return nil +} + +func computerUseToolCallMetadata(call responses.ResponseComputerToolCall) *OpenAIComputerUseCallMetadata { + metadata := &OpenAIComputerUseCallMetadata{ + CallID: call.CallID, + } + if len(call.PendingSafetyChecks) > 0 { + metadata.PendingSafetyChecks = make([]OpenAIComputerUsePendingSafetyCheck, 0, len(call.PendingSafetyChecks)) + for _, check := range call.PendingSafetyChecks { + metadata.PendingSafetyChecks = append(metadata.PendingSafetyChecks, OpenAIComputerUsePendingSafetyCheck{ + ID: check.ID, + Code: check.Code, + Message: check.Message, + }) + } + } + return metadata +} + +func computerUseSafetyChecksToAcknowledgedParams(checks []OpenAIComputerUsePendingSafetyCheck) 
[]responses.ResponseInputItemComputerCallOutputAcknowledgedSafetyCheckParam { + if len(checks) == 0 { + return nil + } + paramsList := make([]responses.ResponseInputItemComputerCallOutputAcknowledgedSafetyCheckParam, 0, len(checks)) + for _, check := range checks { + ack := responses.ResponseInputItemComputerCallOutputAcknowledgedSafetyCheckParam{ + ID: check.ID, + } + if check.Code != "" { + ack.Code = param.NewOpt(check.Code) + } + if check.Message != "" { + ack.Message = param.NewOpt(check.Message) + } + paramsList = append(paramsList, ack) + } + return paramsList +} + +func computerUseToolResultInput(toolResult fantasy.ToolResultPart, metadata *OpenAIComputerUseCallMetadata) (responses.ResponseInputItemUnionParam, error) { + if metadata == nil || metadata.CallID == "" { + return responses.ResponseInputItemUnionParam{}, fmt.Errorf("openai computer tool call metadata is missing call_id") + } + output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output) + if !ok { + return responses.ResponseInputItemUnionParam{}, fmt.Errorf("openai computer tool results must use media output") + } + if output.MediaType == "" { + return responses.ResponseInputItemUnionParam{}, fmt.Errorf("openai computer tool results must include a media type") + } + item := responses.ResponseInputItemParamOfComputerCallOutput(metadata.CallID, responses.ResponseComputerToolCallOutputScreenshotParam{ + ImageURL: param.NewOpt(fmt.Sprintf("data:%s;base64,%s", output.MediaType, output.Data)), + }) + item.OfComputerCallOutput.AcknowledgedSafetyChecks = computerUseSafetyChecksToAcknowledgedParams(metadata.PendingSafetyChecks) + return item, nil +} + +func computerUseToolCallContent(call responses.ResponseComputerToolCall) (fantasy.ToolCallContent, error) { + input, err := computerUseToolCallInput(call) + if err != nil { + return fantasy.ToolCallContent{}, err + } + return fantasy.ToolCallContent{ + ToolCallID: call.ID, + ToolName: computerUseAPIName, + Input: 
input, + ProviderExecuted: false, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: computerUseToolCallMetadata(call), + }, + }, nil +} + +func computerUseToolCallInput(call responses.ResponseComputerToolCall) (string, error) { + if len(call.Actions) > 0 { + payload, err := computerUseActionsJSON(call.Actions) + if err != nil { + return "", err + } + return string(payload), nil + } + + payload, err := computerUseResponseActionJSON(call.Action) + if err != nil { + return "", err + } + return string(payload), nil +} + +func computerUseActionsJSON(actions responses.ComputerActionList) ([]byte, error) { + if len(actions) == 0 { + return nil, fmt.Errorf("computer use tool call is missing actions") + } + rawActions := make([]json.RawMessage, 0, len(actions)) + for _, action := range actions { + payload, err := computerUseActionJSON(action) + if err != nil { + return nil, err + } + rawActions = append(rawActions, payload) + } + return json.Marshal(rawActions) +} + +func computerUseActionJSON(action responses.ComputerActionUnion) (json.RawMessage, error) { + if raw := action.RawJSON(); raw != "" { + return json.RawMessage(raw), nil + } + variant := action.AsAny() + if variant == nil { + return nil, fmt.Errorf("computer use tool call is missing action payload") + } + payload, err := json.Marshal(variant) + if err != nil { + return nil, err + } + return json.RawMessage(payload), nil +} + +func computerUseResponseActionJSON(action responses.ResponseComputerToolCallActionUnion) (json.RawMessage, error) { + if raw := action.RawJSON(); raw != "" { + return json.RawMessage(raw), nil + } + variant := action.AsAny() + if variant == nil { + return nil, fmt.Errorf("computer use tool call is missing action payload") + } + payload, err := json.Marshal(variant) + if err != nil { + return nil, err + } + return json.RawMessage(payload), nil +} diff --git a/providers/openai/computer_use_test.go b/providers/openai/computer_use_test.go new file mode 100644 index 000000000..d656e48ff --- 
/dev/null +++ b/providers/openai/computer_use_test.go @@ -0,0 +1,237 @@ +package openai + +import ( + "context" + "encoding/base64" + "errors" + "reflect" + "testing" + + "charm.land/fantasy" + "github.com/charmbracelet/openai-go/responses" + "github.com/stretchr/testify/require" +) + +func TestNewComputerUseTool(t *testing.T) { + t.Parallel() + + tool := NewComputerUseTool(ComputerUseToolOptions{ + DisplayWidthPx: 1920, + DisplayHeightPx: 1080, + Environment: responses.ComputerUsePreviewToolEnvironmentUbuntu, + }, func(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{}, nil + }) + + definition := tool.Definition() + require.Equal(t, computerUseToolID, definition.ID) + require.Equal(t, computerUseAPIName, definition.Name) + require.Len(t, definition.Args, 3) + require.Equal(t, int64(1920), definition.Args["display_width_px"]) + require.Equal(t, int64(1080), definition.Args["display_height_px"]) + require.Equal(t, responses.ComputerUsePreviewToolEnvironmentUbuntu, definition.Args["environment"]) + _, hasDisplayNumber := definition.Args["display_number"] + require.False(t, hasDisplayNumber) +} + +func TestComputerUseToolOptions_DoesNotExposeDisplayNumber(t *testing.T) { + t.Parallel() + + _, exists := reflect.TypeOf(ComputerUseToolOptions{}).FieldByName("DisplayNumber") + require.False(t, exists) +} + +func TestParseComputerUseInput(t *testing.T) { + t.Parallel() + + t.Run("click", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"click","button":"left","x":100,"y":200}`)) + require.NoError(t, err) + require.Nil(t, input.Actions) + + action, ok := input.Action.(ComputerUseClickAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeClick, action.Type()) + require.Equal(t, "left", action.Button) + require.Equal(t, int64(100), action.X) + require.Equal(t, int64(200), action.Y) + }) + + t.Run("double click", func(t *testing.T) { + t.Parallel() + + input, err := 
ParseComputerUseInput([]byte(`{"type":"double_click","x":10,"y":20}`)) + require.NoError(t, err) + + action, ok := input.Action.(ComputerUseDoubleClickAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeDoubleClick, action.Type()) + require.Equal(t, int64(10), action.X) + require.Equal(t, int64(20), action.Y) + }) + + t.Run("drag", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"drag","path":[{"x":1,"y":2},{"x":3,"y":4},{"x":5,"y":6}]}`)) + require.NoError(t, err) + + action, ok := input.Action.(ComputerUseDragAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeDrag, action.Type()) + require.Equal(t, []ComputerUsePoint{{X: 1, Y: 2}, {X: 3, Y: 4}, {X: 5, Y: 6}}, action.Path) + }) + + t.Run("keypress", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"keypress","keys":["CTRL","L"]}`)) + require.NoError(t, err) + + action, ok := input.Action.(ComputerUseKeypressAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeKeypress, action.Type()) + require.Equal(t, []string{"CTRL", "L"}, action.Keys) + }) + + t.Run("move", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"move","x":320,"y":240}`)) + require.NoError(t, err) + + action, ok := input.Action.(ComputerUseMoveAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeMove, action.Type()) + require.Equal(t, int64(320), action.X) + require.Equal(t, int64(240), action.Y) + }) + + t.Run("screenshot", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"screenshot"}`)) + require.NoError(t, err) + + _, ok := input.Action.(ComputerUseScreenshotAction) + require.True(t, ok) + }) + + t.Run("scroll", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"scroll","x":10,"y":20,"scroll_x":0,"scroll_y":600}`)) + require.NoError(t, err) + + action, ok := 
input.Action.(ComputerUseScrollAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeScroll, action.Type()) + require.Equal(t, int64(10), action.X) + require.Equal(t, int64(20), action.Y) + require.Equal(t, int64(0), action.ScrollX) + require.Equal(t, int64(600), action.ScrollY) + }) + + t.Run("type", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"type","text":"hello"}`)) + require.NoError(t, err) + + action, ok := input.Action.(ComputerUseTypeAction) + require.True(t, ok) + require.Equal(t, ComputerUseActionTypeType, action.Type()) + require.Equal(t, "hello", action.Text) + }) + + t.Run("wait", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`{"type":"wait"}`)) + require.NoError(t, err) + + _, ok := input.Action.(ComputerUseWaitAction) + require.True(t, ok) + }) + + t.Run("batched actions", func(t *testing.T) { + t.Parallel() + + input, err := ParseComputerUseInput([]byte(`[{"type":"move","x":10,"y":20},{"type":"click","button":"left","x":10,"y":20}]`)) + require.NoError(t, err) + require.Nil(t, input.Action) + require.Len(t, input.Actions, 2) + + moveAction, ok := input.Actions[0].(ComputerUseMoveAction) + require.True(t, ok) + require.Equal(t, int64(10), moveAction.X) + require.Equal(t, int64(20), moveAction.Y) + + clickAction, ok := input.Actions[1].(ComputerUseClickAction) + require.True(t, ok) + require.Equal(t, "left", clickAction.Button) + }) + + t.Run("empty batched actions error", func(t *testing.T) { + t.Parallel() + + _, err := ParseComputerUseInput([]byte(`[]`)) + require.EqualError(t, err, "openai computer use input requires at least one action") + }) + + t.Run("unknown action errors", func(t *testing.T) { + t.Parallel() + + _, err := ParseComputerUseInput([]byte(`{"type":"future_action"}`)) + require.Error(t, err) + }) +} + +func TestNewComputerUseScreenshotResult(t *testing.T) { + t.Parallel() + + pngData := []byte{0x89, 0x50, 0x4E, 0x47} + result := 
NewComputerUseScreenshotResult("tool_123", pngData) + + require.Equal(t, "tool_123", result.ToolCallID) + media, ok := result.Output.(fantasy.ToolResultOutputContentMedia) + require.True(t, ok) + require.Equal(t, "image/png", media.MediaType) + require.Equal(t, base64.StdEncoding.EncodeToString(pngData), media.Data) +} + +func TestNewComputerUseScreenshotResultWithMediaType(t *testing.T) { + t.Parallel() + + result := NewComputerUseScreenshotResultWithMediaType("tool_123", "ZmFrZQ==", "image/jpeg") + + require.Equal(t, "tool_123", result.ToolCallID) + media, ok := result.Output.(fantasy.ToolResultOutputContentMedia) + require.True(t, ok) + require.Equal(t, "image/jpeg", media.MediaType) + require.Equal(t, "ZmFrZQ==", media.Data) +} + +func TestNewComputerUseErrorResult(t *testing.T) { + t.Parallel() + + result := NewComputerUseErrorResult("tool_123", errors.New("boom")) + + require.Equal(t, "tool_123", result.ToolCallID) + errOutput, ok := result.Output.(fantasy.ToolResultOutputContentError) + require.True(t, ok) + require.EqualError(t, errOutput.Error, "boom") +} + +func TestNewComputerUseTextResult(t *testing.T) { + t.Parallel() + + result := NewComputerUseTextResult("tool_123", "done") + + require.Equal(t, "tool_123", result.ToolCallID) + textOutput, ok := result.Output.(fantasy.ToolResultOutputContentText) + require.True(t, ok) + require.Equal(t, "done", textOutput.Text) +} diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 07bcdc981..0a26adff7 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -3183,7 +3183,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 1, "should only have user message") require.Len(t, warnings, 1) @@ -3209,7 +3211,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t 
*testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 2, "should have both user and assistant messages") require.Empty(t, warnings) @@ -3237,7 +3241,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 2, "should have both user and assistant messages") require.Empty(t, warnings) @@ -3258,7 +3264,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Empty(t, input) require.Len(t, warnings, 2) // One for unsupported type, one for empty message @@ -3280,7 +3288,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 1) require.Empty(t, warnings) @@ -3301,7 +3311,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 1) require.Empty(t, warnings) @@ -3322,7 +3334,9 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Len(t, input, 1) require.Empty(t, warnings) @@ -3955,7 +3969,9 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { 
t.Run("store false skips item reference", func(t *testing.T) { t.Parallel() - input, warnings := toResponsesPrompt(prompt, "system instructions", false) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", false) + + require.NoError(t, err) require.Empty(t, warnings) require.Len(t, input, 2, @@ -3967,7 +3983,9 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { t.Run("store true uses item reference", func(t *testing.T) { t.Parallel() - input, warnings := toResponsesPrompt(prompt, "system instructions", true) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", true) + + require.NoError(t, err) require.Empty(t, warnings) require.Len(t, input, 3, @@ -4019,7 +4037,9 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store true skips reasoning", func(t *testing.T) { t.Parallel() - input, warnings := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true) + + require.NoError(t, err) require.Empty(t, warnings) // With store=true: user, assistant text (reasoning @@ -4036,7 +4056,9 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store false skips reasoning", func(t *testing.T) { t.Parallel() - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + + require.NoError(t, err) require.Empty(t, warnings) // With store=false: user, assistant text, follow-up user. 
diff --git a/providers/openai/responses_computer_use_test.go b/providers/openai/responses_computer_use_test.go new file mode 100644 index 000000000..aa6287b38 --- /dev/null +++ b/providers/openai/responses_computer_use_test.go @@ -0,0 +1,607 @@ +package openai + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + + "charm.land/fantasy" + "github.com/charmbracelet/openai-go/responses" + "github.com/stretchr/testify/require" +) + +func TestIsResponsesModel_ComputerUse(t *testing.T) { + t.Parallel() + + require.True(t, IsResponsesModel("gpt-4o")) + require.True(t, IsResponsesModel("computer-use-preview")) + require.True(t, IsResponsesModel("computer-use-preview-2025-03-11")) + require.False(t, IsResponsesModel("acme-computer-use-preview")) + require.False(t, IsResponsesModel("not-a-computer-use-model")) + require.False(t, IsResponsesModel("not-a-responses-model")) +} + +func TestGetResponsesModelConfig_ComputerUseAllowlist(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + modelID string + wantReasoning bool + wantSystemMode string + }{ + { + name: "official computer use model", + modelID: string(responses.ResponsesModelComputerUsePreview), + wantReasoning: true, + wantSystemMode: "developer", + }, + { + name: "unknown computer use style model", + modelID: "acme-computer-use-preview", + wantReasoning: false, + wantSystemMode: "system", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + config := getResponsesModelConfig(tc.modelID) + require.Equal(t, tc.wantReasoning, config.isReasoningModel) + require.Equal(t, tc.wantSystemMode, config.systemMessageMode) + }) + } +} + +func TestPrepareParams_ComputerUseRequiresStore(t *testing.T) { + t.Parallel() + + lm := responsesLanguageModel{ + provider: Name, + modelID: string(responses.ResponsesModelComputerUsePreview), + } + tool := NewComputerUseTool(ComputerUseToolOptions{ + DisplayWidthPx: 1024, + DisplayHeightPx: 768, + Environment: 
responses.ComputerUsePreviewToolEnvironmentBrowser, + }, func(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{}, nil + }) + prompt := fantasy.Prompt{testTextMessage(fantasy.MessageRoleUser, "take a screenshot")} + + t.Run("missing provider options", func(t *testing.T) { + t.Parallel() + + _, _, err := lm.prepareParams(fantasy.Call{Prompt: prompt, Tools: []fantasy.Tool{tool}}) + require.EqualError(t, err, computerUseStoreError) + }) + + t.Run("store false", func(t *testing.T) { + t.Parallel() + + _, _, err := lm.prepareParams(fantasy.Call{ + Prompt: prompt, + Tools: []fantasy.Tool{tool}, + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{Store: fantasy.Opt(false)}, + }, + }) + require.EqualError(t, err, computerUseStoreError) + }) + + t.Run("store true", func(t *testing.T) { + t.Parallel() + + params, warnings, err := lm.prepareParams(fantasy.Call{ + Prompt: prompt, + Tools: []fantasy.Tool{tool}, + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)}, + }, + }) + require.NoError(t, err) + require.Empty(t, warnings) + require.True(t, params.Store.Valid()) + require.True(t, params.Store.Value) + require.Len(t, params.Tools, 1) + require.NotNil(t, params.Tools[0].OfComputerUsePreview) + }) + + for _, tc := range []struct { + name string + displayWidth int64 + displayHeight int64 + wantErr string + }{ + { + name: "zero display height", + displayWidth: 1024, + displayHeight: 0, + wantErr: "computer use tool has invalid display_height_px", + }, + { + name: "zero display width", + displayWidth: 0, + displayHeight: 768, + wantErr: "computer use tool has invalid display_width_px", + }, + { + name: "negative display width", + displayWidth: -100, + displayHeight: 768, + wantErr: "computer use tool has invalid display_width_px", + }, + { + name: "negative display height", + displayWidth: 1024, + displayHeight: -100, + wantErr: "computer use tool has 
invalid display_height_px", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + invalidTool := NewComputerUseTool(ComputerUseToolOptions{ + DisplayWidthPx: tc.displayWidth, + DisplayHeightPx: tc.displayHeight, + Environment: responses.ComputerUsePreviewToolEnvironmentBrowser, + }, func(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{}, nil + }) + + _, _, err := lm.prepareParams(fantasy.Call{ + Prompt: prompt, + Tools: []fantasy.Tool{invalidTool}, + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)}, + }, + }) + require.EqualError(t, err, tc.wantErr) + }) + } +} + +func TestToResponsesTools_ComputerUsePreview(t *testing.T) { + t.Parallel() + + tool := NewComputerUseTool(ComputerUseToolOptions{ + DisplayWidthPx: 1440, + DisplayHeightPx: 900, + Environment: responses.ComputerUsePreviewToolEnvironmentUbuntu, + }, func(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{}, nil + }) + + definition := tool.Definition() + argsJSON, err := json.Marshal(definition.Args) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(argsJSON, &definition.Args)) + require.Len(t, definition.Args, 3) + _, hasDisplayNumber := definition.Args["display_number"] + require.False(t, hasDisplayNumber) + + tools, toolChoice, warnings, err := toResponsesTools([]fantasy.Tool{definition}, nil, nil) + require.NoError(t, err) + require.Empty(t, warnings) + require.Empty(t, toolChoice) + require.Len(t, tools, 1) + require.NotNil(t, tools[0].OfComputerUsePreview) + require.Equal(t, int64(900), tools[0].OfComputerUsePreview.DisplayHeight) + require.Equal(t, int64(1440), tools[0].OfComputerUsePreview.DisplayWidth) + require.Equal(t, responses.ComputerUsePreviewToolEnvironmentUbuntu, tools[0].OfComputerUsePreview.Environment) + + toolJSON, err := json.Marshal(tools[0]) + require.NoError(t, err) + require.JSONEq(t, 
`{"display_height":900,"display_width":1440,"environment":"ubuntu","type":"computer_use_preview"}`, string(toolJSON)) + + for _, tc := range []struct { + name string + displayWidth int64 + displayHeight int64 + wantErr string + }{ + { + name: "zero display height", + displayWidth: 1440, + displayHeight: 0, + wantErr: "computer use tool has invalid display_height_px", + }, + { + name: "zero display width", + displayWidth: 0, + displayHeight: 900, + wantErr: "computer use tool has invalid display_width_px", + }, + { + name: "negative display width", + displayWidth: -100, + displayHeight: 900, + wantErr: "computer use tool has invalid display_width_px", + }, + { + name: "negative display height", + displayWidth: 1440, + displayHeight: -100, + wantErr: "computer use tool has invalid display_height_px", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + invalidTool := NewComputerUseTool(ComputerUseToolOptions{ + DisplayWidthPx: tc.displayWidth, + DisplayHeightPx: tc.displayHeight, + Environment: responses.ComputerUsePreviewToolEnvironmentUbuntu, + }, func(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.ToolResponse{}, nil + }) + + invalidDefinition := invalidTool.Definition() + invalidArgsJSON, err := json.Marshal(invalidDefinition.Args) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(invalidArgsJSON, &invalidDefinition.Args)) + + _, _, warnings, err := toResponsesTools([]fantasy.Tool{invalidDefinition}, nil, nil) + require.EqualError(t, err, tc.wantErr) + require.Empty(t, warnings) + }) + } +} + +func TestResponsesToPrompt_ComputerUseWithStore(t *testing.T) { + t.Parallel() + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "take a screenshot"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: "comp_item_01", + ToolName: computerUseAPIName, + 
Input: `{"type":"screenshot"}`, + ProviderOptions: fantasy.ProviderOptions{ + Name: &OpenAIComputerUseCallMetadata{ + CallID: "call_01", + PendingSafetyChecks: []OpenAIComputerUsePendingSafetyCheck{ + {ID: "safe_01", Code: "account_access", Message: "Confirm access."}, + }, + }, + }, + }, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + NewComputerUseScreenshotResultWithMediaType("comp_item_01", "ZmFrZQ==", "image/jpeg"), + }, + }, + } + + input, warnings, err := toResponsesPrompt(prompt, "system", true) + require.NoError(t, err) + require.Empty(t, warnings) + require.Len(t, input, 3) + require.NotNil(t, input[1].OfItemReference) + require.Equal(t, "comp_item_01", input[1].OfItemReference.ID) + + computerOutput := input[2].OfComputerCallOutput + require.NotNil(t, computerOutput) + require.Equal(t, "call_01", computerOutput.CallID) + require.True(t, computerOutput.Output.ImageURL.Valid()) + require.Equal(t, "data:image/jpeg;base64,ZmFrZQ==", computerOutput.Output.ImageURL.Value) + require.Len(t, computerOutput.AcknowledgedSafetyChecks, 1) + require.Equal(t, "safe_01", computerOutput.AcknowledgedSafetyChecks[0].ID) + require.True(t, computerOutput.AcknowledgedSafetyChecks[0].Code.Valid()) + require.Equal(t, "account_access", computerOutput.AcknowledgedSafetyChecks[0].Code.Value) + require.True(t, computerOutput.AcknowledgedSafetyChecks[0].Message.Valid()) + require.Equal(t, "Confirm access.", computerOutput.AcknowledgedSafetyChecks[0].Message.Value) +} + +func TestResponsesToPrompt_ComputerUseMissingMetadata(t *testing.T) { + t.Parallel() + + t.Run("media output without metadata errors", func(t *testing.T) { + t.Parallel() + + prompt := fantasy.Prompt{ + testTextMessage(fantasy.MessageRoleUser, "take a screenshot"), + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: "comp_item_02", + ToolName: computerUseAPIName, + Input: `{"type":"screenshot"}`, + }, + }, + }, + { + 
Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + NewComputerUseScreenshotResultWithMediaType("comp_item_02", "ZmFrZQ==", "image/png"), + }, + }, + } + + _, _, err := toResponsesPrompt(prompt, "system", true) + require.EqualError(t, err, "malformed prompt: openai computer tool result for tool_call_id \"comp_item_02\" is missing matching call metadata") + }) + + for _, tc := range []struct { + name string + output fantasy.ToolResultOutputContent + wantText string + }{ + { + name: "text output without metadata stays generic", + output: fantasy.ToolResultOutputContentText{ + Text: "done", + }, + wantText: "done", + }, + { + name: "error output without metadata stays generic", + output: fantasy.ToolResultOutputContentError{ + Error: context.DeadlineExceeded, + }, + wantText: context.DeadlineExceeded.Error(), + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + prompt := fantasy.Prompt{ + testTextMessage(fantasy.MessageRoleUser, "run the tool"), + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: "call_123", + ToolName: "echo", + Input: `{"text":"hello"}`, + }, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "call_123", + Output: tc.output, + }, + }, + }, + } + + input, warnings, err := toResponsesPrompt(prompt, "system", false) + require.NoError(t, err) + require.Empty(t, warnings) + require.Len(t, input, 3) + require.NotNil(t, input[2].OfFunctionCallOutput) + require.True(t, input[2].OfFunctionCallOutput.Output.OfString.Valid()) + require.Equal(t, tc.wantText, input[2].OfFunctionCallOutput.Output.OfString.Value) + }) + } +} + +func TestOpenAIComputerUseCallMetadata_JSON(t *testing.T) { + t.Parallel() + + encoded, err := json.Marshal(OpenAIComputerUseCallMetadata{ + CallID: "call_01", + PendingSafetyChecks: []OpenAIComputerUsePendingSafetyCheck{ + {ID: "safe_01", Code: "account_access", Message: "Confirm 
access."}, + }, + }) + require.NoError(t, err) + + decoded, err := fantasy.UnmarshalProviderMetadata(map[string]json.RawMessage{ + Name: encoded, + }) + require.NoError(t, err) + + metadata, ok := decoded[Name].(*OpenAIComputerUseCallMetadata) + require.True(t, ok) + require.Equal(t, "call_01", metadata.CallID) + require.Len(t, metadata.PendingSafetyChecks, 1) + require.Equal(t, "safe_01", metadata.PendingSafetyChecks[0].ID) + require.Equal(t, "account_access", metadata.PendingSafetyChecks[0].Code) + require.Equal(t, "Confirm access.", metadata.PendingSafetyChecks[0].Message) +} + +func TestResponsesGenerate_ComputerUseResponse(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.response = mockResponsesComputerUseResponse() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + WithUseResponsesAPI(), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), string(responses.ResponsesModelComputerUsePreview)) + require.NoError(t, err) + + resp, err := model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt, + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)}, + }, + Tools: []fantasy.Tool{testComputerUseToolDefinition()}, + }) + require.NoError(t, err) + require.Equal(t, "/responses", server.calls[0].path) + require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason) + + toolCalls := resp.Content.ToolCalls() + require.Len(t, toolCalls, 1) + assertComputerUseToolCall(t, toolCalls[0], `[ + {"type":"move","x":320,"y":240}, + {"type":"click","button":"left","x":320,"y":240} + ]`) +} + +func TestResponsesStream_ComputerUseResponse(t *testing.T) { + t.Parallel() + + chunks := []string{ + "event: response.output_item.added\n" + + `data: {"type":"response.output_item.added","output_index":0,"item":{"type":"computer_call","id":"comp_item_01","status":"in_progress"}}` + "\n\n", + "event: 
response.output_item.done\n" + + `data: {"type":"response.output_item.done","output_index":0,"item":{"type":"computer_call","id":"comp_item_01","call_id":"call_01","status":"completed","pending_safety_checks":[{"id":"safe_01","code":"account_access","message":"Confirm access."}],"actions":[{"type":"move","x":320,"y":240},{"type":"click","button":"left","x":320,"y":240}]}}` + "\n\n", + "event: response.completed\n" + + `data: {"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":10,"output_tokens":4,"total_tokens":14}}}` + "\n\n", + } + + sms := newStreamingMockServer() + defer sms.close() + sms.chunks = chunks + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(sms.server.URL), + WithUseResponsesAPI(), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), string(responses.ResponsesModelComputerUsePreview)) + require.NoError(t, err) + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt, + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{Store: fantasy.Opt(true)}, + }, + Tools: []fantasy.Tool{testComputerUseToolDefinition()}, + }) + require.NoError(t, err) + + var ( + toolInputStarts []fantasy.StreamPart + toolInputEnds []fantasy.StreamPart + toolCalls []fantasy.StreamPart + finishes []fantasy.StreamPart + ) + stream(func(part fantasy.StreamPart) bool { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart: + toolInputStarts = append(toolInputStarts, part) + case fantasy.StreamPartTypeToolInputEnd: + toolInputEnds = append(toolInputEnds, part) + case fantasy.StreamPartTypeToolCall: + toolCalls = append(toolCalls, part) + case fantasy.StreamPartTypeFinish: + finishes = append(finishes, part) + } + return true + }) + + require.Len(t, toolInputStarts, 1) + require.Equal(t, "comp_item_01", toolInputStarts[0].ID) + require.Equal(t, computerUseAPIName, toolInputStarts[0].ToolCallName) + 
+ require.Len(t, toolInputEnds, 1) + require.Equal(t, "comp_item_01", toolInputEnds[0].ID) + + require.Len(t, toolCalls, 1) + require.Equal(t, "comp_item_01", toolCalls[0].ID) + require.Equal(t, computerUseAPIName, toolCalls[0].ToolCallName) + require.JSONEq(t, `[ + {"type":"move","x":320,"y":240}, + {"type":"click","button":"left","x":320,"y":240} + ]`, toolCalls[0].ToolCallInput) + + metadata, ok := toolCalls[0].ProviderMetadata[Name].(*OpenAIComputerUseCallMetadata) + require.True(t, ok) + require.Equal(t, "call_01", metadata.CallID) + require.Len(t, metadata.PendingSafetyChecks, 1) + require.Equal(t, "safe_01", metadata.PendingSafetyChecks[0].ID) + + require.Len(t, finishes, 1) + require.Equal(t, fantasy.FinishReasonToolCalls, finishes[0].FinishReason) +} + +func mockResponsesComputerUseResponse() map[string]any { + return map[string]any{ + "id": "resp_01", + "object": "response", + "model": string(responses.ResponsesModelComputerUsePreview), + "output": []any{ + map[string]any{ + "type": "computer_call", + "id": "comp_item_01", + "call_id": "call_01", + "status": "completed", + "pending_safety_checks": []any{ + map[string]any{ + "id": "safe_01", + "code": "account_access", + "message": "Confirm access.", + }, + }, + "actions": []any{ + map[string]any{"type": "move", "x": 320, "y": 240}, + map[string]any{"type": "click", "button": "left", "x": 320, "y": 240}, + }, + }, + }, + "status": "completed", + "usage": map[string]any{ + "input_tokens": 10, + "output_tokens": 4, + "total_tokens": 14, + }, + } +} + +func assertComputerUseToolCall(t *testing.T, toolCall fantasy.ToolCallContent, wantInput string) { + t.Helper() + + require.False(t, toolCall.ProviderExecuted) + require.Equal(t, computerUseAPIName, toolCall.ToolName) + require.Equal(t, "comp_item_01", toolCall.ToolCallID) + require.JSONEq(t, wantInput, toolCall.Input) + + metadata, ok := toolCall.ProviderMetadata[Name].(*OpenAIComputerUseCallMetadata) + require.True(t, ok) + require.Equal(t, "call_01", 
metadata.CallID) + require.Len(t, metadata.PendingSafetyChecks, 1) + require.Equal(t, "safe_01", metadata.PendingSafetyChecks[0].ID) + require.Equal(t, "account_access", metadata.PendingSafetyChecks[0].Code) + require.Equal(t, "Confirm access.", metadata.PendingSafetyChecks[0].Message) +} + +func testComputerUseToolDefinition() fantasy.ProviderDefinedTool { + return fantasy.ProviderDefinedTool{ + ID: computerUseToolID, + Name: computerUseAPIName, + Args: map[string]any{ + "display_width_px": int64(1024), + "display_height_px": int64(768), + "environment": responses.ComputerUsePreviewToolEnvironmentBrowser, + }, + } +} + +func computerUseCassettePaths(t *testing.T, modelName string) []string { + base := filepath.Join("testdata", t.Name(), modelName) + return []string{ + filepath.Join(base, "computer_use.yaml"), + filepath.Join(base, "computer_use_streaming.yaml"), + } +} diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index eb027109e..2bd037a5d 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -90,7 +90,7 @@ func getResponsesModelConfig(modelID string) responsesModelConfig { strings.HasPrefix(modelID, "o4") || strings.Contains(modelID, "-o4") || strings.HasPrefix(modelID, "oss") || strings.Contains(modelID, "-oss") || strings.Contains(modelID, "gpt-5") || strings.Contains(modelID, "codex-") || - strings.Contains(modelID, "computer-use") { + isOpenAIComputerUseModel(modelID) { if strings.Contains(modelID, "o1-mini") || strings.Contains(modelID, "o1-preview") { return responsesModelConfig{ isReasoningModel: true, @@ -158,6 +158,11 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res } } + storeEnabled := openaiOptions != nil && openaiOptions.Store != nil && *openaiOptions.Store + if hasComputerUseTool(call.Tools) && !storeEnabled { + return nil, warnings, errors.New(computerUseStoreError) + } + if openaiOptions != nil && 
openaiOptions.Store != nil { params.Store = param.NewOpt(*openaiOptions.Store) } else { @@ -168,14 +173,16 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res if err := validatePreviousResponseIDPrompt(call.Prompt); err != nil { return nil, warnings, err } - if openaiOptions.Store == nil || !*openaiOptions.Store { + if !storeEnabled { return nil, warnings, errors.New(previousResponseIDStoreError) } params.PreviousResponseID = param.NewOpt(*openaiOptions.PreviousResponseID) } - storeEnabled := openaiOptions != nil && openaiOptions.Store != nil && *openaiOptions.Store - input, inputWarnings := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode, storeEnabled) + input, inputWarnings, err := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode, storeEnabled) + if err != nil { + return nil, warnings, err + } warnings = append(warnings, inputWarnings...) var include []IncludeType @@ -338,7 +345,10 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res } } - tools, toolChoice, toolWarnings := toResponsesTools(call.Tools, call.ToolChoice, openaiOptions) + tools, toolChoice, toolWarnings, err := toResponsesTools(call.Tools, call.ToolChoice, openaiOptions) + if err != nil { + return nil, warnings, err + } warnings = append(warnings, toolWarnings...) 
if len(tools) > 0 { @@ -390,9 +400,10 @@ func responsesUsage(resp responses.Response) fantasy.Usage { return usage } -func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bool) (responses.ResponseInputParam, []fantasy.CallWarning) { +func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bool) (responses.ResponseInputParam, []fantasy.CallWarning, error) { var input responses.ResponseInputParam var warnings []fantasy.CallWarning + computerToolCalls := make(map[string]*OpenAIComputerUseCallMetadata) for _, msg := range prompt { switch msg.Role { @@ -536,6 +547,14 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bo continue } + if metadata := getComputerUseCallMetadata(toolCallPart.ProviderOptions); metadata != nil { + computerToolCalls[toolCallPart.ToolCallID] = metadata + if store { + input = append(input, responses.ResponseInputItemParamOfItemReference(toolCallPart.ToolCallID)) + } + continue + } + if toolCallPart.ProviderExecuted { if store { // Round-trip provider-executed tools via @@ -610,6 +629,22 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bo continue } + metadata := computerToolCalls[toolResultPart.ToolCallID] + if metadata == nil { + metadata = getComputerUseCallMetadata(toolResultPart.ProviderOptions) + } + if metadata != nil { + computerOutput, err := computerUseToolResultInput(toolResultPart, metadata) + if err != nil { + return nil, warnings, fmt.Errorf("malformed prompt: failed to build openai computer tool result for tool_call_id %q: %w", toolResultPart.ToolCallID, err) + } + input = append(input, computerOutput) + continue + } + if _, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResultPart.Output); ok { + return nil, warnings, fmt.Errorf("malformed prompt: openai computer tool result for tool_call_id %q is missing matching call metadata", toolResultPart.ToolCallID) + } + var outputStr string switch 
toolResultPart.Output.GetType() { @@ -640,7 +675,7 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bo } } - return input, warnings + return input, warnings, nil } func hasVisibleResponsesUserContent(content responses.ResponseInputMessageContentListParam) bool { @@ -657,12 +692,12 @@ func hasVisibleResponsesAssistantContent(items []responses.ResponseInputItemUnio return false } -func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, options *ResponsesProviderOptions) ([]responses.ToolUnionParam, responses.ResponseNewParamsToolChoiceUnion, []fantasy.CallWarning) { +func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, options *ResponsesProviderOptions) ([]responses.ToolUnionParam, responses.ResponseNewParamsToolChoiceUnion, []fantasy.CallWarning, error) { warnings := make([]fantasy.CallWarning, 0) var openaiTools []responses.ToolUnionParam if len(tools) == 0 { - return nil, responses.ResponseNewParamsToolChoiceUnion{}, nil + return nil, responses.ResponseNewParamsToolChoiceUnion{}, nil, nil } strictJSONSchema := false @@ -688,7 +723,7 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti continue } if tool.GetType() == fantasy.ToolTypeProviderDefined { - pt, ok := tool.(fantasy.ProviderDefinedTool) + pt, ok := asProviderDefinedTool(tool) if !ok { continue } @@ -696,6 +731,13 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti case "web_search": openaiTools = append(openaiTools, toWebSearchToolParam(pt)) continue + case computerUseToolID: + computerTool, err := toComputerUseToolParam(pt) + if err != nil { + return nil, responses.ResponseNewParamsToolChoiceUnion{}, warnings, err + } + openaiTools = append(openaiTools, computerTool) + continue } } @@ -707,7 +749,7 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti } if toolChoice == nil { - return openaiTools, 
responses.ResponseNewParamsToolChoiceUnion{}, warnings + return openaiTools, responses.ResponseNewParamsToolChoiceUnion{}, warnings, nil } var openaiToolChoice responses.ResponseNewParamsToolChoiceUnion @@ -734,7 +776,7 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti } } - return openaiTools, openaiToolChoice, warnings + return openaiTools, openaiToolChoice, warnings, nil } func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { @@ -806,6 +848,14 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) Input: outputItem.Arguments.OfString, }) + case "computer_call": + hasFunctionCall = true + computerCall, err := computerUseToolCallContent(outputItem.AsComputerCall()) + if err != nil { + return nil, fmt.Errorf("failed to build computer tool call content: %w", err) + } + content = append(content, computerCall) + case "web_search_call": // Provider-executed web search tool call. Emit both // a ToolCallContent and ToolResultContent as a pair, @@ -942,6 +992,15 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( return } + case "computer_call": + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, + ID: added.Item.ID, + ToolCallName: computerUseAPIName, + }) { + return + } + case "web_search_call": // Provider-executed web search; emit start. 
if !yield(fantasy.StreamPart{ @@ -1009,6 +1068,35 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( } } + case "computer_call": + hasFunctionCall = true + computerCall, err := computerUseToolCallContent(done.Item.AsComputerCall()) + if err != nil { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("failed to build computer tool call content: %w", err), + }) { + return + } + return + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, + ID: computerCall.ToolCallID, + }) { + return + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: computerCall.ToolCallID, + ToolCallName: computerCall.ToolName, + ToolCallInput: computerCall.Input, + ProviderExecuted: computerCall.ProviderExecuted, + ProviderMetadata: computerCall.ProviderMetadata, + }) { + return + } + case "web_search_call": // Provider-executed web search completed. // Source citations come from url_citation annotations diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index a80a5d310..ab822bca4 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -14,6 +14,7 @@ const ( TypeResponsesProviderOptions = Name + ".responses.options" TypeResponsesReasoningMetadata = Name + ".responses.reasoning_metadata" TypeWebSearchCallMetadata = Name + ".responses.web_search_call_metadata" + TypeComputerUseCallMetadata = Name + ".responses.computer_use_call_metadata" ) // Register OpenAI Responses API-specific types with the global registry. 
@@ -46,6 +47,13 @@ func init() { } return &v, nil }) + fantasy.RegisterProviderType(TypeComputerUseCallMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v OpenAIComputerUseCallMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) } // ResponsesProviderMetadata contains response-level metadata from the OpenAI Responses API. @@ -222,35 +230,44 @@ var responsesReasoningModelIDs = []string{ "gpt-oss-120b", } +// computerUseModelIDs lists model IDs that support OpenAI's computer-use tool. +var computerUseModelIDs = []string{ + "computer-use-preview", + "computer-use-preview-2025-03-11", +} + // responsesModelIds lists all model IDs for OpenAI Responses API. -var responsesModelIDs = append([]string{ - "gpt-4.1", - "gpt-4.1-2025-04-14", - "gpt-4.1-mini", - "gpt-4.1-mini-2025-04-14", - "gpt-4.1-nano", - "gpt-4.1-nano-2025-04-14", - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-2024-08-06", - "gpt-4o-2024-11-20", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-4.5-preview", - "gpt-4.5-preview-2025-02-27", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "chatgpt-4o-latest", - "gpt-5-chat-latest", -}, responsesReasoningModelIDs...) 
+var responsesModelIDs = append( + append([]string{ + "gpt-4.1", + "gpt-4.1-2025-04-14", + "gpt-4.1-mini", + "gpt-4.1-mini-2025-04-14", + "gpt-4.1-nano", + "gpt-4.1-nano-2025-04-14", + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-2024-11-20", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-4.5-preview", + "gpt-4.5-preview-2025-02-27", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "chatgpt-4o-latest", + "gpt-5-chat-latest", + }, computerUseModelIDs...), + responsesReasoningModelIDs..., +) // NewResponsesProviderOptions creates new provider options for OpenAI Responses API. func NewResponsesProviderOptions(opts *ResponsesProviderOptions) fantasy.ProviderOptions { @@ -268,6 +285,12 @@ func ParseResponsesOptions(data map[string]any) (*ResponsesProviderOptions, erro return &options, nil } +// isOpenAIComputerUseModel checks if a model ID is an OpenAI computer-use +// Responses API model. +func isOpenAIComputerUseModel(modelID string) bool { + return slices.Contains(computerUseModelIDs, modelID) +} + // IsResponsesModel checks if a model ID is a Responses API model for OpenAI. func IsResponsesModel(modelID string) bool { return slices.Contains(responsesModelIDs, modelID) @@ -386,3 +409,39 @@ func (m *WebSearchCallMetadata) UnmarshalJSON(data []byte) error { *m = WebSearchCallMetadata(p) return nil } + +// OpenAIComputerUsePendingSafetyCheck stores a pending safety check from an +// OpenAI computer tool call using local types suitable for JSON round-tripping. +type OpenAIComputerUsePendingSafetyCheck struct { + ID string `json:"id"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// OpenAIComputerUseCallMetadata stores structured metadata for an OpenAI +// computer tool call. 
CallID is required to submit the matching +// computer_call_output on replay. +type OpenAIComputerUseCallMetadata struct { + CallID string `json:"call_id"` + PendingSafetyChecks []OpenAIComputerUsePendingSafetyCheck `json:"pending_safety_checks,omitempty"` +} + +// Options implements the ProviderOptionsData interface. +func (*OpenAIComputerUseCallMetadata) Options() {} + +// MarshalJSON implements custom JSON marshaling with type info. +func (m OpenAIComputerUseCallMetadata) MarshalJSON() ([]byte, error) { + type plain OpenAIComputerUseCallMetadata + return fantasy.MarshalProviderType(TypeComputerUseCallMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info. +func (m *OpenAIComputerUseCallMetadata) UnmarshalJSON(data []byte) error { + type plain OpenAIComputerUseCallMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = OpenAIComputerUseCallMetadata(p) + return nil +} diff --git a/providertests/openai_responses_test.go b/providertests/openai_responses_test.go index db85f73d4..9cf52db89 100644 --- a/providertests/openai_responses_test.go +++ b/providertests/openai_responses_test.go @@ -3,11 +3,13 @@ package providertests import ( "net/http" "os" + "path/filepath" "testing" "charm.land/fantasy" "charm.land/fantasy/providers/openai" "charm.land/x/vcr" + openairesponses "github.com/charmbracelet/openai-go/responses" "github.com/stretchr/testify/require" ) @@ -94,3 +96,184 @@ func testOpenAIResponsesThinkingWithSummaryThinking(t *testing.T, result *fantas require.Greater(t, encryptedData, 0) require.Equal(t, reasoningContentCount, encryptedData) } + +func TestOpenAIResponsesComputerUse(t *testing.T) { + modelName := "openai-computer-use-preview" + for _, cassettePath := range computerUseCassettePaths(t, modelName) { + if _, err := os.Stat(cassettePath); err != nil { + t.Skip("requires vcr cassette") + } + } + + modelID := 
string(openairesponses.ResponsesModelComputerUsePreview) + providerOptions := fantasy.ProviderOptions{ + openai.Name: &openai.ResponsesProviderOptions{ + Store: fantasy.Opt(true), + }, + } + + t.Run(modelName, func(t *testing.T) { + t.Run("computer use", func(t *testing.T) { + r := vcr.NewRecorder(t) + + model, err := openAIReasoningBuilder(modelID)(t, r) + require.NoError(t, err) + + cuTool := jsonRoundTripTool(t, openai.NewComputerUseTool(openai.ComputerUseToolOptions{ + DisplayWidthPx: 1920, + DisplayHeightPx: 1080, + Environment: openairesponses.ComputerUsePreviewToolEnvironmentBrowser, + }, noopComputerRun)) + + resp, err := model.Generate(t.Context(), fantasy.Call{ + Prompt: fantasy.Prompt{ + {Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}}, + {Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}}, + }, + ProviderOptions: providerOptions, + Tools: []fantasy.Tool{cuTool}, + }) + require.NoError(t, err) + require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason) + + toolCalls := resp.Content.ToolCalls() + require.Len(t, toolCalls, 1) + require.Equal(t, "computer", toolCalls[0].ToolName) + require.Contains(t, toolCalls[0].Input, "screenshot") + + resp2, err := model.Generate(t.Context(), fantasy.Call{ + Prompt: fantasy.Prompt{ + {Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}}, + {Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}}, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: toolCalls[0].ToolCallID, + ToolName: toolCalls[0].ToolName, + Input: toolCalls[0].Input, + ProviderOptions: fantasy.ProviderOptions(toolCalls[0].ProviderMetadata), + ProviderExecuted: toolCalls[0].ProviderExecuted, + }, + }, + }, + { 
+ Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: toolCalls[0].ToolCallID, + Output: fantasy.ToolResultOutputContentMedia{ + Data: screenshotBase64, + MediaType: "image/png", + }, + }, + }, + }, + }, + ProviderOptions: providerOptions, + Tools: []fantasy.Tool{cuTool}, + }) + require.NoError(t, err) + require.NotEmpty(t, resp2.Content.Text()) + require.Contains(t, resp2.Content.Text(), "desktop") + }) + + t.Run("computer use streaming", func(t *testing.T) { + r := vcr.NewRecorder(t) + + model, err := openAIReasoningBuilder(modelID)(t, r) + require.NoError(t, err) + + cuTool := jsonRoundTripTool(t, openai.NewComputerUseTool(openai.ComputerUseToolOptions{ + DisplayWidthPx: 1920, + DisplayHeightPx: 1080, + Environment: openairesponses.ComputerUsePreviewToolEnvironmentBrowser, + }, noopComputerRun)) + + stream, err := model.Stream(t.Context(), fantasy.Call{ + Prompt: fantasy.Prompt{ + {Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}}, + {Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}}, + }, + ProviderOptions: providerOptions, + Tools: []fantasy.Tool{cuTool}, + }) + require.NoError(t, err) + + var toolCallID, toolCallName, toolCallInput string + var toolCallMetadata fantasy.ProviderMetadata + var toolCallProviderExecuted bool + var finishReason fantasy.FinishReason + stream(func(part fantasy.StreamPart) bool { + switch part.Type { + case fantasy.StreamPartTypeToolCall: + toolCallID = part.ID + toolCallName = part.ToolCallName + toolCallInput = part.ToolCallInput + toolCallMetadata = part.ProviderMetadata + toolCallProviderExecuted = part.ProviderExecuted + case fantasy.StreamPartTypeFinish: + finishReason = part.FinishReason + } + return true + }) + + require.Equal(t, fantasy.FinishReasonToolCalls, finishReason) + require.Equal(t, "computer", toolCallName) + 
// computerUseCassettePaths returns the two cassette file paths for modelName,
// non-streaming first and streaming second. Both live in the directory
// testdata/<t.Name()>/<modelName>, joined with the OS path separator.
func computerUseCassettePaths(t *testing.T, modelName string) []string {
	dir := filepath.Join("testdata", t.Name(), modelName)

	// Start from the bare file names and turn each one into a full path in
	// place, preserving the non-streaming/streaming order.
	cassettes := []string{"computer_use.yaml", "computer_use_streaming.yaml"}
	for i, file := range cassettes {
		cassettes[i] = filepath.Join(dir, file)
	}
	return cassettes
}