Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/cmd/flags_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func registerCoreFlags(cmd *cobra.Command) {
cmd.Flags().BoolVar(&validateEnv, "validate-env", false, "Validate execution environment (Docker, env vars) before starting")
cmd.Flags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level (use -v for info, -vv for debug, -vvv for trace)")

// Mark mutually exclusive flags
// Flag validation groups
cmd.MarkFlagsMutuallyExclusive("routed", "unified")
cmd.MarkFlagsOneRequired("config", "config-stdin")
}
54 changes: 35 additions & 19 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os/signal"
"strings"
"syscall"
"time"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/logger"
Expand Down Expand Up @@ -48,12 +49,17 @@ It provides routing, aggregation, and management of multiple MCP backend servers
SilenceUsage: true, // Don't show help on runtime errors
PersistentPreRunE: preRun,
RunE: run,
PersistentPostRun: postRun,
}

func init() {
// Set custom error prefix for better branding
rootCmd.SetErrPrefix("MCPG Error:")

// Set custom version template with enhanced formatting
rootCmd.SetVersionTemplate(`MCPG Gateway {{.Version}}
`)

// Register all flags from feature modules (flags_*.go files)
registerAllFlags(rootCmd)

Expand Down Expand Up @@ -91,15 +97,18 @@ func registerFlagCompletions(cmd *cobra.Command) {
cmd.RegisterFlagCompletionFunc("env", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return []string{"env"}, cobra.ShellCompDirectiveFilterFileExt
})

// Add ActiveHelp for --config and --config-stdin flags
cmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
// Provide helpful tips when completing the command
return cobra.AppendActiveHelp(nil,
"Tip: Use --config <file> for file-based config or --config-stdin for piped JSON config"),
cobra.ShellCompDirectiveNoFileComp
}
}

// preRun performs validation before command execution
func preRun(cmd *cobra.Command, args []string) error {
// Validate that either --config or --config-stdin is provided
if !configStdin && configFile == "" {
return fmt.Errorf("configuration source required: specify either --config <file> or --config-stdin")
}

// Apply verbosity level to logging (if DEBUG is not already set)
// -v (1): info level, -vv (2): debug level, -vvv (3): trace level
debugEnv := os.Getenv(logger.EnvDebug)
Expand Down Expand Up @@ -128,33 +137,39 @@ func preRun(cmd *cobra.Command, args []string) error {
return nil
}

// postRun performs cleanup after command execution
func postRun(cmd *cobra.Command, args []string) {
// Close all loggers
logger.CloseMarkdownLogger()
logger.CloseJSONLLogger()
logger.CloseServerFileLogger()
logger.CloseGlobalLogger()
}

func run(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(context.Background())
// Use signal.NotifyContext for proper cancellation on SIGINT/SIGTERM
ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM)
defer cancel()

// Initialize file logger early
if err := logger.InitFileLogger(logDir, "mcp-gateway.log"); err != nil {
log.Printf("Warning: Failed to initialize file logger: %v", err)
}
defer logger.CloseGlobalLogger()

// Initialize per-serverID logger
if err := logger.InitServerFileLogger(logDir); err != nil {
log.Printf("Warning: Failed to initialize server file logger: %v", err)
}
defer logger.CloseServerFileLogger()

// Initialize markdown logger for GitHub workflow preview
if err := logger.InitMarkdownLogger(logDir, "gateway.md"); err != nil {
log.Printf("Warning: Failed to initialize markdown logger: %v", err)
}
defer logger.CloseMarkdownLogger()

// Initialize JSONL logger for RPC message logging
if err := logger.InitJSONLLogger(logDir, "rpc-messages.jsonl"); err != nil {
log.Printf("Warning: Failed to initialize JSONL logger: %v", err)
}
defer logger.CloseJSONLLogger()

logger.LogInfoMd("startup", "MCPG Gateway version: %s", cliVersion)

Expand Down Expand Up @@ -270,19 +285,12 @@ func run(cmd *cobra.Command, args []string) error {
}
defer unifiedServer.Close()

// Handle graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

// Handle graceful shutdown via context cancellation
go func() {
<-sigChan
<-ctx.Done()
logger.LogInfoMd("shutdown", "Shutting down gateway...")
log.Println("Shutting down...")
cancel()
unifiedServer.Close()
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unifiedServer.Close() is called both via defer unifiedServer.Close() and again inside the shutdown goroutine. UnifiedServer.Close() delegates to launcher.Close(), which calls sessionPool.Stop() and will block up to 1s on subsequent calls, so this can introduce an unnecessary delay and makes shutdown behavior harder to reason about. Prefer ensuring shutdown happens exactly once (e.g., remove the goroutine call and rely on the deferred Close, or switch to an idempotent shutdown method such as InitiateShutdown() guarded by sync.Once).

Suggested change
unifiedServer.Close()

Copilot uses AI. Check for mistakes.
logger.CloseMarkdownLogger()
logger.CloseGlobalLogger()
os.Exit(0)
}()

// Create HTTP server based on mode
Expand Down Expand Up @@ -329,6 +337,14 @@ func run(cmd *cobra.Command, args []string) error {

// Wait for shutdown signal
<-ctx.Done()

// Gracefully shutdown HTTP server with timeout
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err)
}

return nil
}

Expand Down
86 changes: 71 additions & 15 deletions internal/cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package cmd

import (
"bytes"
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -77,13 +80,9 @@ func TestRunRequiresConfigSource(t *testing.T) {
configStdin = origConfigStdin
})

