diff --git a/api.go b/api.go index f8fc60f1..3669ae24 100644 --- a/api.go +++ b/api.go @@ -19,6 +19,7 @@ const ( ProviderAnthropic = config.ProviderAnthropic ProviderOpenAI = config.ProviderOpenAI ProviderCopilot = config.ProviderCopilot + ProviderBedrock = config.ProviderBedrock ) type ( @@ -37,6 +38,7 @@ type ( AnthropicConfig = config.Anthropic AWSBedrockConfig = config.AWSBedrock + BedrockConfig = config.Bedrock OpenAIConfig = config.OpenAI CopilotConfig = config.Copilot ) @@ -57,6 +59,10 @@ func NewCopilotProvider(cfg config.Copilot) provider.Provider { return provider.NewCopilot(cfg) } +func NewBedrockProvider(cfg config.Bedrock) provider.Provider { + return provider.NewBedrock(cfg) +} + func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { return metrics.NewMetrics(reg) } diff --git a/config/config.go b/config/config.go index 17ce01e4..1e3e9651 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,7 @@ const ( ProviderAnthropic = "anthropic" ProviderOpenAI = "openai" ProviderCopilot = "copilot" + ProviderBedrock = "bedrock" ) type Anthropic struct { @@ -34,6 +35,17 @@ type AWSBedrock struct { BaseURL string } +// Bedrock is a standalone Bedrock provider configuration. It acts as a +// SigV4-signing reverse proxy, forwarding native Bedrock API requests +// to AWS and adding centralized AWS credentials. +type Bedrock struct { + // Name is the provider instance name. If empty, defaults to "bedrock". + Name string + APIDumpDir string + CircuitBreaker *CircuitBreaker + AWSBedrock // Region, AccessKey, AccessKeySecret, SessionToken, Model, SmallFastModel, BaseURL +} + type OpenAI struct { // Name is the provider instance name. If empty, defaults to "openai". 
Name string diff --git a/fixtures/bedrock/parse_eventstream.go b/fixtures/bedrock/parse_eventstream.go new file mode 100644 index 00000000..51ca8529 --- /dev/null +++ b/fixtures/bedrock/parse_eventstream.go @@ -0,0 +1,88 @@ +//go:build ignore + +// Usage: go run parse_eventstream.go +// +// Decodes an AWS EventStream binary file and prints each frame's +// decoded JSON body. Use this to inspect captured Bedrock responses +// and verify fixture contents. +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "os" + + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" +) + +func main() { + if len(os.Args) < 2 { + fmt.Fprintf(os.Stderr, "usage: go run parse_eventstream.go \n") + os.Exit(1) + } + + data, err := os.ReadFile(os.Args[1]) + if err != nil { + fmt.Fprintf(os.Stderr, "read file: %v\n", err) + os.Exit(1) + } + + decoder := eventstream.NewDecoder() + reader := bytes.NewReader(data) + frameNum := 0 + + for { + msg, err := decoder.Decode(reader, nil) + if err != nil { + break + } + frameNum++ + + messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) + eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader) + + fmt.Printf("=== Frame %d ===\n", frameNum) + fmt.Printf(" message-type: %s\n", headerStr(messageType)) + fmt.Printf(" event-type: %s\n", headerStr(eventType)) + + if headerStr(eventType) != "chunk" { + fmt.Printf(" payload: %s\n\n", string(msg.Payload)) + continue + } + + var chunk struct { + Bytes string `json:"bytes"` + } + if err := json.Unmarshal(msg.Payload, &chunk); err != nil { + fmt.Printf(" unmarshal error: %v\n\n", err) + continue + } + + decoded, err := base64.StdEncoding.DecodeString(chunk.Bytes) + if err != nil { + fmt.Printf(" base64 decode error: %v\n\n", err) + continue + } + + var pretty json.RawMessage + if err := json.Unmarshal(decoded, &pretty); err != nil { + fmt.Printf(" json: %s\n\n", string(decoded)) + continue + 
} + + indented, _ := json.MarshalIndent(pretty, " ", " ") + fmt.Printf(" body:\n %s\n\n", string(indented)) + } + + fmt.Printf("Total frames: %d\n", frameNum) +} + +func headerStr(h eventstream.Value) string { + if h == nil { + return "" + } + return h.String() +} diff --git a/fixtures/bedrock/simple.req.json b/fixtures/bedrock/simple.req.json new file mode 100644 index 00000000..b61b2ac2 --- /dev/null +++ b/fixtures/bedrock/simple.req.json @@ -0,0 +1,10 @@ +{ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 50, + "messages": [ + { + "role": "user", + "content": "how many angels can dance on the head of a pin" + } + ] +} diff --git a/fixtures/bedrock/simple.resp.bin b/fixtures/bedrock/simple.resp.bin new file mode 100644 index 00000000..056c54a0 Binary files /dev/null and b/fixtures/bedrock/simple.resp.bin differ diff --git a/fixtures/bedrock/simple.resp.decoded b/fixtures/bedrock/simple.resp.decoded new file mode 100644 index 00000000..49cb0689 --- /dev/null +++ b/fixtures/bedrock/simple.resp.decoded @@ -0,0 +1,511 @@ +=== Frame 1 === + message-type: event + event-type: chunk + body: + { + "type": "message_start", + "message": { + "model": "claude-sonnet-4-5-20250929", + "id": "msg_bdrk_01V7qMAYt4GEtvfD9w8MJsGm", + "type": "message", + "role": "assistant", + "content": [], + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 18, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "cache_creation": { + "ephemeral_5m_input_tokens": 0, + "ephemeral_1h_input_tokens": 0 + }, + "output_tokens": 1 + } + } + } + +=== Frame 2 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "text", + "text": "" + } + } + +=== Frame 3 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "This" + } + } + +=== Frame 4 === + message-type: event + 
event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " famous" + } + } + +=== Frame 5 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " question" + } + } + +=== Frame 6 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " is often" + } + } + +=== Frame 7 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " used" + } + } + +=== Frame 8 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " to" + } + } + +=== Frame 9 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " mock" + } + } + +=== Frame 10 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " medieval" + } + } + +=== Frame 11 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " schol" + } + } + +=== Frame 12 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "astic philosophy" + } + } + +=== Frame 13 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " as" + } + } + +=== Frame 14 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " abs" + } + } + 
+=== Frame 15 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "urd" + } + } + +=== Frame 16 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "ly abstract" + } + } + +=== Frame 17 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " and" + } + } + +=== Frame 18 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " disconn" + } + } + +=== Frame 19 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "ected from reality." + } + } + +=== Frame 20 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " However" + } + } + +=== Frame 21 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": ", it's" + } + } + +=== Frame 22 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " actually" + } + } + +=== Frame 23 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " a **" + } + } + +=== Frame 24 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "miscon" + } + } + +=== Frame 25 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": 
{ + "type": "text_delta", + "text": "ception** —" + } + } + +=== Frame 26 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " there" + } + } + +=== Frame 27 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "'s" + } + } + +=== Frame 28 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " no evidence" + } + } + +=== Frame 29 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " medieval" + } + } + +=== Frame 30 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " philosophers" + } + } + +=== Frame 31 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " seriously" + } + } + +=== Frame 32 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " deb" + } + } + +=== Frame 33 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "ated this." 
+ } + } + +=== Frame 34 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "\n\nThe question likely" + } + } + +=== Frame 35 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " originated" + } + } + +=== Frame 36 === + message-type: event + event-type: chunk + body: + { + "type": "content_block_stop", + "index": 0 + } + +=== Frame 37 === + message-type: event + event-type: chunk + body: + { + "type": "message_delta", + "delta": { + "stop_reason": "max_tokens", + "stop_sequence": null + }, + "usage": { + "input_tokens": 18, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "output_tokens": 50 + } + } + +=== Frame 38 === + message-type: event + event-type: chunk + body: + { + "type": "message_stop", + "amazon-bedrock-invocationMetrics": { + "inputTokenCount": 18, + "outputTokenCount": 50, + "invocationLatency": 3403, + "firstByteLatency": 1550 + } + } + +Total frames: 38 diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index c731e0fb..18b57a30 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -163,6 +163,23 @@ var ( OaiResponsesStreamingWrongResponseFormat []byte ) +// Bedrock fixtures are binary (eventstream) and cannot use the txtar format. +// Each fixture is a pair: .req.json (request body) and .resp.bin (raw +// eventstream response). +var ( + //go:embed bedrock/simple.req.json + BedrockSimpleReq []byte + + //go:embed bedrock/simple.resp.bin + BedrockSimpleResp []byte +) + +// BedrockFixture holds a request/response pair for a Bedrock test case. +type BedrockFixture struct { + Request []byte + Response []byte +} + // Section name constants matching the file names used in txtar fixtures. 
const ( fileRequest = "request" diff --git a/intercept/bedrock/interceptor.go b/intercept/bedrock/interceptor.go new file mode 100644 index 00000000..1783c217 --- /dev/null +++ b/intercept/bedrock/interceptor.go @@ -0,0 +1,461 @@ +// Package bedrock provides a SigV4-signing reverse proxy interceptor +// for native Bedrock API requests. It forwards requests to AWS Bedrock +// with centralized AWS credentials and extracts audit metadata from +// the response stream. +package bedrock + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/intercept/messages" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/tracing" +) + +var _ intercept.Interceptor = &Interceptor{} + +// Interceptor is a SigV4-signing reverse proxy for native Bedrock API +// requests. It forwards the request body as-is, signs with centralized +// AWS credentials, and extracts audit metadata from the response. 
+type Interceptor struct { + id uuid.UUID + modelID string + streaming bool + reqBody []byte + originalPath string + bedrockCfg config.AWSBedrock + providerName string + dumpDir string + tracer trace.Tracer + credential intercept.CredentialInfo + + logger slog.Logger + recorder recorder.Recorder +} + +func NewInterceptor( + id uuid.UUID, + modelID string, + streaming bool, + reqBody []byte, + originalPath string, + bedrockCfg config.AWSBedrock, + providerName string, + dumpDir string, + tracer trace.Tracer, + cred intercept.CredentialInfo, +) *Interceptor { + return &Interceptor{ + id: id, + modelID: modelID, + streaming: streaming, + reqBody: reqBody, + originalPath: originalPath, + bedrockCfg: bedrockCfg, + providerName: providerName, + dumpDir: dumpDir, + tracer: tracer, + credential: cred, + } +} + +func (i *Interceptor) ID() uuid.UUID { return i.id } +func (i *Interceptor) Model() string { return i.modelID } + +func (i *Interceptor) Setup(logger slog.Logger, rec recorder.Recorder, _ mcp.ServerProxier) { + i.logger = logger + i.recorder = rec +} + +func (i *Interceptor) Streaming() bool { return i.streaming } +func (i *Interceptor) Credential() intercept.CredentialInfo { return i.credential } +func (i *Interceptor) CorrelatingToolCallID() *string { return nil } + +func (i *Interceptor) TraceAttributes(r *http.Request) []attribute.KeyValue { + return []attribute.KeyValue{ + attribute.String(tracing.Provider, i.providerName), + attribute.String(tracing.Model, i.modelID), + attribute.Bool(tracing.Streaming, i.streaming), + attribute.Bool(tracing.IsBedrock, true), + } +} + +func (i *Interceptor) ProcessRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + _, span := i.tracer.Start(ctx, "bedrock.ProcessRequest") + defer span.End() + + // Extract user prompt before sending the request. 
+ var promptText string + var promptFound bool + if reqPayload, err := messages.NewRequestPayload(i.reqBody); err == nil { + promptText, promptFound, _ = reqPayload.LastUserPrompt() + } + + baseURL := i.bedrockCfg.BaseURL + if baseURL == "" { + baseURL = "https://bedrock-runtime." + i.bedrockCfg.Region + ".amazonaws.com" + } + + outReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+i.originalPath, bytes.NewReader(i.reqBody)) + if err != nil { + return xerrors.Errorf("create outbound request: %w", err) + } + + outReq.Header = intercept.PrepareClientHeaders(r.Header) + outReq.Header.Set("Content-Type", "application/json") + + awsCreds, err := i.loadCredentials(ctx) + if err != nil { + return xerrors.Errorf("load AWS credentials: %w", err) + } + + hash := sha256.Sum256(i.reqBody) + signer := v4.NewSigner() + if err = signer.SignHTTP(ctx, awsCreds, outReq, hex.EncodeToString(hash[:]), "bedrock", i.bedrockCfg.Region, time.Now()); err != nil { + return xerrors.Errorf("sign request: %w", err) + } + + resp, err := http.DefaultClient.Do(outReq) + if err != nil { + return xerrors.Errorf("send request to bedrock: %w", err) + } + defer resp.Body.Close() + + for key, values := range resp.Header { + for _, val := range values { + w.Header().Add(key, val) + } + } + w.WriteHeader(resp.StatusCode) + + if resp.StatusCode != http.StatusOK { + _, _ = io.Copy(w, resp.Body) + return nil + } + + // Buffer the response while streaming to the client so we can + // parse it for audit data after the stream completes. 
+ var auditBuf bytes.Buffer + tee := io.TeeReader(resp.Body, &auditBuf) + + if i.streaming { + flusher, ok := w.(http.Flusher) + buf := make([]byte, 32*1024) + for { + n, readErr := tee.Read(buf) + if n > 0 { + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + return xerrors.Errorf("write streaming chunk: %w", writeErr) + } + if ok { + flusher.Flush() + } + } + if readErr != nil { + if readErr == io.EOF { + break + } + return xerrors.Errorf("read streaming chunk: %w", readErr) + } + } + } else { + if _, err = io.Copy(w, tee); err != nil { + return xerrors.Errorf("copy response body: %w", err) + } + } + + respBytes := auditBuf.Bytes() + + // Dump request and response for debugging/fixture generation. + i.dumpRequestResponse(ctx, respBytes) + + // Extract audit metadata from the buffered response. + if i.streaming { + i.extractStreamingAudit(ctx, respBytes, promptText, promptFound) + } else { + i.extractBlockingAudit(ctx, respBytes, promptText, promptFound) + } + + return nil +} + +// dumpRequestResponse writes the raw request body and response bytes +// to files for debugging and test fixture generation. 
+func (i *Interceptor) dumpRequestResponse(ctx context.Context, respBytes []byte) { + if i.dumpDir == "" { + return + } + + safeModel := strings.ReplaceAll(i.modelID, "/", "-") + dir := filepath.Join(i.dumpDir, i.providerName, safeModel) + if err := os.MkdirAll(dir, 0o755); err != nil { + i.logger.Warn(ctx, "failed to create dump dir", slog.Error(err)) + return + } + + base := filepath.Join(dir, fmt.Sprintf("%d-%s", time.Now().UTC().UnixMilli(), i.id)) + + if err := os.WriteFile(base+".req.json", i.reqBody, 0o644); err != nil { + i.logger.Warn(ctx, "failed to dump request", slog.Error(err)) + } + + suffix := ".resp.json" + if i.streaming { + suffix = ".resp.bin" + } + if err := os.WriteFile(base+suffix, respBytes, 0o644); err != nil { + i.logger.Warn(ctx, "failed to dump response", slog.Error(err)) + } +} + +// extractStreamingAudit parses a buffered AWS EventStream response +// and records audit metadata. +func (i *Interceptor) extractStreamingAudit(ctx context.Context, data []byte, promptText string, promptFound bool) { + decoder := eventstream.NewDecoder() + reader := bytes.NewReader(data) + + var msgID string + + // Accumulators for content blocks indexed by block position. 
+ type toolBlock struct { + id string + name string + args bytes.Buffer + } + var toolBlocks []toolBlock + thinkingBlocks := map[int]*bytes.Buffer{} + blockTypes := map[int]string{} + + for { + msg, err := decoder.Decode(reader, nil) + if err != nil { + break + } + + messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) + if messageType == nil || messageType.String() != eventstreamapi.EventMessageType { + continue + } + eventType := msg.Headers.Get(eventstreamapi.EventTypeHeader) + if eventType == nil || eventType.String() != "chunk" { + continue + } + + var chunk struct { + Bytes string `json:"bytes"` + } + if err := json.Unmarshal(msg.Payload, &chunk); err != nil { + continue + } + decoded, err := base64.StdEncoding.DecodeString(chunk.Bytes) + if err != nil { + continue + } + + eventKind := gjson.GetBytes(decoded, "type").String() + + switch eventKind { + case "message_start": + msgID = gjson.GetBytes(decoded, "message.id").String() + usage := gjson.GetBytes(decoded, "message.usage") + if usage.Exists() { + _ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + Input: usage.Get("input_tokens").Int(), + Output: usage.Get("output_tokens").Int(), + CacheReadInputTokens: usage.Get("cache_read_input_tokens").Int(), + CacheWriteInputTokens: usage.Get("cache_creation_input_tokens").Int(), + }) + } + if promptFound { + _ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + Prompt: promptText, + }) + promptFound = false + } + + case "message_delta": + usage := gjson.GetBytes(decoded, "usage") + if usage.Exists() { + _ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + Output: usage.Get("output_tokens").Int(), + }) + } + + case "content_block_start": + idx := int(gjson.GetBytes(decoded, "index").Int()) + blockType := gjson.GetBytes(decoded, "content_block.type").String() + 
blockTypes[idx] = blockType + + if blockType == "tool_use" { + toolBlocks = append(toolBlocks, toolBlock{ + id: gjson.GetBytes(decoded, "content_block.id").String(), + name: gjson.GetBytes(decoded, "content_block.name").String(), + }) + } + if blockType == "thinking" { + thinkingBlocks[idx] = &bytes.Buffer{} + } + + case "content_block_delta": + idx := int(gjson.GetBytes(decoded, "index").Int()) + switch blockTypes[idx] { + case "tool_use": + partialJSON := gjson.GetBytes(decoded, "delta.partial_json").String() + for ti := range toolBlocks { + if toolBlocks[ti].id != "" { + toolBlocks[len(toolBlocks)-1].args.WriteString(partialJSON) + break + } + } + case "thinking": + if buf, ok := thinkingBlocks[idx]; ok { + buf.WriteString(gjson.GetBytes(decoded, "delta.thinking").String()) + } + } + + case "message_stop": + for _, tb := range toolBlocks { + var args json.RawMessage + if tb.args.Len() > 0 { + args = json.RawMessage(tb.args.Bytes()) + } + _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + ToolCallID: tb.id, + Tool: tb.name, + Args: args, + Injected: false, + }) + } + for _, buf := range thinkingBlocks { + if buf.Len() > 0 { + _ = i.recorder.RecordModelThought(ctx, &recorder.ModelThoughtRecord{ + InterceptionID: i.id.String(), + Content: buf.String(), + Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking}, + }) + } + } + } + } +} + +// extractBlockingAudit parses a JSON response body and records audit +// metadata. 
+func (i *Interceptor) extractBlockingAudit(ctx context.Context, data []byte, promptText string, promptFound bool) { + msgID := gjson.GetBytes(data, "id").String() + + if promptFound { + _ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + Prompt: promptText, + }) + } + + usage := gjson.GetBytes(data, "usage") + if usage.Exists() { + _ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + Input: usage.Get("input_tokens").Int(), + Output: usage.Get("output_tokens").Int(), + CacheReadInputTokens: usage.Get("cache_read_input_tokens").Int(), + CacheWriteInputTokens: usage.Get("cache_creation_input_tokens").Int(), + }) + } + + content := gjson.GetBytes(data, "content") + if content.IsArray() { + content.ForEach(func(_, block gjson.Result) bool { + switch block.Get("type").String() { + case "tool_use": + _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ + InterceptionID: i.id.String(), + MsgID: msgID, + ToolCallID: block.Get("id").String(), + Tool: block.Get("name").String(), + Args: json.RawMessage(block.Get("input").Raw), + Injected: false, + }) + case "thinking": + thinking := block.Get("thinking").String() + if thinking != "" { + _ = i.recorder.RecordModelThought(ctx, &recorder.ModelThoughtRecord{ + InterceptionID: i.id.String(), + Content: thinking, + Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking}, + }) + } + } + return true + }) + } +} + +func (i *Interceptor) loadCredentials(ctx context.Context) (aws.Credentials, error) { + loadOpts := []func(*awsconfig.LoadOptions) error{ + awsconfig.WithRegion(i.bedrockCfg.Region), + } + + switch { + case i.bedrockCfg.AccessKey != "" && i.bedrockCfg.AccessKeySecret != "": + loadOpts = append(loadOpts, awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider( + i.bedrockCfg.AccessKey, + i.bedrockCfg.AccessKeySecret, + i.bedrockCfg.SessionToken, + ), + 
)) + case i.bedrockCfg.AccessKey != "" || i.bedrockCfg.AccessKeySecret != "": + return aws.Credentials{}, xerrors.New("both access key and access key secret must be provided together") + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, loadOpts...) + if err != nil { + return aws.Credentials{}, xerrors.Errorf("load AWS config: %w", err) + } + + creds, err := cfg.Credentials.Retrieve(ctx) + if err != nil { + return aws.Credentials{}, xerrors.Errorf("retrieve AWS credentials: %w", err) + } + + return creds, nil +} diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7fb3f562..277e769f 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -77,7 +77,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() var prompt *string - promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() + promptText, promptFound, promptErr := i.reqPayload.LastUserPrompt() if promptErr != nil { i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(promptErr)) } else if promptFound { diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go index dfe52fc8..6374ae92 100644 --- a/intercept/messages/reqpayload.go +++ b/intercept/messages/reqpayload.go @@ -144,10 +144,10 @@ func (p RequestPayload) correlatingToolCallID() *string { return nil } -// lastUserPrompt returns the prompt text from the last user message. If no prompt +// LastUserPrompt returns the prompt text from the last user message. If no prompt // is found, it returns empty string, false, nil. Unexpected shapes are treated as // unsupported and do not fail the request path. 
-func (p RequestPayload) lastUserPrompt() (string, bool, error) { +func (p RequestPayload) LastUserPrompt() (string, bool, error) { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.Exists() || messages.Type == gjson.Null { return "", false, nil diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go index a5de61f8..d73eaeca 100644 --- a/intercept/messages/reqpayload_test.go +++ b/intercept/messages/reqpayload_test.go @@ -216,7 +216,7 @@ func TestRequestPayloadLastUserPrompt(t *testing.T) { t.Parallel() payload := mustMessagesPayload(t, testCase.requestBody) - prompt, found, err := payload.lastUserPrompt() + prompt, found, err := payload.LastUserPrompt() if testCase.expectError { require.Error(t, err) return diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index d62441a9..0c4d0dbe 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -111,7 +111,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re err error ) - prompt, promptFound, err = i.reqPayload.lastUserPrompt() + prompt, promptFound, err = i.reqPayload.LastUserPrompt() if err != nil { logger.Warn(ctx, "failed to determine last user prompt", slog.Error(err)) } diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index d0fdff16..4332fb50 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -2132,3 +2132,57 @@ func TestActorHeaders(t *testing.T) { } } } + +// Native Bedrock tests use the AWS EventStream binary protocol +// (application/vnd.amazon.eventstream) instead of Anthropic's SSE format. +// The provider acts as a SigV4-signing reverse proxy. 
+func TestNativeBedrockSimple(t *testing.T) { + t.Parallel() + + const bedrockModel = "us.anthropic.claude-sonnet-4-5-20250929-v1:0" + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) + t.Cleanup(cancel) + + fix := fixtures.BedrockFixture{ + Request: fixtures.BedrockSimpleReq, + Response: fixtures.BedrockSimpleResp, + } + + upstream := newMockUpstream(ctx, t, upstreamResponse{ + Streaming: fix.Response, + }) + + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) + + path := "/bedrock/model/" + bedrockModel + "/invoke-with-response-stream" + resp, err := bridgeServer.makeRequest(t, http.MethodPost, path, fix.Request) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify non-empty response was forwarded to client. + bodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.NotEmpty(t, bodyBytes, "should have received response body") + + // Verify the upstream received the request at the correct path. + received := upstream.receivedRequests() + require.Len(t, received, 1) + assert.Equal(t, "/model/"+bedrockModel+"/invoke-with-response-stream", received[0].Path) + + // Verify prompt was recorded. + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() + require.NotEmpty(t, promptUsages, "no prompts tracked") + assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") + + // Streaming produces 2 token records: message_start + message_delta. + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, 2) + assert.EqualValues(t, 18, bridgeServer.Recorder.TotalInputTokens(), "input tokens") + // message_start reports output_tokens=1, message_delta reports output_tokens=50. + assert.EqualValues(t, 51, bridgeServer.Recorder.TotalOutputTokens(), "output tokens") + + // Verify interception lifecycle. 
+ bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) +} diff --git a/internal/integrationtest/helpers.go b/internal/integrationtest/helpers.go index 038e6335..bead5a8b 100644 --- a/internal/integrationtest/helpers.go +++ b/internal/integrationtest/helpers.go @@ -35,6 +35,20 @@ func bedrockCfg(url string) *config.AWSBedrock { } } +// standaloneBedrockCfg creates a Bedrock provider config for testing. +// The BaseURL points at the mock upstream so no real AWS calls are made. +func standaloneBedrockCfg(url string) config.Bedrock { + return config.Bedrock{ + Name: config.ProviderBedrock, + AWSBedrock: config.AWSBedrock{ + Region: "us-east-1", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + BaseURL: url, + }, + } +} + // openAICfg creates a minimal OpenAI config for testing. func openAICfg(url string, key string) config.OpenAI { return config.OpenAI{ diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index faef0488..7a0ee4ff 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -171,6 +171,10 @@ func (ms *mockUpstream) handle(w http.ResponseWriter, r *http.Request) { ms.writeRawHTTPResponse(w, r, resp.Streaming) return } + if isBedrockPath(r.URL.Path) { + ms.writeEventStream(w, resp.Streaming) + return + } ms.writeSSE(w, resp.Streaming) return } @@ -228,6 +232,20 @@ func (ms *mockUpstream) writeSSE(w http.ResponseWriter, data []byte) { require.NoError(ms.t, scanner.Err()) } +// isBedrockPath returns true if the URL path looks like a Bedrock invoke path. +func isBedrockPath(path string) bool { + return strings.Contains(path, "/model/") && (strings.HasSuffix(path, "/invoke") || strings.HasSuffix(path, "/invoke-with-response-stream")) +} + +// writeEventStream writes raw binary eventstream data as-is. 
+func (ms *mockUpstream) writeEventStream(w http.ResponseWriter, data []byte) { + ms.t.Helper() + w.Header().Set("Content-Type", "application/vnd.amazon.eventstream") + w.WriteHeader(http.StatusOK) + _, err := w.Write(data) + require.NoError(ms.t, err) +} + // isRawHTTPResponse returns true if data starts with "HTTP/", indicating // it contains a complete HTTP response (status line + headers + body) rather // than just a response body. diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 5bfa01bc..fd017f1e 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -32,9 +32,8 @@ const ( pathCopilotChatCompletions = "/copilot/chat/completions" pathCopilotResponses = "/copilot/responses" - // providerBedrock identifies a Bedrock provider in [withProvider]. - // other providers use config.Provider* constants. - providerBedrock = "bedrock" + // providerBedrockAnthropic identifies a Bedrock-via-Anthropic provider. 
+ providerBedrockAnthropic = "bedrock-anthropic" // defaults apiKey = "api-key" @@ -163,6 +162,7 @@ func newBridgeTestServer( providers = []aibridge.Provider{ newDefaultProvider(config.ProviderAnthropic, upstreamURL), newDefaultProvider(config.ProviderOpenAI, upstreamURL), + newDefaultProvider(config.ProviderBedrock, upstreamURL), } } @@ -253,8 +253,10 @@ func newDefaultProvider(providerType string, addr string) aibridge.Provider { return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) case config.ProviderOpenAI: return provider.NewOpenAI(openAICfg(addr, apiKey)) - case providerBedrock: + case providerBedrockAnthropic: return provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg(addr)) + case config.ProviderBedrock: + return provider.NewBedrock(standaloneBedrockCfg(addr)) default: panic("unknown provider type: " + providerType) } diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index dc86815f..09326294 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -153,7 +153,7 @@ func TestTraceAnthropic(t *testing.T) { withTracer(tracer), } if tc.bedrock { - opts = append(opts, withProvider(providerBedrock)) + opts = append(opts, withProvider(providerBedrockAnthropic)) } bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) @@ -269,7 +269,7 @@ func TestTraceAnthropicErr(t *testing.T) { withTracer(tracer), } if tc.bedrock { - opts = append(opts, withProvider(providerBedrock)) + opts = append(opts, withProvider(providerBedrockAnthropic)) } bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) 
@@ -354,7 +354,7 @@ func TestInjectedToolsTrace(t *testing.T) {
 			path:           pathAnthropicMessages,
 			expectModel:    "beddel",
 			expectProvider: config.ProviderAnthropic,
-			opts:           []bridgeOption{withProvider(providerBedrock)},
+			opts:           []bridgeOption{withProvider(providerBedrockAnthropic)},
 		},
 		{
 			name: "bedrock_streaming",
@@ -364,7 +364,7 @@ func TestInjectedToolsTrace(t *testing.T) {
 			path:           pathAnthropicMessages,
 			expectModel:    "beddel",
 			expectProvider: config.ProviderAnthropic,
-			opts:           []bridgeOption{withProvider(providerBedrock)},
+			opts:           []bridgeOption{withProvider(providerBedrockAnthropic)},
 		},
 		{
 			name: "openai_blocking",
diff --git a/provider/bedrock.go b/provider/bedrock.go
new file mode 100644
index 00000000..86de2063
--- /dev/null
+++ b/provider/bedrock.go
@@ -0,0 +1,163 @@
+package provider
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"strings"
+
+	"github.com/google/uuid"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/aibridge/config"
+	"github.com/coder/aibridge/intercept"
+	bedrockintercept "github.com/coder/aibridge/intercept/bedrock"
+	"github.com/coder/aibridge/tracing"
+)
+
+// Compile-time assertion that *Bedrock satisfies the Provider interface.
+var _ Provider = &Bedrock{}
+
+// Bedrock is a standalone Bedrock provider that accepts native Bedrock
+// API requests and proxies them to AWS with centralized SigV4 signing.
+type Bedrock struct {
+	cfg config.Bedrock
+}
+
+// NewBedrock constructs a Bedrock provider from cfg, defaulting the
+// instance name to "bedrock" and the API dump directory to the
+// BRIDGE_DUMP_DIR environment variable when unset. When a circuit
+// breaker is configured, Anthropic-style failure detection and error
+// responses are wired in.
+func NewBedrock(cfg config.Bedrock) *Bedrock {
+	if cfg.Name == "" {
+		cfg.Name = config.ProviderBedrock
+	}
+	if cfg.APIDumpDir == "" {
+		cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
+	}
+	if cfg.CircuitBreaker != nil {
+		// Bedrock returns Anthropic-compatible errors for Claude models,
+		// and also returns 429/503 for rate limiting across all models.
+		cfg.CircuitBreaker.IsFailure = anthropicIsFailure
+		cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse
+	}
+
+	return &Bedrock{cfg: cfg}
+}
+
+// Type reports the provider type constant ("bedrock").
+func (*Bedrock) Type() string {
+	return config.ProviderBedrock
+}
+
+// Name reports the configured provider instance name.
+func (p *Bedrock) Name() string {
+	return p.cfg.Name
+}
+
+// BaseURL returns the configured upstream base URL, falling back to the
+// regional bedrock-runtime endpoint derived from cfg.Region.
+func (p *Bedrock) BaseURL() string {
+	if p.cfg.AWSBedrock.BaseURL != "" {
+		return p.cfg.AWSBedrock.BaseURL
+	}
+	return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", p.cfg.Region)
+}
+
+// RoutePrefix is the path prefix this provider is mounted under,
+// derived from the instance name (e.g. "/bedrock").
+func (p *Bedrock) RoutePrefix() string {
+	return fmt.Sprintf("/%s", p.Name())
+}
+
+// BridgedRoutes returns a prefix pattern that catches all Bedrock
+// invoke paths: /model/{modelId}/invoke and
+// /model/{modelId}/invoke-with-response-stream.
+func (*Bedrock) BridgedRoutes() []string {
+	return []string{"/model/"}
+}
+
+// PassthroughRoutes returns nil. All Bedrock requests
+// require SigV4 signing which cannot be done via simple header
+// injection.
+func (*Bedrock) PassthroughRoutes() []string {
+	return nil
+}
+
+// AuthHeader names the HTTP header that carries upstream credentials.
+func (*Bedrock) AuthHeader() string {
+	return "Authorization"
+}
+
+// InjectAuthHeader is a no-op for Bedrock. Authentication is handled
+// by SigV4 signing inside the interceptor, not via header injection.
+func (*Bedrock) InjectAuthHeader(_ *http.Header) {}
+
+// CircuitBreakerConfig returns the configured circuit breaker, or nil.
+func (p *Bedrock) CircuitBreakerConfig() *config.CircuitBreaker {
+	return p.cfg.CircuitBreaker
+}
+
+// APIDumpDir returns the directory for API request/response dumps.
+func (p *Bedrock) APIDumpDir() string {
+	return p.cfg.APIDumpDir
+}
+
+// parseBedrockPath extracts the model ID and streaming flag from a
+// Bedrock invoke path. Expected format:
+//
+//	/model/{modelId}/invoke
+//	/model/{modelId}/invoke-with-response-stream
+//
+// It returns an error when the prefix or suffix does not match, or
+// when the model ID segment is empty.
+func parseBedrockPath(path string) (modelID string, streaming bool, err error) {
+	// path should be like /model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream
+	const modelPrefix = "/model/"
+	if !strings.HasPrefix(path, modelPrefix) {
+		return "", false, xerrors.Errorf("path does not start with %s: %s", modelPrefix, path)
+	}
+
+	rest := path[len(modelPrefix):]
+	// rest = "us.anthropic.claude-sonnet-4-6/invoke-with-response-stream"
+
+	// The streaming suffix is checked first because "/invoke" is a
+	// suffix of "/invoke-with-response-stream"? No — it is not, but
+	// checking the longer, more specific suffix first is still the
+	// safe ordering.
+	switch {
+	case strings.HasSuffix(rest, "/invoke-with-response-stream"):
+		modelID = strings.TrimSuffix(rest, "/invoke-with-response-stream")
+		streaming = true
+	case strings.HasSuffix(rest, "/invoke"):
+		modelID = strings.TrimSuffix(rest, "/invoke")
+		streaming = false
+	default:
+		return "", false, xerrors.Errorf("path does not end with /invoke or /invoke-with-response-stream: %s", path)
+	}
+
+	if modelID == "" {
+		return "", false, xerrors.Errorf("empty model ID in path: %s", path)
+	}
+
+	return modelID, streaming, nil
+}
+
+// CreateInterceptor parses the Bedrock invoke path, reads the request
+// body, and constructs an interceptor that forwards the request to AWS
+// with centralized SigV4 credentials. Unrecognized paths yield
+// ErrUnknownRoute.
+func (p *Bedrock) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
+	id := uuid.New()
+	_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
+	defer tracing.EndSpanErr(span, &outErr)
+
+	// Strip the provider mount prefix to recover the native Bedrock path.
+	path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
+
+	modelID, streaming, err := parseBedrockPath(path)
+	if err != nil {
+		span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
+		return nil, ErrUnknownRoute
+	}
+
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		return nil, xerrors.Errorf("read body: %w", err)
+	}
+
+	// Credentials are centralized: SigV4 signing uses the provider's
+	// AWS configuration, not anything supplied by the caller.
+	cred := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "")
+
+	interceptor := bedrockintercept.NewInterceptor(
+		id,
+		modelID,
+		streaming,
+		body,
+		path,
+		p.cfg.AWSBedrock,
+		p.Name(),
+		p.cfg.APIDumpDir,
+		tracer,
+		cred,
+	)
+
+	span.SetAttributes(interceptor.TraceAttributes(r)...)
+	return interceptor, nil
+}