Skip to content
Open
15 changes: 15 additions & 0 deletions e2e/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,18 @@ var CachedVMSizeSupportsNVMe = cachedFunc(func(ctx context.Context, req VMSizeSK
var CachedIsVMSizeGen2Only = cachedFunc(func(ctx context.Context, req VMSizeSKURequest) (bool, error) {
return config.Azure.IsVMSizeGen2Only(ctx, req.Location, req.VMSize)
})

// GetLatestExtensionVersionRequest is the cache key for VM extension version lookups.
type GetLatestExtensionVersionRequest struct {
Location string
ExtType string
Publisher string
}

// CachedGetLatestVMExtensionImageVersion caches the result of querying the Azure API
// for the latest VM extension image version.
var CachedGetLatestVMExtensionImageVersion = cachedFunc(
func(ctx context.Context, req GetLatestExtensionVersionRequest) (string, error) {
return config.Azure.GetLatestVMExtensionImageVersion(ctx, req.Location, req.ExtType, req.Publisher)
},
)
123 changes: 123 additions & 0 deletions e2e/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package e2e

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_cachedFunc_returns_consistent_results(t *testing.T) {
var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "result-" + key, nil
})

ctx := context.Background()
first, err := fn(ctx, "a")
require.NoError(t, err)

second, err := fn(ctx, "a")
require.NoError(t, err)

assert.Equal(t, first, second, "cached function should return the same result on repeated calls")
assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once for the same key")
}

func Test_cachedFunc_warm_call_is_faster_than_cold(t *testing.T) {
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
// simulate a slow operation like a network call
time.Sleep(10 * time.Millisecond)
return "result", nil
})

ctx := context.Background()

start := time.Now()
_, err := fn(ctx, "key")
coldDuration := time.Since(start)
require.NoError(t, err)

start = time.Now()
_, err = fn(ctx, "key")
warmDuration := time.Since(start)
require.NoError(t, err)

assert.Less(t, warmDuration, coldDuration, "warm (cached) call should be faster than cold call")
}

func Test_cachedFunc_different_keys_produce_different_cache_entries(t *testing.T) {
var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "result-" + key, nil
})

ctx := context.Background()

resultA, err := fn(ctx, "a")
require.NoError(t, err)

resultB, err := fn(ctx, "b")
require.NoError(t, err)

assert.Equal(t, "result-a", resultA)
assert.Equal(t, "result-b", resultB)
assert.NotEqual(t, resultA, resultB, "different keys should produce different results")
assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique key")
}

func Test_cachedFunc_caches_errors(t *testing.T) {
var callCount atomic.Int32
expectedErr := fmt.Errorf("something went wrong")
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "", expectedErr
})

ctx := context.Background()

_, err1 := fn(ctx, "a")
require.ErrorIs(t, err1, expectedErr)

_, err2 := fn(ctx, "a")
require.ErrorIs(t, err2, expectedErr)

assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once even when it returns an error")
}

func Test_cachedFunc_with_struct_key(t *testing.T) {
type request struct {
Location string
Type string
}

var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, req request) (string, error) {
callCount.Add(1)
return req.Location + "-" + req.Type, nil
})

ctx := context.Background()