t.Run("no config source provided", func(t *testing.T) {
configFile = ""
configStdin = false
err := preRun(nil, nil)
require.Error(t, err, "Expected error when neither --config nor --config-stdin is provided")
assert.Contains(t, err.Error(), "configuration source required", "Error should mention configuration source required")
})
// Note: The validation for "one of config or config-stdin is required" is now
// handled by Cobra's MarkFlagsOneRequired, which validates at command execution time,
// not in preRun. Therefore, preRun should pass validation as long as at least one is set.

t.Run("config file provided", func(t *testing.T) {
configFile = "test.toml"
Expand Down Expand Up @@ -138,14 +137,8 @@ func TestPreRunValidation(t *testing.T) {
assert.NoError(t, err)
})

t.Run("validation fails without config source", func(t *testing.T) {
configFile = ""
configStdin = false
verbosity = 0
err := preRun(nil, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "configuration source required")
})
// Note: validation for "one of config or config-stdin is required" is now
// handled by Cobra's MarkFlagsOneRequired, so preRun doesn't check this anymore

t.Run("verbosity level 1 does not set DEBUG", func(t *testing.T) {
// Save and clear DEBUG env var
Expand Down Expand Up @@ -499,3 +492,66 @@ func TestWriteGatewayConfig(t *testing.T) {
assert.Contains(t, output, DefaultListenPort)
})
}

// TestContextCancellation tests that context cancellation works properly
func TestContextCancellation(t *testing.T) {
t.Run("context with timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
Comment on lines +496 to +500
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestContextCancellation only verifies standard library context behavior (timeouts/cancel), and does not exercise the new CLI wiring (e.g., that run() derives from cmd.Context() / cancellation triggers shutdown). Consider replacing this with a test that sets a cancelable context on a command and asserts run() reacts to cancellation (or remove this test if it doesn't validate project behavior).

This issue also appears in the following locations of the same file:

  • line 523
  • line 542
  • line 551

Copilot uses AI. Check for mistakes.

// Wait for context to be done
<-ctx.Done()

// Verify context was cancelled due to timeout
assert.Equal(t, context.DeadlineExceeded, ctx.Err())
})

t.Run("context with cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

// Cancel immediately
cancel()

// Wait for context to be done
<-ctx.Done()

// Verify context was cancelled
assert.Equal(t, context.Canceled, ctx.Err())
})
}

// TestFlagValidationGroups tests that flag validation groups work correctly
func TestFlagValidationGroups(t *testing.T) {
// Note: This tests that the flag validation groups are registered correctly.
// Actual validation is performed by Cobra during command execution.
t.Run("mutually exclusive flags registered", func(t *testing.T) {
// Create a new root command to test
cmd := &cobra.Command{
Use: "test",
}
registerCoreFlags(cmd)

// Verify flags are registered
assert.NotNil(t, cmd.Flags().Lookup("routed"))
assert.NotNil(t, cmd.Flags().Lookup("unified"))
assert.NotNil(t, cmd.Flags().Lookup("config"))
assert.NotNil(t, cmd.Flags().Lookup("config-stdin"))
})
}

// TestVersionTemplate tests that custom version template is set
func TestVersionTemplate(t *testing.T) {
t.Run("version template is set", func(t *testing.T) {
// The version template should be set during init
// We can verify the version command works by checking it's not empty
assert.NotEmpty(t, rootCmd.Version, "Version should be set")
})
}

// TestPostRunCleanup tests that postRun cleanup is called
func TestPostRunCleanup(t *testing.T) {
t.Run("postRun is registered", func(t *testing.T) {
// Verify that postRun hook is set
assert.NotNil(t, rootCmd.PersistentPostRun, "PersistentPostRun should be set")
})
}
11 changes: 6 additions & 5 deletions test/integration/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,15 @@ func TestBinaryInvocation_NoConfigRequired(t *testing.T) {
require.Error(t, err)

outputStr := string(output)
// Should contain the error message about requiring config
if !bytes.Contains(output, []byte("configuration source required")) {
t.Errorf("Expected 'configuration source required' error message, got: %s", outputStr)
// Should contain the error message about requiring at least one of the flags
// Note: Cobra's MarkFlagsOneRequired produces a different error message than manual validation
if !bytes.Contains(output, []byte("at least one of the flags in the group [config config-stdin] is required")) {
t.Errorf("Expected 'at least one of the flags in the group [config config-stdin] is required' error message, got: %s", outputStr)
}

// Should mention both --config and --config-stdin
if !bytes.Contains(output, []byte("--config")) || !bytes.Contains(output, []byte("--config-stdin")) {
t.Errorf("Expected error message to mention both --config and --config-stdin, got: %s", outputStr)
if !bytes.Contains(output, []byte("config")) || !bytes.Contains(output, []byte("config-stdin")) {
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion is ineffective for verifying that both flags are mentioned: "config" is a substring of "config-stdin", so the check will pass as soon as config-stdin appears. If the intent is to ensure both flags are referenced, check for --config and --config-stdin (or use a stricter match that distinguishes config from config-stdin).

Suggested change
if !bytes.Contains(output, []byte("config")) || !bytes.Contains(output, []byte("config-stdin")) {
if !bytes.Contains(output, []byte("--config")) || !bytes.Contains(output, []byte("--config-stdin")) {

Copilot uses AI. Check for mistakes.
t.Errorf("Expected error message to mention both config and config-stdin, got: %s", outputStr)
}

t.Logf("✓ Binary correctly requires config source: %s", outputStr)
Expand Down
Loading