Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cli/azd/cmd/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ func registerCommonDependencies(container *ioc.NestedContainer) {
Key: key,
}, nil
})
container.MustRegisterSingleton(func() auth.UserAgent {
return auth.UserAgent(internal.UserAgent())
})
container.MustRegisterScoped(auth.NewManager)
container.MustRegisterSingleton(azapi.NewUserProfileService)
container.MustRegisterScoped(func(authManager *auth.Manager) middleware.CurrentUserAuthManager {
Expand Down
74 changes: 45 additions & 29 deletions cli/azd/pkg/auth/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,13 @@ type Manager struct {
console input.Console
externalAuthCfg ExternalAuthConfiguration
azCli az.AzCli
userAgent string
}

// UserAgent is a typed string for the application user-agent,
// used for dependency injection.
type UserAgent string

type ExternalAuthConfiguration struct {
Endpoint string
Key string
Expand All @@ -114,6 +119,7 @@ func NewManager(
console input.Console,
externalAuthCfg ExternalAuthConfiguration,
azCli az.AzCli,
userAgent UserAgent,
) (*Manager, error) {
cfgRoot, err := config.GetUserConfigDir()
if err != nil {
Expand All @@ -135,10 +141,12 @@ func NewManager(
return nil, fmt.Errorf("joining authority url: %w", err)
}

msalClient := newUserAgentClient(httpClient, string(userAgent))

options := []public.Option{
public.WithCache(newCache(cacheRoot)),
public.WithAuthority(authorityUrl),
public.WithHTTPClient(httpClient),
public.WithHTTPClient(msalClient),
}

publicClientApp, err := public.New(azdClientID, options...)
Expand All @@ -157,9 +165,25 @@ func NewManager(
console: console,
externalAuthCfg: externalAuthCfg,
azCli: azCli,
userAgent: string(userAgent),
}, nil
}

// authClientOptions returns azcore.ClientOptions configured with the custom user-agent policy
// for use with Azure Identity SDK credentials.
func (m *Manager) authClientOptions() azcore.ClientOptions {
opts := azcore.ClientOptions{
Transport: m.httpClient,
Cloud: m.cloud.Configuration,
}
if m.userAgent != "" {
opts.Telemetry = policy.TelemetryOptions{
ApplicationID: m.userAgent,
}
}
return opts
}

// LoginScopes returns the default scopes requested when logging in.
func LoginScopes(cloud *cloud.Cloud) []string {
arm := cloud.Configuration.Services[azcloud.ResourceManager]
Expand Down Expand Up @@ -464,7 +488,9 @@ func (m *Manager) GetLoggedInServicePrincipalTenantID(ctx context.Context) (*str
}

func (m *Manager) newCredentialFromManagedIdentity(clientID string) (azcore.TokenCredential, error) {
options := &azidentity.ManagedIdentityCredentialOptions{}
options := &azidentity.ManagedIdentityCredentialOptions{
ClientOptions: m.authClientOptions(),
}
if clientID != "" {
options.ID = azidentity.ClientID(clientID)
}
Expand All @@ -483,12 +509,7 @@ func (m *Manager) newCredentialFromClientSecret(
clientSecret string,
) (azcore.TokenCredential, error) {
options := &azidentity.ClientSecretCredentialOptions{
ClientOptions: azcore.ClientOptions{
Transport: m.httpClient,
// TODO: Inject client options instead? this can be done if we're OK
// using the default user agent string.
Cloud: m.cloud.Configuration,
},
ClientOptions: m.authClientOptions(),
}
cred, err := azidentity.NewClientSecretCredential(tenantID, clientID, clientSecret, options)
if err != nil {
Expand All @@ -514,12 +535,7 @@ func (m *Manager) newCredentialFromClientCertificate(
}

options := &azidentity.ClientCertificateCredentialOptions{
ClientOptions: azcore.ClientOptions{
Transport: m.httpClient,
// TODO: Inject client options instead? this can be done if we're OK
// using the default user agent string.
Cloud: m.cloud.Configuration,
},
ClientOptions: m.authClientOptions(),
}
cred, err := azidentity.NewClientCertificateCredential(
tenantID, clientID, certs, key, options)
Expand All @@ -537,12 +553,7 @@ func (m *Manager) newCredentialFromFederatedTokenProvider(
provider federatedTokenProvider,
serviceConnectionID *string,
) (azcore.TokenCredential, error) {
clientOptions := azcore.ClientOptions{
Transport: m.httpClient,
// TODO: Inject client options instead? this can be done if we're OK
// using the default user agent string.
Cloud: m.cloud.Configuration,
}
clientOptions := m.authClientOptions()

switch provider {
case gitHubFederatedTokenProvider:
Expand Down Expand Up @@ -845,7 +856,9 @@ func (m *Manager) LoginWithDeviceCode(
}

func (m *Manager) LoginWithManagedIdentity(ctx context.Context, clientID string) (azcore.TokenCredential, error) {
options := &azidentity.ManagedIdentityCredentialOptions{}
options := &azidentity.ManagedIdentityCredentialOptions{
ClientOptions: m.authClientOptions(),
}
if clientID != "" {
options.ID = azidentity.ClientID(clientID)
}
Expand All @@ -865,7 +878,11 @@ func (m *Manager) LoginWithManagedIdentity(ctx context.Context, clientID string)
func (m *Manager) LoginWithServicePrincipalSecret(
ctx context.Context, tenantId, clientId, clientSecret string,
) (azcore.TokenCredential, error) {
cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, clientSecret, nil)
opts := &azidentity.ClientSecretCredentialOptions{
ClientOptions: m.authClientOptions(),
}
cred, err := azidentity.NewClientSecretCredential(
tenantId, clientId, clientSecret, opts)
if err != nil {
return nil, fmt.Errorf("creating credential: %w", err)
}
Expand All @@ -891,7 +908,11 @@ func (m *Manager) LoginWithServicePrincipalCertificate(
return nil, fmt.Errorf("parsing certificate: %w", err)
}

cred, err := azidentity.NewClientCertificateCredential(tenantId, clientId, certs, key, nil)
certOpts := &azidentity.ClientCertificateCredentialOptions{
ClientOptions: m.authClientOptions(),
}
cred, err := azidentity.NewClientCertificateCredential(
tenantId, clientId, certs, key, certOpts)
if err != nil {
return nil, fmt.Errorf("creating credential: %w", err)
}
Expand Down Expand Up @@ -944,12 +965,7 @@ func (m *Manager) LoginWithAzurePipelinesFederatedTokenProvider(
}

options := &azidentity.AzurePipelinesCredentialOptions{
ClientOptions: azcore.ClientOptions{
Transport: m.httpClient,
// TODO: Inject client options instead? this can be done if we're OK
// using the default user agent string.
Cloud: m.cloud.Configuration,
},
ClientOptions: m.authClientOptions(),
}

cred, err := azidentity.NewAzurePipelinesCredential(tenantID, clientID, serviceConnectionID, systemAccessToken, options)
Expand Down
36 changes: 36 additions & 0 deletions cli/azd/pkg/auth/user_agent_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package auth

import "net/http"

// userAgentClient wraps an HttpClient to inject a User-Agent header on all requests.
type userAgentClient struct {
inner HttpClient
userAgent string
}

func newUserAgentClient(inner HttpClient, userAgent string) HttpClient {
if userAgent == "" {
return inner
}
return &userAgentClient{inner: inner, userAgent: userAgent}
}

func (c *userAgentClient) Do(req *http.Request) (*http.Response, error) {
if req.Header == nil {
req.Header = make(http.Header)
}
existingUA := req.Header.Get("User-Agent")
if existingUA == "" {
req.Header.Set("User-Agent", c.userAgent)
} else {
req.Header.Set("User-Agent", existingUA+","+c.userAgent)
}
return c.inner.Do(req)
}

func (c *userAgentClient) CloseIdleConnections() {
c.inner.CloseIdleConnections()
}
86 changes: 86 additions & 0 deletions cli/azd/pkg/auth/user_agent_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package auth

import (
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

type mockHttpClient struct {
lastRequest *http.Request
}

func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) {
m.lastRequest = req
return &http.Response{StatusCode: 200}, nil
}

func (m *mockHttpClient) CloseIdleConnections() {}

func TestUserAgentClient(t *testing.T) {
tests := []struct {
name string
userAgent string
existingUserAgent string
nilHeader bool
expectedUserAgent string
expectWrapped bool
}{
{
name: "SetsUserAgentWhenEmpty",
userAgent: "azdev/1.0.0",
existingUserAgent: "",
expectedUserAgent: "azdev/1.0.0",
expectWrapped: true,
},
{
name: "AppendsToExistingUserAgent",
userAgent: "azdev/1.0.0",
existingUserAgent: "existing-agent/2.0",
expectedUserAgent: "existing-agent/2.0,azdev/1.0.0",
expectWrapped: true,
},
{
name: "EmptyUserAgentReturnsInnerClient",
userAgent: "",
expectWrapped: false,
},
{
name: "HandlesNilHeader",
userAgent: "azdev/1.0.0",
nilHeader: true,
expectedUserAgent: "azdev/1.0.0",
expectWrapped: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inner := &mockHttpClient{}
client := newUserAgentClient(inner, tt.userAgent)

if !tt.expectWrapped {
// Should return the inner client unchanged
require.Equal(t, inner, client)
return
}

req, err := http.NewRequest("GET", "https://example.com", nil)
require.NoError(t, err)

if tt.nilHeader {
req.Header = nil
} else if tt.existingUserAgent != "" {
req.Header.Set("User-Agent", tt.existingUserAgent)
}

_, err = client.Do(req)
require.NoError(t, err)
require.Equal(t, tt.expectedUserAgent, inner.lastRequest.Header.Get("User-Agent"))
})
}
}
1 change: 1 addition & 0 deletions cli/azd/pkg/devcentersdk/developer_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func Test_DevCenter_Client(t *testing.T) {
mockContext.Console,
auth.ExternalAuthConfiguration{},
azCli,
"",
)
require.NoError(t, err)

Expand Down
1 change: 1 addition & 0 deletions cli/azd/test/functional/remote_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func createBlobClient(
httpClient, mockContext.Console,
auth.ExternalAuthConfiguration{},
azCli,
"",
)
require.NoError(t, err)

Expand Down
Loading