diff --git a/cli/azd/cmd/container.go b/cli/azd/cmd/container.go index c5a373068a9..956cf1319e1 100644 --- a/cli/azd/cmd/container.go +++ b/cli/azd/cmd/container.go @@ -634,6 +634,9 @@ func registerCommonDependencies(container *ioc.NestedContainer) { Key: key, }, nil }) + container.MustRegisterSingleton(func() auth.UserAgent { + return auth.UserAgent(internal.UserAgent()) + }) container.MustRegisterScoped(auth.NewManager) container.MustRegisterSingleton(azapi.NewUserProfileService) container.MustRegisterScoped(func(authManager *auth.Manager) middleware.CurrentUserAuthManager { diff --git a/cli/azd/pkg/auth/manager.go b/cli/azd/pkg/auth/manager.go index 194d2313bfe..71fa9ed63e5 100644 --- a/cli/azd/pkg/auth/manager.go +++ b/cli/azd/pkg/auth/manager.go @@ -98,8 +98,13 @@ type Manager struct { console input.Console externalAuthCfg ExternalAuthConfiguration azCli az.AzCli + userAgent string } +// UserAgent is a typed string for the application user-agent, +// used for dependency injection. +type UserAgent string + type ExternalAuthConfiguration struct { Endpoint string Key string @@ -114,6 +119,7 @@ func NewManager( console input.Console, externalAuthCfg ExternalAuthConfiguration, azCli az.AzCli, + userAgent UserAgent, ) (*Manager, error) { cfgRoot, err := config.GetUserConfigDir() if err != nil { @@ -135,10 +141,12 @@ func NewManager( return nil, fmt.Errorf("joining authority url: %w", err) } + msalClient := newUserAgentClient(httpClient, string(userAgent)) + options := []public.Option{ public.WithCache(newCache(cacheRoot)), public.WithAuthority(authorityUrl), - public.WithHTTPClient(httpClient), + public.WithHTTPClient(msalClient), } publicClientApp, err := public.New(azdClientID, options...) @@ -157,9 +165,25 @@ func NewManager( console: console, externalAuthCfg: externalAuthCfg, azCli: azCli, + userAgent: string(userAgent), }, nil } +// authClientOptions returns azcore.ClientOptions configured with the custom user-agent policy +// for use with Azure Identity SDK credentials. +func (m *Manager) authClientOptions() azcore.ClientOptions { + opts := azcore.ClientOptions{ + Transport: m.httpClient, + Cloud: m.cloud.Configuration, + } + if m.userAgent != "" { + opts.Telemetry = policy.TelemetryOptions{ + ApplicationID: m.userAgent, + } + } + return opts +} + // LoginScopes returns the default scopes requested when logging in. func LoginScopes(cloud *cloud.Cloud) []string { arm := cloud.Configuration.Services[azcloud.ResourceManager] @@ -464,7 +488,9 @@ func (m *Manager) GetLoggedInServicePrincipalTenantID(ctx context.Context) (*str } func (m *Manager) newCredentialFromManagedIdentity(clientID string) (azcore.TokenCredential, error) { - options := &azidentity.ManagedIdentityCredentialOptions{} + options := &azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: m.authClientOptions(), + } if clientID != "" { options.ID = azidentity.ClientID(clientID) } @@ -483,12 +509,7 @@ func (m *Manager) newCredentialFromClientSecret( clientSecret string, ) (azcore.TokenCredential, error) { options := &azidentity.ClientSecretCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Transport: m.httpClient, - // TODO: Inject client options instead? this can be done if we're OK - // using the default user agent string. - Cloud: m.cloud.Configuration, - }, + ClientOptions: m.authClientOptions(), } cred, err := azidentity.NewClientSecretCredential(tenantID, clientID, clientSecret, options) if err != nil { @@ -514,12 +535,7 @@ func (m *Manager) newCredentialFromClientCertificate( } options := &azidentity.ClientCertificateCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Transport: m.httpClient, - // TODO: Inject client options instead? this can be done if we're OK - // using the default user agent string. - Cloud: m.cloud.Configuration, - }, + ClientOptions: m.authClientOptions(), } cred, err := azidentity.NewClientCertificateCredential( tenantID, clientID, certs, key, options) @@ -537,12 +553,7 @@ func (m *Manager) newCredentialFromFederatedTokenProvider( provider federatedTokenProvider, serviceConnectionID *string, ) (azcore.TokenCredential, error) { - clientOptions := azcore.ClientOptions{ - Transport: m.httpClient, - // TODO: Inject client options instead? this can be done if we're OK - // using the default user agent string. - Cloud: m.cloud.Configuration, - } + clientOptions := m.authClientOptions() switch provider { case gitHubFederatedTokenProvider: @@ -845,7 +856,9 @@ func (m *Manager) LoginWithDeviceCode( } func (m *Manager) LoginWithManagedIdentity(ctx context.Context, clientID string) (azcore.TokenCredential, error) { - options := &azidentity.ManagedIdentityCredentialOptions{} + options := &azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: m.authClientOptions(), + } if clientID != "" { options.ID = azidentity.ClientID(clientID) } @@ -865,7 +878,11 @@ func (m *Manager) LoginWithManagedIdentity(ctx context.Context, clientID string) func (m *Manager) LoginWithServicePrincipalSecret( ctx context.Context, tenantId, clientId, clientSecret string, ) (azcore.TokenCredential, error) { - cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, clientSecret, nil) + opts := &azidentity.ClientSecretCredentialOptions{ + ClientOptions: m.authClientOptions(), + } + cred, err := azidentity.NewClientSecretCredential( + tenantId, clientId, clientSecret, opts) if err != nil { return nil, fmt.Errorf("creating credential: %w", err) } @@ -891,7 +908,11 @@ func (m *Manager) LoginWithServicePrincipalCertificate( return nil, fmt.Errorf("parsing certificate: %w", err) } - cred, err := azidentity.NewClientCertificateCredential(tenantId, clientId, certs, key, nil) + certOpts := &azidentity.ClientCertificateCredentialOptions{ + ClientOptions: m.authClientOptions(), + } + cred, err := azidentity.NewClientCertificateCredential( + tenantId, clientId, certs, key, certOpts) if err != nil { return nil, fmt.Errorf("creating credential: %w", err) } @@ -944,12 +965,7 @@ func (m *Manager) LoginWithAzurePipelinesFederatedTokenProvider( } options := &azidentity.AzurePipelinesCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Transport: m.httpClient, - // TODO: Inject client options instead? this can be done if we're OK - // using the default user agent string. - Cloud: m.cloud.Configuration, - }, + ClientOptions: m.authClientOptions(), } cred, err := azidentity.NewAzurePipelinesCredential(tenantID, clientID, serviceConnectionID, systemAccessToken, options) diff --git a/cli/azd/pkg/auth/user_agent_client.go b/cli/azd/pkg/auth/user_agent_client.go new file mode 100644 index 00000000000..4a3be541ea5 --- /dev/null +++ b/cli/azd/pkg/auth/user_agent_client.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import "net/http" + +// userAgentClient wraps an HttpClient to inject a User-Agent header on all requests. +type userAgentClient struct { + inner HttpClient + userAgent string +} + +func newUserAgentClient(inner HttpClient, userAgent string) HttpClient { + if userAgent == "" { + return inner + } + return &userAgentClient{inner: inner, userAgent: userAgent} +} + +func (c *userAgentClient) Do(req *http.Request) (*http.Response, error) { + if req.Header == nil { + req.Header = make(http.Header) + } + existingUA := req.Header.Get("User-Agent") + if existingUA == "" { + req.Header.Set("User-Agent", c.userAgent) + } else { + req.Header.Set("User-Agent", existingUA+","+c.userAgent) + } + return c.inner.Do(req) +} + +func (c *userAgentClient) CloseIdleConnections() { + c.inner.CloseIdleConnections() +} diff --git a/cli/azd/pkg/auth/user_agent_client_test.go b/cli/azd/pkg/auth/user_agent_client_test.go new file mode 100644 index 00000000000..1c36b1f1cb4 --- /dev/null +++ b/cli/azd/pkg/auth/user_agent_client_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +type mockHttpClient struct { + lastRequest *http.Request +} + +func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) { + m.lastRequest = req + return &http.Response{StatusCode: 200}, nil +} + +func (m *mockHttpClient) CloseIdleConnections() {} + +func TestUserAgentClient(t *testing.T) { + tests := []struct { + name string + userAgent string + existingUserAgent string + nilHeader bool + expectedUserAgent string + expectWrapped bool + }{ + { + name: "SetsUserAgentWhenEmpty", + userAgent: "azdev/1.0.0", + existingUserAgent: "", + expectedUserAgent: "azdev/1.0.0", + expectWrapped: true, + }, + { + name: "AppendsToExistingUserAgent", + userAgent: "azdev/1.0.0", + existingUserAgent: "existing-agent/2.0", + expectedUserAgent: "existing-agent/2.0,azdev/1.0.0", + expectWrapped: true, + }, + { + name: "EmptyUserAgentReturnsInnerClient", + userAgent: "", + expectWrapped: false, + }, + { + name: "HandlesNilHeader", + userAgent: "azdev/1.0.0", + nilHeader: true, + expectedUserAgent: "azdev/1.0.0", + expectWrapped: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := &mockHttpClient{} + client := newUserAgentClient(inner, tt.userAgent) + + if !tt.expectWrapped { + // Should return the inner client unchanged + require.Equal(t, inner, client) + return + } + + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + if tt.nilHeader { + req.Header = nil + } else if tt.existingUserAgent != "" { + req.Header.Set("User-Agent", tt.existingUserAgent) + } + + _, err = client.Do(req) + require.NoError(t, err) + require.Equal(t, tt.expectedUserAgent, inner.lastRequest.Header.Get("User-Agent")) + }) + } +} diff --git a/cli/azd/pkg/devcentersdk/developer_client_test.go b/cli/azd/pkg/devcentersdk/developer_client_test.go index 92faa224187..56139fffa26 100644 --- a/cli/azd/pkg/devcentersdk/developer_client_test.go +++ b/cli/azd/pkg/devcentersdk/developer_client_test.go @@ -36,6 +36,7 @@ func Test_DevCenter_Client(t *testing.T) { mockContext.Console, auth.ExternalAuthConfiguration{}, azCli, + "", ) require.NoError(t, err) diff --git a/cli/azd/test/functional/remote_state_test.go b/cli/azd/test/functional/remote_state_test.go index 3aba4ee5cac..091edb5cb94 100644 --- a/cli/azd/test/functional/remote_state_test.go +++ b/cli/azd/test/functional/remote_state_test.go @@ -89,6 +89,7 @@ func createBlobClient( httpClient, mockContext.Console, auth.ExternalAuthConfiguration{}, azCli, + "", ) require.NoError(t, err)