diff --git a/e2e/cache.go b/e2e/cache.go index cfa45122af7..1b07d383815 100644 --- a/e2e/cache.go +++ b/e2e/cache.go @@ -245,3 +245,18 @@ var CachedVMSizeSupportsNVMe = cachedFunc(func(ctx context.Context, req VMSizeSK var CachedIsVMSizeGen2Only = cachedFunc(func(ctx context.Context, req VMSizeSKURequest) (bool, error) { return config.Azure.IsVMSizeGen2Only(ctx, req.Location, req.VMSize) }) + +// GetLatestExtensionVersionRequest is the cache key for VM extension version lookups. +type GetLatestExtensionVersionRequest struct { + Location string + ExtType string + Publisher string +} + +// CachedGetLatestVMExtensionImageVersion caches the result of querying the Azure API +// for the latest VM extension image version. +var CachedGetLatestVMExtensionImageVersion = cachedFunc( + func(ctx context.Context, req GetLatestExtensionVersionRequest) (string, error) { + return config.Azure.GetLatestVMExtensionImageVersion(ctx, req.Location, req.ExtType, req.Publisher) + }, +) diff --git a/e2e/cache_test.go b/e2e/cache_test.go new file mode 100644 index 00000000000..deeb83aee7f --- /dev/null +++ b/e2e/cache_test.go @@ -0,0 +1,123 @@ +package e2e + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_cachedFunc_returns_consistent_results(t *testing.T) { + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "result-" + key, nil + }) + + ctx := context.Background() + first, err := fn(ctx, "a") + require.NoError(t, err) + + second, err := fn(ctx, "a") + require.NoError(t, err) + + assert.Equal(t, first, second, "cached function should return the same result on repeated calls") + assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once for the same key") +} + +func Test_cachedFunc_warm_call_is_faster_than_cold(t *testing.T) { + fn := cachedFunc(func(ctx 
context.Context, key string) (string, error) { + // simulate a slow operation like a network call + time.Sleep(10 * time.Millisecond) + return "result", nil + }) + + ctx := context.Background() + + start := time.Now() + _, err := fn(ctx, "key") + coldDuration := time.Since(start) + require.NoError(t, err) + + start = time.Now() + _, err = fn(ctx, "key") + warmDuration := time.Since(start) + require.NoError(t, err) + + assert.Less(t, warmDuration, coldDuration, "warm (cached) call should be faster than cold call") +} + +func Test_cachedFunc_different_keys_produce_different_cache_entries(t *testing.T) { + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "result-" + key, nil + }) + + ctx := context.Background() + + resultA, err := fn(ctx, "a") + require.NoError(t, err) + + resultB, err := fn(ctx, "b") + require.NoError(t, err) + + assert.Equal(t, "result-a", resultA) + assert.Equal(t, "result-b", resultB) + assert.NotEqual(t, resultA, resultB, "different keys should produce different results") + assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique key") +} + +func Test_cachedFunc_caches_errors(t *testing.T) { + var callCount atomic.Int32 + expectedErr := fmt.Errorf("something went wrong") + fn := cachedFunc(func(ctx context.Context, key string) (string, error) { + callCount.Add(1) + return "", expectedErr + }) + + ctx := context.Background() + + _, err1 := fn(ctx, "a") + require.ErrorIs(t, err1, expectedErr) + + _, err2 := fn(ctx, "a") + require.ErrorIs(t, err2, expectedErr) + + assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once even when it returns an error") +} + +func Test_cachedFunc_with_struct_key(t *testing.T) { + type request struct { + Location string + Type string + } + + var callCount atomic.Int32 + fn := cachedFunc(func(ctx context.Context, req request) (string, error) { + callCount.Add(1) + 
return req.Location + "-" + req.Type, nil + }) + + ctx := context.Background() + + r1, err := fn(ctx, request{Location: "eastus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, "eastus-ext1", r1) + + // same key should return cached result + r2, err := fn(ctx, request{Location: "eastus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, r1, r2) + + // different key should call the function again + r3, err := fn(ctx, request{Location: "westus", Type: "ext1"}) + require.NoError(t, err) + assert.Equal(t, "westus-ext1", r3) + + assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique struct key") +} diff --git a/e2e/config/azure.go b/e2e/config/azure.go index e003a87754c..723dd4518d0 100644 --- a/e2e/config/azure.go +++ b/e2e/config/azure.go @@ -1,6 +1,7 @@ package config import ( + "cmp" "context" "crypto/tls" "errors" @@ -8,7 +9,7 @@ import ( "net" "net/http" "os" - "sort" + "slices" "strconv" "strings" "time" @@ -745,75 +746,99 @@ func (a *AzureClient) DeleteSnapshot(ctx context.Context, resourceGroupName, sna return nil } +// vmExtensionImageVersionLister abstracts the ListVersions method of the VM extension images client for testability. +type vmExtensionImageVersionLister interface { + ListVersions(ctx context.Context, location string, publisherName string, typeParam string, + options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions, + ) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error) +} + // GetLatestVMExtensionImageVersion lists VM extension images for a given extension name and returns the latest version. 
// This is equivalent to: az vm extension image list -n Compute.AKS.Linux.AKSNode --latest func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, location, extType, extPublisher string) (string, error) { + return getLatestVMExtensionImageVersion(ctx, a.VMExtensionImages, location, extType, extPublisher) +} + +// getLatestVMExtensionImageVersion lists VM extension images using the provided lister and returns the latest version. +func getLatestVMExtensionImageVersion(ctx context.Context, lister vmExtensionImageVersionLister, location, extType, extPublisher string) (string, error) { // List extension versions - resp, err := a.VMExtensionImages.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{}) + resp, err := lister.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{}) if err != nil { return "", fmt.Errorf("listing extension versions: %w", err) } - if len(resp.VirtualMachineExtensionImageArray) == 0 { return "", fmt.Errorf("no extension versions found") } - version := make([]VMExtenstionVersion, len(resp.VirtualMachineExtensionImageArray)) + versions := make([]vmExtensionVersion, len(resp.VirtualMachineExtensionImageArray)) for i, ext := range resp.VirtualMachineExtensionImageArray { - version[i] = parseVersion(ext) + versions[i] = parseVersion(ctx, ext) } - sort.Slice(version, func(i, j int) bool { - return version[i].Less(version[j]) + latest := slices.MaxFunc(versions, func(a, b vmExtensionVersion) int { + return a.cmp(b) }) - - return *version[len(version)-1].Original.Name, nil + if latest.original.Name == nil { + return "", fmt.Errorf("latest extension version has nil name") + } + return *latest.original.Name, nil } -// VMExtenstionVersion represents a parsed version of a VM extension image. 
-type VMExtenstionVersion struct { - Original *armcompute.VirtualMachineExtensionImage - Major int - Minor int - Patch int +// vmExtensionVersion represents a parsed version of a VM extension image. +type vmExtensionVersion struct { + original *armcompute.VirtualMachineExtensionImage + major int + minor int + patch int } // parseVersion parses the version from a VM extension image name, which can be in the format 1.151, 1.0.1, etc. // You can find all the versions of a specific VM extension by running: // az vm extension image list -n Compute.AKS.Linux.AKSNode -func parseVersion(v *armcompute.VirtualMachineExtensionImage) VMExtenstionVersion { +func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImage) vmExtensionVersion { + version := vmExtensionVersion{original: v} + if v.Name == nil { + toolkit.Logf(ctx, "warning: VM extension image has nil name, skipping version parse") + return version + } + // Split by dots parts := strings.Split(*v.Name, ".") - version := VMExtenstionVersion{Original: v} - if len(parts) >= 1 { if major, err := strconv.Atoi(parts[0]); err == nil { - version.Major = major + version.major = major + } else { + toolkit.Logf(ctx, "warning: failed to parse major version from %q: %v", *v.Name, err) } } if len(parts) >= 2 { if minor, err := strconv.Atoi(parts[1]); err == nil { - version.Minor = minor + version.minor = minor + } else { + toolkit.Logf(ctx, "warning: failed to parse minor version from %q: %v", *v.Name, err) } } if len(parts) >= 3 { if patch, err := strconv.Atoi(parts[2]); err == nil { - version.Patch = patch + version.patch = patch + } else { + toolkit.Logf(ctx, "warning: failed to parse patch version from %q: %v", *v.Name, err) } } return version } -func (v VMExtenstionVersion) Less(other VMExtenstionVersion) bool { - if v.Major != other.Major { - return v.Major < other.Major +// cmp compares two versions, returning -1, 0, or 1. 
+func (v vmExtensionVersion) cmp(other vmExtensionVersion) int { + if c := cmp.Compare(v.major, other.major); c != 0 { + return c } - if v.Minor != other.Minor { - return v.Minor < other.Minor + if c := cmp.Compare(v.minor, other.minor); c != 0 { + return c } - return v.Patch < other.Patch + return cmp.Compare(v.patch, other.patch) } // getResourceSKU queries the Azure Resource SKUs API to find the SKU for the given VM size and location. diff --git a/e2e/config/azure_vmext_test.go b/e2e/config/azure_vmext_test.go new file mode 100644 index 00000000000..44e4570dcf1 --- /dev/null +++ b/e2e/config/azure_vmext_test.go @@ -0,0 +1,304 @@ +package config + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/Azure/agentbaker/e2e/toolkit" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v7" +) + +// mockVMExtensionImageVersionLister implements vmExtensionImageVersionLister for testing. +type mockVMExtensionImageVersionLister struct { + resp armcompute.VirtualMachineExtensionImagesClientListVersionsResponse + err error +} + +func (m *mockVMExtensionImageVersionLister) ListVersions( + ctx context.Context, + location string, + publisherName string, + typeParam string, + options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions, +) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error) { + return m.resp, m.err +} + +// makeVersionResponse builds a ListVersionsResponse from a list of version name pointers. +// Pass nil to represent an image with a nil Name field. 
+func makeVersionResponse(versions ...*string) armcompute.VirtualMachineExtensionImagesClientListVersionsResponse { + images := make([]*armcompute.VirtualMachineExtensionImage, len(versions)) + for i, v := range versions { + images[i] = &armcompute.VirtualMachineExtensionImage{Name: v} + } + return armcompute.VirtualMachineExtensionImagesClientListVersionsResponse{ + VirtualMachineExtensionImageArray: images, + } +} + +func Test_parseVersion(t *testing.T) { + tests := []struct { + name string + inputName *string + expectedMajor int + expectedMinor int + expectedPatch int + }{ + { + name: "three-part version", + inputName: to.Ptr("1.0.1"), + expectedMajor: 1, + expectedMinor: 0, + expectedPatch: 1, + }, + { + name: "two-part version", + inputName: to.Ptr("1.151"), + expectedMajor: 1, + expectedMinor: 151, + expectedPatch: 0, + }, + { + name: "single-part version", + inputName: to.Ptr("5"), + expectedMajor: 5, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "nil name", + inputName: nil, + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "non-numeric parts", + inputName: to.Ptr("abc.def.ghi"), + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "partially numeric", + inputName: to.Ptr("2.abc.3"), + expectedMajor: 2, + expectedMinor: 0, + expectedPatch: 3, + }, + { + name: "empty string", + inputName: to.Ptr(""), + expectedMajor: 0, + expectedMinor: 0, + expectedPatch: 0, + }, + { + name: "extra parts ignored", + inputName: to.Ptr("1.2.3.4"), + expectedMajor: 1, + expectedMinor: 2, + expectedPatch: 3, + }, + { + name: "large numbers", + inputName: to.Ptr("100.200.300"), + expectedMajor: 100, + expectedMinor: 200, + expectedPatch: 300, + }, + { + name: "leading zeros", + inputName: to.Ptr("01.02.03"), + expectedMajor: 1, + expectedMinor: 2, + expectedPatch: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := toolkit.ContextWithT(context.Background(), t) + img := 
&armcompute.VirtualMachineExtensionImage{Name: tt.inputName} + result := parseVersion(ctx, img) + + if result.major != tt.expectedMajor { + t.Errorf("major: got %d, want %d", result.major, tt.expectedMajor) + } + if result.minor != tt.expectedMinor { + t.Errorf("minor: got %d, want %d", result.minor, tt.expectedMinor) + } + if result.patch != tt.expectedPatch { + t.Errorf("patch: got %d, want %d", result.patch, tt.expectedPatch) + } + if result.original != img { + t.Errorf("original: got %p, want %p", result.original, img) + } + }) + } +} + +func Test_vmExtensionVersion_cmp(t *testing.T) { + tests := []struct { + name string + a vmExtensionVersion + b vmExtensionVersion + expected int + }{ + { + name: "equal", + a: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + expected: 0, + }, + { + name: "a higher major", + a: vmExtensionVersion{major: 2, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 1, minor: 9, patch: 9}, + expected: 1, + }, + { + name: "a lower major", + a: vmExtensionVersion{major: 1, minor: 9, patch: 9}, + b: vmExtensionVersion{major: 2, minor: 0, patch: 0}, + expected: -1, + }, + { + name: "same major, a higher minor", + a: vmExtensionVersion{major: 1, minor: 5, patch: 0}, + b: vmExtensionVersion{major: 1, minor: 3, patch: 9}, + expected: 1, + }, + { + name: "same major, a lower minor", + a: vmExtensionVersion{major: 1, minor: 3, patch: 9}, + b: vmExtensionVersion{major: 1, minor: 5, patch: 0}, + expected: -1, + }, + { + name: "same major+minor, a higher patch", + a: vmExtensionVersion{major: 1, minor: 2, patch: 5}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + expected: 1, + }, + { + name: "same major+minor, a lower patch", + a: vmExtensionVersion{major: 1, minor: 2, patch: 3}, + b: vmExtensionVersion{major: 1, minor: 2, patch: 5}, + expected: -1, + }, + { + name: "both zero", + a: vmExtensionVersion{major: 0, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 0, minor: 0, 
patch: 0}, + expected: 0, + }, + { + name: "zero vs non-zero", + a: vmExtensionVersion{major: 0, minor: 0, patch: 0}, + b: vmExtensionVersion{major: 0, minor: 0, patch: 1}, + expected: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.a.cmp(tt.b) + if got != tt.expected { + t.Errorf("(%v).cmp(%v) = %d, want %d", tt.a, tt.b, got, tt.expected) + } + }) + } +} + +func Test_getLatestVMExtensionImageVersion(t *testing.T) { + tests := []struct { + name string + mock *mockVMExtensionImageVersionLister + expected string + errContains string + }{ + { + name: "multiple versions, returns latest", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("1.0.0"), to.Ptr("2.1.0"), to.Ptr("1.5.3")), + }, + expected: "2.1.0", + }, + { + name: "single version", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("3.2.1")), + }, + expected: "3.2.1", + }, + { + name: "two-part versions", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("1.100"), to.Ptr("1.151"), to.Ptr("1.50")), + }, + expected: "1.151", + }, + { + name: "API error propagated", + mock: &mockVMExtensionImageVersionLister{ + err: fmt.Errorf("network failure"), + }, + errContains: "listing extension versions", + }, + { + name: "empty list", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(), + }, + errContains: "no extension versions found", + }, + { + name: "all nil names", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(nil), + }, + errContains: "latest extension version has nil name", + }, + { + name: "mix valid and malformed", + mock: &mockVMExtensionImageVersionLister{ + resp: makeVersionResponse(to.Ptr("abc"), to.Ptr("1.2.3"), to.Ptr("xyz")), + }, + expected: "1.2.3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := toolkit.ContextWithT(context.Background(), t) + got, err := getLatestVMExtensionImageVersion( 
+ ctx, + tt.mock, + "eastus", + "TestExtension", + "TestPublisher", + ) + + if tt.errContains != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.errContains) + } + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.expected { + t.Errorf("got %q, want %q", got, tt.expected) + } + }) + } +} diff --git a/e2e/config/config.go b/e2e/config/config.go index 02cf274cdc9..57d616a72a1 100644 --- a/e2e/config/config.go +++ b/e2e/config/config.go @@ -179,7 +179,7 @@ func mustGetNewRSAKeyPair() ([]byte, []byte, string) { privateKeyFileName, err := writePrivateKeyToTempFile(privatePEMBytes) if err != nil { - panic(fmt.Sprintf("failed to write private key to temp file: %w", err)) + panic(fmt.Sprintf("failed to write private key to temp file: %v", err)) } return privatePEMBytes, publicKeyBytes, privateKeyFileName diff --git a/e2e/scenario_gpu_managed_experience_test.go b/e2e/scenario_gpu_managed_experience_test.go index 80cefed3e51..42693167cef 100644 --- a/e2e/scenario_gpu_managed_experience_test.go +++ b/e2e/scenario_gpu_managed_experience_test.go @@ -259,7 +259,7 @@ func Test_Ubuntu2404_NvidiaDevicePluginRunning(t *testing.T) { vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) }, @@ -334,7 +334,7 @@ func Test_Ubuntu2204_NvidiaDevicePluginRunning(t *testing.T) { vmss.Tags["EnableManagedGPUExperience"] = to.Ptr("true") // Enable the AKS VM extension for GPU nodes - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := 
// Test_CreateVMExtensionLinuxAKSNode_Timing calls createVMExtensionLinuxAKSNode
// twice with a nil location and checks that both calls succeed and report the
// same non-empty TypeHandlerVersion. The two call durations are logged for
// inspection only; nothing is asserted about them. Skipped under -short.
func Test_CreateVMExtensionLinuxAKSNode_Timing(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	// First call: may hit the Azure API or be served from the cache.
	start := time.Now()
	ext, err := createVMExtensionLinuxAKSNode(t.Context(), nil)
	firstDuration := time.Since(start)
	require.NoError(t, err, "first call to createVMExtensionLinuxAKSNode failed")
	require.NotNil(t, ext, "first call returned nil extension")
	t.Logf("First call duration: %s", firstDuration)

	// Second call: expected to be served from the cache.
	start = time.Now()
	ext2, err := createVMExtensionLinuxAKSNode(t.Context(), nil)
	secondDuration := time.Since(start)
	require.NoError(t, err, "second call to createVMExtensionLinuxAKSNode failed")
	require.NotNil(t, ext2, "second call returned nil extension")
	t.Logf("Second call duration: %s", secondDuration)

	// Both calls must return a populated TypeHandlerVersion. The pointers are
	// checked before being dereferenced below.
	require.NotNil(t, ext.Properties, "first extension has nil Properties")
	require.NotNil(t, ext2.Properties, "second extension has nil Properties")
	require.NotNil(t, ext.Properties.TypeHandlerVersion, "first TypeHandlerVersion is nil")
	require.NotNil(t, ext2.Properties.TypeHandlerVersion, "second TypeHandlerVersion is nil")
	require.NotEmpty(t, *ext.Properties.TypeHandlerVersion, "first TypeHandlerVersion is empty")
	require.NotEmpty(t, *ext2.Properties.TypeHandlerVersion, "second TypeHandlerVersion is empty")

	// TODO(@surajssd): enable the check below once the latest AKS VM extension
	// version differs from the hardcoded fallback "1.406". Until then a
	// fallback result cannot be told apart from a real API result.
	//
	// Ensure we actually hit Azure and didn't just get the fallback version:
	// require.NotEqual(t, "1.406", *ext.Properties.TypeHandlerVersion,
	// 	"extension version is the hardcoded fallback — Azure API may not have been reached")

	// Cache consistency: both calls should return the same version.
	require.Equal(t, *ext.Properties.TypeHandlerVersion, *ext2.Properties.TypeHandlerVersion,
		"both calls should return the same extension version")
}
- // region := "westus" - // if location != nil { - // region = *location - // } +func createVMExtensionLinuxAKSNode(ctx context.Context, location *string) (*armcompute.VirtualMachineScaleSetExtension, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + region := config.Config.DefaultLocation + if location != nil { + region = *location + } + const fallbackExtensionVersion = "1.406" extensionName := "Compute.AKS.Linux.AKSNode" publisher := "Microsoft.AKS" - extensionVersion := "1.406" - // NOTE (@surajssd): If this is gonna be called multiple times, then find a way to cache the latest version. - // extensionVersion, err := config.Azure.GetLatestVMExtensionImageVersion(context.TODO(), region, extensionName, publisher) - // if err != nil { - // return nil, fmt.Errorf("getting latest VM extension image version: %v", err) - // } + extensionVersion, err := CachedGetLatestVMExtensionImageVersion(ctx, GetLatestExtensionVersionRequest{ + Location: region, + ExtType: extensionName, + Publisher: publisher, + }) + if err != nil { + toolkit.Logf(ctx, "warning: failed to get latest VM extension version, falling back to %s: %v", fallbackExtensionVersion, err) + extensionVersion = fallbackExtensionVersion + } return &armcompute.VirtualMachineScaleSetExtension{ Name: to.Ptr(extensionName), @@ -795,7 +800,7 @@ func runScenarioGPUNPD(t *testing.T, vmSize, location, k8sSystemPoolSKU string) VMConfigMutator: func(vmss *armcompute.VirtualMachineScaleSet) { vmss.SKU.Name = to.Ptr(vmSize) - extension, err := createVMExtensionLinuxAKSNode(vmss.Location) + extension, err := createVMExtensionLinuxAKSNode(t.Context(), vmss.Location) require.NoError(t, err, "creating AKS VM extension") vmss.Properties = addVMExtensionToVMSS(vmss.Properties, extension) diff --git a/e2e/toolkit/log.go b/e2e/toolkit/log.go index dbbbb7a19cc..5d11d13dbc2 100644 --- a/e2e/toolkit/log.go +++ b/e2e/toolkit/log.go @@ -23,6 +23,7 @@ func Logf(ctx context.Context, format string, args 
...any) { t, ok := ctx.Value(testLoggerKey{}).(testing.TB) if !ok || t == nil { log.Printf(format+"WARNING: No *testing.T in Context, this function should only be called from ", args...) + return } t.Helper() t.Logf(format, args...)