diff --git a/pkg/cli/actionlint_test.go b/pkg/cli/actionlint_test.go index ed777b2e007..9548d6ded0d 100644 --- a/pkg/cli/actionlint_test.go +++ b/pkg/cli/actionlint_test.go @@ -3,10 +3,11 @@ package cli import ( - "bytes" - "os" - "strings" "testing" + + "github.com/github/gh-aw/pkg/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseAndDisplayActionlintOutput(t *testing.T) { @@ -58,109 +59,6 @@ func TestParseAndDisplayActionlintOutput(t *testing.T) { stdout: `{invalid json}`, expectError: true, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Capture stderr output - originalStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w - - count, kinds, err := parseAndDisplayActionlintOutput(tt.stdout, tt.verbose) - - // Restore stderr and get output - w.Close() - os.Stderr = originalStderr - var buf bytes.Buffer - buf.ReadFrom(r) - output := buf.String() - - // Check error expectation - if tt.expectError && err == nil { - t.Errorf("Expected error but got none") - } - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Check count - if count != tt.expectedCount { - t.Errorf("Expected count %d, got %d", tt.expectedCount, count) - } - - // Check kinds map - if !tt.expectError && tt.expectedKinds != nil { - if len(kinds) != len(tt.expectedKinds) { - t.Errorf("Expected %d kinds, got %d", len(tt.expectedKinds), len(kinds)) - } - for kind, expectedCount := range tt.expectedKinds { - if kinds[kind] != expectedCount { - t.Errorf("Expected %d errors of kind %s, got %d", expectedCount, kind, kinds[kind]) - } - } - } - - // Check expected output strings are present - for _, expected := range tt.expectedOutput { - if !strings.Contains(output, expected) { - t.Errorf("Expected output to contain %q, but it didn't.\nGot: %s", expected, output) - } - } - }) - } -} - -func TestGetActionlintVersion(t *testing.T) { - // Reset the cached version before test - originalVersion := actionlintVersion - defer func() { actionlintVersion = originalVersion }() - - tests := []struct { - name string - presetVersion string - expectCached bool - }{ - { - name: "first call fetches version", - presetVersion: "", - expectCached: false, - }, - { - name: "second call returns cached version", - presetVersion: "1.7.9", - expectCached: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actionlintVersion = tt.presetVersion - - // If we preset a version, this should return immediately - if tt.expectCached { - version, err := getActionlintVersion() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if version != tt.presetVersion { - t.Errorf("Expected cached version %q, got %q", tt.presetVersion, version) - } - } - }) - } -} - -func TestParseAndDisplayActionlintOutputMultiFile(t *testing.T) { - tests := []struct { - name string - stdout string - verbose bool - expectedOutput []string - expectError bool - expectedCount int - expectedKinds map[string]int - }{ { name: "multiple errors from multiple files", stdout: `[ @@ -195,55 +93,41 @@ func TestParseAndDisplayActionlintOutputMultiFile(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Capture stderr output - originalStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w - - count, kinds, err := parseAndDisplayActionlintOutput(tt.stdout, tt.verbose) - - // Restore stderr and get output - w.Close() - os.Stderr = originalStderr - var buf bytes.Buffer - buf.ReadFrom(r) - output := buf.String() - - // Check error expectation - if tt.expectError && err == nil { - t.Errorf("Expected error but got none") - } - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Check count - if count != tt.expectedCount { - t.Errorf("Expected count %d, got %d", tt.expectedCount, count) - } - - // Check kinds map - if !tt.expectError && tt.expectedKinds != nil { - if len(kinds) != len(tt.expectedKinds) { - t.Errorf("Expected %d kinds, got %d", len(tt.expectedKinds), len(kinds)) + var count int + var kinds map[string]int + var err error + + output := testutil.CaptureStderr(t, func() { + count, kinds, err = parseAndDisplayActionlintOutput(tt.stdout, tt.verbose) + }) + + if tt.expectError { + require.Error(t, err, "should return error for invalid input") + } else { + require.NoError(t, err, "should not return error for valid input") + assert.Equal(t, tt.expectedCount, count, "error count should match expected") + if tt.expectedKinds != nil { + assert.Equal(t, tt.expectedKinds, kinds, "error kinds should match expected") } - for kind, expectedCount := range tt.expectedKinds { - if kinds[kind] != expectedCount { - t.Errorf("Expected %d errors of kind %s, got %d", expectedCount, kind, kinds[kind]) - } - } - } - - // Check expected output strings are present - for _, expected := range tt.expectedOutput { - if !strings.Contains(output, expected) { - t.Errorf("Expected output to contain %q, but it didn't.\nGot: %s", expected, output) + for _, expected := range tt.expectedOutput { + assert.Contains(t, output, expected, + "output should contain %q", expected) } } }) } } +func TestGetActionlintVersion(t *testing.T) { + original := actionlintVersion + defer func() { actionlintVersion = original }() + + actionlintVersion = "1.7.9" + version, err := getActionlintVersion() + require.NoError(t, err, "should not error when version is cached") + assert.Equal(t, "1.7.9", version, "should return cached version") +} + func TestDisplayActionlintSummary(t *testing.T) { tests := []struct { name string @@ -313,64 +197,32 @@ func TestDisplayActionlintSummary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Save original stats and restore after test originalStats := actionlintStats defer func() { actionlintStats = originalStats }() - - // Set up test stats actionlintStats = tt.stats - // Capture stderr output - originalStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w - - displayActionlintSummary() + output := testutil.CaptureStderr(t, displayActionlintSummary) - // Restore stderr and get output - w.Close() - os.Stderr = originalStderr - var buf bytes.Buffer - buf.ReadFrom(r) - output := buf.String() - - // Check expected strings are present for _, expected := range tt.expectedContains { - if !strings.Contains(output, expected) { - t.Errorf("Expected output to contain %q, but it didn't.\nGot: %s", expected, output) - } + assert.Contains(t, output, expected, + "output should contain %q", expected) } }) } } func TestInitActionlintStats(t *testing.T) { - // Save original stats and restore after test originalStats := actionlintStats defer func() { actionlintStats = originalStats }() - // Initialize stats initActionlintStats() - // Check that stats were initialized - if actionlintStats == nil { - t.Fatal("actionlintStats should not be nil after initialization") - } - if actionlintStats.TotalWorkflows != 0 { - t.Errorf("TotalWorkflows should be 0, got %d", actionlintStats.TotalWorkflows) - } - if actionlintStats.TotalErrors != 0 { - t.Errorf("TotalErrors should be 0, got %d", actionlintStats.TotalErrors) - } - if actionlintStats.TotalWarnings != 0 { - t.Errorf("TotalWarnings should be 0, got %d", actionlintStats.TotalWarnings) - } - if actionlintStats.ErrorsByKind == nil { - t.Error("ErrorsByKind should not be nil after initialization") - } - if len(actionlintStats.ErrorsByKind) != 0 { - t.Errorf("ErrorsByKind should be empty, got %d entries", len(actionlintStats.ErrorsByKind)) - } + require.NotNil(t, actionlintStats, "actionlintStats should not be nil after initialization") + assert.Zero(t, actionlintStats.TotalWorkflows, "TotalWorkflows should start at 0") + assert.Zero(t, actionlintStats.TotalErrors, "TotalErrors should start at 0") + assert.Zero(t, actionlintStats.TotalWarnings, "TotalWarnings should start at 0") + assert.NotNil(t, actionlintStats.ErrorsByKind, "ErrorsByKind map should be initialized") + assert.Empty(t, actionlintStats.ErrorsByKind, "ErrorsByKind should start empty") } func TestGetActionlintDocsURL(t *testing.T) { @@ -419,9 +271,7 @@ func TestGetActionlintDocsURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := getActionlintDocsURL(tt.kind) - if result != tt.expected { - t.Errorf("getActionlintDocsURL(%q) = %q, want %q", tt.kind, result, tt.expected) - } + assert.Equal(t, tt.expected, result, "docs URL should match expected for kind %q", tt.kind) }) } } diff --git a/pkg/testutil/tempdir.go b/pkg/testutil/tempdir.go index 0962e621e04..5fa1d9ac1ce 100644 --- a/pkg/testutil/tempdir.go +++ b/pkg/testutil/tempdir.go @@ -1,6 +1,7 @@ package testutil import ( + "bytes" "fmt" "os" "path/filepath" @@ -68,6 +69,29 @@ func TempDir(t *testing.T, pattern string) string { return tempDir } +// CaptureStderr runs fn and returns everything written to os.Stderr during its execution. +// It restores os.Stderr automatically via t.Cleanup. +func CaptureStderr(t *testing.T, fn func()) string { + t.Helper() + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("CaptureStderr: failed to create pipe: %v", err) + } + + origStderr := os.Stderr + os.Stderr = w + t.Cleanup(func() { os.Stderr = origStderr }) + + fn() + + w.Close() + var buf bytes.Buffer + if _, err = buf.ReadFrom(r); err != nil { + t.Fatalf("CaptureStderr: failed to read pipe: %v", err) + } + return buf.String() +} + // StripYAMLCommentHeader removes the comment header from generated YAML files // and returns only the non-comment YAML content. This is useful for tests that // need to verify content without matching strings in the comment header.