diff --git a/cli/azd/.vscode/cspell.yaml b/cli/azd/.vscode/cspell.yaml index fc5d6118f45..99b5744d722 100644 --- a/cli/azd/.vscode/cspell.yaml +++ b/cli/azd/.vscode/cspell.yaml @@ -1,6 +1,7 @@ import: ../../../.vscode/cspell.global.yaml words: - agentdetect + - Authenticode - azcloud - azdext - azurefd @@ -13,6 +14,7 @@ words: - cmds - Codespace - Codespaces + - codesign - cooldown - customtype - devcontainers @@ -27,6 +29,7 @@ words: - OPENCODE - opencode - grpcbroker + - msiexec - nosec - oneof - idxs diff --git a/cli/azd/cmd/root.go b/cli/azd/cmd/root.go index 9abf0cff426..80119b55a59 100644 --- a/cli/azd/cmd/root.go +++ b/cli/azd/cmd/root.go @@ -193,6 +193,14 @@ func NewRootCmd( }, }) + root.Add("update", &actions.ActionDescriptorOptions{ + Command: newUpdateCmd(), + FlagsResolver: newUpdateFlags, + ActionResolver: newUpdateAction, + OutputFormats: []output.Format{output.NoneFormat}, + DefaultFormat: output.NoneFormat, + }) + root.Add("vs-server", &actions.ActionDescriptorOptions{ Command: newVsServerCmd(), FlagsResolver: newVsServerFlags, diff --git a/cli/azd/cmd/update.go b/cli/azd/cmd/update.go new file mode 100644 index 00000000000..6e96b08c6d9 --- /dev/null +++ b/cli/azd/cmd/update.go @@ -0,0 +1,342 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/azure/azure-dev/cli/azd/cmd/actions" + "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/internal/tracing" + "github.com/azure/azure-dev/cli/azd/internal/tracing/fields" + "github.com/azure/azure-dev/cli/azd/internal/tracing/resource" + "github.com/azure/azure-dev/cli/azd/pkg/alpha" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/installer" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/azure/azure-dev/cli/azd/pkg/output/ux" + "github.com/azure/azure-dev/cli/azd/pkg/update" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +type updateFlags struct { + channel string + autoUpdate string + checkIntervalHours int + global *internal.GlobalCommandOptions +} + +func newUpdateFlags(cmd *cobra.Command, global *internal.GlobalCommandOptions) *updateFlags { + flags := &updateFlags{} + flags.Bind(cmd.Flags(), global) + return flags +} + +func (f *updateFlags) Bind(local *pflag.FlagSet, global *internal.GlobalCommandOptions) { + f.global = global + + local.StringVar( + &f.channel, + "channel", + "", + "Update channel: stable or daily.", + ) + local.StringVar( + &f.autoUpdate, + "auto-update", + "", + "Enable or disable auto-update: on or off.", + ) + local.IntVar( + &f.checkIntervalHours, + "check-interval-hours", + 0, + "Override the update check interval in hours.", + ) +} + +func newUpdateCmd() *cobra.Command { + return &cobra.Command{ + Use: "update", + Short: "Updates azd to the latest version.", + Hidden: true, + } +} + +type updateAction struct { + flags *updateFlags + console input.Console + formatter output.Formatter + writer io.Writer + configManager config.UserConfigManager + commandRunner exec.CommandRunner + alphaFeatureManager *alpha.FeatureManager +} + +func newUpdateAction( + flags 
*updateFlags, + console input.Console, + formatter output.Formatter, + writer io.Writer, + configManager config.UserConfigManager, + commandRunner exec.CommandRunner, + alphaFeatureManager *alpha.FeatureManager, +) actions.Action { + return &updateAction{ + flags: flags, + console: console, + formatter: formatter, + writer: writer, + configManager: configManager, + commandRunner: commandRunner, + alphaFeatureManager: alphaFeatureManager, + } +} + +func (a *updateAction) Run(ctx context.Context) (*actions.ActionResult, error) { + // Auto-enable the alpha feature if not already enabled. + // The user's intent is clear by running `azd update` directly. + if !a.alphaFeatureManager.IsEnabled(update.FeatureUpdate) { + userCfg, err := a.configManager.Load() + if err != nil { + userCfg = config.NewEmptyConfig() + } + + if err := userCfg.Set(fmt.Sprintf("alpha.%s", update.FeatureUpdate), "on"); err != nil { + return nil, fmt.Errorf("failed to enable update feature: %w", err) + } + + if err := a.configManager.Save(userCfg); err != nil { + return nil, fmt.Errorf("failed to save config: %w", err) + } + + a.console.MessageUxItem(ctx, &ux.MessageTitle{ + Title: fmt.Sprintf("azd update is in alpha. 
"+ + "To turn off in the future, run `azd config unset alpha.%s`.\n", + update.FeatureUpdate), + }) + } + + // Track install method for telemetry + installedBy := installer.InstalledBy() + tracing.SetUsageAttributes( + fields.UpdateInstallMethod.String(string(installedBy)), + ) + + userConfig, err := a.configManager.Load() + if err != nil { + userConfig = config.NewEmptyConfig() + } + + // Determine current channel BEFORE persisting any flags + currentCfg := update.LoadUpdateConfig(userConfig) + switchingChannels := a.flags.channel != "" && update.Channel(a.flags.channel) != currentCfg.Channel + + // Persist non-channel config flags immediately (auto-update, check-interval) + configChanged, err := a.persistNonChannelFlags(userConfig) + if err != nil { + return nil, err + } + + // If switching channels, persist channel to a temporary config for the version check + // but don't save to disk until after confirmation + if switchingChannels { + newChannel, err := update.ParseChannel(a.flags.channel) + if err != nil { + return nil, err + } + _ = update.SaveChannel(userConfig, newChannel) + configChanged = true + } else if a.flags.channel != "" { + // Same channel explicitly set — just persist it + if err := update.SaveChannel(userConfig, update.Channel(a.flags.channel)); err != nil { + return nil, err + } + configChanged = true + } + + cfg := update.LoadUpdateConfig(userConfig) + + // Track channel for telemetry + tracing.SetUsageAttributes( + fields.UpdateChannel.String(string(cfg.Channel)), + fields.UpdateFromVersion.String(internal.VersionInfo().Version.String()), + ) + + mgr := update.NewManager(a.commandRunner) + + // Block update in CI/CD environments + if resource.IsRunningOnCI() { + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeSkippedCI)) + return nil, &update.UpdateError{ + Code: update.CodeSkippedCI, + Err: &internal.ErrorWithSuggestion{ + Err: fmt.Errorf("azd update is not supported in CI/CD environments"), + Suggestion: "Use your pipeline 
to install the desired version directly.", + }, + } + } + + // Check if the user is trying to switch to daily via a package manager + if a.flags.channel == string(update.ChannelDaily) && update.IsPackageManagerInstall() { + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodePackageManagerFailed)) + + uninstallCmd := update.PackageManagerUninstallCmd(installedBy) + return nil, &update.UpdateError{ + Code: update.CodePackageManagerFailed, + Err: &internal.ErrorWithSuggestion{ + Err: fmt.Errorf("daily builds aren't available via %s", installedBy), + Suggestion: fmt.Sprintf( + "Uninstall first with: %s\nThen install daily with: "+ + "curl -fsSL https://aka.ms/install-azd.sh | bash -s -- --version daily", + uninstallCmd), + }, + } + } + + // If only config flags were set (no channel change, no update needed), just confirm + if a.onlyConfigFlagsSet() { + if configChanged { + if err := a.configManager.Save(userConfig); err != nil { + return nil, fmt.Errorf("failed to save config: %w", err) + } + } + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeSuccess)) + return &actions.ActionResult{ + Message: &actions.ResultMessage{ + Header: "Update preferences saved.", + }, + }, nil + } + + // Check for updates (always fresh for manual invocation) + a.console.ShowSpinner(ctx, "Checking for updates...", input.Step) + versionInfo, err := mgr.CheckForUpdate(ctx, cfg, true) + a.console.StopSpinner(ctx, "", input.StepDone) + + if err != nil { + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeVersionCheckFailed)) + return nil, &update.UpdateError{ + Code: update.CodeVersionCheckFailed, Err: err, + } + } + + // Track target version + tracing.SetUsageAttributes( + fields.UpdateToVersion.String(versionInfo.Version), + ) + + if !versionInfo.HasUpdate && !switchingChannels { + currentVersion := internal.VersionInfo().Version.String() + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeAlreadyUpToDate)) + + header := 
fmt.Sprintf("azd is up to date (version %s) on the %s channel.", currentVersion, cfg.Channel) + if cfg.Channel == update.ChannelDaily { + header += " To check for stable updates, run: azd update --channel stable" + } + + return &actions.ActionResult{ + Message: &actions.ResultMessage{ + Header: header, + }, + }, nil + } + + // Confirm channel switch with version details + if switchingChannels { + currentVersion := internal.VersionInfo().Version.String() + confirmMsg := fmt.Sprintf( + "Switch from %s channel (%s) to %s channel (%s)?", + currentCfg.Channel, currentVersion, + cfg.Channel, versionInfo.Version, + ) + + confirm, err := a.console.Confirm(ctx, input.ConsoleOptions{ + Message: confirmMsg, + DefaultValue: true, + }) + + if err != nil || !confirm { + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeChannelSwitchDecline)) + a.console.Message(ctx, "Channel switch cancelled.") + return nil, nil + } + } + + // Now persist all config changes (including channel) after confirmation + if configChanged { + if err := a.configManager.Save(userConfig); err != nil { + return nil, fmt.Errorf("failed to save config: %w", err) + } + } + + // Perform the update + a.console.MessageUxItem(ctx, &ux.MessageTitle{ + Title: fmt.Sprintf("Updating azd to %s (%s)", versionInfo.Version, cfg.Channel), + }) + + stdout := a.console.Handles().Stdout + if err := mgr.Update(ctx, cfg, stdout); err != nil { + // UpdateError already has the right code, just track it + var updateErr *update.UpdateError + if errors.As(err, &updateErr) { + tracing.SetUsageAttributes(fields.UpdateResult.String(updateErr.Code)) + } else { + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeReplaceFailed)) + } + return nil, err + } + + tracing.SetUsageAttributes(fields.UpdateResult.String(update.CodeSuccess)) + + // Clean up any staged binary now that a manual update succeeded + update.CleanStagedUpdate() + + return &actions.ActionResult{ + Message: &actions.ResultMessage{ + Header: 
fmt.Sprintf( + "Successfully updated azd to version %s. Changes take effect on next invocation.", + versionInfo.Version, + ), + }, + }, nil +} + +// persistNonChannelFlags saves auto-update and check-interval flags to config. +// Channel is handled separately to allow confirmation before persisting. +func (a *updateAction) persistNonChannelFlags(cfg config.Config) (bool, error) { + changed := false + + if a.flags.autoUpdate != "" { + enabled := a.flags.autoUpdate == "on" + if a.flags.autoUpdate != "on" && a.flags.autoUpdate != "off" { + return false, fmt.Errorf("invalid auto-update value %q, must be \"on\" or \"off\"", a.flags.autoUpdate) + } + if err := update.SaveAutoUpdate(cfg, enabled); err != nil { + return false, err + } + changed = true + } + + if a.flags.checkIntervalHours > 0 { + if err := update.SaveCheckIntervalHours(cfg, a.flags.checkIntervalHours); err != nil { + return false, err + } + changed = true + } + + return changed, nil +} + +// onlyConfigFlagsSet returns true if only config flags were provided (no channel that requires an update). 
+func (a *updateAction) onlyConfigFlagsSet() bool { + return a.flags.channel == "" && + (a.flags.autoUpdate != "" || a.flags.checkIntervalHours > 0) +} diff --git a/cli/azd/cmd/version.go b/cli/azd/cmd/version.go index fb464321fc2..8298ebc9dc0 100644 --- a/cli/azd/cmd/version.go +++ b/cli/azd/cmd/version.go @@ -10,9 +10,11 @@ import ( "github.com/azure/azure-dev/cli/azd/cmd/actions" "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/pkg/alpha" "github.com/azure/azure-dev/cli/azd/pkg/contracts" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/azure/azure-dev/cli/azd/pkg/update" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -33,10 +35,11 @@ func newVersionFlags(cmd *cobra.Command, global *internal.GlobalCommandOptions) } type versionAction struct { - flags *versionFlags - formatter output.Formatter - writer io.Writer - console input.Console + flags *versionFlags + formatter output.Formatter + writer io.Writer + console input.Console + alphaFeatureManager *alpha.FeatureManager } func newVersionAction( @@ -44,19 +47,22 @@ func newVersionAction( formatter output.Formatter, writer io.Writer, console input.Console, + alphaFeatureManager *alpha.FeatureManager, ) actions.Action { return &versionAction{ - flags: flags, - formatter: formatter, - writer: writer, - console: console, + flags: flags, + formatter: formatter, + writer: writer, + console: console, + alphaFeatureManager: alphaFeatureManager, } } func (v *versionAction) Run(ctx context.Context) (*actions.ActionResult, error) { switch v.formatter.Kind() { case output.NoneFormat: - fmt.Fprintf(v.console.Handles().Stdout, "azd version %s\n", internal.Version) + channelSuffix := v.channelSuffix() + fmt.Fprintf(v.console.Handles().Stdout, "azd version %s%s\n", internal.Version, channelSuffix) case output.JsonFormat: var result contracts.VersionResult versionSpec := internal.VersionInfo() @@ -72,3 +78,19 @@ func (v 
*versionAction) Run(ctx context.Context) (*actions.ActionResult, error) return nil, nil } + +// channelSuffix returns a display suffix like " (stable)" or " (daily)". +// Based on the running binary's version string, not the configured channel. +// Only shown when the update alpha feature is enabled. +func (v *versionAction) channelSuffix() string { + if !v.alphaFeatureManager.IsEnabled(update.FeatureUpdate) { + return "" + } + + // Detect from the binary itself: if the version contains "daily.", it's a daily build. + if _, err := update.ParseDailyBuildNumber(internal.Version); err == nil { + return " (daily)" + } + + return " (stable)" +} diff --git a/cli/azd/internal/cmd/errors.go b/cli/azd/internal/cmd/errors.go index 8cce7b9f9cf..5eab4fb7d26 100644 --- a/cli/azd/internal/cmd/errors.go +++ b/cli/azd/internal/cmd/errors.go @@ -32,6 +32,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/pipeline" "github.com/azure/azure-dev/cli/azd/pkg/tools" "github.com/azure/azure-dev/cli/azd/pkg/tools/git" + "github.com/azure/azure-dev/cli/azd/pkg/update" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" ) @@ -56,8 +57,11 @@ func MapError(err error, span tracing.Span) { // internal errors var errWithSuggestion *internal.ErrorWithSuggestion var loginErr *auth.ReLoginRequiredError + var updateErr *update.UpdateError - if errors.As(err, &loginErr) { + if errors.As(err, &updateErr) { + errCode = updateErr.Code + } else if errors.As(err, &loginErr) { errCode = "auth.login_required" } else if errors.As(err, &errWithSuggestion) { errCode = "error.suggestion" @@ -197,6 +201,8 @@ func MapError(err error, span tracing.Span) { errCode = "internal.preview_not_supported" } else if errors.Is(err, provisioning.ErrBindMountOperationDisabled) { errCode = "internal.bind_mount_disabled" + } else if errors.Is(err, update.ErrNeedsElevation) { + errCode = "update.elevationRequired" } else if errors.Is(err, pipeline.ErrRemoteHostIsNotAzDo) { errCode = "internal.remote_not_azdo" 
} else if isNetworkError(err) { diff --git a/cli/azd/internal/tracing/fields/fields.go b/cli/azd/internal/tracing/fields/fields.go index f347c5788d6..16021d98c2f 100644 --- a/cli/azd/internal/tracing/fields/fields.go +++ b/cli/azd/internal/tracing/fields/fields.go @@ -581,3 +581,37 @@ var ( Purpose: FeatureInsight, } ) + +// Update related fields +var ( + // UpdateChannel is the update channel (stable, daily). + UpdateChannel = AttributeKey{ + Key: attribute.Key("update.channel"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } + // UpdateInstallMethod is the install method (brew, winget, choco, script, etc.). + UpdateInstallMethod = AttributeKey{ + Key: attribute.Key("update.installMethod"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } + // UpdateFromVersion is the version before the update. + UpdateFromVersion = AttributeKey{ + Key: attribute.Key("update.fromVersion"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } + // UpdateToVersion is the target version for the update. + UpdateToVersion = AttributeKey{ + Key: attribute.Key("update.toVersion"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } + // UpdateResult is the outcome of the update operation. 
+ UpdateResult = AttributeKey{ + Key: attribute.Key("update.result"), + Classification: SystemMetadata, + Purpose: FeatureInsight, + } +) diff --git a/cli/azd/main.go b/cli/azd/main.go index a3ea8f23542..1b2643cad6d 100644 --- a/cli/azd/main.go +++ b/cli/azd/main.go @@ -7,19 +7,16 @@ package main import ( "context" - "encoding/json" "errors" "fmt" "io" - "io/fs" "log" - "net/http" "os" "os/exec" "path/filepath" "runtime" "strconv" - "strings" + "syscall" "time" azcorelog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" @@ -27,14 +24,15 @@ import ( "github.com/azure/azure-dev/cli/azd/internal" "github.com/azure/azure-dev/cli/azd/internal/telemetry" "github.com/azure/azure-dev/cli/azd/internal/tracing" + "github.com/azure/azure-dev/cli/azd/internal/tracing/resource" + "github.com/azure/azure-dev/cli/azd/pkg/alpha" "github.com/azure/azure-dev/cli/azd/pkg/config" "github.com/azure/azure-dev/cli/azd/pkg/installer" "github.com/azure/azure-dev/cli/azd/pkg/ioc" "github.com/azure/azure-dev/cli/azd/pkg/oneauth" - "github.com/azure/azure-dev/cli/azd/pkg/osutil" "github.com/azure/azure-dev/cli/azd/pkg/output" "github.com/azure/azure-dev/cli/azd/pkg/tools" - "github.com/blang/semver/v4" + "github.com/azure/azure-dev/cli/azd/pkg/update" "github.com/mattn/go-colorable" "github.com/spf13/pflag" ) @@ -64,7 +62,59 @@ func main() { ctx = tracing.ContextFromEnv(ctx) } - latest := make(chan semver.Version) + // Auto-update: check for applied update marker and display banner + if !internal.IsDevVersion() { + if fromVersion, err := update.ReadAppliedMarker(); err == nil && fromVersion != "" { + update.RemoveAppliedMarker() + fmt.Fprintln( + os.Stderr, + output.WithSuccessFormat( + "azd has been auto-updated from %s to %s", fromVersion, internal.Version)) + } + } + + // Auto-update: apply staged binary if one exists (before anything else) + showedElevationWarning := false + if !internal.IsDevVersion() && update.HasStagedUpdate() { + applyConfigMgr := 
config.NewUserConfigManager(config.NewFileConfigManager(config.NewManager())) + applyCfg, cfgErr := applyConfigMgr.Load() + if cfgErr != nil { + applyCfg = config.NewEmptyConfig() + } + + applyFeatures := alpha.NewFeaturesManagerWithConfig(applyCfg) + updateCfg := update.LoadUpdateConfig(applyCfg) + + if applyFeatures.IsEnabled(update.FeatureUpdate) && updateCfg.AutoUpdate { + appliedPath, err := update.ApplyStagedUpdate() + if errors.Is(err, update.ErrNeedsElevation) { + versionStr := "a new version" + if cache, cacheErr := update.LoadCache(); cacheErr == nil && cache != nil && cache.Version != "" { + versionStr = "version " + cache.Version + } + fmt.Fprintln( + os.Stderr, + output.WithWarningFormat( + "WARNING: azd %s has been downloaded. "+ + "Run 'azd update' to apply it.", versionStr)) + showedElevationWarning = true + } else if err != nil { + log.Printf("failed to apply staged update: %v", err) + } else if appliedPath != "" { + log.Printf("applied staged update, re-executing: %s", appliedPath) + update.WriteAppliedMarker(internal.Version) + if err := reExec(appliedPath); err != nil { + log.Printf("re-exec failed: %v, continuing with current binary", err) + } + // reExec replaces the process; if we get here it failed + } + } else { + // Feature or auto-update was disabled after staging — clean up + update.CleanStagedUpdate() + } + } + + latest := make(chan *update.VersionInfo) go fetchLatestVersion(latest) rootContainer := ioc.NewNestedContainer(nil) @@ -86,7 +136,7 @@ func main() { } } - latestVersion, ok := <-latest + versionInfo, ok := <-latest // If we were able to fetch a latest version, check to see if we are up to date and // print a warning if we are not. Note that we don't print this warning when the CLI version @@ -95,65 +145,41 @@ func main() { // // Don't write this message when JSON output is enabled, since in that case we use stderr to return structured // information about command progress. 
- if !isJsonOutput() && ok { + if !isJsonOutput() && ok && !suppressUpdateBanner() && !showedElevationWarning { if internal.IsDevVersion() { - // This is a dev build (i.e. built using `go install without setting a version`) - don't print a warning in this - // case log.Printf("eliding update message for dev build") - } else if latestVersion.GT(internal.VersionInfo().Version) { - var upgradeText string - - installedBy := installer.InstalledBy() - if runtime.GOOS == "windows" { - switch installedBy { - case installer.InstallTypePs: - //nolint:lll - upgradeText = "run:\npowershell -ex AllSigned -c \"Invoke-RestMethod 'https://aka.ms/install-azd.ps1' | Invoke-Expression\"\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. For advanced install instructions, see: https://aka.ms/azd/upgrade/windows" - case installer.InstallTypeWinget: - upgradeText = "run:\nwinget upgrade Microsoft.Azd" - case installer.InstallTypeChoco: - upgradeText = "run:\nchoco upgrade azd" - default: - // Also covers "msi" case where the user installed directly - // via MSI - upgradeText = "visit https://aka.ms/azd/upgrade/windows" - } - } else if runtime.GOOS == "linux" { - switch installedBy { - case installer.InstallTypeSh: - //nolint:lll - upgradeText = "run:\ncurl -fsSL https://aka.ms/install-azd.sh | bash\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. For advanced install instructions, see: https://aka.ms/azd/upgrade/linux" - default: - // Also covers "deb" and "rpm" cases which are currently - // documented. When package manager distribution support is - // added, this will need to be updated. 
- upgradeText = "visit https://aka.ms/azd/upgrade/linux" - } - } else if runtime.GOOS == "darwin" { - switch installedBy { - case installer.InstallTypeBrew: - upgradeText = "run:\nbrew update && brew upgrade azd" - case installer.InstallTypeSh: - //nolint:lll - upgradeText = "run:\ncurl -fsSL https://aka.ms/install-azd.sh | bash\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. For advanced install instructions, see: https://aka.ms/azd/upgrade/mac" - default: - upgradeText = "visit https://aka.ms/azd/upgrade/mac" - } - } else { - // Platform is not recognized, use the generic install link - upgradeText = "visit https://aka.ms/azd/upgrade" + } else if versionInfo.HasUpdate { + currentVersionStr := internal.VersionInfo().Version.String() + latestVersionStr := versionInfo.Version + if versionInfo.BuildNumber > 0 { + latestVersionStr = fmt.Sprintf("%s (build %d)", versionInfo.Version, versionInfo.BuildNumber) } fmt.Fprintln( os.Stderr, output.WithWarningFormat( - "WARNING: your version of azd is out of date, you have %s and the latest version is %s", - internal.VersionInfo().Version.String(), latestVersion.String())) + "WARNING: your version of azd is out of date, you have %s and the latest %s version is %s", + currentVersionStr, versionInfo.Channel, latestVersionStr)) fmt.Fprintln(os.Stderr) - fmt.Fprintln( - os.Stderr, - output.WithWarningFormat(`To update to the latest version, %s`, - upgradeText)) + + // Show "azd update" hint only if the update feature is enabled, + // otherwise show the original platform-specific upgrade instructions. 
+ configMgr := config.NewUserConfigManager(config.NewFileConfigManager(config.NewManager())) + userCfg, cfgErr := configMgr.Load() + if cfgErr != nil { + userCfg = config.NewEmptyConfig() + } + featureManager := alpha.NewFeaturesManagerWithConfig(userCfg) + if featureManager.IsEnabled(update.FeatureUpdate) { + fmt.Fprintln( + os.Stderr, + output.WithWarningFormat("To update to the latest version, run: azd update")) + } else { + upgradeText := platformUpgradeText() + fmt.Fprintln( + os.Stderr, + output.WithWarningFormat("To update to the latest version, %s", upgradeText)) + } } } @@ -176,15 +202,11 @@ func main() { } } -// updateCheckCacheFileName is the name of the file created in the azd configuration directory -// which is used to cache version information for our up to date check. -const updateCheckCacheFileName = "update-check.json" - -// fetchLatestVersion fetches the latest version of the CLI and sends the result -// across the version channel, which it then closes. If the latest version can not -// be determined, the channel is closed without writing a value. -func fetchLatestVersion(version chan<- semver.Version) { - defer close(version) +// fetchLatestVersion checks for a newer version of the CLI using the user's +// configured channel and sends the result across the channel, which it then closes. +// If the latest version can not be determined, the channel is closed without writing a value. +func fetchLatestVersion(result chan<- *update.VersionInfo) { + defer close(result) // Allow the user to skip the update check if they wish, by setting AZD_SKIP_UPDATE_CHECK to // a truthy value. @@ -198,129 +220,36 @@ func fetchLatestVersion(version chan<- semver.Version) { } } - // To avoid fetching the latest version of the CLI on every invocation, we cache the result for a period - // of time, in the user's home directory. 
- configDir, err := config.GetUserConfigDir() + // Load user config to determine channel + configMgr := config.NewUserConfigManager(config.NewFileConfigManager(config.NewManager())) + userConfig, err := configMgr.Load() if err != nil { - log.Printf("could not determine config directory: %v, skipping update check", err) - return + userConfig = config.NewEmptyConfig() } - cacheFilePath := filepath.Join(configDir, updateCheckCacheFileName) - cacheFile, err := os.ReadFile(cacheFilePath) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - log.Printf("error reading update cache file: %v, skipping update check", err) + cfg := update.LoadUpdateConfig(userConfig) + + mgr := update.NewManager(nil) + versionInfo, err := mgr.CheckForUpdate(context.Background(), cfg, false) + if err != nil { + log.Printf("failed to check for updates: %v, skipping update check", err) return } - // If we were able to read the update file, try to interpret it and use the cached - // value if it is still valid. Note the `err == nil` guard here ensures we don't run - // this logic when the cache file did not exist (since err will be a form of fs.ErrNotExist) - var cachedLatestVersion *semver.Version - if err == nil { - var cache updateCacheFile - if err := json.Unmarshal(cacheFile, &cache); err == nil { - parsedVersion, parseVersionErr := semver.Parse(cache.Version) - parsedExpiresOn, parseExpiresOnErr := time.Parse(time.RFC3339, cache.ExpiresOn) - - if parseVersionErr == nil && parseExpiresOnErr == nil { - if time.Now().UTC().Before(parsedExpiresOn) { - log.Printf("using cached latest version: %s (expires on: %s)", cache.Version, cache.ExpiresOn) - cachedLatestVersion = &parsedVersion - } else { - log.Printf("ignoring cached latest version, it is out of date") - } - } else { - if parseVersionErr != nil { - log.Printf("failed to parse cached version '%s' as a semver: %v,"+ - " ignoring cached value", cache.Version, parseVersionErr) - } - if parseExpiresOnErr != nil { - log.Printf( - "failed to 
parse cached version expiration time '%s' as a RFC3339"+ - " timestamp: %v, ignoring cached value", - cache.ExpiresOn, - parseExpiresOnErr) - } + // Auto-update: if enabled and an update is available, stage the new binary in the background. + // Skip in CI environments and package manager installs. + if cfg.AutoUpdate && versionInfo.HasUpdate && !update.IsPackageManagerInstall() && + !resource.IsRunningOnCI() { + featureManager := alpha.NewFeaturesManagerWithConfig(userConfig) + if featureManager.IsEnabled(update.FeatureUpdate) { + log.Printf("auto-update: staging update to %s", versionInfo.Version) + if stageErr := mgr.StageUpdate(context.Background(), cfg); stageErr != nil { + log.Printf("auto-update: staging failed: %v", stageErr) } - } else { - log.Printf("could not unmarshal cache file: %v, ignoring cache", err) } } - // If we don't have a cached version we can use, fetch one (and cache it) - if cachedLatestVersion == nil { - log.Print("fetching latest version information for update check") - req, err := http.NewRequest(http.MethodGet, "https://aka.ms/azure-dev/versions/cli/latest", nil) - if err != nil { - log.Printf("failed to create request object: %v, skipping update check", err) - } - - req.Header.Set("User-Agent", internal.UserAgent()) - - res, err := http.DefaultClient.Do(req) - if err != nil { - log.Printf("failed to fetch latest version: %v, skipping update check", err) - return - } - body, err := readToEndAndClose(res.Body) - if err != nil { - log.Printf("failed to read response body: %v, skipping update check", err) - return - } - - if res.StatusCode != http.StatusOK { - log.Printf( - "failed to refresh latest version, http status: %v, body: %v, skipping update check", - res.StatusCode, - body, - ) - return - } - - // Parse the body of the response as a semver, and if it's valid, cache it. 
- fetchedVersionText := strings.TrimSpace(body) - fetchedVersion, err := semver.Parse(fetchedVersionText) - if err != nil { - log.Printf("failed to parse latest version '%s' as a semver: %v, skipping update check", fetchedVersionText, err) - return - } - - cachedLatestVersion = &fetchedVersion - - // Write the value back to the cache. Note that on these logging paths for errors we do not return - // eagerly, since we have not yet sent the latest versions across the channel (and we don't want to do that until - // we've updated the cache since reader on the other end of the channel will exit the process after it receives this - // value and finishes - // the up to date check, possibly while this go-routine is still running) - if err := os.MkdirAll(filepath.Dir(cacheFilePath), osutil.PermissionFile); err != nil { - log.Printf("failed to create cache folder '%s': %v", filepath.Dir(cacheFilePath), err) - } else { - cacheObject := updateCacheFile{ - Version: fetchedVersionText, - ExpiresOn: time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339), - } - - // The marshal call can not fail, so we ignore the error. - cacheContents, _ := json.Marshal(cacheObject) - - if err := os.WriteFile(cacheFilePath, cacheContents, osutil.PermissionDirectory); err != nil { - log.Printf("failed to write update cache file: %v", err) - } else { - log.Printf("updated cache file to version %s (expires on: %s)", cacheObject.Version, cacheObject.ExpiresOn) - } - } - } - - // Publish our value, the defer above will close the channel. 
-	version <- *cachedLatestVersion
-}
-
-type updateCacheFile struct {
-	// The semver of the latest version the CLI
-	Version string `json:"version"`
-	// A time at which this cached value expires, stored as an RFC3339 timestamp
-	ExpiresOn string `json:"expiresOn"`
+	result <- versionInfo
 }
 
 // isDebugEnabled checks to see if `--debug` was passed with a truthy
@@ -346,6 +275,15 @@ func isDebugEnabled() bool {
 }
 
+// suppressUpdateBanner returns true for commands where the "out of date" banner
+// adds no value: azd update (stale version in-process), azd config (managing settings).
+func suppressUpdateBanner() bool {
+	if len(os.Args) < 2 {
+		return false
+	}
+	return os.Args[1] == "update" || os.Args[1] == "config"
+}
+
 // isJsonOutput checks to see if `--output` was passed with the value `json`
 func isJsonOutput() bool {
 	output := ""
 	flags := pflag.NewFlagSet("", pflag.ContinueOnError)
@@ -367,13 +305,6 @@ func isJsonOutput() bool {
 	return output == "json"
 }
 
-func readToEndAndClose(r io.ReadCloser) (string, error) {
-	defer r.Close()
-	var buf strings.Builder
-	_, err := io.Copy(&buf, r)
-	return buf.String(), err
-}
-
 // setupLogging configures log output based on AZD_DEBUG_LOG environment variable
 // Returns a cleanup function that should be called when the program exits
 func setupLogging(debugEnabled bool) func() {
@@ -437,6 +368,73 @@ func createDailyLogFile() (*os.File, error) {
 	return logFile, nil
 }
 
+// platformUpgradeText returns the original platform-specific upgrade instructions.
+func platformUpgradeText() string {
+	installedBy := installer.InstalledBy()
+
+	if runtime.GOOS == "windows" {
+		switch installedBy {
+		case installer.InstallTypePs:
+			//nolint:lll
+			return "run:\npowershell -ex AllSigned -c \"Invoke-RestMethod 'https://aka.ms/install-azd.ps1' | Invoke-Expression\"\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. 
For advanced install instructions, see: https://aka.ms/azd/upgrade/windows" + case installer.InstallTypeWinget: + return "run:\nwinget upgrade Microsoft.Azd" + case installer.InstallTypeChoco: + return "run:\nchoco upgrade azd" + default: + return "visit https://aka.ms/azd/upgrade/windows" + } + } else if runtime.GOOS == "linux" { + switch installedBy { + case installer.InstallTypeSh: + //nolint:lll + return "run:\ncurl -fsSL https://aka.ms/install-azd.sh | bash\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. For advanced install instructions, see: https://aka.ms/azd/upgrade/linux" + default: + return "visit https://aka.ms/azd/upgrade/linux" + } + } else if runtime.GOOS == "darwin" { + switch installedBy { + case installer.InstallTypeBrew: + return "run:\nbrew update && brew upgrade azd" + case installer.InstallTypeSh: + //nolint:lll + return "run:\ncurl -fsSL https://aka.ms/install-azd.sh | bash\n\nIf the install script was run with custom parameters, ensure that the same parameters are used for the upgrade. For advanced install instructions, see: https://aka.ms/azd/upgrade/mac" + default: + return "visit https://aka.ms/azd/upgrade/mac" + } + } + + return "visit https://aka.ms/azd/upgrade" +} + +// reExec replaces the current process with the binary at the given path, +// passing the same arguments. On Unix, this uses syscall.Exec to replace +// the process in-place. On Windows, it spawns a new process and exits. +func reExec(binaryPath string) error { + args := os.Args + args[0] = binaryPath + + if runtime.GOOS == "windows" { + // Windows doesn't support exec-style process replacement. + // Spawn the new binary and exit. + // #nosec G204 -- binaryPath is the staged azd binary we just verified + cmd := exec.Command(binaryPath, args[1:]...) 
//nolint:gosec + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitCode()) + } + return err + } + os.Exit(0) + } + + return syscall.Exec(binaryPath, args, os.Environ()) //nolint:gosec +} + func startBackgroundUploadProcess() error { // The background upload process executable is ourself execPath, err := os.Executable() diff --git a/cli/azd/pkg/update/config.go b/cli/azd/pkg/update/config.go new file mode 100644 index 00000000000..97e8d5a3b3a --- /dev/null +++ b/cli/azd/pkg/update/config.go @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package update + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/azure/azure-dev/cli/azd/pkg/alpha" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" +) + +// FeatureUpdate is the alpha feature key for the azd update command. +var FeatureUpdate = alpha.MustFeatureKey("update") + +// Channel represents the update channel for azd builds. +type Channel string + +const ( + // ChannelStable represents the stable release channel. + ChannelStable Channel = "stable" + // ChannelDaily represents the daily build channel. + ChannelDaily Channel = "daily" +) + +// ParseChannel parses a string into a Channel value. +func ParseChannel(s string) (Channel, error) { + switch Channel(s) { + case ChannelStable: + return ChannelStable, nil + case ChannelDaily: + return ChannelDaily, nil + default: + return "", fmt.Errorf("invalid channel %q, must be %q or %q", s, ChannelStable, ChannelDaily) + } +} + +const ( + // configKeyChannel is the config key for the update channel. + configKeyChannel = "updates.channel" + // configKeyAutoUpdate is the config key for auto-update. 
+ configKeyAutoUpdate = "updates.autoUpdate" + // configKeyCheckIntervalHours is the config key for the check interval. + configKeyCheckIntervalHours = "updates.checkIntervalHours" +) + +const ( + // DefaultCheckIntervalStable is the default check interval for the stable channel. + DefaultCheckIntervalStable = 24 * time.Hour + // DefaultCheckIntervalDaily is the default check interval for the daily channel. + DefaultCheckIntervalDaily = 4 * time.Hour +) + +// UpdateConfig holds the user's update preferences. +type UpdateConfig struct { + Channel Channel + AutoUpdate bool + CheckIntervalHours int +} + +// DefaultCheckInterval returns the default check interval for the configured channel. +func (c *UpdateConfig) DefaultCheckInterval() time.Duration { + if c.CheckIntervalHours > 0 { + return time.Duration(c.CheckIntervalHours) * time.Hour + } + + if c.Channel == ChannelDaily { + return DefaultCheckIntervalDaily + } + + return DefaultCheckIntervalStable +} + +// LoadUpdateConfig reads update configuration from the user config. +func LoadUpdateConfig(cfg config.Config) *UpdateConfig { + uc := &UpdateConfig{ + Channel: ChannelStable, + } + + if ch, ok := cfg.GetString(configKeyChannel); ok { + if parsed, err := ParseChannel(ch); err == nil { + uc.Channel = parsed + } + } + + if au, ok := cfg.GetString(configKeyAutoUpdate); ok { + uc.AutoUpdate = au == "on" + } + + if interval, ok := cfg.Get(configKeyCheckIntervalHours); ok { + switch v := interval.(type) { + case float64: + uc.CheckIntervalHours = int(v) + case int: + uc.CheckIntervalHours = v + case string: + if n, err := strconv.Atoi(v); err == nil { + uc.CheckIntervalHours = n + } + } + } + + return uc +} + +// SaveChannel persists the channel to user config. +func SaveChannel(cfg config.Config, channel Channel) error { + return cfg.Set(configKeyChannel, string(channel)) +} + +// SaveAutoUpdate persists the auto-update setting to user config. 
+func SaveAutoUpdate(cfg config.Config, enabled bool) error { + value := "off" + if enabled { + value = "on" + } + return cfg.Set(configKeyAutoUpdate, value) +} + +// SaveCheckIntervalHours persists the check interval to user config. +func SaveCheckIntervalHours(cfg config.Config, hours int) error { + return cfg.Set(configKeyCheckIntervalHours, hours) +} + +// CacheFile represents the cached version check result. +type CacheFile struct { + // Channel is the update channel this cache entry is for. + Channel string `json:"channel,omitempty"` + // Version is the semver of the latest version. + Version string `json:"version"` + // BuildNumber is the Azure DevOps build ID (used for daily builds). + BuildNumber int `json:"buildNumber,omitempty"` + // ExpiresOn is the time at which this cached value expires, stored as an RFC3339 timestamp. + ExpiresOn string `json:"expiresOn"` +} + +const cacheFileName = "update-check.json" + +// LoadCache reads the cached version check result. +func LoadCache() (*CacheFile, error) { + configDir, err := config.GetUserConfigDir() + if err != nil { + return nil, fmt.Errorf("could not determine config directory: %w", err) + } + + cacheFilePath := filepath.Join(configDir, cacheFileName) + data, err := os.ReadFile(cacheFilePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("error reading update cache file: %w", err) + } + + var cache CacheFile + if err := json.Unmarshal(data, &cache); err != nil { + return nil, fmt.Errorf("could not unmarshal cache file: %w", err) + } + + return &cache, nil +} + +// SaveCache writes the cached version check result. 
+func SaveCache(cache *CacheFile) error { + configDir, err := config.GetUserConfigDir() + if err != nil { + return fmt.Errorf("could not determine config directory: %w", err) + } + + cacheFilePath := filepath.Join(configDir, cacheFileName) + if err := os.MkdirAll(filepath.Dir(cacheFilePath), osutil.PermissionDirectory); err != nil { + return fmt.Errorf("failed to create cache folder: %w", err) + } + + data, _ := json.Marshal(cache) + if err := os.WriteFile(cacheFilePath, data, osutil.PermissionFile); err != nil { + return fmt.Errorf("failed to write update cache file: %w", err) + } + + log.Printf("updated cache file to version %s (expires on: %s)", cache.Version, cache.ExpiresOn) + return nil +} + +// IsCacheValid checks if the cache is still valid (not expired) and matches the given channel. +func IsCacheValid(cache *CacheFile, channel Channel) bool { + if cache == nil { + return false + } + + // If cache has no channel, treat as stable (backward compatibility) + cacheChannel := Channel(cache.Channel) + if cacheChannel == "" { + cacheChannel = ChannelStable + } + + if cacheChannel != channel { + return false + } + + expiresOn, err := time.Parse(time.RFC3339, cache.ExpiresOn) + if err != nil { + return false + } + + return time.Now().UTC().Before(expiresOn) +} diff --git a/cli/azd/pkg/update/config_test.go b/cli/azd/pkg/update/config_test.go new file mode 100644 index 00000000000..2905b554805 --- /dev/null +++ b/cli/azd/pkg/update/config_test.go @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package update + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/stretchr/testify/require" +) + +func TestParseChannel(t *testing.T) { + tests := []struct { + name string + input string + want Channel + wantErr bool + }{ + {"stable", "stable", ChannelStable, false}, + {"daily", "daily", ChannelDaily, false}, + {"invalid", "nightly", "", true}, + {"empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseChannel(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestLoadUpdateConfig(t *testing.T) { + tests := []struct { + name string + config map[string]any + expected *UpdateConfig + }{ + { + name: "defaults", + config: map[string]any{}, + expected: &UpdateConfig{ + Channel: ChannelStable, + }, + }, + { + name: "daily channel with auto-update", + config: map[string]any{ + "updates": map[string]any{ + "channel": "daily", + "autoUpdate": "on", + }, + }, + expected: &UpdateConfig{ + Channel: ChannelDaily, + AutoUpdate: true, + }, + }, + { + name: "custom check interval", + config: map[string]any{ + "updates": map[string]any{ + "checkIntervalHours": float64(8), + }, + }, + expected: &UpdateConfig{ + Channel: ChannelStable, + CheckIntervalHours: 8, + }, + }, + { + name: "invalid channel falls back to stable", + config: map[string]any{ + "updates": map[string]any{ + "channel": "nightly", + }, + }, + expected: &UpdateConfig{ + Channel: ChannelStable, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.NewConfig(tt.config) + got := LoadUpdateConfig(cfg) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestUpdateConfigDefaultCheckInterval(t *testing.T) { + tests := []struct { + name string + config UpdateConfig + expected time.Duration + }{ + { + name: "stable default", + 
config: UpdateConfig{Channel: ChannelStable}, + expected: 24 * time.Hour, + }, + { + name: "daily default", + config: UpdateConfig{Channel: ChannelDaily}, + expected: 4 * time.Hour, + }, + { + name: "custom override", + config: UpdateConfig{Channel: ChannelStable, CheckIntervalHours: 12}, + expected: 12 * time.Hour, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.config.DefaultCheckInterval()) + }) + } +} + +func TestSaveAndLoadConfig(t *testing.T) { + cfg := config.NewConfig(map[string]any{}) + + require.NoError(t, SaveChannel(cfg, ChannelDaily)) + require.NoError(t, SaveAutoUpdate(cfg, true)) + require.NoError(t, SaveCheckIntervalHours(cfg, 6)) + + loaded := LoadUpdateConfig(cfg) + require.Equal(t, ChannelDaily, loaded.Channel) + require.True(t, loaded.AutoUpdate) + require.Equal(t, 6, loaded.CheckIntervalHours) +} + +func TestIsCacheValid(t *testing.T) { + future := time.Now().UTC().Add(1 * time.Hour).Format(time.RFC3339) + past := time.Now().UTC().Add(-1 * time.Hour).Format(time.RFC3339) + + tests := []struct { + name string + cache *CacheFile + channel Channel + want bool + }{ + { + name: "nil cache", + cache: nil, + channel: ChannelStable, + want: false, + }, + { + name: "valid stable cache", + cache: &CacheFile{ + Channel: "stable", + Version: "1.23.6", + ExpiresOn: future, + }, + channel: ChannelStable, + want: true, + }, + { + name: "expired cache", + cache: &CacheFile{ + Channel: "stable", + Version: "1.23.6", + ExpiresOn: past, + }, + channel: ChannelStable, + want: false, + }, + { + name: "channel mismatch", + cache: &CacheFile{ + Channel: "stable", + Version: "1.23.6", + ExpiresOn: future, + }, + channel: ChannelDaily, + want: false, + }, + { + name: "missing channel defaults to stable", + cache: &CacheFile{ + Version: "1.23.6", + ExpiresOn: future, + }, + channel: ChannelStable, + want: true, + }, + { + name: "missing channel, requesting daily", + cache: &CacheFile{ + Version: "1.23.6", 
+ ExpiresOn: future, + }, + channel: ChannelDaily, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, IsCacheValid(tt.cache, tt.channel)) + }) + } +} + +func TestCacheFileJSON(t *testing.T) { + t.Run("new format round-trip", func(t *testing.T) { + cache := &CacheFile{ + Channel: "daily", + Version: "1.24.0-beta.1", + BuildNumber: 98770, + ExpiresOn: "2026-02-26T08:00:00Z", + } + + data, err := json.Marshal(cache) + require.NoError(t, err) + + var loaded CacheFile + require.NoError(t, json.Unmarshal(data, &loaded)) + require.Equal(t, cache, &loaded) + }) + + t.Run("old format backward compatible", func(t *testing.T) { + // Old format without channel or buildNumber + oldJSON := `{"version":"1.23.6","expiresOn":"2026-02-26T01:24:50Z"}` + + var cache CacheFile + require.NoError(t, json.Unmarshal([]byte(oldJSON), &cache)) + require.Equal(t, "1.23.6", cache.Version) + require.Equal(t, "", cache.Channel) // zero value + require.Equal(t, 0, cache.BuildNumber) // zero value + require.Equal(t, "2026-02-26T01:24:50Z", cache.ExpiresOn) + }) +} + +func TestSaveAndLoadCache(t *testing.T) { + // Use a temp dir for config + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + cache := &CacheFile{ + Channel: "daily", + Version: "1.24.0-beta.1", + BuildNumber: 12345, + ExpiresOn: time.Now().UTC().Add(4 * time.Hour).Format(time.RFC3339), + } + + require.NoError(t, SaveCache(cache)) + + // Verify file exists + _, err := os.Stat(filepath.Join(tempDir, cacheFileName)) + require.NoError(t, err) + + loaded, err := LoadCache() + require.NoError(t, err) + require.Equal(t, cache.Channel, loaded.Channel) + require.Equal(t, cache.Version, loaded.Version) + require.Equal(t, cache.BuildNumber, loaded.BuildNumber) +} + +func TestIsPackageManagerInstall(t *testing.T) { + // This test just ensures the function doesn't panic. + // The actual result depends on the install method of the test runner. 
+ _ = IsPackageManagerInstall() +} diff --git a/cli/azd/pkg/update/errors.go b/cli/azd/pkg/update/errors.go new file mode 100644 index 00000000000..19a72eda995 --- /dev/null +++ b/cli/azd/pkg/update/errors.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package update + +import "fmt" + +// UpdateError represents a typed update error with a result code for telemetry. +type UpdateError struct { + // Code is the telemetry result code (e.g. "update.downloadFailed"). + Code string + // Err is the underlying error. + Err error +} + +func (e *UpdateError) Error() string { + return e.Err.Error() +} + +func (e *UpdateError) Unwrap() error { + return e.Err +} + +// Result codes matching the design doc. +const ( + CodeSuccess = "update.success" + CodeAlreadyUpToDate = "update.alreadyUpToDate" + CodeDownloadFailed = "update.downloadFailed" + CodeReplaceFailed = "update.replaceFailed" + CodeElevationFailed = "update.elevationFailed" + CodePackageManagerFailed = "update.packageManagerFailed" + CodeVersionCheckFailed = "update.versionCheckFailed" + CodeChannelSwitchDecline = "update.channelSwitchDowngrade" + CodeSkippedCI = "update.skippedCI" + CodeSignatureInvalid = "update.signatureInvalid" + CodeElevationRequired = "update.elevationRequired" + CodeUnsupportedInstallMethod = "update.unsupportedInstallMethod" +) + +func newUpdateError(code string, err error) *UpdateError { + return &UpdateError{Code: code, Err: err} +} + +func newUpdateErrorf(code, format string, args ...any) *UpdateError { + return &UpdateError{Code: code, Err: fmt.Errorf(format, args...)} +} diff --git a/cli/azd/pkg/update/manager.go b/cli/azd/pkg/update/manager.go new file mode 100644 index 00000000000..3b126ccbcf1 --- /dev/null +++ b/cli/azd/pkg/update/manager.go @@ -0,0 +1,995 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package update + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "context" + "fmt" + "io" + "log" + "net/http" + "os" + osexec "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "github.com/azure/azure-dev/cli/azd/internal" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/pkg/installer" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" + "github.com/blang/semver/v4" +) + +const ( + // stableVersionURL is the endpoint that returns the latest stable version. + stableVersionURL = "https://aka.ms/azure-dev/versions/cli/latest" + // blobBaseURL is the base URL for Azure Blob Storage where azd binaries are hosted. + blobBaseURL = "https://azuresdkartifacts.z5.web.core.windows.net/azd/standalone/release" +) + +// VersionInfo holds the result of a version check. +type VersionInfo struct { + Version string + BuildNumber int + Channel Channel + HasUpdate bool +} + +// Manager handles checking for and applying azd updates. +type Manager struct { + commandRunner exec.CommandRunner +} + +// NewManager creates a new update Manager. +func NewManager(commandRunner exec.CommandRunner) *Manager { + return &Manager{ + commandRunner: commandRunner, + } +} + +// CheckForUpdate checks whether a newer version of azd is available. 
+func (m *Manager) CheckForUpdate(ctx context.Context, cfg *UpdateConfig, ignoreCache bool) (*VersionInfo, error) { + if !ignoreCache { + cache, err := LoadCache() + if err != nil { + log.Printf("error loading update cache: %v", err) + } + + if IsCacheValid(cache, cfg.Channel) { + return m.buildVersionInfoFromCache(cache, cfg.Channel) + } + } + + var info *VersionInfo + var err error + + switch cfg.Channel { + case ChannelStable: + info, err = m.checkStableVersion(ctx) + case ChannelDaily: + info, err = m.checkDailyVersion(ctx) + default: + return nil, fmt.Errorf("unsupported channel: %s", cfg.Channel) + } + + if err != nil { + return nil, err + } + + // Update cache + cacheEntry := &CacheFile{ + Channel: string(cfg.Channel), + Version: info.Version, + BuildNumber: info.BuildNumber, + ExpiresOn: time.Now().UTC().Add(cfg.DefaultCheckInterval()).Format(time.RFC3339), + } + + if err := SaveCache(cacheEntry); err != nil { + log.Printf("failed to save update cache: %v", err) + } + + return info, nil +} + +func (m *Manager) buildVersionInfoFromCache(cache *CacheFile, channel Channel) (*VersionInfo, error) { + info := &VersionInfo{ + Version: cache.Version, + BuildNumber: cache.BuildNumber, + Channel: channel, + } + + if channel == ChannelDaily { + // For daily builds, compare cached build number against the running binary's build number. + // Azure DevOps build IDs are globally monotonically increasing, so a higher build number + // always means a newer build regardless of the semver prefix. + currentBuild, currentErr := parseDailyBuildNumber(internal.Version) + if currentErr == nil && currentBuild > 0 { + info.HasUpdate = cache.BuildNumber > currentBuild + } else { + // Current binary is not a daily build (e.g. stable or dev). + // Fall back to semver comparison to avoid suggesting a downgrade + // (e.g. stable 1.23.5 should not "update" to daily 1.5.0). 
+ dailyVersion, parseErr := semver.Parse(cache.Version) + currentVersion := internal.VersionInfo().Version + if parseErr == nil { + info.HasUpdate = dailyVersion.GT(currentVersion) + } else { + info.HasUpdate = true + } + } + } else { + latestVersion, err := semver.Parse(cache.Version) + if err != nil { + return nil, fmt.Errorf("failed to parse cached version: %w", err) + } + currentVersion := internal.VersionInfo().Version + info.HasUpdate = latestVersion.GT(currentVersion) + } + + return info, nil +} + +func (m *Manager) checkStableVersion(ctx context.Context) (*VersionInfo, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, stableVersionURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("User-Agent", internal.UserAgent()) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch latest stable version: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch latest stable version, status: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + versionText := strings.TrimSpace(string(body)) + latestVersion, err := semver.Parse(versionText) + if err != nil { + return nil, fmt.Errorf("failed to parse version %q: %w", versionText, err) + } + + currentVersion := internal.VersionInfo().Version + return &VersionInfo{ + Version: versionText, + Channel: ChannelStable, + HasUpdate: latestVersion.GT(currentVersion), + }, nil +} + +func (m *Manager) checkDailyVersion(ctx context.Context) (*VersionInfo, error) { + versionURL := blobBaseURL + "/daily/version.txt" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("User-Agent", internal.UserAgent()) + + resp, err := 
http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch daily version info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch daily version info, status: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + version := strings.TrimSpace(string(body)) + buildNumber, err := parseDailyBuildNumber(version) + if err != nil { + return nil, fmt.Errorf("failed to parse daily version %q: %w", version, err) + } + + // Compare build numbers to determine if an update is available. + // Azure DevOps build IDs are globally monotonically increasing, so a higher build number + // always means a newer build regardless of the semver prefix (e.g. daily.5935787 > daily.5935780). + // Extract current build number from the running binary's version string. + currentBuild, currentErr := parseDailyBuildNumber(internal.Version) + hasUpdate := true + if currentErr == nil && currentBuild > 0 { + hasUpdate = buildNumber > currentBuild + } else { + // Current binary is not a daily build (e.g. stable or dev). + // Fall back to semver comparison to avoid suggesting a downgrade + // (e.g. stable 1.23.5 should not "update" to daily 1.5.0). + dailyVersion, parseErr := semver.Parse(version) + currentVersion := internal.VersionInfo().Version + if parseErr == nil { + hasUpdate = dailyVersion.GT(currentVersion) + } + } + + return &VersionInfo{ + Version: version, + BuildNumber: buildNumber, + Channel: ChannelDaily, + HasUpdate: hasUpdate, + }, nil +} + +// ParseDailyBuildNumber extracts the build number from a daily version string. +// e.g. 
"1.24.0-beta.1-daily.5935787" → 5935787 +// Also handles internal.Version format: "1.24.0-beta.1-daily.5935787 (commit ...)" → 5935787 +func ParseDailyBuildNumber(version string) (int, error) { + return parseDailyBuildNumber(version) +} + +func parseDailyBuildNumber(version string) (int, error) { + const prefix = "daily." + idx := strings.LastIndex(version, prefix) + if idx == -1 { + return 0, fmt.Errorf("version %q does not contain %q suffix", version, prefix) + } + + numStr := version[idx+len(prefix):] + // Trim anything after the build number (e.g. " (commit ...)" from internal.Version) + if spaceIdx := strings.IndexByte(numStr, ' '); spaceIdx != -1 { + numStr = numStr[:spaceIdx] + } + + buildNumber, err := strconv.Atoi(numStr) + if err != nil { + return 0, fmt.Errorf("invalid build number %q in version %q: %w", numStr, version, err) + } + + return buildNumber, nil +} + +// Update performs the update based on the install method. +func (m *Manager) Update(ctx context.Context, cfg *UpdateConfig, writer io.Writer) error { + installedBy := installer.InstalledBy() + + switch installedBy { + case installer.InstallTypeBrew: + return m.updateViaPackageManager(ctx, "brew", []string{"upgrade", "azd"}, writer) + case installer.InstallTypeWinget: + return m.updateViaPackageManager(ctx, "winget", []string{"upgrade", "Microsoft.Azd"}, writer) + case installer.InstallTypeChoco: + return m.updateViaPackageManager(ctx, "choco", []string{"upgrade", "azd"}, writer) + case installer.InstallTypePs, installer.InstallTypeSh, installer.InstallTypeDeb, + installer.InstallTypeRpm, installer.InstallTypeUnknown: + if runtime.GOOS == "windows" { + return m.updateViaMSI(ctx, cfg, writer) + } + return m.updateViaBinaryDownload(ctx, cfg, writer) + default: + if runtime.GOOS == "windows" { + return m.updateViaMSI(ctx, cfg, writer) + } + return m.updateViaBinaryDownload(ctx, cfg, writer) + } +} + +func (m *Manager) updateViaPackageManager( + ctx context.Context, + command string, + args 
[]string, + writer io.Writer, +) error { + fmt.Fprintf(writer, "Updating azd via %s...\n", command) + + runArgs := exec.NewRunArgs(command, args...) + runArgs = runArgs.WithStdOut(writer).WithStdErr(writer).WithInteractive(true) + + result, err := m.commandRunner.Run(ctx, runArgs) + if err != nil { + return newUpdateError(CodePackageManagerFailed, err) + } + + if result.ExitCode != 0 { + return newUpdateErrorf(CodePackageManagerFailed, + "package manager update failed with exit code %d", result.ExitCode) + } + + return nil +} + +func (m *Manager) updateViaMSI(ctx context.Context, cfg *UpdateConfig, writer io.Writer) error { + msiURL, err := m.buildMSIDownloadURL(cfg.Channel) + if err != nil { + return err + } + + fmt.Fprintf(writer, "Downloading MSI from %s...\n", msiURL) + + tempDir := os.TempDir() + msiPath := filepath.Join(tempDir, "azd-windows-amd64.msi") + + if err := m.downloadFile(ctx, msiURL, msiPath, writer); err != nil { + return newUpdateError(CodeDownloadFailed, err) + } + defer os.Remove(msiPath) + + fmt.Fprintf(writer, "Installing update via MSI...\n") + runArgs := exec.NewRunArgs("msiexec", "/i", msiPath, "/qn") + runArgs = runArgs.WithStdOut(writer).WithStdErr(writer) + + result, err := m.commandRunner.Run(ctx, runArgs) + if err != nil { + return newUpdateError(CodeReplaceFailed, err) + } + + if result.ExitCode != 0 { + return newUpdateErrorf(CodeReplaceFailed, + "MSI installation failed with exit code %d", result.ExitCode) + } + + return nil +} + +func (m *Manager) updateViaBinaryDownload(ctx context.Context, cfg *UpdateConfig, writer io.Writer) error { + downloadURL, err := m.buildDownloadURL(cfg.Channel) + if err != nil { + return err + } + + fmt.Fprintf(writer, "Downloading azd from %s...\n", downloadURL) + + // Download to a temp file + tempDir := os.TempDir() + archiveName := fmt.Sprintf("azd-%s-%s%s", runtime.GOOS, runtime.GOARCH, archiveExtension()) + tempArchivePath := filepath.Join(tempDir, archiveName) + + if err := m.downloadFile(ctx, 
downloadURL, tempArchivePath, writer); err != nil { + return newUpdateError(CodeDownloadFailed, err) + } + defer os.Remove(tempArchivePath) + + // Extract the binary + binaryName := "azd" + if runtime.GOOS == "windows" { + binaryName = "azd.exe" + } + + tempBinaryPath := filepath.Join(tempDir, "azd-update-"+binaryName) + if err := extractBinary(tempArchivePath, binaryName, tempBinaryPath); err != nil { + return fmt.Errorf("extraction failed: %w", err) + } + defer os.Remove(tempBinaryPath) + + // Make executable on unix + if runtime.GOOS != "windows" { + if err := os.Chmod(tempBinaryPath, 0o755); err != nil { + return fmt.Errorf("failed to set permissions: %w", err) + } + } + + // Verify code signature (macOS and Windows only) + if err := m.verifyCodeSignature(ctx, tempBinaryPath, writer); err != nil { + return newUpdateError(CodeSignatureInvalid, err) + } + + // Determine current binary location + currentBinaryPath, err := currentExePath() + if err != nil { + return fmt.Errorf("failed to determine current binary path: %w", err) + } + + // Replace the binary (may need elevation) + fmt.Fprintf(writer, "Installing update...\n") + if err := m.replaceBinary(ctx, tempBinaryPath, currentBinaryPath); err != nil { + return newUpdateError(CodeReplaceFailed, err) + } + + return nil +} + +func (m *Manager) buildDownloadURL(channel Channel) (string, error) { + platform := runtime.GOOS + arch := runtime.GOARCH + ext := archiveExtension() + + var folder string + switch channel { + case ChannelStable: + folder = "stable" + case ChannelDaily: + folder = "daily" + default: + return "", fmt.Errorf("unsupported channel: %s", channel) + } + + return fmt.Sprintf("%s/%s/azd-%s-%s%s", blobBaseURL, folder, platform, arch, ext), nil +} + +func (m *Manager) buildMSIDownloadURL(channel Channel) (string, error) { + var folder string + switch channel { + case ChannelStable: + folder = "stable" + case ChannelDaily: + folder = "daily" + default: + return "", fmt.Errorf("unsupported channel: %s", 
channel) + } + + return fmt.Sprintf("%s/%s/azd-windows-amd64.msi", blobBaseURL, folder), nil +} + +func archiveExtension() string { + if runtime.GOOS == "linux" { + return ".tar.gz" + } + return ".zip" +} + +func (m *Manager) downloadFile(ctx context.Context, url string, destPath string, writer io.Writer) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("User-Agent", internal.UserAgent()) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download failed with status %d", resp.StatusCode) + } + + out, err := os.Create(destPath) + if err != nil { + return err + } + defer out.Close() + + // Show progress + contentLength := resp.ContentLength + var src io.Reader = resp.Body + if contentLength > 0 { + src = &progressReader{ + reader: resp.Body, + total: contentLength, + writer: writer, + } + } + + _, err = io.Copy(out, src) + if contentLength > 0 { + fmt.Fprintln(writer) // newline after progress + } + + return err +} + +// verifyCodeSignature checks the code signature of the downloaded binary. +// On macOS, it uses codesign -v. On Windows, it uses Get-AuthenticodeSignature. +// On Linux or if the command runner is nil, verification is skipped gracefully. 
+func (m *Manager) verifyCodeSignature(ctx context.Context, binaryPath string, writer io.Writer) error { + if m.commandRunner == nil { + log.Printf("no command runner available, skipping code signature verification") + return nil + } + + switch runtime.GOOS { + case "darwin": + return m.verifyCodesignMac(ctx, binaryPath, writer) + case "windows": + return m.verifyAuthenticode(ctx, binaryPath, writer) + default: + // Linux has no standard code signing verification tool + log.Printf("code signing verification not available on %s, skipping", runtime.GOOS) + return nil + } +} + +func (m *Manager) verifyCodesignMac(ctx context.Context, binaryPath string, writer io.Writer) error { + runArgs := exec.NewRunArgs("codesign", "-v", "--strict", binaryPath) + result, err := m.commandRunner.Run(ctx, runArgs) + if err != nil { + log.Printf("codesign verification failed: %v, skipping", err) + return nil + } + + if result.ExitCode != 0 { + return fmt.Errorf( + "code signature verification failed for %s (exit code %d): %s", + binaryPath, result.ExitCode, result.Stderr, + ) + } + + fmt.Fprintf(writer, "Code signature verified.\n") + return nil +} + +func (m *Manager) verifyAuthenticode(ctx context.Context, binaryPath string, writer io.Writer) error { + // PowerShell script to check Authenticode signature status + script := fmt.Sprintf( + `$sig = Get-AuthenticodeSignature -FilePath '%s'; if ($sig.Status -ne 'Valid') { `+ + `Write-Error "Signature status: $($sig.Status)"; exit 1 }`, + binaryPath, + ) + + runArgs := exec.NewRunArgs("powershell", "-NoProfile", "-Command", script) + result, err := m.commandRunner.Run(ctx, runArgs) + if err != nil { + log.Printf("Authenticode verification failed: %v, skipping", err) + return nil + } + + if result.ExitCode != 0 { + return fmt.Errorf( + "Authenticode signature verification failed for %s: %s", + binaryPath, result.Stderr, + ) + } + + fmt.Fprintf(writer, "Code signature verified.\n") + return nil +} + +func (m *Manager) replaceBinary(ctx 
context.Context, newBinaryPath, currentBinaryPath string) error { + // Try direct replacement first + err := os.Rename(newBinaryPath, currentBinaryPath) + if err == nil { + return nil + } + + // If direct rename fails (cross-device or permissions), try copy + err = copyFile(newBinaryPath, currentBinaryPath) + if err == nil { + return nil + } + + // On unix, try with sudo if permission denied + if runtime.GOOS != "windows" { + log.Printf("direct replacement failed (%v), trying with sudo", err) + runArgs := exec.NewRunArgs("sudo", "cp", newBinaryPath, currentBinaryPath) + runArgs = runArgs.WithInteractive(true) + result, sudoErr := m.commandRunner.Run(ctx, runArgs) + if sudoErr != nil { + return newUpdateError(CodeElevationFailed, sudoErr) + } + if result.ExitCode != 0 { + return newUpdateErrorf(CodeElevationFailed, + "sudo copy failed with exit code %d", result.ExitCode) + } + return nil + } + + return fmt.Errorf("failed to replace binary: %w", err) +} + +// currentExePath returns the resolved path of the currently running azd binary. +func currentExePath() (string, error) { + exePath, err := os.Executable() + if err != nil { + return "", err + } + return filepath.EvalSymlinks(exePath) +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + srcInfo, err := in.Stat() + if err != nil { + return err + } + + out, err := os.Create(dst) + if err != nil { + // On Linux, overwriting a running executable fails with ETXTBSY ("text file busy"). + // Unlink the old file first — the running process retains its fd via the inode — + // then create a new file at the same path. + if removeErr := os.Remove(dst); removeErr == nil { + out, err = os.Create(dst) + } else if os.IsPermission(removeErr) { + // Can't remove the old file due to directory permissions. + // Return the permission error so callers can handle elevation. 
+ err = removeErr + } + } + if err != nil { + return err + } + defer out.Close() + + if _, err = io.Copy(out, in); err != nil { + return err + } + + if err := out.Sync(); err != nil { + return err + } + + // Preserve source file permissions. After remove-then-create, the new file gets + // default 0666 permissions instead of the original executable permissions. + return os.Chmod(dst, srcInfo.Mode().Perm()) +} + +// extractBinary extracts the azd binary from the archive to destPath. +func extractBinary(archivePath, binaryName, destPath string) error { + if strings.HasSuffix(archivePath, ".tar.gz") { + return extractFromTarGz(archivePath, binaryName, destPath) + } + return extractFromZip(archivePath, binaryName, destPath) +} + +func extractFromTarGz(archivePath, binaryName, destPath string) error { + f, err := os.Open(archivePath) + if err != nil { + return err + } + defer f.Close() + + gz, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + name := filepath.Base(header.Name) + if name == binaryName || strings.HasPrefix(name, "azd-") { + out, err := os.Create(destPath) + if err != nil { + return err + } + defer out.Close() + + //nolint:gosec + _, err = io.Copy(out, tr) + return err + } + } + + return fmt.Errorf("binary %q not found in archive", binaryName) +} + +func extractFromZip(archivePath, binaryName, destPath string) error { + r, err := zip.OpenReader(archivePath) + if err != nil { + return err + } + defer r.Close() + + for _, f := range r.File { + name := filepath.Base(f.Name) + if name == binaryName || strings.HasPrefix(name, "azd-") { + rc, err := f.Open() + if err != nil { + return err + } + defer rc.Close() + + out, err := os.Create(destPath) + if err != nil { + return err + } + defer out.Close() + + //nolint:gosec + _, err = io.Copy(out, rc) + return err + } + } + + return fmt.Errorf("binary %q 
not found in archive", binaryName) +} + +// IsPackageManagerInstall returns true if azd was installed via a package manager. +func IsPackageManagerInstall() bool { + switch installer.InstalledBy() { + case installer.InstallTypeBrew, installer.InstallTypeWinget, installer.InstallTypeChoco: + return true + default: + return false + } +} + +// PackageManagerUninstallCmd returns the uninstall command for the detected package manager. +func PackageManagerUninstallCmd(installedBy installer.InstallType) string { + switch installedBy { + case installer.InstallTypeBrew: + return "brew uninstall azd" + case installer.InstallTypeWinget: + return "winget uninstall Microsoft.Azd" + case installer.InstallTypeChoco: + return "choco uninstall azd" + default: + return "your package manager" + } +} + +// progressReader wraps an io.Reader to report download progress. +type progressReader struct { + reader io.Reader + total int64 + current int64 + writer io.Writer + lastPct int +} + +func (pr *progressReader) Read(p []byte) (int, error) { + n, err := pr.reader.Read(p) + pr.current += int64(n) + pct := int(float64(pr.current) / float64(pr.total) * 100) + if pct != pr.lastPct && pct%10 == 0 { + fmt.Fprintf(pr.writer, "\rDownloading... %d%%", pct) + pr.lastPct = pct + } + return n, err +} + +const stagingDirName = "staging" + +// stagingDir returns the path to ~/.azd/staging/. +func stagingDir() (string, error) { + configDir, err := config.GetUserConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, stagingDirName), nil +} + +// StagedBinaryPath returns the path where a staged binary would be placed. +func StagedBinaryPath() (string, error) { + dir, err := stagingDir() + if err != nil { + return "", err + } + + binaryName := "azd" + if runtime.GOOS == "windows" { + binaryName = "azd.exe" + } + + return filepath.Join(dir, binaryName), nil +} + +// HasStagedUpdate returns true if a staged binary exists and is ready to apply. 
+func HasStagedUpdate() bool { + path, err := StagedBinaryPath() + if err != nil { + return false + } + + info, err := os.Stat(path) + return err == nil && !info.IsDir() && info.Size() > 0 +} + +// StageUpdate downloads the latest binary to ~/.azd/staging/ for later apply. +// This is intended to run in the background without user interaction. +func (m *Manager) StageUpdate(ctx context.Context, cfg *UpdateConfig) error { + // Only stage for direct binary installs, not package managers + if IsPackageManagerInstall() { + log.Printf("auto-update: package manager install, skipping staging") + return nil + } + + // On Windows, updates are applied via MSI (updateViaMSI); staging a standalone binary + // would be unused and potentially inconsistent with the MSI-based install. + if runtime.GOOS == "windows" { + log.Printf("auto-update: windows MSI-based install, skipping staging") + return nil + } + + downloadURL, err := m.buildDownloadURL(cfg.Channel) + if err != nil { + return err + } + + dir, err := stagingDir() + if err != nil { + return err + } + + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("failed to create staging directory: %w", err) + } + + // Download archive to staging dir + archiveName := fmt.Sprintf("azd-%s-%s%s", runtime.GOOS, runtime.GOARCH, archiveExtension()) + archivePath := filepath.Join(dir, archiveName) + + if err := m.downloadFile(ctx, downloadURL, archivePath, io.Discard); err != nil { + return fmt.Errorf("auto-update download failed: %w", err) + } + defer os.Remove(archivePath) + + // Extract binary to staging dir + binaryName := "azd" + if runtime.GOOS == "windows" { + binaryName = "azd.exe" + } + + stagedPath, err := StagedBinaryPath() + if err != nil { + return err + } + + if err := extractBinary(archivePath, binaryName, stagedPath); err != nil { + return fmt.Errorf("auto-update extraction failed: %w", err) + } + + if runtime.GOOS != "windows" { + if err := os.Chmod(stagedPath, 0o755); err != nil { + return 
fmt.Errorf("failed to set permissions on staged binary: %w", err) + } + } + + log.Printf("auto-update: staged new binary to %s", stagedPath) + return nil +} + +// CleanStagedUpdate removes any staged binary, e.g. when auto-update is disabled after staging. +func CleanStagedUpdate() { + path, err := StagedBinaryPath() + if err != nil { + return + } + + if _, err := os.Stat(path); err == nil { + os.Remove(path) + dir, _ := stagingDir() + os.Remove(dir) + log.Printf("auto-update: cleaned up staged binary at %s", path) + } +} + +// ErrNeedsElevation is returned when the staged update can't be applied without elevation. +var ErrNeedsElevation = fmt.Errorf("applying staged update requires elevation") + +// ApplyStagedUpdate replaces the current binary with the staged one and cleans up. +// Returns the path to the new binary if applied, or empty string if no staged update exists. +// Returns ErrNeedsElevation if the install location is not writable (e.g. /opt/microsoft/azd/). +func ApplyStagedUpdate() (string, error) { + stagedPath, err := StagedBinaryPath() + if err != nil { + return "", err + } + + if !HasStagedUpdate() { + return "", nil + } + + // Verify the staged binary is valid before applying. + // A background goroutine may have been interrupted mid-download, leaving a truncated file. 
+ if err := verifyStagedBinary(stagedPath); err != nil { + log.Printf("auto-update: staged binary is invalid, cleaning up: %v", err) + os.Remove(stagedPath) + dir, _ := stagingDir() + os.Remove(dir) + return "", nil + } + + currentPath, err := currentExePath() + if err != nil { + return "", fmt.Errorf("failed to determine current binary: %w", err) + } + + // Check if we can write to the install location without elevation + if err := copyFile(stagedPath, currentPath); err != nil { + if os.IsPermission(err) { + // Keep the staged binary — user can apply via 'azd update' + log.Printf("auto-update: install location %s requires elevation, skipping apply", currentPath) + return "", ErrNeedsElevation + } + + // Non-permission error — clean up to avoid retrying a broken stage + os.Remove(stagedPath) + return "", fmt.Errorf("failed to apply staged update: %w", err) + } + + // Clean up staging directory + os.Remove(stagedPath) + dir, _ := stagingDir() + os.Remove(dir) // remove dir if empty + + log.Printf("auto-update: applied staged binary from %s to %s", stagedPath, currentPath) + return currentPath, nil +} + +// verifyStagedBinary performs basic validation on the staged binary. +// Checks minimum file size (catches truncated downloads and non-binary files). +// On macOS, also runs codesign verification. Unsigned binaries (e.g. dev builds) are allowed, +// but binaries with invalid/corrupted signatures are rejected. +func verifyStagedBinary(path string) error { + // Size sanity check — azd binary is typically 40-65 MB. + // A minimum of 1 MB catches truncated downloads and non-binary files + // that codesign would incorrectly report as "not signed at all". 
+ const minBinarySize = 1024 * 1024 // 1 MB + info, err := os.Stat(path) + if err != nil { + return fmt.Errorf("cannot stat staged binary: %w", err) + } + if info.Size() < minBinarySize { + return fmt.Errorf("staged binary too small (%d bytes), likely corrupted", info.Size()) + } + + if runtime.GOOS == "darwin" { + //nolint:gosec // path is not user-controlled — it's the well-known staging directory + cmd := osexec.Command("codesign", "-v", "--strict", path) + if combinedOut, err := cmd.CombinedOutput(); err != nil { + outStr := string(combinedOut) + // "not signed at all" is OK — dev builds and some installs are unsigned. + // Only reject binaries with invalid/corrupted signatures (e.g. truncated downloads). + if strings.Contains(outStr, "not signed") { + log.Printf("auto-update: staged binary is unsigned, allowing: %s", outStr) + return nil + } + return fmt.Errorf("code signature invalid: %s", outStr) + } + } + + return nil +} + +const appliedMarkerFile = "update-applied.txt" + +// WriteAppliedMarker writes a marker file recording the version before auto-update. +// This is read on the next startup to display an "updated" banner. +func WriteAppliedMarker(fromVersion string) { + configDir, err := config.GetUserConfigDir() + if err != nil { + return + } + + path := filepath.Join(configDir, appliedMarkerFile) + _ = os.WriteFile(path, []byte(fromVersion), osutil.PermissionFile) +} + +// ReadAppliedMarker reads the applied update marker and returns the previous version. +func ReadAppliedMarker() (string, error) { + configDir, err := config.GetUserConfigDir() + if err != nil { + return "", err + } + + data, err := os.ReadFile(filepath.Join(configDir, appliedMarkerFile)) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(data)), nil +} + +// RemoveAppliedMarker deletes the applied update marker file. 
+func RemoveAppliedMarker() { + configDir, err := config.GetUserConfigDir() + if err != nil { + return + } + + os.Remove(filepath.Join(configDir, appliedMarkerFile)) +} diff --git a/cli/azd/pkg/update/manager_test.go b/cli/azd/pkg/update/manager_test.go new file mode 100644 index 00000000000..e97e6333144 --- /dev/null +++ b/cli/azd/pkg/update/manager_test.go @@ -0,0 +1,662 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package update + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/pkg/installer" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockexec" + "github.com/stretchr/testify/require" +) + +func TestParseDailyBuildNumber(t *testing.T) { + tests := []struct { + name string + version string + want int + wantErr bool + }{ + {"standard daily", "1.24.0-beta.1-daily.5935787", 5935787, false}, + {"simple daily", "1.0.0-daily.100", 100, false}, + {"large build number", "2.0.0-beta.2-daily.9999999", 9999999, false}, + {"with commit suffix", + "1.4.9-beta.1-daily.5000000 (commit aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa)", 5000000, false}, + {"no daily suffix", "1.23.6", 0, true}, + {"dev version", "0.0.0-dev.0", 0, true}, + {"empty string", "", 0, true}, + {"daily but no number", "1.0.0-daily.", 0, true}, + {"daily with non-numeric", "1.0.0-daily.abc", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseDailyBuildNumber(tt.version) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestBuildDownloadURL(t *testing.T) { + m := NewManager(nil) + + tests := []struct { + name string + channel Channel + wantErr bool + // We check that the URL 
contains these substrings + contains []string + }{ + { + name: "stable", + channel: ChannelStable, + contains: []string{ + blobBaseURL + "/stable/", + fmt.Sprintf("azd-%s-%s", runtime.GOOS, runtime.GOARCH), + }, + }, + { + name: "daily", + channel: ChannelDaily, + contains: []string{ + blobBaseURL + "/daily/", + fmt.Sprintf("azd-%s-%s", runtime.GOOS, runtime.GOARCH), + }, + }, + { + name: "invalid channel", + channel: Channel("nightly"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := m.buildDownloadURL(tt.channel) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + for _, s := range tt.contains { + require.Contains(t, got, s) + } + }) + } +} + +func TestArchiveExtension(t *testing.T) { + ext := archiveExtension() + if runtime.GOOS == "linux" { + require.Equal(t, ".tar.gz", ext) + } else { + require.Equal(t, ".zip", ext) + } +} + +func TestPackageManagerUninstallCmd(t *testing.T) { + tests := []struct { + name string + installedBy installer.InstallType + want string + }{ + {"brew", installer.InstallTypeBrew, "brew uninstall azd"}, + {"winget", installer.InstallTypeWinget, "winget uninstall Microsoft.Azd"}, + {"choco", installer.InstallTypeChoco, "choco uninstall azd"}, + {"unknown", installer.InstallTypeUnknown, "your package manager"}, + {"script", installer.InstallTypeSh, "your package manager"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, PackageManagerUninstallCmd(tt.installedBy)) + }) + } +} + +func TestBuildVersionInfoFromCache_Stable(t *testing.T) { + m := NewManager(nil) + + tests := []struct { + name string + version string + hasUpdate bool + }{ + // Dev build (0.0.0-dev.0) is always less than any release + {"newer version available", "999.0.0", true}, + // In semver, 0.0.0 > 0.0.0-dev.0 (pre-release has lower precedence) + // so even 0.0.0 is considered an update from a dev build + {"release beats 
pre-release", "0.0.0", true}, + // A pre-release that equals the current version is not an update + {"same pre-release version", "0.0.0-dev.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := &CacheFile{ + Channel: "stable", + Version: tt.version, + } + info, err := m.buildVersionInfoFromCache(cache, ChannelStable) + require.NoError(t, err) + require.Equal(t, tt.hasUpdate, info.HasUpdate) + require.Equal(t, tt.version, info.Version) + require.Equal(t, ChannelStable, info.Channel) + }) + } +} + +func TestBuildVersionInfoFromCache_Daily(t *testing.T) { + m := NewManager(nil) + + // Dev build (0.0.0-dev.0) can't parse a daily build number, + // so it always assumes update available + cache := &CacheFile{ + Channel: "daily", + Version: "1.24.0-beta.1-daily.5935787", + BuildNumber: 5935787, + } + + info, err := m.buildVersionInfoFromCache(cache, ChannelDaily) + require.NoError(t, err) + require.True(t, info.HasUpdate, "dev build should always see daily update available") + require.Equal(t, ChannelDaily, info.Channel) + require.Equal(t, 5935787, info.BuildNumber) +} + +func TestBuildVersionInfoFromCache_InvalidVersion(t *testing.T) { + m := NewManager(nil) + cache := &CacheFile{ + Channel: "stable", + Version: "not-a-version", + } + + _, err := m.buildVersionInfoFromCache(cache, ChannelStable) + require.Error(t, err) + require.Contains(t, err.Error(), "parse") +} + +func TestCheckForUpdate_StableHTTP(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "999.0.0") + })) + defer server.Close() + + // Override the default client transport to redirect requests to test server + origTransport := http.DefaultTransport + http.DefaultTransport = &urlRewriteTransport{ + base: origTransport, + targetURL: server.URL, + } + defer func() { http.DefaultTransport = origTransport }() + + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", 
tempDir) + + m := NewManager(nil) + cfg := &UpdateConfig{Channel: ChannelStable} + + info, err := m.CheckForUpdate(context.Background(), cfg, true) + require.NoError(t, err) + require.Equal(t, "999.0.0", info.Version) + require.Equal(t, ChannelStable, info.Channel) + require.True(t, info.HasUpdate) +} + +func TestCheckForUpdate_DailyHTTP(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "1.24.0-beta.1-daily.9999999") + })) + defer server.Close() + + origTransport := http.DefaultTransport + http.DefaultTransport = &urlRewriteTransport{ + base: origTransport, + targetURL: server.URL, + } + defer func() { http.DefaultTransport = origTransport }() + + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + m := NewManager(nil) + cfg := &UpdateConfig{Channel: ChannelDaily} + + info, err := m.CheckForUpdate(context.Background(), cfg, true) + require.NoError(t, err) + require.Equal(t, "1.24.0-beta.1-daily.9999999", info.Version) + require.Equal(t, 9999999, info.BuildNumber) + require.Equal(t, ChannelDaily, info.Channel) + require.True(t, info.HasUpdate) +} + +func TestCheckForUpdate_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + origTransport := http.DefaultTransport + http.DefaultTransport = &urlRewriteTransport{ + base: origTransport, + targetURL: server.URL, + } + defer func() { http.DefaultTransport = origTransport }() + + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + m := NewManager(nil) + cfg := &UpdateConfig{Channel: ChannelStable} + + _, err := m.CheckForUpdate(context.Background(), cfg, true) + require.Error(t, err) + require.Contains(t, err.Error(), "500") +} + +func TestCheckForUpdate_UsesCache(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + 
// Pre-populate cache with a future expiry + cache := &CacheFile{ + Channel: "stable", + Version: "888.0.0", + ExpiresOn: "2099-01-01T00:00:00Z", + } + require.NoError(t, SaveCache(cache)) + + m := NewManager(nil) + cfg := &UpdateConfig{Channel: ChannelStable} + + // ignoreCache=false should use the cache (no HTTP call needed) + info, err := m.CheckForUpdate(context.Background(), cfg, false) + require.NoError(t, err) + require.Equal(t, "888.0.0", info.Version) + require.True(t, info.HasUpdate) +} + +func TestCheckForUpdate_InvalidChannel(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + m := NewManager(nil) + cfg := &UpdateConfig{Channel: Channel("nightly")} + + _, err := m.CheckForUpdate(context.Background(), cfg, true) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported channel") +} + +func TestUpdateViaPackageManager_Success(t *testing.T) { + mockRunner := mockexec.NewMockCommandRunner() + mockRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "brew upgrade azd") + }).Respond(exec.NewRunResult(0, "Updated azd", "")) + + m := NewManager(mockRunner) + var buf bytes.Buffer + + err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf) + require.NoError(t, err) + require.Contains(t, buf.String(), "Updating azd via brew") +} + +func TestUpdateViaPackageManager_Failure(t *testing.T) { + mockRunner := mockexec.NewMockCommandRunner() + mockRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "brew upgrade azd") + }).Respond(exec.NewRunResult(1, "", "Error: no such formula")) + + m := NewManager(mockRunner) + var buf bytes.Buffer + + err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf) + require.Error(t, err) + + var updateErr *UpdateError + require.ErrorAs(t, err, &updateErr) + require.Equal(t, CodePackageManagerFailed, updateErr.Code) +} + +func 
TestUpdateViaPackageManager_CommandError(t *testing.T) { + mockRunner := mockexec.NewMockCommandRunner() + mockRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).SetError(fmt.Errorf("command not found: brew")) + + m := NewManager(mockRunner) + var buf bytes.Buffer + + err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf) + require.Error(t, err) + + var updateErr *UpdateError + require.ErrorAs(t, err, &updateErr) + require.Equal(t, CodePackageManagerFailed, updateErr.Code) +} + +func TestVerifyCodeSignature_NilRunner(t *testing.T) { + m := NewManager(nil) + err := m.verifyCodeSignature(context.Background(), "/some/binary", io.Discard) + require.NoError(t, err, "should skip when no command runner") +} + +func TestExtractFromZip(t *testing.T) { + tempDir := t.TempDir() + + // Create a zip archive containing a fake "azd" binary + archivePath := filepath.Join(tempDir, "test.zip") + binaryContent := []byte("#!/bin/sh\necho hello") + + zipFile, err := os.Create(archivePath) + require.NoError(t, err) + + zw := zip.NewWriter(zipFile) + fw, err := zw.Create("azd") + require.NoError(t, err) + _, err = fw.Write(binaryContent) + require.NoError(t, err) + require.NoError(t, zw.Close()) + require.NoError(t, zipFile.Close()) + + // Extract + destPath := filepath.Join(tempDir, "extracted-azd") + err = extractFromZip(archivePath, "azd", destPath) + require.NoError(t, err) + + // Verify content + extracted, err := os.ReadFile(destPath) + require.NoError(t, err) + require.Equal(t, binaryContent, extracted) +} + +func TestExtractFromZip_BinaryNotFound(t *testing.T) { + tempDir := t.TempDir() + + archivePath := filepath.Join(tempDir, "empty.zip") + zipFile, err := os.Create(archivePath) + require.NoError(t, err) + + zw := zip.NewWriter(zipFile) + fw, err := zw.Create("other-file.txt") + require.NoError(t, err) + _, err = fw.Write([]byte("not the binary")) + require.NoError(t, err) + require.NoError(t, zw.Close()) + 
require.NoError(t, zipFile.Close()) + + destPath := filepath.Join(tempDir, "extracted") + err = extractFromZip(archivePath, "azd", destPath) + require.Error(t, err) + require.Contains(t, err.Error(), "not found in archive") +} + +func TestExtractFromTarGz(t *testing.T) { + tempDir := t.TempDir() + + archivePath := filepath.Join(tempDir, "test.tar.gz") + binaryContent := []byte("#!/bin/sh\necho hello from tar") + + // Create tar.gz + f, err := os.Create(archivePath) + require.NoError(t, err) + + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + hdr := &tar.Header{ + Name: "azd", + Mode: 0o755, + Size: int64(len(binaryContent)), + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err = tw.Write(binaryContent) + require.NoError(t, err) + + require.NoError(t, tw.Close()) + require.NoError(t, gw.Close()) + require.NoError(t, f.Close()) + + // Extract + destPath := filepath.Join(tempDir, "extracted-azd") + err = extractFromTarGz(archivePath, "azd", destPath) + require.NoError(t, err) + + extracted, err := os.ReadFile(destPath) + require.NoError(t, err) + require.Equal(t, binaryContent, extracted) +} + +func TestExtractBinary_ChoosesFormat(t *testing.T) { + tempDir := t.TempDir() + binaryContent := []byte("binary data") + + // Create a zip + archivePath := filepath.Join(tempDir, "test.zip") + zipFile, err := os.Create(archivePath) + require.NoError(t, err) + zw := zip.NewWriter(zipFile) + fw, err := zw.Create("azd") + require.NoError(t, err) + _, err = fw.Write(binaryContent) + require.NoError(t, err) + require.NoError(t, zw.Close()) + require.NoError(t, zipFile.Close()) + + destPath := filepath.Join(tempDir, "out-azd") + err = extractBinary(archivePath, "azd", destPath) + require.NoError(t, err) + + extracted, err := os.ReadFile(destPath) + require.NoError(t, err) + require.Equal(t, binaryContent, extracted) +} + +func TestStagedBinaryPath(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + path, err := StagedBinaryPath() + 
require.NoError(t, err) + require.Contains(t, path, "staging") + + binaryName := "azd" + if runtime.GOOS == "windows" { + binaryName = "azd.exe" + } + require.True(t, strings.HasSuffix(path, binaryName)) +} + +func TestHasStagedUpdate_NoFile(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + require.False(t, HasStagedUpdate()) +} + +func TestHasStagedUpdate_WithFile(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + path, err := StagedBinaryPath() + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755)) + require.NoError(t, os.WriteFile(path, []byte("fake binary"), 0o755)) //nolint:gosec + + require.True(t, HasStagedUpdate()) +} + +func TestCleanStagedUpdate(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + // Stage a fake binary + path, err := StagedBinaryPath() + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755)) + require.NoError(t, os.WriteFile(path, []byte("fake binary"), 0o755)) //nolint:gosec + require.True(t, HasStagedUpdate()) + + // Clean it + CleanStagedUpdate() + require.False(t, HasStagedUpdate()) +} + +func TestAppliedMarkerLifecycle(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", tempDir) + + // No marker initially + _, err := ReadAppliedMarker() + require.Error(t, err) + + // Write marker + WriteAppliedMarker("1.22.0") + + // Read marker + version, err := ReadAppliedMarker() + require.NoError(t, err) + require.Equal(t, "1.22.0", version) + + // Remove marker + RemoveAppliedMarker() + _, err = ReadAppliedMarker() + require.Error(t, err) +} + +func TestProgressReader(t *testing.T) { + content := bytes.Repeat([]byte("x"), 100) + reader := bytes.NewReader(content) + + var output bytes.Buffer + pr := &progressReader{ + reader: reader, + total: 100, + writer: &output, + } + + buf := make([]byte, 10) + totalRead := 0 + for { + n, err := pr.Read(buf) + totalRead += n + if 
errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + } + + require.Equal(t, 100, totalRead) + // Should have reported at least some progress percentages + require.Contains(t, output.String(), "%") +} + +func TestCopyFile(t *testing.T) { + tempDir := t.TempDir() + + src := filepath.Join(tempDir, "source") + dst := filepath.Join(tempDir, "dest") + content := []byte("hello world") + + require.NoError(t, os.WriteFile(src, content, 0o600)) + require.NoError(t, copyFile(src, dst)) + + copied, err := os.ReadFile(dst) + require.NoError(t, err) + require.Equal(t, content, copied) +} + +func TestUpdateError(t *testing.T) { + inner := fmt.Errorf("connection refused") + ue := newUpdateError(CodeDownloadFailed, inner) + + require.Equal(t, "connection refused", ue.Error()) + require.Equal(t, CodeDownloadFailed, ue.Code) + require.ErrorIs(t, ue, inner) + + ue2 := newUpdateErrorf(CodeDownloadFailed, "hash mismatch: expected %s", "abc123") + require.Contains(t, ue2.Error(), "hash mismatch") + require.Equal(t, CodeDownloadFailed, ue2.Code) +} + +func TestErrNeedsElevation(t *testing.T) { + require.NotNil(t, ErrNeedsElevation) + require.Contains(t, ErrNeedsElevation.Error(), "elevation") +} + +func TestDownloadFile(t *testing.T) { + content := []byte("downloaded binary content") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content))) + w.WriteHeader(http.StatusOK) + w.Write(content) + })) + defer server.Close() + + tempDir := t.TempDir() + destPath := filepath.Join(tempDir, "downloaded") + + m := NewManager(nil) + err := m.downloadFile(context.Background(), server.URL+"/azd.zip", destPath, io.Discard) + require.NoError(t, err) + + got, err := os.ReadFile(destPath) + require.NoError(t, err) + require.Equal(t, content, got) +} + +func TestDownloadFile_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) 
{ + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tempDir := t.TempDir() + destPath := filepath.Join(tempDir, "downloaded") + + m := NewManager(nil) + err := m.downloadFile(context.Background(), server.URL+"/missing.zip", destPath, io.Discard) + require.Error(t, err) + require.Contains(t, err.Error(), "404") +} + +// urlRewriteTransport rewrites all outgoing request URLs to point at a test server. +type urlRewriteTransport struct { + base http.RoundTripper + targetURL string +} + +func (t *urlRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL to the test server, preserving path + newURL := t.targetURL + req.URL.Path + newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) + if err != nil { + return nil, err + } + newReq.Header = req.Header + return t.base.RoundTrip(newReq) +} diff --git a/cli/azd/resources/alpha_features.yaml b/cli/azd/resources/alpha_features.yaml index d872e7efa3e..e0e622a7d7d 100644 --- a/cli/azd/resources/alpha_features.yaml +++ b/cli/azd/resources/alpha_features.yaml @@ -14,3 +14,5 @@ description: "Enables the use of LLMs in the CLI with support for intelligent azd init assistance and error handling workflows." - id: language.custom description: "Enables support for services to use custom language." +- id: update + description: "Enables the azd update command for self-updating azd, including channel management and auto-update." 
diff --git a/eng/pipelines/templates/stages/publish.yml b/eng/pipelines/templates/stages/publish.yml index e9c8e57984c..a75bd85bb79 100644 --- a/eng/pipelines/templates/stages/publish.yml +++ b/eng/pipelines/templates/stages/publish.yml @@ -375,6 +375,29 @@ stages: CreateGitHubRelease: false PublishUploadLocations: release/daily;daily/archive/$(Build.BuildId)-$(Build.SourceVersion) + - pwsh: | + New-Item -ItemType Directory -Path daily-version -Force + Set-Content -Path ./daily-version/version.txt -Value "$(CLI_VERSION)" -NoNewline + Write-Host "Daily version.txt:" + Get-Content ./daily-version/version.txt + displayName: Write daily version.txt + + - task: AzurePowerShell@5 + displayName: Upload daily version.txt to storage + inputs: + azureSubscription: 'Azure SDK Artifacts' + azurePowerShellVersion: LatestVersion + pwsh: true + ScriptType: InlineScript + Inline: | + azcopy copy "daily-version/version.txt" "$(publish-storage-location)/`$web/azd/standalone/release/daily/" --overwrite=true + if ($LASTEXITCODE) { + Write-Error "Upload failed" + exit 1 + } + env: + AZCOPY_AUTO_LOGIN_TYPE: 'PSCRED' + - deployment: Publish_For_PR environment: none condition: >-