diff --git a/cli/azd/.vscode/cspell.yaml b/cli/azd/.vscode/cspell.yaml index 2545353fe12..a0ce9e4c7f4 100644 --- a/cli/azd/.vscode/cspell.yaml +++ b/cli/azd/.vscode/cspell.yaml @@ -69,6 +69,10 @@ words: - genproto - errdetails - yarnpkg + - azconfig + - hostnames + - seekable + - seekability languageSettings: - languageId: go ignoreRegExpList: @@ -293,6 +297,9 @@ overrides: - filename: extensions/microsoft.azd.concurx/internal/cmd/prompt_model.go words: - textinput + - filename: pkg/azdext/scope_detector.go + words: + - fakeazure ignorePaths: - "**/*_test.go" - "**/mock*.go" diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go index 3e11267778d..671f9ae6b16 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go @@ -66,7 +66,7 @@ This is useful for troubleshooting agent startup issues or monitoring agent beha action := &MonitorAction{ AgentContext: agentContext, - flags: flags, + flags: flags, } return action.Run(ctx) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/show.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/show.go index d31e8761b1e..3a4c4e649f5 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/show.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/show.go @@ -59,7 +59,7 @@ replica configuration, and any error messages.`, action := &ShowAction{ AgentContext: agentContext, - flags: flags, + flags: flags, } return action.Run(ctx) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go index 3290e100a07..2af66252fb3 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go @@ -287,7 +287,7 @@ type AgentContainerDetails struct { ProvisioningState 
string `json:"provisioning_state,omitempty"` State string `json:"state,omitempty"` UpdatedOn string `json:"updated_on,omitempty"` - Replicas []AgentContainerReplicaState `json:"replicas,omitempty"` + Replicas []AgentContainerReplicaState `json:"replicas,omitempty"` } // AgentContainerObject represents the details of an agent container diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index c9c2122eaeb..67072372887 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -6,6 +6,7 @@ package azdext import ( "fmt" "net" + "net/http" "net/url" "os" "path/filepath" @@ -25,6 +26,10 @@ type MCPSecurityPolicy struct { blockedHosts map[string]bool // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) + // onBlocked is an optional callback invoked when a URL or path is blocked. + // Parameters: action ("url_blocked", "path_blocked", "redirect_blocked"), + // detail (human-readable explanation). Safe for concurrent use. + onBlocked func(action, detail string) } // NewMCPSecurityPolicy creates an empty security policy. @@ -111,6 +116,20 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } +// OnBlocked registers a callback that is invoked whenever a URL, path, or +// redirect is blocked by the security policy. This enables security audit +// logging without coupling the policy to a specific logging framework. +// +// The callback receives an action tag ("url_blocked", "path_blocked", +// "redirect_blocked") and a human-readable detail string. It must be safe +// for concurrent invocation. +func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecurityPolicy { + p.mu.Lock() + defer p.mu.Unlock() + p.onBlocked = fn + return p +} + // isLocalhostHost returns true if the host is localhost or a loopback address. 
func isLocalhostHost(host string) bool { h := strings.ToLower(host) @@ -125,8 +144,20 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - defer p.mu.RUnlock() + fn := p.onBlocked + err := p.checkURLCore(rawURL) + p.mu.RUnlock() + + if fn != nil && err != nil { + fn("url_blocked", err.Error()) + } + + return err +} +// checkURLCore performs URL validation without acquiring the lock or invoking +// the onBlocked callback. Callers must hold p.mu (at least RLock). +func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) @@ -193,55 +224,18 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { return fmt.Errorf("blocked IP %s (private/loopback/link-local) for host %s", ip, originalHost) } - // Handle encoding variants that Go's net.IP methods don't classify, by extracting - // the embedded IPv4 address and re-checking it against all blocked ranges. - if len(ip) == net.IPv6len && ip.To4() == nil { - // IPv4-compatible (::x.x.x.x, RFC 4291 §2.5.5.1): first 12 bytes are zero. - isV4Compatible := true - for i := 0; i < 12; i++ { - if ip[i] != 0 { - isV4Compatible = false - break + // Handle encoding variants that Go's net.IP methods don't classify, + // by extracting the embedded IPv4 and re-checking it. 
+ if v4 := extractEmbeddedIPv4(ip); v4 != nil { + for _, cidr := range p.blockedCIDRs { + if cidr.Contains(v4) { + return fmt.Errorf("blocked IP %s (embedded %s, CIDR %s) for host %s", + ip, v4, cidr, originalHost) } } - if isV4Compatible && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-compatible %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } - - // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765 §4.2.1): bytes 0-7 zero, - // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4. - // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil. - isV4Translated := ip[8] == 0xFF && ip[9] == 0xFF && ip[10] == 0x00 && ip[11] == 0x00 - if isV4Translated { - for i := 0; i < 8; i++ { - if ip[i] != 0 { - isV4Translated = false - break - } - } - } - if isV4Translated && (ip[12] != 0 || ip[13] != 0 || ip[14] != 0 || ip[15] != 0) { - v4 := net.IPv4(ip[12], ip[13], ip[14], ip[15]) - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (IPv4-translated %s, private/loopback) for host %s", - ip, v4, originalHost) - } + if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { + return fmt.Errorf("blocked IP %s (embedded %s, private/loopback) for host %s", + ip, v4, originalHost) } } } @@ -251,10 +245,37 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath 
validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. +// +// Security note (TOCTOU): There is an inherent time-of-check to time-of-use +// gap between the symlink resolution performed here and the caller's +// subsequent file operation. An adversary with write access to the filesystem +// could create or modify a symlink between the check and the use. This is a +// fundamental limitation of path-based validation on POSIX systems. +// +// Mitigations callers should consider: +// - Use O_NOFOLLOW when opening files after validation (prevents symlink +// following at the final component). +// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on +// Linux 5.6+) where possible. +// - Avoid writing to directories that untrusted users can modify. +// - Consider validating the opened fd's path post-open via /proc/self/fd/N +// or fstat. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - defer p.mu.RUnlock() + fn := p.onBlocked + err := p.checkPathCore(path) + p.mu.RUnlock() + + if fn != nil && err != nil { + fn("path_blocked", err.Error()) + } + + return err +} +// checkPathCore performs path validation without acquiring the lock or invoking +// the onBlocked callback. Callers must hold p.mu (at least RLock). +func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil } @@ -348,3 +369,140 @@ func resolveExistingPrefix(p string) string { } } } + +// --------------------------------------------------------------------------- +// Redirect SSRF protection +// --------------------------------------------------------------------------- + +// redirectBlockedHosts lists cloud metadata service endpoints that must never +// be the target of an HTTP redirect. 
+var redirectBlockedHosts = map[string]bool{ + "169.254.169.254": true, + "fd00:ec2::254": true, + "metadata.google.internal": true, + "100.100.100.200": true, +} + +// SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks +// redirects to private networks and cloud metadata endpoints. It prevents +// redirect-based SSRF attacks where an attacker-controlled URL redirects to +// an internal service. +// +// Usage: +// +// client := &http.Client{CheckRedirect: azdext.SSRFSafeRedirect} +func SSRFSafeRedirect(req *http.Request, via []*http.Request) error { + const maxRedirects = 10 + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + + // Block HTTPS → HTTP scheme downgrades to prevent leaking + // Authorization headers (including Bearer tokens) in cleartext. + // Go's net/http preserves headers on same-host redirects regardless + // of scheme change. + if len(via) > 0 && via[len(via)-1].URL.Scheme == "https" && req.URL.Scheme != "https" { + return fmt.Errorf( + "redirect from HTTPS to %s blocked (credential protection)", req.URL.Scheme) + } + + host := req.URL.Hostname() + + // Block redirects to known metadata endpoints. + if redirectBlockedHosts[strings.ToLower(host)] { + return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host) + } + + // Block redirects to localhost hostnames (e.g. "localhost", + // "127.0.0.1") regardless of how they are spelled, preventing + // hostname-based SSRF bypasses of the IP-literal checks below. + if isLocalhostHost(host) { + return fmt.Errorf("redirect to localhost %s blocked (SSRF protection)", host) + } + + // Block redirects to private/loopback IP addresses, including + // IPv4-compatible and IPv4-translated IPv6 encoding variants + // that bypass Go's IsPrivate()/IsLoopback() classification. 
+ if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { + return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection)", ip) + } + + // Check IPv6 encoding variants (IPv4-compatible, IPv4-translated) + // that embed private IPv4 addresses but aren't caught by Go's + // net.IP classifier methods. + if err := checkIPEncodingVariants(ip, host); err != nil { + return err + } + } + + return nil +} + +// checkIPEncodingVariants detects IPv4-compatible (::x.x.x.x) and +// IPv4-translated (::ffff:0:x.x.x.x) IPv6 addresses that embed +// private IPv4 addresses but bypass Go's IsPrivate()/IsLoopback(). +func checkIPEncodingVariants(ip net.IP, originalHost string) error { + v4 := extractEmbeddedIPv4(ip) + if v4 == nil { + return nil + } + + if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { + return fmt.Errorf( + "redirect to embedded IPv4 address %s (embedded %s) blocked (SSRF protection)", + ip, v4) + } + + return nil +} + +// extractEmbeddedIPv4 returns the embedded IPv4 address from IPv4-compatible +// (::x.x.x.x, RFC 4291 §2.5.5.1) or IPv4-translated (::ffff:0:x.x.x.x, +// RFC 2765 §4.2.1) IPv6 encodings. Returns nil if the address is not one of +// these encoding variants. +// +// This handles addresses that Go's net.IP.To4() does not classify as IPv4 +// (To4 returns nil for these), which means Go's IsPrivate()/IsLoopback() +// methods also return false for them. +func extractEmbeddedIPv4(ip net.IP) net.IP { + if len(ip) != net.IPv6len || ip.To4() != nil { + return nil // Not a pure IPv6 address or already handled as IPv4-mapped + } + + // Check if last 4 bytes are non-zero (otherwise it's just :: which is + // already handled by IsUnspecified). + if ip[12] == 0 && ip[13] == 0 && ip[14] == 0 && ip[15] == 0 { + return nil + } + + // IPv4-compatible (::x.x.x.x): first 12 bytes are zero. 
+ isV4Compatible := true + for i := 0; i < 12; i++ { + if ip[i] != 0 { + isV4Compatible = false + break + } + } + if isV4Compatible { + return net.IPv4(ip[12], ip[13], ip[14], ip[15]) + } + + // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765): bytes 0-7 zero, + // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4. + // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil. + if ip[8] == 0xFF && ip[9] == 0xFF && ip[10] == 0x00 && ip[11] == 0x00 { + allZero := true + for i := 0; i < 8; i++ { + if ip[i] != 0 { + allZero = false + break + } + } + if allZero { + return net.IPv4(ip[12], ip[13], ip[14], ip[15]) + } + } + + return nil +} diff --git a/cli/azd/pkg/azdext/mcp_security_test.go b/cli/azd/pkg/azdext/mcp_security_test.go index 6cc0d1d44da..6507e279dab 100644 --- a/cli/azd/pkg/azdext/mcp_security_test.go +++ b/cli/azd/pkg/azdext/mcp_security_test.go @@ -5,6 +5,9 @@ package azdext import ( "fmt" + "net" + "net/http" + "net/url" "os" "path/filepath" "strings" @@ -335,3 +338,196 @@ func TestMCPSecurityFluentBuilder(t *testing.T) { t.Errorf("expected 1 base path, got %d", len(policy.allowedBasePaths)) } } + +func TestSSRFSafeRedirect_SchemeDowngrade(t *testing.T) { + t.Parallel() + + // Simulate HTTPS → HTTP redirect (credential leak vector). + via := []*http.Request{ + {URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/api"}}, + } + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/api"}, + } + + err := SSRFSafeRedirect(req, via) + if err == nil { + t.Fatal("expected error for HTTPS → HTTP redirect (credential protection)") + } + if !strings.Contains(err.Error(), "credential protection") { + t.Errorf("error = %q, want mention of credential protection", err.Error()) + } +} + +func TestSSRFSafeRedirect_HTTPToHTTPAllowed(t *testing.T) { + t.Parallel() + + // HTTP → HTTP redirect (no downgrade) should be allowed. 
+ via := []*http.Request{ + {URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/api"}}, + } + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/other"}, + } + + err := SSRFSafeRedirect(req, via) + if err != nil { + t.Errorf("HTTP → HTTP redirect should be allowed, got: %v", err) + } +} + +func TestSSRFSafeRedirect_LocalhostHostname(t *testing.T) { + t.Parallel() + + // Redirect to "localhost" hostname should be blocked. + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "localhost:8080", Path: "/steal"}, + } + + err := SSRFSafeRedirect(req, nil) + if err == nil { + t.Fatal("expected error for redirect to localhost hostname") + } + if !strings.Contains(err.Error(), "localhost") { + t.Errorf("error = %q, want mention of localhost", err.Error()) + } +} + +func TestSSRFSafeRedirect_IPv4CompatiblePrivate(t *testing.T) { + t.Parallel() + + // Redirect to IPv4-compatible IPv6 embedding private IP. + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "[::10.0.0.1]", Path: "/steal"}, + } + + err := SSRFSafeRedirect(req, nil) + if err == nil { + t.Fatal("expected error for redirect to IPv4-compatible private address") + } + if !strings.Contains(err.Error(), "SSRF") { + t.Errorf("error = %q, want mention of SSRF", err.Error()) + } +} + +func TestMCPSecurityOnBlocked_URLCallback(t *testing.T) { + t.Parallel() + + var ( + gotAction string + gotDetail string + callCount int + ) + + policy := NewMCPSecurityPolicy(). + RequireHTTPS(). + OnBlocked(func(action, detail string) { + gotAction = action + gotDetail = detail + callCount++ + }) + + // This should trigger the callback: HTTP to non-localhost host. 
+ err := policy.CheckURL("http://example.com/api") + if err == nil { + t.Fatal("expected error for HTTP URL with HTTPS required") + } + + if callCount != 1 { + t.Errorf("callCount = %d, want 1", callCount) + } + if gotAction != "url_blocked" { + t.Errorf("action = %q, want %q", gotAction, "url_blocked") + } + if !strings.Contains(gotDetail, "HTTPS required") { + t.Errorf("detail = %q, want to contain %q", gotDetail, "HTTPS required") + } +} + +func TestMCPSecurityOnBlocked_PathCallback(t *testing.T) { + t.Parallel() + + var gotAction string + + base := t.TempDir() + outside := t.TempDir() + outsideFile := filepath.Join(outside, "secret.txt") + if err := os.WriteFile(outsideFile, []byte("secret"), 0o600); err != nil { + t.Fatal(err) + } + + policy := NewMCPSecurityPolicy(). + ValidatePathsWithinBase(base). + OnBlocked(func(action, detail string) { + gotAction = action + }) + + err := policy.CheckPath(outsideFile) + if err == nil { + t.Fatal("expected error for path outside base") + } + + if gotAction != "path_blocked" { + t.Errorf("action = %q, want %q", gotAction, "path_blocked") + } +} + +func TestExtractEmbeddedIPv4(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ip net.IP + wantV4 net.IP + }{ + { + name: "IPv4-compatible private", + ip: net.ParseIP("::10.0.0.1"), + wantV4: net.IPv4(10, 0, 0, 1), + }, + { + name: "IPv4-compatible loopback", + ip: net.ParseIP("::127.0.0.1"), + wantV4: net.IPv4(127, 0, 0, 1), + }, + { + name: "IPv4-translated private", + ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0, 0, 10, 0, 0, 1}, + wantV4: net.IPv4(10, 0, 0, 1), + }, + { + name: "IPv4-mapped (handled by To4)", + ip: net.ParseIP("::ffff:10.0.0.1"), + wantV4: nil, // To4() != nil, so extractEmbeddedIPv4 returns nil + }, + { + name: "public IPv6", + ip: net.ParseIP("2607:f8b0:4004:800::200e"), + wantV4: nil, + }, + { + name: "pure IPv4", + ip: net.ParseIP("10.0.0.1"), + wantV4: nil, // len != IPv6len, returns nil + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := extractEmbeddedIPv4(tt.ip) + if tt.wantV4 == nil { + if got != nil { + t.Errorf("extractEmbeddedIPv4(%s) = %s, want nil", tt.ip, got) + } + } else { + if got == nil { + t.Errorf("extractEmbeddedIPv4(%s) = nil, want %s", tt.ip, tt.wantV4) + } else if !got.Equal(tt.wantV4) { + t.Errorf("extractEmbeddedIPv4(%s) = %s, want %s", tt.ip, got, tt.wantV4) + } + } + }) + } +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go new file mode 100644 index 00000000000..5db6262ef61 --- /dev/null +++ b/cli/azd/pkg/azdext/pagination.go @@ -0,0 +1,342 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +const ( + // defaultMaxPages is the default upper bound on pages fetched by [Pager.Collect]. + // Individual callers can override this via [PagerOptions.MaxPages]. + // A value of 0 means unlimited (no cap), which is the default for manual + // NextPage iteration. Collect uses this default when MaxPages is unset. + defaultMaxPages = 500 +) + +const ( + // maxPageResponseSize limits the maximum size of a single page response + // body to prevent excessive memory consumption from malicious or + // misconfigured servers. 10 MB is intentionally above typical Azure list + // payloads while still bounding memory use. + maxPageResponseSize int64 = 10 << 20 // 10 MB + + // maxErrorBodySize limits the size of error response bodies captured + // for diagnostic purposes. + maxErrorBodySize int64 = 64 << 10 // 64 KB +) + +// Pager provides a generic, lazy iterator over paginated Azure REST API +// responses that use the standard { value, nextLink } pattern. +// +// Usage: +// +// pager := azdext.NewPager[MyItem](client, firstURL, nil) +// for pager.More() { +// page, err := pager.NextPage(ctx) +// if err != nil { ... 
} +// for _, item := range page.Value { +// // process item +// } +// } +type Pager[T any] struct { + client HTTPDoer + nextURL string + done bool + opts PagerOptions + originHost string // host of the initial URL for SSRF protection + pageCount int // number of pages fetched so far + truncated bool +} + +// PageResponse is a single page returned by [Pager.NextPage]. +type PageResponse[T any] struct { + // Value contains the items for this page. + Value []T `json:"value"` + + // NextLink is the URL to the next page, or empty if this is the last page. + NextLink string `json:"nextLink,omitempty"` +} + +// PagerOptions configures a [Pager]. +type PagerOptions struct { + // Method overrides the HTTP method used for page requests. Defaults to GET. + Method string + + // MaxPages limits the maximum number of pages that [Pager.Collect] will + // fetch. When set to a positive value, Collect stops after fetching that + // many pages. A value of 0 means unlimited (no cap) for manual NextPage + // iteration; Collect applies [defaultMaxPages] when this is 0. + MaxPages int + + // MaxItems limits the maximum total items that [Pager.Collect] will + // accumulate. When the collected items reach this count, Collect stops + // and returns the items gathered so far (truncated to MaxItems). + // A value of 0 means unlimited (no cap). + MaxItems int +} + +// HTTPDoer abstracts the HTTP call so that [ResilientClient] or any +// *http.Client can power pagination. +type HTTPDoer interface { + Do(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) +} + +// stdHTTPDoer wraps *http.Client to satisfy HTTPDoer. 
+type stdHTTPDoer struct { + client *http.Client +} + +func (s *stdHTTPDoer) Do(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + return s.client.Do(req) +} + +// NewPager creates a [Pager] that iterates over a paginated endpoint. +// +// client may be a [*ResilientClient] or any type satisfying [HTTPDoer]. +// firstURL is the initial page URL. +func NewPager[T any](client HTTPDoer, firstURL string, opts *PagerOptions) *Pager[T] { + if opts == nil { + opts = &PagerOptions{} + } + + if opts.Method == "" { + opts.Method = http.MethodGet + } + + var originHost string + if u, err := url.Parse(firstURL); err == nil { + originHost = strings.ToLower(u.Hostname()) + } + + return &Pager[T]{ + client: client, + nextURL: firstURL, + opts: *opts, + originHost: originHost, + } +} + +// NewPagerFromHTTPClient creates a [Pager] backed by a standard [*http.Client]. +func NewPagerFromHTTPClient[T any](client *http.Client, firstURL string, opts *PagerOptions) *Pager[T] { + return NewPager[T](&stdHTTPDoer{client: client}, firstURL, opts) +} + +// More reports whether there are more pages to fetch. +func (p *Pager[T]) More() bool { + return !p.done && p.nextURL != "" +} + +// Truncated reports whether the last [Collect] call stopped early due to +// MaxPages or MaxItems limits. This allows callers to detect truncation +// without a breaking API change (Collect still returns ([]T, nil) on +// successful truncation). +func (p *Pager[T]) Truncated() bool { + return p.truncated +} + +// NextPage fetches the next page of results. Returns an error if the request +// fails, the response is not 2xx, or the body cannot be decoded. +// +// Response bodies are bounded to [maxPageResponseSize] to prevent +// excessive memory consumption. nextLink URLs are validated to prevent +// SSRF attacks (must stay on the same host with HTTPS). 
+// +// After the last page is consumed, [More] returns false. +func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { + if !p.More() { + return nil, errors.New("azdext.Pager.NextPage: no more pages") + } + + if p.client == nil { + return nil, errors.New("azdext.Pager.NextPage: client must not be nil") + } + + resp, err := p.client.Do(ctx, p.opts.Method, p.nextURL, nil) + if err != nil { + return nil, fmt.Errorf("azdext.Pager.NextPage: request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) + return nil, &PaginationError{ + StatusCode: resp.StatusCode, + URL: p.nextURL, + Body: sanitizeErrorBody(string(body)), + } + } + + data, err := io.ReadAll(io.LimitReader(resp.Body, maxPageResponseSize)) + if err != nil { + return nil, fmt.Errorf("azdext.Pager.NextPage: failed to read response: %w", err) + } + + var page PageResponse[T] + if err := json.Unmarshal(data, &page); err != nil { + return nil, fmt.Errorf("azdext.Pager.NextPage: failed to decode response: %w", err) + } + + if page.NextLink == "" { + p.done = true + p.nextURL = "" + } else if err := p.validateNextLink(page.NextLink); err != nil { + p.done = true + p.nextURL = "" + return &page, fmt.Errorf("azdext.Pager.NextPage: %w", err) + } else { + p.nextURL = page.NextLink + } + + // Track page count for MaxPages enforcement in Collect. + p.pageCount++ + + return &page, nil +} + +// validateNextLink checks that a nextLink URL is safe to follow. +// It rejects non-HTTPS schemes, URLs with embedded credentials, and +// URLs pointing to a different host than the original request (SSRF protection). 
+func (p *Pager[T]) validateNextLink(nextLink string) error { + u, err := url.Parse(nextLink) + if err != nil { + return fmt.Errorf("invalid nextLink URL: %w", err) + } + + if u.Scheme != "" && u.Scheme != "https" { + return fmt.Errorf("nextLink must use HTTPS (got %q)", u.Scheme) + } + + if u.User != nil { + return errors.New("nextLink must not contain user credentials") + } + + host := strings.ToLower(u.Hostname()) + if host != "" && p.originHost != "" && host != p.originHost { + return fmt.Errorf("nextLink host %q does not match origin host %q (possible SSRF)", host, p.originHost) + } + + return nil +} + +// Collect is a convenience method that fetches all remaining pages and +// returns all items in a single slice. +// +// To prevent unbounded memory growth from runaway pagination, Collect +// enforces [PagerOptions.MaxPages] (defaults to [defaultMaxPages] when +// unset) and [PagerOptions.MaxItems]. When either limit is reached, +// iteration stops and the items collected so far are returned. +// +// If NextPage returns both page data and an error (e.g. rejected nextLink), +// the page data is included in the returned slice before returning the error. +func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { + var all []T + p.truncated = false + + maxPages := p.opts.MaxPages + if maxPages <= 0 { + maxPages = defaultMaxPages + } + + for p.More() { + page, err := p.NextPage(ctx) + if page != nil { + all = append(all, page.Value...) + } + if err != nil { + return all, err + } + + // Enforce MaxItems: truncate and stop if exceeded. + if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + if len(all) > p.opts.MaxItems { + all = all[:p.opts.MaxItems] + } + p.truncated = true + break + } + + // Enforce MaxPages: stop after collecting the configured number of pages. 
+ if p.pageCount >= maxPages { + p.truncated = true + break + } + } + + return all, nil +} + +// maxPaginationErrorBodyLen limits the response body length stored in +// PaginationError to prevent sensitive data leakage through error messages. +// Response bodies from non-2xx pages may contain credentials, tokens, or +// other secrets embedded by the upstream service. +const maxPaginationErrorBodyLen = 1024 + +// PaginationError is returned when a page request receives a non-2xx response. +type PaginationError struct { + StatusCode int + URL string + // Body is a truncated, sanitized excerpt of the error response body for + // diagnostics. It is capped at [maxPaginationErrorBodyLen] bytes and + // stripped of control characters to prevent log forging. + Body string +} + +func (e *PaginationError) Error() string { + return fmt.Sprintf( + "azdext.Pager: page request returned HTTP %d (url=%s)", + e.StatusCode, redactURL(e.URL), + ) +} + +// sanitizeErrorBody truncates and strips control characters from an error +// response body to prevent log forging and sensitive data leakage. +func sanitizeErrorBody(body string) string { + if len(body) > maxPaginationErrorBodyLen { + body = body[:maxPaginationErrorBodyLen] + "...[truncated]" + } + return stripControlChars(body) +} + +// stripControlChars replaces ASCII control characters (except tab) with a +// space to prevent log forging via CR/LF injection or terminal escape +// sequences. Tab (0x09) is preserved as it appears in legitimate JSON. +func stripControlChars(s string) string { + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r < 0x20 && r != '\t' { + b.WriteRune(' ') + } else if r == 0x7F { + b.WriteRune(' ') + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// redactURL strips query parameters and fragments from a URL to avoid leaking +// tokens, SAS signatures, or other secrets in log/error messages. 
+func redactURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + u.RawQuery = "" + u.Fragment = "" + return u.String() +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go new file mode 100644 index 00000000000..8f8175dcd65 --- /dev/null +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -0,0 +1,659 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "testing" +) + +// mockDoer is a test double for HTTPDoer. +type mockDoer struct { + responses []*doerResponse + calls int +} + +type doerResponse struct { + resp *http.Response + err error +} + +func (m *mockDoer) Do(_ context.Context, _, _ string, _ io.Reader) (*http.Response, error) { + if m.calls >= len(m.responses) { + return nil, errors.New("no more mock responses") + } + + r := m.responses[m.calls] + m.calls++ + + return r.resp, r.err +} + +// pageJSON builds a PageResponse JSON body. 
+func pageJSON[T any](value []T, nextLink string) string { + page := PageResponse[T]{Value: value, NextLink: nextLink} + data, _ := json.Marshal(page) + return string(data) +} + +func TestPager_SinglePage(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b", "c"}, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api?page=1", nil) + + if !pager.More() { + t.Fatal("expected More() = true before first page") + } + + page, err := pager.NextPage(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(page.Value) != 3 { + t.Fatalf("len(Value) = %d, want 3", len(page.Value)) + } + + if page.Value[0] != "a" || page.Value[1] != "b" || page.Value[2] != "c" { + t.Errorf("Value = %v, want [a b c]", page.Value) + } + + if pager.More() { + t.Error("expected More() = false after last page") + } +} + +func TestPager_MultiplePages(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]int{1, 2}, "https://example.com/api?page=2") + page2 := pageJSON([]int{3, 4}, "https://example.com/api?page=3") + page3 := pageJSON([]int{5}, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page2)), + Header: http.Header{}, + }}, + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page3)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[int](doer, "https://example.com/api?page=1", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Fatalf("len(all) = %d, want 5", 
len(all)) + } + + for i, want := range []int{1, 2, 3, 4, 5} { + if all[i] != want { + t.Errorf("all[%d] = %d, want %d", i, all[i], want) + } + } +} + +func TestPager_EmptyPage(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{}, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + page, err := pager.NextPage(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(page.Value) != 0 { + t.Errorf("len(Value) = %d, want 0", len(page.Value)) + } + + if pager.More() { + t.Error("expected More() = false after empty last page") + } +} + +func TestPager_HTTPError(t *testing.T) { + t.Parallel() + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(strings.NewReader(`{"error":"forbidden"}`)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for HTTP 403") + } + + var pagErr *PaginationError + if !errors.As(err, &pagErr) { + t.Fatalf("error type = %T, want *PaginationError", err) + } + + if pagErr.StatusCode != http.StatusForbidden { + t.Errorf("StatusCode = %d, want %d", pagErr.StatusCode, http.StatusForbidden) + } +} + +func TestPager_NetworkError(t *testing.T) { + t.Parallel() + + doer := &mockDoer{ + responses: []*doerResponse{ + {err: errors.New("connection reset")}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for network failure") + } +} + +func TestPager_InvalidJSON(t *testing.T) { + t.Parallel() + + doer := &mockDoer{ + responses: []*doerResponse{ + 
{resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("not json")), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestPager_NoMorePages(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"x"}, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, _ = pager.NextPage(context.Background()) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error when calling NextPage after last page") + } +} + +func TestPager_EmptyFirstURL(t *testing.T) { + t.Parallel() + + doer := &mockDoer{} + pager := NewPager[string](doer, "", nil) + + if pager.More() { + t.Error("expected More() = false for empty initial URL") + } +} + +type testStruct struct { + Name string `json:"name"` + Count int `json:"count"` +} + +func TestPager_StructType(t *testing.T) { + t.Parallel() + + items := []testStruct{ + {Name: "alpha", Count: 1}, + {Name: "beta", Count: 2}, + } + + body := pageJSON(items, "") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[testStruct](doer, "https://example.com/api", nil) + + page, err := pager.NextPage(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(page.Value) != 2 { + t.Fatalf("len(Value) = %d, want 2", len(page.Value)) + } + + if page.Value[0].Name != "alpha" || page.Value[0].Count != 1 { + t.Errorf("Value[0] = %+v, want {alpha 1}", page.Value[0]) + } + + if page.Value[1].Name 
!= "beta" || page.Value[1].Count != 2 { + t.Errorf("Value[1] = %+v, want {beta 2}", page.Value[1]) + } +} + +func TestPager_CollectPartialError(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a"}, "https://example.com/api?page=2") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + {err: errors.New("network timeout")}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api?page=1", nil) + + all, err := pager.Collect(context.Background()) + if err == nil { + t.Fatal("expected error from second page") + } + + // Should still return items collected before the error. + if len(all) != 1 { + t.Errorf("len(all) = %d, want 1 (partial results before error)", len(all)) + } + + if all[0] != "a" { + t.Errorf("all[0] = %q, want %q", all[0], "a") + } +} + +func TestNewPagerFromHTTPClient(t *testing.T) { + t.Parallel() + + // Just test that the constructor works; actual HTTP calls tested above. 
+ pager := NewPagerFromHTTPClient[string](http.DefaultClient, "https://example.com/api", nil) + if pager == nil { + t.Fatal("NewPagerFromHTTPClient returned nil") + } +} + +func TestPager_NilClient(t *testing.T) { + t.Parallel() + + pager := NewPager[string](nil, "https://example.com/api", nil) + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for nil client") + } + + if !strings.Contains(err.Error(), "client must not be nil") { + t.Errorf("error = %q, want mention of nil client", err.Error()) + } +} + +func TestPager_NextLinkSSRF_DifferentHost(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a"}, "https://evil.com/steal-data") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + page, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for nextLink to different host") + } + + // Page data should still be returned despite the nextLink error. + if page == nil || len(page.Value) != 1 || page.Value[0] != "a" { + t.Errorf("expected valid page data despite nextLink error, got %+v", page) + } + + if !strings.Contains(err.Error(), "SSRF") { + t.Errorf("error = %q, want mention of SSRF", err.Error()) + } + + // Pager should be done (won't follow malicious link). 
+ if pager.More() { + t.Error("expected More() = false after nextLink rejection") + } +} + +func TestPager_NextLinkHTTP(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a"}, "http://example.com/page2") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for HTTP nextLink") + } + + if !strings.Contains(err.Error(), "HTTPS") { + t.Errorf("error = %q, want mention of HTTPS", err.Error()) + } +} + +func TestPager_NextLinkUserCredentials(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a"}, "https://user:pass@example.com/page2") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + _, err := pager.NextPage(context.Background()) + if err == nil { + t.Fatal("expected error for nextLink with user credentials") + } + + if !strings.Contains(err.Error(), "credentials") { + t.Errorf("error = %q, want mention of credentials", err.Error()) + } +} + +func TestPager_CollectWithSSRFError(t *testing.T) { + t.Parallel() + + page1 := pageJSON([]string{"a", "b"}, "https://evil.com/steal") + + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(page1)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err == nil { + t.Fatal("expected error from SSRF nextLink") + } + + // Collect should still return the items from the valid page. 
+ if len(all) != 2 || all[0] != "a" || all[1] != "b" { + t.Errorf("all = %v, want [a b] (partial results before SSRF error)", all) + } +} + +func TestPager_CollectMaxPages(t *testing.T) { + t.Parallel() + + // Build 5 pages; set MaxPages to 3. + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3 (MaxPages=3)", len(all)) + } + for i, want := range []int{1, 2, 3} { + if all[i] != want { + t.Errorf("all[%d] = %d, want %d", i, all[i], want) + } + } +} + +func TestPager_CollectMaxItems(t *testing.T) { + t.Parallel() + + // Build 3 pages of 4 items each; set MaxItems to 5. 
+ var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5 (MaxItems=5)", len(all)) + } +} + +func TestPager_TruncatedByMaxPages(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 1; i <= 5; i++ { + nextLink := "" + if i < 5 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) + } + body := pageJSON([]int{i}, nextLink) + responses = append(responses, &doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 3 { + t.Errorf("len(all) = %d, want 3", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxPages)") + } +} + +func TestPager_TruncatedByMaxItems(t *testing.T) { + t.Parallel() + + var responses []*doerResponse + for i := 0; i < 3; i++ { + items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} + nextLink := "" + if i < 2 { + nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) + } + body := pageJSON(items, nextLink) + responses = append(responses, 
&doerResponse{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }, + }) + } + + doer := &mockDoer{responses: responses} + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 5 { + t.Errorf("len(all) = %d, want 5", len(all)) + } + + if !pager.Truncated() { + t.Error("Truncated() = false, want true (stopped at MaxItems)") + } +} + +func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { + t.Parallel() + + body := pageJSON([]string{"a", "b"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: http.Header{}, + }}, + }, + } + + pager := NewPager[string](doer, "https://example.com/api", nil) + + all, err := pager.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + + if len(all) != 2 { + t.Errorf("len(all) = %d, want 2", len(all)) + } + + if pager.Truncated() { + t.Error("Truncated() = true, want false (natural end)") + } +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go new file mode 100644 index 00000000000..f00e85d38e3 --- /dev/null +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -0,0 +1,301 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/azure/azure-dev/cli/azd/pkg/httputil" +) + +const ( + // maxRetryAfterDuration caps the Retry-After header value to prevent + // a malicious or misconfigured server from stalling the client indefinitely. + maxRetryAfterDuration = 120 * time.Second + + // maxRetryBodyDrain limits how many bytes are consumed when draining a + // retryable response body before the next attempt. This prevents a + // malicious or misconfigured server from stalling the client with an + // unbounded response body. + maxRetryBodyDrain int64 = 1 << 20 // 1 MB +) + +// ResilientClient is an HTTP client with built-in retry, exponential backoff, +// timeout, and optional bearer-token injection. It is designed for extension +// authors who need to call Azure REST APIs directly. +// +// Usage: +// +// rc := azdext.NewResilientClient(tokenProvider, nil) +// resp, err := rc.Do(ctx, http.MethodGet, "https://management.azure.com/...", nil) +type ResilientClient struct { + httpClient *http.Client + tokenProvider azcore.TokenCredential + scopeDetector *ScopeDetector + opts ResilientClientOptions +} + +var _ HTTPDoer = (*ResilientClient)(nil) + +// ResilientClientOptions configures a [ResilientClient]. +type ResilientClientOptions struct { + // MaxRetries is the maximum number of retry attempts for transient failures. + // Defaults to 3. + MaxRetries int + + // InitialDelay is the base delay before the first retry. Subsequent retries + // use exponential backoff (delay * 2^attempt) capped at MaxDelay. + // Defaults to 500ms. + InitialDelay time.Duration + + // MaxDelay caps the computed backoff delay. Defaults to 30s. + MaxDelay time.Duration + + // Timeout is the per-request timeout. + // A value of zero or less uses the default of 30s. 
+ Timeout time.Duration + + // Transport overrides the default HTTP transport. Useful for testing. + Transport http.RoundTripper + + // ScopeDetector overrides the default scope detector used for automatic + // scope resolution. When nil, a default detector is created. + ScopeDetector *ScopeDetector +} + +// defaults fills zero-value fields with production defaults. +func (o *ResilientClientOptions) defaults() { + if o.MaxRetries <= 0 { + o.MaxRetries = 3 + } + + if o.InitialDelay <= 0 { + o.InitialDelay = 500 * time.Millisecond + } + + if o.MaxDelay <= 0 { + o.MaxDelay = 30 * time.Second + } + + if o.Timeout <= 0 { + o.Timeout = 30 * time.Second + } +} + +// NewResilientClient creates a [ResilientClient]. +// +// tokenProvider may be nil if the caller handles Authorization headers manually. +// When non-nil, the client automatically injects a Bearer token using scopes +// resolved from the request URL via the [ScopeDetector]. +func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientClientOptions) *ResilientClient { + if opts == nil { + opts = &ResilientClientOptions{} + } + + opts.defaults() + + transport := opts.Transport + if transport == nil { + transport = http.DefaultTransport + } + + sd := opts.ScopeDetector + if sd == nil { + sd = NewScopeDetector(nil) + } + + return &ResilientClient{ + httpClient: &http.Client{ + Transport: transport, + Timeout: opts.Timeout, + CheckRedirect: SSRFSafeRedirect, + }, + tokenProvider: tokenProvider, + scopeDetector: sd, + opts: *opts, + } +} + +// Do executes an HTTP request with retry logic and optional bearer-token injection. +// +// body may be nil for requests without a body (GET, DELETE). +// When body is non-nil and retries are enabled (MaxRetries > 0), the body must +// implement [io.ReadSeeker] so it can be re-read on each attempt. If a retry is +// needed and the body does not implement [io.ReadSeeker], Do returns an error. 
+func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) { + if ctx == nil { + return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") + } + + // Validate body seekability upfront when retries are enabled. + // Fail fast rather than discovering the body is not seekable after the + // first attempt has already consumed it. + if body != nil && rc.opts.MaxRetries > 0 { + if _, ok := body.(io.ReadSeeker); !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + } + + var lastErr error + var retryAfterOverride time.Duration + + for attempt := range rc.opts.MaxRetries + 1 { + if attempt > 0 { + // Use Retry-After from the previous response if available; + // otherwise compute exponential backoff. This avoids a + // double-wait (backoff + Retry-After) on each iteration. + delay := retryAfterOverride + if delay == 0 { + delay = rc.backoff(attempt) + } + retryAfterOverride = 0 + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + + // Reset body for retry — require io.ReadSeeker for non-nil bodies. + if body != nil { + seeker, ok := body.(io.ReadSeeker) + if !ok { + return nil, errors.New( + "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + + "retries require a seekable body (use bytes.NewReader or strings.NewReader)") + } + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("azdext.ResilientClient.Do: failed to reset request body: %w", err) + } + } + } + + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, fmt.Errorf("azdext.ResilientClient.Do: failed to create request: %w", err) + } + + // Inject bearer token when a token provider is available. 
+ if rc.tokenProvider != nil { + if authErr := rc.applyAuth(ctx, req); authErr != nil { + return nil, fmt.Errorf("azdext.ResilientClient.Do: authorization failed: %w", authErr) + } + } + + resp, err := rc.httpClient.Do(req) + if err != nil { + lastErr = err + + // Don't retry on context cancellation. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + continue // network error → retry + } + + if !isRetryable(resp.StatusCode) { + return resp, nil + } + + // Consume body before retry to release the connection. + // Bound the read to prevent a malicious server from stalling the + // client with an infinitely long response body. + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxRetryBodyDrain)) + resp.Body.Close() + + // Capture Retry-After for the next iteration's delay, + // capped to prevent indefinite stalling. + if ra := httputil.RetryAfter(resp); ra > 0 { + if ra > maxRetryAfterDuration { + ra = maxRetryAfterDuration + } + retryAfterOverride = ra + } + + lastErr = &RetryableHTTPError{ + StatusCode: resp.StatusCode, + Status: resp.Status, + } + } + + return nil, fmt.Errorf("azdext.ResilientClient.Do: exhausted retries: %w", lastErr) +} + +// applyAuth resolves scopes from the request URL and sets the Authorization header. +// It refuses to send bearer tokens over non-HTTPS connections. +func (rc *ResilientClient) applyAuth(ctx context.Context, req *http.Request) error { + if req.URL.Scheme != "https" { + return fmt.Errorf("bearer token requires HTTPS; refusing to authenticate to %s URL", req.URL.Scheme) + } + + scopes, err := rc.scopeDetector.ScopesForURL(req.URL.String()) + if err != nil { + return err + } + + tok, err := rc.tokenProvider.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) + if err != nil { + return err + } + + req.Header.Set("Authorization", "Bearer "+tok.Token) + + return nil +} + +// backoff computes the delay for a given attempt using exponential backoff. 
+func (rc *ResilientClient) backoff(attempt int) time.Duration { + delay := time.Duration(float64(rc.opts.InitialDelay) * math.Pow(2, float64(attempt-1))) + if delay > rc.opts.MaxDelay { + delay = rc.opts.MaxDelay + } + + // Add jitter: randomize between [50%, 100%) of computed delay to prevent + // thundering herd when multiple clients retry simultaneously. + var b [8]byte + jitter := 0.75 + if _, err := rand.Read(b[:]); err == nil { + randFloat := float64(binary.BigEndian.Uint64(b[:])) / float64(^uint64(0)) + jitter = 0.5 + randFloat*0.5 + } + return time.Duration(float64(delay) * jitter) +} + +// isRetryable returns true for status codes that indicate a transient failure. +func isRetryable(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, + http.StatusRequestTimeout, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + return true + default: + return false + } +} + +// RetryableHTTPError represents a retryable HTTP failure. +type RetryableHTTPError struct { + StatusCode int + Status string +} + +func (e *RetryableHTTPError) Error() string { + return fmt.Sprintf("azdext.ResilientClient: retryable HTTP error %d (%s)", e.StatusCode, e.Status) +} diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go new file mode 100644 index 00000000000..9170202a8ca --- /dev/null +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -0,0 +1,787 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/azure/azure-dev/cli/azd/pkg/httputil" +) + +// roundTripFunc is an adapter to allow ordinary functions as http.RoundTripper. 
+type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +// fakeTokenCredential satisfies azcore.TokenCredential for testing. +type fakeTokenCredential struct { + token string + err error +} + +func (f *fakeTokenCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + return azcore.AccessToken{Token: f.token, ExpiresOn: time.Now().Add(time.Hour)}, f.err +} + +func TestResilientClient_Success(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport}) + + resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want %d", resp.StatusCode, http.StatusOK) + } +} + +func TestResilientClient_RetriesTransientFailures(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if attempts < 3 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(strings.NewReader("unavailable")), + Header: http.Header{}, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 3, + InitialDelay: time.Millisecond, // fast for testing + MaxDelay: 10 * time.Millisecond, + }) + + resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if 
err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want %d", resp.StatusCode, http.StatusOK) + } + + if attempts != 3 { + t.Errorf("attempts = %d, want 3", attempts) + } +} + +func TestResilientClient_ExhaustsRetries(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }) + + _, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + + var retryErr *RetryableHTTPError + if !errors.As(err, &retryErr) { + t.Fatalf("error type = %T, want *RetryableHTTPError", err) + } + + if retryErr.StatusCode != http.StatusTooManyRequests { + t.Errorf("StatusCode = %d, want %d", retryErr.StatusCode, http.StatusTooManyRequests) + } +} + +func TestResilientClient_NoRetryOn4xx(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 3, + InitialDelay: time.Millisecond, + }) + + resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (no retry on 404)", attempts) + } +} + +func 
TestResilientClient_NetworkError(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 1, + InitialDelay: time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }) + + _, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err == nil { + t.Fatal("expected error for network failure") + } +} + +func TestResilientClient_ContextCancelled(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(strings.NewReader("unavailable")), + Header: http.Header{}, + }, nil + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 3, + InitialDelay: time.Second, + }) + + _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) + if !errors.Is(err, context.Canceled) { + t.Errorf("error = %v, want context.Canceled", err) + } +} + +func TestResilientClient_BearerTokenInjection(t *testing.T) { + t.Parallel() + + var capturedAuth string + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + capturedAuth = r.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + cred := &fakeTokenCredential{token: "my-access-token"} + + rc := NewResilientClient(cred, &ResilientClientOptions{Transport: transport}) + + // URL must match a known scope for the detector. 
+ resp, err := rc.Do(context.Background(), http.MethodGet, "https://management.azure.com/subscriptions", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if capturedAuth != "Bearer my-access-token" { + t.Errorf("Authorization = %q, want %q", capturedAuth, "Bearer my-access-token") + } +} + +func TestResilientClient_TokenProviderError(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + t.Fatal("should not reach transport when token fails") + return nil, nil + }) + + cred := &fakeTokenCredential{err: errors.New("token expired")} + + rc := NewResilientClient(cred, &ResilientClientOptions{Transport: transport}) + + _, err := rc.Do(context.Background(), http.MethodGet, "https://management.azure.com/subs", nil) + if err == nil { + t.Fatal("expected error when token provider fails") + } +} + +func TestResilientClient_BodyRewind(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if r.Body != nil { + data, _ := io.ReadAll(r.Body) + if string(data) != "payload" { + t.Errorf("attempt %d: body = %q, want %q", attempts, string(data), "payload") + } + } + + if attempts < 2 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(strings.NewReader("")), + Header: http.Header{}, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + body := bytes.NewReader([]byte("payload")) + resp, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if attempts != 2 { + t.Errorf("attempts = %d, want 2", 
attempts) + } +} + +func TestResilientClient_RetryAfterHeader(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + h := http.Header{} + h.Set("retry-after-ms", "1") // 1ms + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if attempts != 2 { + t.Errorf("attempts = %d, want 2", attempts) + } +} + +func TestResilientClient_NilContext(t *testing.T) { + t.Parallel() + + rc := NewResilientClient(nil, nil) + + //lint:ignore SA1012 intentional nil context for test + //nolint:staticcheck // intentional nil context for test + _, err := rc.Do(nil, http.MethodGet, "https://example.com/api", nil) + if err == nil { + t.Fatal("expected error for nil context") + } +} + +func TestResilientClient_DefaultOptions(t *testing.T) { + t.Parallel() + + opts := &ResilientClientOptions{} + opts.defaults() + + if opts.MaxRetries != 3 { + t.Errorf("MaxRetries = %d, want 3", opts.MaxRetries) + } + + if opts.InitialDelay != 500*time.Millisecond { + t.Errorf("InitialDelay = %v, want 500ms", opts.InitialDelay) + } + + if opts.MaxDelay != 30*time.Second { + t.Errorf("MaxDelay = %v, want 30s", opts.MaxDelay) + } + + if opts.Timeout != 30*time.Second { + t.Errorf("Timeout = %v, want 30s", opts.Timeout) + } +} + +func TestRetryAfterFromResponse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header string + value 
string + want time.Duration + }{ + {name: "retry-after-ms", header: "retry-after-ms", value: "500", want: 500 * time.Millisecond}, + {name: "x-ms-retry-after-ms", header: "x-ms-retry-after-ms", value: "1000", want: time.Second}, + {name: "retry-after seconds", header: "retry-after", value: "2", want: 2 * time.Second}, + {name: "empty header", header: "retry-after", value: "", want: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := http.Header{} + if tt.value != "" { + h.Set(tt.header, tt.value) + } + + resp := &http.Response{Header: h} + got := httputil.RetryAfter(resp) + + if got != tt.want { + t.Errorf("retryAfterFromResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRetryAfterFromResponse_Nil(t *testing.T) { + t.Parallel() + + got := httputil.RetryAfter(nil) + if got != 0 { + t.Errorf("retryAfterFromResponse(nil) = %v, want 0", got) + } +} + +func TestIsRetryable(t *testing.T) { + t.Parallel() + + retryable := []int{ + http.StatusTooManyRequests, + http.StatusRequestTimeout, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, + } + + for _, code := range retryable { + if !isRetryable(code) { + t.Errorf("isRetryable(%d) = false, want true", code) + } + } + + notRetryable := []int{ + http.StatusOK, + http.StatusCreated, + http.StatusBadRequest, + http.StatusUnauthorized, + http.StatusForbidden, + http.StatusNotFound, + } + + for _, code := range notRetryable { + if isRetryable(code) { + t.Errorf("isRetryable(%d) = true, want false", code) + } + } +} + +func TestResilientClient_AllRetryableStatusCodes(t *testing.T) { + t.Parallel() + + retryableCodes := []int{ + http.StatusTooManyRequests, + http.StatusRequestTimeout, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, + } + + for _, code := range retryableCodes { + t.Run(strconv.Itoa(code), func(t *testing.T) { 
+ t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: code, + Body: io.NopCloser(strings.NewReader("")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 1, + InitialDelay: time.Millisecond, + }) + + _, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + if err == nil { + t.Fatal("expected error after retries exhausted") + } + + // 1 initial + 1 retry = 2 + if attempts != 2 { + t.Errorf("attempts = %d, want 2 for status %d", attempts, code) + } + }) + } +} + +func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(strings.NewReader("")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + // io.NopCloser wrapping strings.NewReader is NOT an io.ReadSeeker. + // With upfront validation, the error is caught before any HTTP call. + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + + // Should have made zero attempts — upfront check rejects before any HTTP call. 
+ if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_TokenOverHTTP(t *testing.T) { + t.Parallel() + + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + t.Fatal("should not reach transport for HTTP URL with token provider") + return nil, nil + }) + + cred := &fakeTokenCredential{token: "secret-token"} + rc := NewResilientClient(cred, &ResilientClientOptions{Transport: transport}) + + _, err := rc.Do(context.Background(), http.MethodGet, "http://example.com/api", nil) + if err == nil { + t.Fatal("expected error for HTTP URL with token provider") + } + + if !strings.Contains(err.Error(), "HTTPS") { + t.Errorf("error = %q, want mention of HTTPS", err.Error()) + } +} + +func TestResilientClient_RetryAfterReplacesBackoff(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + h := http.Header{} + h.Set("retry-after-ms", "1") // 1ms retry-after + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: 5 * time.Second, // large backoff — should NOT be used + MaxDelay: 10 * time.Second, + }) + + start := time.Now() + resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + // If Retry-After replaces backoff, total time should be ~1ms, not 5s. 
+ if elapsed > 2*time.Second { + t.Errorf("elapsed = %v, want < 2s (Retry-After should replace backoff, not add to it)", elapsed) + } +} + +func TestResilientClient_RetryAfterCapped(t *testing.T) { + t.Parallel() + + // Verify the cap constant is reasonable. + if maxRetryAfterDuration > 5*time.Minute { + t.Errorf("maxRetryAfterDuration = %v, should be <= 5m", maxRetryAfterDuration) + } + + // A large Retry-After value should be capped in Do(). + h := http.Header{} + h.Set("retry-after", "999999") + resp := &http.Response{Header: h} + + got := httputil.RetryAfter(resp) + // RetryAfter parser itself doesn't cap (pure parser), but Do() caps it. + if got != 999999*time.Second { + t.Errorf("RetryAfter() = %v, want %v (capping happens in Do)", got, 999999*time.Second) + } +} + +func TestResilientClient_RetryBodyDrainBounded(t *testing.T) { + t.Parallel() + + // Verify the constant used for bounded retry body drain is set + // and reasonable: it should prevent memory exhaustion but allow + // realistic retryable response bodies to be fully drained. + if maxRetryBodyDrain <= 0 { + t.Fatal("maxRetryBodyDrain must be positive") + } + if maxRetryBodyDrain > 10<<20 { // 10 MB + t.Errorf("maxRetryBodyDrain = %d, should be <= 10 MB", maxRetryBodyDrain) + } + + // Simulate a retry scenario where the retryable response body is larger + // than the drain limit. The client should not hang or OOM. + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + if attempts == 1 { + // Return a retryable status with a body larger than the drain limit. + // Use a LimitedReader to simulate a large body without allocating. 
// infiniteReader is an io.Reader that never returns an error or EOF: every
// Read fills the whole destination slice with zero bytes.
type infiniteReader struct{}

// Read zeroes p and reports it as completely filled.
func (infiniteReader) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = 0
	}

	return len(p), nil
}
+ if d < 50*time.Millisecond || d >= 100*time.Millisecond { + t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + } + } + if len(seen) < 2 { + t.Error("backoff jitter produced identical values across 20 calls") + } +} + +func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { + t.Parallel() + + var attempts int + transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + }, nil + }) + + rc := NewResilientClient(nil, &ResilientClientOptions{ + Transport: transport, + MaxRetries: 2, + InitialDelay: time.Millisecond, + }) + + // Non-seekable body with retries enabled should fail before any request. + body := io.NopCloser(strings.NewReader("payload")) + _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) + if err == nil { + t.Fatal("expected error for non-seekable body with retries enabled") + } + + if !strings.Contains(err.Error(), "io.ReadSeeker") { + t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) + } + + // Should NOT have made any HTTP request. + if attempts != 0 { + t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + } +} + +func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { + t.Parallel() + + // A huge Retry-After should be capped to maxRetryAfterDuration. + // We verify this by using a very short context timeout: if the raw + // value (999999s) were used, the context would expire instantly + // rather than letting the retry proceed. With capping, the context + // timeout (250ms here) is less than the cap, so we expect the + // context to cancel — proving the delay is finite and capped. 
// ScopeDetector maps Azure resource endpoint URLs to the OAuth 2.0 scopes
// required for token acquisition. Extensions use this to automatically
// determine the correct scope for a given API call without hard-coding values.
//
// Host matching is case-insensitive and ignores a trailing root dot, so
// "https://MANAGEMENT.AZURE.COM./" resolves like "https://management.azure.com/".
//
// Usage:
//
//	sd := azdext.NewScopeDetector(nil)
//	scopes, err := sd.ScopesForURL("https://management.azure.com/subscriptions/...")
//	// scopes = []string{"https://management.azure.com/.default"}
type ScopeDetector struct {
	// rules is evaluated in order; the first rule whose match returns true wins.
	rules []scopeRule
}

// scopeRule binds a host-matching function to a scope.
type scopeRule struct {
	// match receives the lowercased host without port or trailing dot.
	match func(host string) bool
	scope string
}

// ScopeDetectorOptions allows adding custom endpoint-to-scope mappings.
type ScopeDetectorOptions struct {
	// CustomRules appends additional host → scope mappings.
	// Each entry maps a host suffix (e.g. ".openai.azure.com") to a scope
	// (e.g. "https://cognitiveservices.azure.com/.default").
	// Suffixes should start with a dot for subdomain matching; entries
	// without a leading dot are treated as exact host matches to prevent
	// unintended partial-host matching (e.g. "azure.com" matching "fakeazure.com").
	// Keys are matched case-insensitively. Empty keys are ignored.
	// Custom rules are evaluated after the built-in rules, in sorted key order.
	CustomRules map[string]string
}

// defaultRules contains well-known Azure endpoint → scope mappings.
// The built-in host patterns do not overlap, so their relative order does not
// affect the result; rules are evaluated until a match is found.
func defaultRules() []scopeRule {
	// Scopes shared by more than one endpoint family.
	const (
		armScope       = "https://management.azure.com/.default"
		storageScope   = "https://storage.azure.com/.default"
		cognitiveScope = "https://cognitiveservices.azure.com/.default"
		devOpsScope    = "499b84ac-1321-427f-aa17-267ca6975798/.default"
		ossRDBMSScope  = "https://ossrdbms-aad.database.windows.net/.default"
	)

	suffix := func(s string) func(string) bool {
		return func(host string) bool { return strings.HasSuffix(host, s) }
	}
	exact := func(s string) func(string) bool {
		return func(host string) bool { return host == s }
	}

	return []scopeRule{
		// Azure Resource Manager
		{match: exact("management.azure.com"), scope: armScope},

		// Microsoft Graph
		{match: exact("graph.microsoft.com"), scope: "https://graph.microsoft.com/.default"},

		// Azure Key Vault
		{match: suffix(".vault.azure.net"), scope: "https://vault.azure.net/.default"},

		// Azure Storage (Blob, Queue, Table, File, Data Lake)
		{match: suffix(".blob.core.windows.net"), scope: storageScope},
		{match: suffix(".queue.core.windows.net"), scope: storageScope},
		{match: suffix(".table.core.windows.net"), scope: storageScope},
		{match: suffix(".file.core.windows.net"), scope: storageScope},
		{match: suffix(".dfs.core.windows.net"), scope: storageScope},

		// Azure Container Registry
		{match: suffix(".azurecr.io"), scope: armScope},

		// Azure Cognitive Services / OpenAI
		{match: suffix(".openai.azure.com"), scope: cognitiveScope},
		{match: suffix(".cognitiveservices.azure.com"), scope: cognitiveScope},

		// Azure AI Services
		{match: suffix(".services.ai.azure.com"), scope: cognitiveScope},

		// Azure DevOps
		{match: exact("dev.azure.com"), scope: devOpsScope},
		{match: suffix(".visualstudio.com"), scope: devOpsScope},

		// Azure Database for PostgreSQL
		{match: suffix(".postgres.database.azure.com"), scope: ossRDBMSScope},

		// Azure Database for MySQL
		{match: suffix(".mysql.database.azure.com"), scope: ossRDBMSScope},

		// Azure Cosmos DB
		{match: suffix(".documents.azure.com"), scope: "https://cosmos.azure.com/.default"},

		// Azure Event Hubs
		{match: suffix(".servicebus.windows.net"), scope: "https://eventhubs.azure.net/.default"},

		// Azure App Configuration
		{match: suffix(".azconfig.io"), scope: "https://azconfig.io/.default"},
	}
}

// NewScopeDetector creates a [ScopeDetector] with the built-in Azure endpoint
// mappings. Additional custom rules can be supplied via opts; a nil opts
// yields a detector with only the built-in rules.
func NewScopeDetector(opts *ScopeDetectorOptions) *ScopeDetector {
	rules := defaultRules()
	if opts == nil {
		return &ScopeDetector{rules: rules}
	}

	// Map iteration order is random; sort the keys so that overlapping custom
	// rules are always evaluated in the same order.
	keys := make([]string, 0, len(opts.CustomRules))
	for k := range opts.CustomRules {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, key := range keys {
		if key == "" {
			continue // ignore empty keys
		}

		scope := opts.CustomRules[key]

		// ScopesForURL lowercases the host before matching, so the pattern
		// must be lowercased as well or a mixed-case key could never match.
		pattern := strings.ToLower(key)

		// No dot prefix: exact host match to prevent partial-host matching
		// (e.g. "azure.com" matching "fakeazure.com").
		match := func(host string) bool { return host == pattern }
		if strings.HasPrefix(pattern, ".") {
			// Dot-prefixed: suffix match (subdomain matching).
			match = func(host string) bool { return strings.HasSuffix(host, pattern) }
		}

		rules = append(rules, scopeRule{match: match, scope: scope})
	}

	return &ScopeDetector{rules: rules}
}

// ScopesForURL returns the OAuth 2.0 scopes required to access the given URL.
// The host is compared case-insensitively, without its port and without a
// trailing root dot. On success the result holds exactly one scope.
//
// It returns a plain error for an empty URL and a [*ScopeDetectorError] when
// the URL is malformed, has no host, or matches no rule.
func (sd *ScopeDetector) ScopesForURL(rawURL string) ([]string, error) {
	if rawURL == "" {
		return nil, errors.New("azdext.ScopeDetector.ScopesForURL: URL must not be empty")
	}

	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, &ScopeDetectorError{URL: rawURL, Reason: "malformed URL: " + err.Error()}
	}

	// Rules are written in lowercase and without the DNS root dot, so bring
	// the host into the same form ("Example.COM." → "example.com").
	host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
	if host == "" {
		return nil, &ScopeDetectorError{URL: rawURL, Reason: "URL has no host"}
	}

	for _, rule := range sd.rules {
		if rule.match(host) {
			return []string{rule.scope}, nil
		}
	}

	return nil, &ScopeDetectorError{URL: rawURL, Reason: "no scope mapping found for host: " + host}
}

// ScopeDetectorError is returned when [ScopeDetector.ScopesForURL] cannot
// resolve a scope for the given URL.
type ScopeDetectorError struct {
	// URL is the raw URL that was passed to ScopesForURL.
	URL string
	// Reason explains why no scope could be resolved.
	Reason string
}

// Error implements the error interface.
func (e *ScopeDetectorError) Error() string {
	return "azdext.ScopeDetector: " + e.Reason + " (url=" + e.URL + ")"
}
+ +package azdext + +import ( + "errors" + "testing" +) + +func TestScopeDetector_KnownEndpoints(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + + tests := []struct { + name string + url string + wantScope string + }{ + // ARM + { + name: "ARM subscription list", + url: "https://management.azure.com/subscriptions?api-version=2022-01-01", + wantScope: "https://management.azure.com/.default", + }, + // Graph + { + name: "Graph users", + url: "https://graph.microsoft.com/v1.0/me", + wantScope: "https://graph.microsoft.com/.default", + }, + // Key Vault + { + name: "Key Vault secret", + url: "https://myvault.vault.azure.net/secrets/mysecret", + wantScope: "https://vault.azure.net/.default", + }, + // Storage - Blob + { + name: "Blob storage", + url: "https://myaccount.blob.core.windows.net/container/blob", + wantScope: "https://storage.azure.com/.default", + }, + // Storage - Queue + { + name: "Queue storage", + url: "https://myaccount.queue.core.windows.net/myqueue", + wantScope: "https://storage.azure.com/.default", + }, + // Storage - Table + { + name: "Table storage", + url: "https://myaccount.table.core.windows.net/mytable", + wantScope: "https://storage.azure.com/.default", + }, + // Storage - File + { + name: "File storage", + url: "https://myaccount.file.core.windows.net/myshare", + wantScope: "https://storage.azure.com/.default", + }, + // Data Lake + { + name: "Data Lake", + url: "https://myaccount.dfs.core.windows.net/filesystem/path", + wantScope: "https://storage.azure.com/.default", + }, + // ACR + { + name: "Container Registry", + url: "https://myregistry.azurecr.io/v2/repo/tags/list", + wantScope: "https://management.azure.com/.default", + }, + // Azure OpenAI + { + name: "Azure OpenAI", + url: "https://myoai.openai.azure.com/openai/deployments/gpt4/chat/completions", + wantScope: "https://cognitiveservices.azure.com/.default", + }, + // Cognitive Services + { + name: "Cognitive Services", + url: 
"https://mycs.cognitiveservices.azure.com/vision/v3.1/analyze", + wantScope: "https://cognitiveservices.azure.com/.default", + }, + // AI Services + { + name: "AI Services", + url: "https://myai.services.ai.azure.com/api/projects/myproj", + wantScope: "https://cognitiveservices.azure.com/.default", + }, + // Azure DevOps + { + name: "Azure DevOps", + url: "https://dev.azure.com/myorg/myproject/_apis/git/repos", + wantScope: "499b84ac-1321-427f-aa17-267ca6975798/.default", + }, + // PostgreSQL + { + name: "PostgreSQL", + url: "https://myserver.postgres.database.azure.com:5432", + wantScope: "https://ossrdbms-aad.database.windows.net/.default", + }, + // MySQL + { + name: "MySQL", + url: "https://myserver.mysql.database.azure.com:3306", + wantScope: "https://ossrdbms-aad.database.windows.net/.default", + }, + // Cosmos DB + { + name: "Cosmos DB", + url: "https://myaccount.documents.azure.com:443/", + wantScope: "https://cosmos.azure.com/.default", + }, + // Event Hubs / Service Bus + { + name: "Event Hubs", + url: "https://myns.servicebus.windows.net/myhub", + wantScope: "https://eventhubs.azure.net/.default", + }, + // App Configuration + { + name: "App Configuration", + url: "https://myconfig.azconfig.io/kv/mykey", + wantScope: "https://azconfig.io/.default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + scopes, err := sd.ScopesForURL(tt.url) + if err != nil { + t.Fatalf("ScopesForURL(%q) error: %v", tt.url, err) + } + + if len(scopes) != 1 { + t.Fatalf("ScopesForURL(%q) returned %d scopes, want 1", tt.url, len(scopes)) + } + + if scopes[0] != tt.wantScope { + t.Errorf("ScopesForURL(%q) = %q, want %q", tt.url, scopes[0], tt.wantScope) + } + }) + } +} + +func TestScopeDetector_EmptyURL(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + _, err := sd.ScopesForURL("") + if err == nil { + t.Fatal("expected error for empty URL") + } +} + +func TestScopeDetector_NoMatch(t *testing.T) { + t.Parallel() + + sd 
:= NewScopeDetector(nil) + _, err := sd.ScopesForURL("https://example.com/foo") + if err == nil { + t.Fatal("expected error for unknown host") + } + + var scopeErr *ScopeDetectorError + if !errors.As(err, &scopeErr) { + t.Fatalf("error type = %T, want *ScopeDetectorError", err) + } + + if scopeErr.URL != "https://example.com/foo" { + t.Errorf("ScopeDetectorError.URL = %q, want %q", scopeErr.URL, "https://example.com/foo") + } +} + +func TestScopeDetector_MalformedURL(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + _, err := sd.ScopesForURL("://bad") + if err == nil { + t.Fatal("expected error for malformed URL") + } +} + +func TestScopeDetector_NoHost(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + _, err := sd.ScopesForURL("/relative/path") + if err == nil { + t.Fatal("expected error for URL without host") + } +} + +func TestScopeDetector_CustomRules(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(&ScopeDetectorOptions{ + CustomRules: map[string]string{ + ".custom.example.com": "https://custom.example.com/.default", + }, + }) + + scopes, err := sd.ScopesForURL("https://api.custom.example.com/v1/data") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(scopes) != 1 || scopes[0] != "https://custom.example.com/.default" { + t.Errorf("scopes = %v, want [https://custom.example.com/.default]", scopes) + } +} + +func TestScopeDetector_CaseInsensitive(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + + scopes, err := sd.ScopesForURL("https://MANAGEMENT.AZURE.COM/subscriptions") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(scopes) != 1 || scopes[0] != "https://management.azure.com/.default" { + t.Errorf("scopes = %v, want [https://management.azure.com/.default]", scopes) + } +} + +func TestScopeDetector_URLWithPort(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(nil) + + scopes, err := sd.ScopesForURL("https://myvault.vault.azure.net:443/secrets/key") + if 
err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(scopes) != 1 || scopes[0] != "https://vault.azure.net/.default" { + t.Errorf("scopes = %v, want [https://vault.azure.net/.default]", scopes) + } +} + +func TestScopeDetector_CustomRuleWithoutDotPrefix(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(&ScopeDetectorOptions{ + CustomRules: map[string]string{ + "api.example.com": "https://example.com/.default", + }, + }) + + // Exact match should work. + scopes, err := sd.ScopesForURL("https://api.example.com/v1/data") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(scopes) != 1 || scopes[0] != "https://example.com/.default" { + t.Errorf("scopes = %v, want [https://example.com/.default]", scopes) + } + + // Should NOT match a different host that merely ends with the same string. + _, err = sd.ScopesForURL("https://evil-api.example.com/v1/data") + if err == nil { + t.Fatal("expected error: exact match should not match different host") + } +} + +func TestScopeDetector_EmptyCustomRuleIgnored(t *testing.T) { + t.Parallel() + + sd := NewScopeDetector(&ScopeDetectorOptions{ + CustomRules: map[string]string{ + "": "https://catch-all.example.com/.default", + }, + }) + + // Empty key should be ignored, so unknown host still errors. + _, err := sd.ScopesForURL("https://unknown.example.com/data") + if err == nil { + t.Fatal("expected error: empty custom rule should be ignored") + } +} + +func TestScopeDetector_DeterministicCustomRules(t *testing.T) { + t.Parallel() + + // Create a detector with multiple custom rules and verify + // that results are deterministic across multiple invocations. 
// TokenProvider implements [azcore.TokenCredential] so that extensions can
// obtain Azure tokens without manual credential construction.
//
// Unless overridden through [TokenProviderOptions], [NewTokenProvider]
// resolves the tenant from the AZD deployment context retrieved via gRPC and
// delegates to [azidentity.AzureDeveloperCLICredential] for the actual token
// acquisition flow.
//
// Usage:
//
//	tp, err := azdext.NewTokenProvider(ctx, client, nil)
//	// use tp as azcore.TokenCredential with any Azure SDK client
type TokenProvider struct {
	// credential is the underlying credential that GetToken forwards to.
	credential azcore.TokenCredential
	// tenantID is the tenant that was configured or resolved at construction.
	tenantID string
}
+ // When empty, the provider queries the AZD gRPC server for the current tenant. + TenantID string + + // Credential overrides the default credential chain. + // When nil, [azidentity.AzureDeveloperCLICredential] is used. + Credential azcore.TokenCredential +} + +// NewTokenProvider creates a [TokenProvider] for the given AZD client. +// +// If opts is nil, the provider discovers the current tenant from the AZD +// deployment context and constructs an [azidentity.AzureDeveloperCLICredential]. +func NewTokenProvider(ctx context.Context, client *AzdClient, opts *TokenProviderOptions) (*TokenProvider, error) { + if client == nil { + return nil, errors.New("azdext.NewTokenProvider: client must not be nil") + } + + if opts == nil { + opts = &TokenProviderOptions{} + } + + tenantID := opts.TenantID + + // Resolve tenant from deployment context when not explicitly supplied. + if tenantID == "" { + resp, err := client.Deployment().GetDeploymentContext(ctx, &EmptyRequest{}) + if err != nil { + return nil, fmt.Errorf("azdext.NewTokenProvider: failed to retrieve deployment context: %w", err) + } + + if resp.GetAzureContext() != nil && resp.GetAzureContext().GetScope() != nil { + tenantID = resp.GetAzureContext().GetScope().GetTenantId() + } + + if tenantID == "" { + return nil, errors.New( + "azdext.NewTokenProvider: deployment context returned no tenant ID; " + + "set TenantID explicitly", + ) + } + } + + cred := opts.Credential + if cred == nil { + azdCred, err := azidentity.NewAzureDeveloperCLICredential( + &azidentity.AzureDeveloperCLICredentialOptions{ + TenantID: tenantID, + }, + ) + if err != nil { + return nil, fmt.Errorf("azdext.NewTokenProvider: failed to create credential: %w", err) + } + + cred = azdCred + } + + return &TokenProvider{ + credential: cred, + tenantID: tenantID, + }, nil +} + +// GetToken satisfies [azcore.TokenCredential]. 
+func (tp *TokenProvider) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + if len(options.Scopes) == 0 { + return azcore.AccessToken{}, errors.New("azdext.TokenProvider.GetToken: at least one scope is required") + } + + return tp.credential.GetToken(ctx, options) +} + +// TenantID returns the Azure tenant ID that was resolved or configured for this provider. +func (tp *TokenProvider) TenantID() string { + return tp.tenantID +} diff --git a/cli/azd/pkg/azdext/token_provider_test.go b/cli/azd/pkg/azdext/token_provider_test.go new file mode 100644 index 00000000000..f51d85097be --- /dev/null +++ b/cli/azd/pkg/azdext/token_provider_test.go @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// stubCredential is a test double for azcore.TokenCredential. +type stubCredential struct { + token azcore.AccessToken + err error +} + +func (s *stubCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + return s.token, s.err +} + +func TestNewTokenProvider_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewTokenProvider(context.Background(), nil, nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewTokenProvider_ExplicitTenantAndCredential(t *testing.T) { + t.Parallel() + + cred := &stubCredential{ + token: azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)}, + } + + // Use a minimal AzdClient (no gRPC connection needed since we supply tenant+cred). 
+ client := &AzdClient{} + tp, err := NewTokenProvider(context.Background(), client, &TokenProviderOptions{ + TenantID: "test-tenant-id", + Credential: cred, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if tp.TenantID() != "test-tenant-id" { + t.Errorf("TenantID() = %q, want %q", tp.TenantID(), "test-tenant-id") + } + + tok, err := tp.GetToken(context.Background(), policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err != nil { + t.Fatalf("GetToken failed: %v", err) + } + + if tok.Token != "test-token" { + t.Errorf("Token = %q, want %q", tok.Token, "test-token") + } +} + +func TestTokenProvider_GetToken_NoScopes(t *testing.T) { + t.Parallel() + + cred := &stubCredential{ + token: azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)}, + } + + tp := &TokenProvider{ + credential: cred, + tenantID: "tenant", + } + + _, err := tp.GetToken(context.Background(), policy.TokenRequestOptions{}) + if err == nil { + t.Fatal("expected error when no scopes provided") + } +} + +func TestTokenProvider_GetToken_CredentialError(t *testing.T) { + t.Parallel() + + credErr := errors.New("credential unavailable") + cred := &stubCredential{err: credErr} + + tp := &TokenProvider{ + credential: cred, + tenantID: "tenant", + } + + _, err := tp.GetToken(context.Background(), policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err == nil { + t.Fatal("expected error when credential fails") + } + + if !errors.Is(err, credErr) { + t.Errorf("error = %v, want wrapping %v", err, credErr) + } +} + +func TestTokenProvider_ImplementsTokenCredential(t *testing.T) { + t.Parallel() + + // Compile-time check is in the production file; this is a runtime confirmation. + var _ azcore.TokenCredential = &TokenProvider{} +}