r1, err := fn(ctx, request{Location: "eastus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, "eastus-ext1", r1)

// same key should return cached result
r2, err := fn(ctx, request{Location: "eastus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, r1, r2)

// different key should call the function again
r3, err := fn(ctx, request{Location: "westus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, "westus-ext1", r3)

assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique struct key")
}
79 changes: 52 additions & 27 deletions e2e/config/azure.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package config

import (
"cmp"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"sort"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -745,75 +746,99 @@ func (a *AzureClient) DeleteSnapshot(ctx context.Context, resourceGroupName, sna
return nil
}

// vmExtensionImageVersionLister abstracts the ListVersions method of the VM extension images client for testability.
type vmExtensionImageVersionLister interface {
ListVersions(ctx context.Context, location string, publisherName string, typeParam string,
options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions,
) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error)
}

// GetLatestVMExtensionImageVersion lists VM extension images for a given extension name and returns the latest version.
// This is equivalent to: az vm extension image list -n Compute.AKS.Linux.AKSNode --latest
func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, location, extType, extPublisher string) (string, error) {
return getLatestVMExtensionImageVersion(ctx, a.VMExtensionImages, location, extType, extPublisher)
}

// getLatestVMExtensionImageVersion lists VM extension images using the provided lister and returns the latest version.
func getLatestVMExtensionImageVersion(ctx context.Context, lister vmExtensionImageVersionLister, location, extType, extPublisher string) (string, error) {
// List extension versions
resp, err := a.VMExtensionImages.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{})
resp, err := lister.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{})
if err != nil {
return "", fmt.Errorf("listing extension versions: %w", err)
}

if len(resp.VirtualMachineExtensionImageArray) == 0 {
return "", fmt.Errorf("no extension versions found")
}

version := make([]VMExtenstionVersion, len(resp.VirtualMachineExtensionImageArray))
versions := make([]vmExtensionVersion, len(resp.VirtualMachineExtensionImageArray))
for i, ext := range resp.VirtualMachineExtensionImageArray {
version[i] = parseVersion(ext)
versions[i] = parseVersion(ctx, ext)
}

sort.Slice(version, func(i, j int) bool {
return version[i].Less(version[j])
latest := slices.MaxFunc(versions, func(a, b vmExtensionVersion) int {
return a.cmp(b)
})

return *version[len(version)-1].Original.Name, nil
if latest.original.Name == nil {
return "", fmt.Errorf("latest extension version has nil name")
}
return *latest.original.Name, nil
}

// VMExtenstionVersion represents a parsed version of a VM extension image.
type VMExtenstionVersion struct {
Original *armcompute.VirtualMachineExtensionImage
Major int
Minor int
Patch int
// vmExtensionVersion represents a parsed version of a VM extension image.
type vmExtensionVersion struct {
original *armcompute.VirtualMachineExtensionImage
major int
minor int
patch int
}

// parseVersion parses the version from a VM extension image name, which can be in the format 1.151, 1.0.1, etc.
// You can find all the versions of a specific VM extension by running:
// az vm extension image list -n Compute.AKS.Linux.AKSNode
func parseVersion(v *armcompute.VirtualMachineExtensionImage) VMExtenstionVersion {
func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImage) vmExtensionVersion {
version := vmExtensionVersion{original: v}
if v.Name == nil {
toolkit.Logf(ctx, "warning: VM extension image has nil name, skipping version parse")
return version
}

// Split by dots
parts := strings.Split(*v.Name, ".")

version := VMExtenstionVersion{Original: v}

if len(parts) >= 1 {
if major, err := strconv.Atoi(parts[0]); err == nil {
version.Major = major
version.major = major
} else {
toolkit.Logf(ctx, "warning: failed to parse major version from %q: %v", *v.Name, err)
}
}
if len(parts) >= 2 {
if minor, err := strconv.Atoi(parts[1]); err == nil {
version.Minor = minor
version.minor = minor
} else {
toolkit.Logf(ctx, "warning: failed to parse minor version from %q: %v", *v.Name, err)
}
}
if len(parts) >= 3 {
if patch, err := strconv.Atoi(parts[2]); err == nil {
version.Patch = patch
version.patch = patch
} else {
toolkit.Logf(ctx, "warning: failed to parse patch version from %q: %v", *v.Name, err)
}
}

return version
}

func (v VMExtenstionVersion) Less(other VMExtenstionVersion) bool {
if v.Major != other.Major {
return v.Major < other.Major
// cmp compares two versions, returning -1, 0, or 1.
func (v vmExtensionVersion) cmp(other vmExtensionVersion) int {
if c := cmp.Compare(v.major, other.major); c != 0 {
return c
}
if v.Minor != other.Minor {
return v.Minor < other.Minor
if c := cmp.Compare(v.minor, other.minor); c != 0 {
return c
}
return v.Patch < other.Patch
return cmp.Compare(v.patch, other.patch)
}

// getResourceSKU queries the Azure Resource SKUs API to find the SKU for the given VM size and location.
Expand Down
Loading
Loading