From 8bf1a58bb82fbfc034d6da9805643d5f4cd81051 Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 14:48:04 -0700 Subject: [PATCH 1/9] feat(e2e): dynamically fetch latest VM extension version with caching Replace hardcoded VM extension version "1.406" with a cached Azure API call that fetches the latest version at runtime. Add a caching layer via CachedGetLatestVMExtensionImageVersion to avoid redundant API calls across tests. Update all callers to pass context for proper propagation. Signed-off-by: Suraj Deshmukh --- e2e/cache.go | 15 ++++++ e2e/config/azure.go | 52 ++++++++++++--------- e2e/scenario_gpu_managed_experience_test.go | 10 ++-- e2e/scenario_test.go | 2 +- e2e/test_helpers.go | 27 +++++------ 5 files changed, 64 insertions(+), 42 deletions(-) diff --git a/e2e/cache.go b/e2e/cache.go index cfa45122af7..1b07d383815 100644 --- a/e2e/cache.go +++ b/e2e/cache.go @@ -245,3 +245,18 @@ var CachedVMSizeSupportsNVMe = cachedFunc(func(ctx context.Context, req VMSizeSK var CachedIsVMSizeGen2Only = cachedFunc(func(ctx context.Context, req VMSizeSKURequest) (bool, error) { return config.Azure.IsVMSizeGen2Only(ctx, req.Location, req.VMSize) }) + +// GetLatestExtensionVersionRequest is the cache key for VM extension version lookups. +type GetLatestExtensionVersionRequest struct { + Location string + ExtType string + Publisher string +} + +// CachedGetLatestVMExtensionImageVersion caches the result of querying the Azure API +// for the latest VM extension image version. +var CachedGetLatestVMExtensionImageVersion = cachedFunc( + func(ctx context.Context, req GetLatestExtensionVersionRequest) (string, error) { + return config.Azure.GetLatestVMExtensionImageVersion(ctx, req.Location, req.ExtType, req.Publisher) + }, +) diff --git a/e2e/config/azure.go b/e2e/config/azure.go index e003a87754c..3fdfd8d643a 100644 --- a/e2e/config/azure.go +++ b/e2e/config/azure.go @@ -753,67 +753,73 @@ func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, loca if err != nil { return "", fmt.Errorf("listing extension versions: %w", err) } - if len(resp.VirtualMachineExtensionImageArray) == 0 { return "", fmt.Errorf("no extension versions found") } - version := make([]VMExtenstionVersion, len(resp.VirtualMachineExtensionImageArray)) + versions := make([]vmExtensionVersion, len(resp.VirtualMachineExtensionImageArray)) for i, ext := range resp.VirtualMachineExtensionImageArray { - version[i] = parseVersion(ext) + versions[i] = parseVersion(ctx, ext) } - sort.Slice(version, func(i, j int) bool { - return version[i].Less(version[j]) + sort.Slice(versions, func(i, j int) bool { + return versions[i].less(versions[j]) }) - return *version[len(version)-1].Original.Name, nil + return *versions[len(versions)-1].original.Name, nil } -// VMExtenstionVersion represents a parsed version of a VM extension image. -type VMExtenstionVersion struct { - Original *armcompute.VirtualMachineExtensionImage - Major int - Minor int - Patch int +// vmExtensionVersion represents a parsed version of a VM extension image. +type vmExtensionVersion struct { + original *armcompute.VirtualMachineExtensionImage + major int + minor int + patch int } // parseVersion parses the version from a VM extension image name, which can be in the format 1.151, 1.0.1, etc. // You can find all the versions of a specific VM extension by running: // az vm extension image list -n Compute.AKS.Linux.AKSNode -func parseVersion(v *armcompute.VirtualMachineExtensionImage) VMExtenstionVersion { +func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImage) vmExtensionVersion { // Split by dots parts := strings.Split(*v.Name, ".") - version := VMExtenstionVersion{Original: v} + version := vmExtensionVersion{original: v} if len(parts) >= 1 { if major, err := strconv.Atoi(parts[0]); err == nil { - version.Major = major + version.major = major + } else { + toolkit.Logf(ctx, "warning: failed to parse major version from %q: %v", *v.Name, err) } } if len(parts) >= 2 { if minor, err := strconv.Atoi(parts[1]); err == nil { - version.Minor = minor + version.minor = minor + } else { + toolkit.Logf(ctx, "warning: failed to parse minor version from %q: %v", *v.Name, err) } } if len(parts) >= 3 { if patch, err := strconv.Atoi(parts[2]); err == nil { - version.Patch = patch + version.patch = patch + } else { + toolkit.Logf(ctx, "warning: failed to parse patch version from %q: %v", *v.Name, err) } } return version } -func (v VMExtenstionVersion) Less(other VMExtenstionVersion) bool { - if v.Major != other.Major { - return v.Major < other.Major +// less returns true if v is a lower version than other. +func (v vmExtensionVersion) less(other vmExtensionVersion) bool { + if v.major != other.major { + return v.major < other.major } - if v.Minor != other.Minor { - return v.Minor < other.Minor + if v.minor != other.minor { + return v.minor < other.minor } - return v.Patch < other.Patch + return v.patch < other.patch } // getResourceSKU queries the Azure Resource SKUs API to find the SKU for the given VM size and location. diff --git a/e2e/scenario_gpu_managed_experience_test.go b/e2e/scenario_gpu_managed_experience_test.go index 80cefed3e51..b39f9c20cc7 100644 --- a/e2e/scenario_gpu_managed_experience_test.go +++ b/e2e/scenario_gpu_managed_experience_test.go @@ -259,7 +259,7 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) { vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, @@ -334,7 +334,7 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) { vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, @@ -408,7 +408,7 @@ func Test_AzureLinux3_NvidiaDevicePluginRunning(t *testing.T) { vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, @@ -478,7 +478,7 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning_MIG(t *testing.T) { vmss.SKU.Name = to.Ptr("Standard_NC24ads_A100_v4") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, @@ -555,7 +555,7 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning_WithoutVMSSTag(t *testing.T) { // to test that NBC EnableManagedGPU field works independently // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, diff --git a/e2e/scenario_test.go b/e2e/scenario_test.go index 4dd5605388f..8fefd29e755 100644 --- a/e2e/scenario_test.go +++ b/e2e/scenario_test.go @@ -2138,7 +2138,7 @@ func Test_Ubuntu2404_NPD_Basic(t *testing.T) { BootstrapConfigMutator: func(nbc *datamodel.NodeBootstrappingConfiguration) { }, VMConfigMutator: func(vmss *armcompute.VirtualMachineScaleSet) { - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, diff --git a/e2e/test_helpers.go b/e2e/test_helpers.go index 6b05a5d5924..7322aa3f365 100644 --- a/e2e/test_helpers.go +++ b/e2e/test_helpers.go @@ -522,21 +522,22 @@ func addTrustedLaunchToVMSS(properties *armcompute.VirtualMachineScaleSetPropert return properties } -func createVMExtensionLinuxAKSNode(_ *string) (*armcompute.VirtualMachineScaleSetExtension, error) { - // Default to "westus" if location is nil. - // region := "westus" - // if location != nil { - // region = *location - // } +func createVMExtensionLinuxAKSNode(ctx context.Context, location *string) (*armcompute.VirtualMachineScaleSetExtension, error) { + region := config.Config.DefaultLocation + if location != nil { + region = *location + } extensionName := "Compute.AKS.Linux.AKSNode" publisher := "Microsoft.AKS" - extensionVersion := "1.406" - // NOTE (@surajssd): If this is gonna be called multiple times, then find a way to cache the latest version. - // extensionVersion, err := config.Azure.GetLatestVMExtensionImageVersion(context.TODO(), region, extensionName, publisher) - // if err != nil { - // return nil, fmt.Errorf("getting latest VM extension image version: %v", err) - // } + extensionVersion, err := CachedGetLatestVMExtensionImageVersion(ctx, GetLatestExtensionVersionRequest{ + Location: region, + ExtType: extensionName, + Publisher: publisher, + }) + if err != nil { + return nil, fmt.Errorf("getting latest VM extension image version: %w", err) + } return &armcompute.VirtualMachineScaleSetExtension{ Name: to.Ptr(extensionName), @@ -795,7 +796,7 @@ func runScenarioGPUNPD(t *testing.T, vmSize, location, k8sSystemPoolSKU string) VMConfigMutator: func(vmss *armcompute.VirtualMachineScaleSet) { vmss.SKU.Name = to.Ptr(vmSize) - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) From e2446384bd028ed43084dafe6baf2333e13ed2c4 Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 16:20:54 -0700 Subject: [PATCH 2/9] refactor(e2e): use slices.MaxFunc instead of sort.Slice for version lookup Replace `sort.Slice` + last-element pattern with `slices.MaxFunc` in `GetLatestVMExtensionImageVersion` for `O(n)` instead of `O(n log n)`. Rename the less method to `cmp` using `cmp.Compare` from stdlib, and add a nil name guard before accessing the result. Move the nil name check in `parseVersion` to an early return before splitting. Signed-off-by: Suraj Deshmukh --- e2e/config/azure.go | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/e2e/config/azure.go b/e2e/config/azure.go index 3fdfd8d643a..7044e2efa3a 100644 --- a/e2e/config/azure.go +++ b/e2e/config/azure.go @@ -1,6 +1,7 @@ package config import ( + "cmp" "context" "crypto/tls" "errors" @@ -8,7 +9,7 @@ import ( "net" "net/http" "os" - "sort" + "slices" "strconv" "strings" "time" @@ -762,11 +763,13 @@ func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, loca versions[i] = parseVersion(ctx, ext) } - sort.Slice(versions, func(i, j int) bool { - return versions[i].less(versions[j]) + latest := slices.MaxFunc(versions, func(a, b vmExtensionVersion) int { + return a.cmp(b) }) - - return *versions[len(versions)-1].original.Name, nil + if latest.original.Name == nil { + return "", fmt.Errorf("latest extension version has nil name") + } + return *latest.original.Name, nil } // vmExtensionVersion represents a parsed version of a VM extension image. @@ -781,11 +784,15 @@ type vmExtensionVersion struct { // You can find all the versions of a specific VM extension by running: // az vm extension image list -n Compute.AKS.Linux.AKSNode func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImage) vmExtensionVersion { + version := vmExtensionVersion{original: v} + if v.Name == nil { + toolkit.Logf(ctx, "warning: VM extension image has nil name, skipping version parse") + return version + } + // Split by dots parts := strings.Split(*v.Name, ".") - version := vmExtensionVersion{original: v} - if len(parts) >= 1 { if major, err := strconv.Atoi(parts[0]); err == nil { version.major = major @@ -811,15 +818,15 @@ func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImag return version } -// less returns true if v is a lower version than other. -func (v vmExtensionVersion) less(other vmExtensionVersion) bool { - if v.major != other.major { - return v.major < other.major +// cmp compares two versions, returning -1, 0, or 1. +func (v vmExtensionVersion) cmp(other vmExtensionVersion) int { + if c := cmp.Compare(v.major, other.major); c != 0 { + return c } - if v.minor != other.minor { - return v.minor < other.minor + if c := cmp.Compare(v.minor, other.minor); c != 0 { + return c } - return v.patch < other.patch + return cmp.Compare(v.patch, other.patch) } // getResourceSKU queries the Azure Resource SKUs API to find the SKU for the given VM size and location. From f2c91ba4b9908ff93ce30a5b4cb570f3d8eef4df Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 21:53:12 -0700 Subject: [PATCH 3/9] test(e2e): add unit tests for cachedFunc memoization utility Replace the observation-only Test_CreateVMExtensionLinuxAKSNode_Timing from the GPU scenario file with proper unit tests in cache_test.go that assert caching correctness, warm call performance, per-key isolation, error caching, and struct key support. Signed-off-by: Suraj Deshmukh --- e2e/cache_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 e2e/cache_test.go diff --git a/e2e/cache_test.go b/e2e/cache_test.go new file mode 100644 index 00000000000..deeb83aee7f --- /dev/null +++ b/e2e/cache_test.go @@ -0,0 +1,123 @@ +package e2e + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_cachedFunc_returns_consistent_results(t *testing.T) { + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "result-" + key, nil + }) + + ctx := context.Background() + first, err := fn(ctx, "a") + require.NoError(t, err) + + second, err := fn(ctx, "a") + require.NoError(t, err) + + assert.Equal(t, first, second, "cached function should return the same result on repeated calls") + assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once for the same key") +} + +func Test_cachedFunc_warm_call_is_faster_than_cold(t *testing.T) { + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + // simulate a slow operation like a network call + time.Sleep(10 * time.Millisecond) + return "result", nil + }) + + ctx := context.Background() + + start := time.Now() + _, err := fn(ctx, "key") + coldDuration := time.Since(start) + require.NoError(t, err) + + start = time.Now() + _, err = fn(ctx, "key") + warmDuration := time.Since(start) + require.NoError(t, err) + + assert.Less(t, warmDuration, coldDuration, "warm (cached) call should be faster than cold call") +} + +func Test_cachedFunc_different_keys_produce_different_cache_entries(t *testing.T) { + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "result-" + key, nil + }) + + ctx := context.Background() + + resultA, err := fn(ctx, "a") + require.NoError(t, err) + + resultB, err := fn(ctx, "b") + require.NoError(t, err) + + assert.Equal(t, "result-a", resultA) + assert.Equal(t, "result-b", resultB) + assert.NotEqual(t, resultA, resultB, "different keys should produce different results") + assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique key") +} + +func Test_cachedFunc_caches_errors(t *testing.T) { + var callCount atomic.Int32 + expectedErr := fmt.Errorf("something went wrong") + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "", expectedErr + }) + + ctx := context.Background() + + _, err1 := fn(ctx, "a") + require.ErrorIs(t, err1, expectedErr) + + _, err2 := fn(ctx, "a") + require.ErrorIs(t, err2, expectedErr) + + assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once even when it returns an error") +} + +func Test_cachedFunc_with_struct_key(t *testing.T) { + type request struct { + Location string + Type string + } + + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, req request) (string, error) { + callCount.Add(1) + return req.Location + "-" + req.Type, nil + }) + + ctx := context.Background() + + r1, err := fn(ctx, request{Location: "eastus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, "eastus-ext1", r1) + + // same key should return cached result + r2, err := fn(ctx, request{Location: "eastus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, r1, r2) + + // different key should call the function again + r3, err := fn(ctx, request{Location: "westus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, "westus-ext1", r3) + + assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique struct key") +} From ea69622d84736a41c9d64cd11c32c75fdce51a99 Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 22:12:42 -0700 Subject: [PATCH 4/9] fix(e2e): add 30s context timeout to createVMExtensionLinuxAKSNode Prevent the ListVersions Azure API call from blocking indefinitely when the API hangs, which would cause the global 90-minute test timeout to panic and kill all parallel tests. Signed-off-by: Suraj Deshmukh --- e2e/test_helpers.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/e2e/test_helpers.go b/e2e/test_helpers.go index 7322aa3f365..d2a89852663 100644 --- a/e2e/test_helpers.go +++ b/e2e/test_helpers.go @@ -523,6 +523,8 @@ func addTrustedLaunchToVMSS(properties *armcompute.VirtualMachineScaleSetPropert } func createVMExtensionLinuxAKSNode(ctx context.Context, location *string) (*armcompute.VirtualMachineScaleSetExtension, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() region := config.Config.DefaultLocation if location != nil { region = *location From 26f118324c01a293e5633e823c2dd3873d519ccf Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 22:22:07 -0700 Subject: [PATCH 5/9] fix(e2e): fall back to hardcoded VM extension version on discovery failure Instead of failing the test when the Azure API call to discover the latest VM extension version fails, fall back to a known-good hardcoded version (1.406) and log a warning. This prevents transient API issues like network errors or throttling from causing test failures. Signed-off-by: Suraj Deshmukh --- e2e/test_helpers.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/e2e/test_helpers.go b/e2e/test_helpers.go index d2a89852663..e557eea0bb3 100644 --- a/e2e/test_helpers.go +++ b/e2e/test_helpers.go @@ -530,6 +530,7 @@ func createVMExtensionLinuxAKSNode(ctx context.Context, location *string) (*armc region = *location } + const fallbackExtensionVersion = "1.406" extensionName := "Compute.AKS.Linux.AKSNode" publisher := "Microsoft.AKS" extensionVersion, err := CachedGetLatestVMExtensionImageVersion(ctx, GetLatestExtensionVersionRequest{ @@ -538,7 +539,8 @@ func createVMExtensionLinuxAKSNode(ctx context.Context, location *string) (*armc Publisher: publisher, }) if err != nil { - return nil, fmt.Errorf("getting latest VM extension image version: %w", err) + toolkit.Logf(ctx, "warning: failed to get latest VM extension version, falling back to %s: %v", fallbackExtensionVersion, err) + extensionVersion = fallbackExtensionVersion } return &armcompute.VirtualMachineScaleSetExtension{ From 1a9cf30e40124481516f86467f94ad170e80b739 Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 22:32:27 -0700 Subject: [PATCH 6/9] test(e2e): add unit tests for VM extension version parsing and comparison Extract GetLatestVMExtensionImageVersion logic into a testable unexported function accepting a vmExtensionImageVersionLister interface. Add 26 table-driven tests covering parseVersion, vmExtensionVersion.cmp, and getLatestVMExtensionImageVersion including edge cases for nil names, malformed versions, API errors, and empty responses. Signed-off-by: Suraj Deshmukh --- e2e/config/azure.go | 14 +- e2e/config/azure_vmext_test.go | 304 +++++++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 e2e/config/azure_vmext_test.go diff --git a/e2e/config/azure.go b/e2e/config/azure.go index 7044e2efa3a..723dd4518d0 100644 --- a/e2e/config/azure.go +++ b/e2e/config/azure.go @@ -746,11 +746,23 @@ func (a *AzureClient) DeleteSnapshot(ctx context.Context, resourceGroupName, sna return nil } +// vmExtensionImageVersionLister abstracts the ListVersions method of the VM extension images client for testability. +type vmExtensionImageVersionLister interface { + ListVersions(ctx context.Context, location string, publisherName string, typeParam string, + options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions, + ) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error) +} + // GetLatestVMExtensionImageVersion lists VM extension images for a given extension name and returns the latest version. // This is equivalent to: az vm extension image list -n Compute.AKS.Linux.AKSNode --latest func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, location, extType, extPublisher string) (string, error) { + return getLatestVMExtensionImageVersion(ctx, a.VMExtensionImages, location, extType, extPublisher) +} + +// getLatestVMExtensionImageVersion lists VM extension images using the provided lister and returns the latest version. +func getLatestVMExtensionImageVersion(ctx context.Context, lister vmExtensionImageVersionLister, location, extType, extPublisher string) (string, error) { // List extension versions - resp, err := a.VMExtensionImages.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{}) + resp, err := lister.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{}) if err != nil { return "", fmt.Errorf("listing extension versions: %w", err) } diff --git a/e2e/config/azure_vmext_test.go b/e2e/config/azure_vmext_test.go new file mode 100644 index 00000000000..44e4570dcf1 --- /dev/null +++ b/e2e/config/azure_vmext_test.go @@ -0,0 +1,304 @@ +package config + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/Azure/agentbaker/e2e/toolkit" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v7" +) + +// mockVMExtensionImageVersionLister implements vmExtensionImageVersionLister for testing. +type mockVMExtensionImageVersionLister struct { + resp armcompute.VirtualMachineExtensionImagesClientListVersionsResponse + err error +} + +func (m *mockVMExtensionImageVersionLister) ListVersions( + ctx context.Context, + location string, + publisherName string, + typeParam string, + options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions, +) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error) { + return m.resp, m.err +} + +// makeVersionResponse builds a ListVersionsResponse from a list of version name pointers. +// Pass nil to represent an image with a nil Name field. +func makeVersionResponse(versions ...*string) armcompute.VirtualMachineExtensionImagesClientListVersionsResponse { + images := make([]*armcompute.VirtualMachineExtensionImage, len(versions)) + for i, v := range versions { + images[i] = &armcompute.VirtualMachineExtensionImage{Name: v} + } + return armcompute.VirtualMachineExtensionImagesClientListVersionsResponse{ + VirtualMachineExtensionImageArray: images, + } +} + +func Test_parseVersion(t *testing.T) { + tests := []struct { + name string + inputName *string + expectedMajor int + expectedMinor int + expectedPatch int + }{ + { + name: "three-part version", + inputName: to.Ptr("1.0.1"), + expectedMajor: 1, + expectedMinor: 0, + expectedPatch: 1, + }, + { + name: "two-part version", + inputName: to.Ptr("1.151"), + expectedMajor: 1, + expectedMinor: 151, + expectedPatch: 0, + }, + { + name: "single-part version", + inputName: to.Ptr("5"), + expectedMajor: 5, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "nil name", + inputName: nil, + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "non-numeric parts", + inputName: to.Ptr("abc.def.ghi"), + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "partially numeric", + inputName: to.Ptr("2.abc.3"), + expectedMajor: 2, + expectedMinor: 0, + expectedPatch: 3, + }, + { + name: "empty string", + inputName: to.Ptr(""), + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "extra parts ignored", + inputName: to.Ptr("1.2.3.4"), + expectedMajor: 1, + expectedMinor: 2, + expectedPatch: 3, + }, + { + name: "large numbers", + inputName: to.Ptr("100.200.300"), + expectedMajor: 100, + expectedMinor: 200, + expectedPatch: 300, + }, + { + name: "leading zeros", + inputName: to.Ptr("01.02.03"), + expectedMajor: 1, + expectedMinor: 2, + expectedPatch: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := toolkit.ContextWithT(context.Background(), t) + img := &armcompute.VirtualMachineExtensionImage{Name: tt.inputName} + result := parseVersion(ctx, img) + + if result.major != tt.expectedMajor { + t.Errorf("major: got %d, want %d", result.major, tt.expectedMajor) + } + if result.minor != tt.expectedMinor { + t.Errorf("minor: got %d, want %d", result.minor, tt.expectedMinor) + } + if result.patch != tt.expectedPatch { + t.Errorf("patch: got %d, want %d", result.patch, tt.expectedPatch) + } + if result.original != img { + t.Errorf("original: got %p, want %p", result.original, img) + } + }) + } +} + +func Test_vmExtensionVersion_cmp(t *testing.T) { + tests := []struct { + name string + a vmExtensionVersion + b vmExtensionVersion + expected int + }{ + { + name: "equal", + a: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + expected: 0, + }, + { + name: "a higher major", + a: vmExtensionVersion{major: 2, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 1, minor: 9, patch: 9}, + expected: 1, + }, + { + name: "a lower major", + a: vmExtensionVersion{major: 1, minor: 9, patch: 9}, + b: vmExtensionVersion{major: 2, minor: 0, patch: 0}, + expected: -1, + }, + { + name: "same major, a higher minor", + a: vmExtensionVersion{major: 1, minor: 5, patch: 0}, + b: vmExtensionVersion{major: 1, minor: 3, patch: 9}, + expected: 1, + }, + { + name: "same major, a lower minor", + a: vmExtensionVersion{major: 1, minor: 3, patch: 9}, + b: vmExtensionVersion{major: 1, minor: 5, patch: 0}, + expected: -1, + }, + { + name: "same major+minor, a higher patch", + a: vmExtensionVersion{major: 1, minor: 2, patch: 5}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + expected: 1, + }, + { + name: "same major+minor, a lower patch", + a: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 5}, + expected: -1, + }, + { + name: "both zero", + a: vmExtensionVersion{major: 0, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 0, minor: 0, patch: 0}, + expected: 0, + }, + { + name: "zero vs non-zero", + a: vmExtensionVersion{major: 0, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 0, minor: 0, patch: 1}, + expected: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.a.cmp(tt.b) + if got != tt.expected { + t.Errorf("(%v).cmp(%v) = %d, want %d", tt.a, tt.b, got, tt.expected) + } + }) + } +} + +func Test_getLatestVMExtensionImageVersion(t *testing.T) { + tests := []struct { + name string + mock *mockVMExtensionImageVersionLister + expected string + errContains string + }{ + { + name: "multiple versions, returns latest", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("1.0.0"), to.Ptr("2.1.0"), to.Ptr("1.5.3")), + }, + expected: "2.1.0", + }, + { + name: "single version", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("3.2.1")), + }, + expected: "3.2.1", + }, + { + name: "two-part versions", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("1.100"), to.Ptr("1.151"), to.Ptr("1.50")), + }, + expected: "1.151", + }, + { + name: "API error propagated", + mock: &mockVMExtensionImageVersionLister{ + err: fmt.Errorf("network failure"), + }, + errContains: "listing extension versions", + }, + { + name: "empty list", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(), + }, + errContains: "no extension versions found", + }, + { + name: "all nil names", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(nil), + }, + errContains: "latest extension version has nil name", + }, + { + name: "mix valid and malformed", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("abc"), to.Ptr("1.2.3"), to.Ptr("xyz")), + }, + expected: "1.2.3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := toolkit.ContextWithT(context.Background(), t) + got, err := getLatestVMExtensionImageVersion( + ctx, + tt.mock, + "eastus", + "TestExtension", + "TestPublisher", + ) + + if tt.errContains != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.errContains) + } + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.expected { + t.Errorf("got %q, want %q", got, tt.expected) + } + }) + } +} From 28a981ad852a8bd546056803e0281387990b0b0e Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Tue, 10 Mar 2026 22:43:27 -0700 Subject: [PATCH 7/9] fix(e2e): use %v instead of %w in fmt.Sprintf for panic message fmt.Sprintf does not support the error-wrapping directive %w, which caused a build failure. The %w verb is only valid in fmt.Errorf. Signed-off-by: Suraj Deshmukh --- e2e/config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e/config/config.go b/e2e/config/config.go index 02cf274cdc9..57d616a72a1 100644 --- a/e2e/config/config.go +++ b/e2e/config/config.go @@ -179,7 +179,7 @@ func mustGetNewRSAKeyPair() ([]byte, []byte, string) { privateKeyFileName, err := writePrivateKeyToTempFile(privatePEMBytes) if err != nil { - panic(fmt.Sprintf("failed to write private key to temp file: %w", err)) + panic(fmt.Sprintf("failed to write private key to temp file: %v", err)) } return privatePEMBytes, publicKeyBytes, privateKeyFileName From 11558eb056631242bea1f35a98095596e9d006bd Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Wed, 11 Mar 2026 11:38:00 -0700 Subject: [PATCH 8/9] test(e2e): fix flaky VM extension cache timing test Replace fragile wall-clock timing assertions with deterministic checks for the createVMExtensionLinuxAKSNode cache test. The previous 10s/1s thresholds tested CI/network performance rather than caching logic. - Add testing.Short() skip guard for integration test - Add nil/empty checks on returned extensions and properties - Replace assert with require for fail-fast behavior - Keep cache consistency check and duration logging for diagnostics - Remove unused assert import Signed-off-by: Suraj Deshmukh --- e2e/scenario_gpu_managed_experience_test.go | 40 +++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/e2e/scenario_gpu_managed_experience_test.go b/e2e/scenario_gpu_managed_experience_test.go index b39f9c20cc7..42693167cef 100644 --- a/e2e/scenario_gpu_managed_experience_test.go +++ b/e2e/scenario_gpu_managed_experience_test.go @@ -604,3 +604,43 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning_WithoutVMSSTag(t *testing.T) { }, }) } + +func Test_CreateVMExtensionLinuxAKSNode_Timing(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // First call — may hit the Azure API or cache + start := time.Now() + ext, err := createVMExtensionLinuxAKSNode(t.Context(), nil) + firstDuration := time.Since(start) + require.NoError(t, err, "first call to createVMExtensionLinuxAKSNode failed") + require.NotNil(t, ext, "first call returned nil extension") + t.Logf("First call duration: %s", firstDuration) + + // Second call — should be served from cache + start = time.Now() + ext2, err := createVMExtensionLinuxAKSNode(t.Context(), nil) + secondDuration := time.Since(start) + require.NoError(t, err, "second call to createVMExtensionLinuxAKSNode failed") + require.NotNil(t, ext2, "second call returned nil extension") + t.Logf("Second call duration: %s", secondDuration) + + // Both calls should return a valid, consistent TypeHandlerVersion + require.NotNil(t, ext.Properties, "first extension has nil Properties") + require.NotNil(t, ext2.Properties, "second extension has nil Properties") + require.NotNil(t, ext.Properties.TypeHandlerVersion, "first TypeHandlerVersion is nil") + require.NotNil(t, ext2.Properties.TypeHandlerVersion, "second TypeHandlerVersion is nil") + require.NotEmpty(t, *ext.Properties.TypeHandlerVersion, "first TypeHandlerVersion is empty") + require.NotEmpty(t, *ext2.Properties.TypeHandlerVersion, "second TypeHandlerVersion is empty") + + // // TODO: @surajssd, uncomment this when you update the aks vm extension + // // version that is different than 1.406. + // // Ensure we actually hit Azure and didn't just get the fallback version + // require.NotEqual(t, "1.406", *ext.Properties.TypeHandlerVersion, + // "extension version is the hardcoded fallback — Azure API may not have been reached") + + // Cache consistency: both calls should return the same version + require.Equal(t, *ext.Properties.TypeHandlerVersion, *ext2.Properties.TypeHandlerVersion, + "both calls should return the same extension version") +} From 7872d35c237fd58877110c4c196ec93d36b75c26 Mon Sep 17 00:00:00 2001 From: Suraj Deshmukh Date: Wed, 11 Mar 2026 11:49:37 -0700 Subject: [PATCH 9/9] fix(e2e): add missing return in toolkit.Logf nil guard The nil guard in Logf fell through to t.Helper() and t.Logf() when the context had no testing.TB, causing a nil-pointer dereference panic. Add the missing return to match the sibling Log function's pattern. Signed-off-by: Suraj Deshmukh --- e2e/toolkit/log.go | 1 + 1 file changed, 1 insertion(+) diff --git a/e2e/toolkit/log.go b/e2e/toolkit/log.go index dbbbb7a19cc..5d11d13dbc2 100644 --- a/e2e/toolkit/log.go +++ b/e2e/toolkit/log.go @@ -23,6 +23,7 @@ func Logf(ctx context.Context, format string, args ...any) { t, ok := ctx.Value(testLoggerKey{}).(testing.TB) if !ok || t == nil { log.Printf(format+"WARNING: No *testing.T in Context, this function should only be called from ", args...) + return } t.Helper() t.Logf(format, args...)