diff --git a/internal/cmd/flags_core.go b/internal/cmd/flags_core.go index 80414cfe..c8d17f96 100644 --- a/internal/cmd/flags_core.go +++ b/internal/cmd/flags_core.go @@ -41,6 +41,7 @@ func registerCoreFlags(cmd *cobra.Command) { cmd.Flags().BoolVar(&validateEnv, "validate-env", false, "Validate execution environment (Docker, env vars) before starting") cmd.Flags().CountVarP(&verbosity, "verbose", "v", "Increase verbosity level (use -v for info, -vv for debug, -vvv for trace)") - // Mark mutually exclusive flags + // Flag validation groups cmd.MarkFlagsMutuallyExclusive("routed", "unified") + cmd.MarkFlagsOneRequired("config", "config-stdin") } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 4d660e0f..268b51c3 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -13,6 +13,7 @@ import ( "os/signal" "strings" "syscall" + "time" "github.com/github/gh-aw-mcpg/internal/config" "github.com/github/gh-aw-mcpg/internal/logger" @@ -48,12 +49,17 @@ It provides routing, aggregation, and management of multiple MCP backend servers SilenceUsage: true, // Don't show help on runtime errors PersistentPreRunE: preRun, RunE: run, + PersistentPostRun: postRun, } func init() { // Set custom error prefix for better branding rootCmd.SetErrPrefix("MCPG Error:") + // Set custom version template with enhanced formatting + rootCmd.SetVersionTemplate(`MCPG Gateway {{.Version}} +`) + // Register all flags from feature modules (flags_*.go files) registerAllFlags(rootCmd) @@ -91,15 +97,18 @@ func registerFlagCompletions(cmd *cobra.Command) { cmd.RegisterFlagCompletionFunc("env", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return []string{"env"}, cobra.ShellCompDirectiveFilterFileExt }) + + // Add ActiveHelp for --config and --config-stdin flags + cmd.ValidArgsFunction = func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + // Provide helpful tips when completing the command + return cobra.AppendActiveHelp(nil, + "Tip: Use --config for file-based config or --config-stdin for piped JSON config"), + cobra.ShellCompDirectiveNoFileComp + } } // preRun performs validation before command execution func preRun(cmd *cobra.Command, args []string) error { - // Validate that either --config or --config-stdin is provided - if !configStdin && configFile == "" { - return fmt.Errorf("configuration source required: specify either --config or --config-stdin") - } - // Apply verbosity level to logging (if DEBUG is not already set) // -v (1): info level, -vv (2): debug level, -vvv (3): trace level debugEnv := os.Getenv(logger.EnvDebug) @@ -128,33 +137,39 @@ func preRun(cmd *cobra.Command, args []string) error { return nil } +// postRun performs cleanup after command execution +func postRun(cmd *cobra.Command, args []string) { + // Close all loggers + logger.CloseMarkdownLogger() + logger.CloseJSONLLogger() + logger.CloseServerFileLogger() + logger.CloseGlobalLogger() +} + func run(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(context.Background()) + // Use signal.NotifyContext for proper cancellation on SIGINT/SIGTERM + ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM) defer cancel() // Initialize file logger early if err := logger.InitFileLogger(logDir, "mcp-gateway.log"); err != nil { log.Printf("Warning: Failed to initialize file logger: %v", err) } - defer logger.CloseGlobalLogger() // Initialize per-serverID logger if err := logger.InitServerFileLogger(logDir); err != nil { log.Printf("Warning: Failed to initialize server file logger: %v", err) } - defer logger.CloseServerFileLogger() // Initialize markdown logger for GitHub workflow preview if err := logger.InitMarkdownLogger(logDir, "gateway.md"); err != nil { log.Printf("Warning: Failed to initialize markdown logger: %v", err) } - defer logger.CloseMarkdownLogger() // Initialize JSONL logger for RPC message logging if err := logger.InitJSONLLogger(logDir, "rpc-messages.jsonl"); err != nil { log.Printf("Warning: Failed to initialize JSONL logger: %v", err) } - defer logger.CloseJSONLLogger() logger.LogInfoMd("startup", "MCPG Gateway version: %s", cliVersion) @@ -270,19 +285,12 @@ func run(cmd *cobra.Command, args []string) error { } defer unifiedServer.Close() - // Handle graceful shutdown - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - + // Handle graceful shutdown via context cancellation go func() { - <-sigChan + <-ctx.Done() logger.LogInfoMd("shutdown", "Shutting down gateway...") log.Println("Shutting down...") - cancel() unifiedServer.Close() - logger.CloseMarkdownLogger() - logger.CloseGlobalLogger() - os.Exit(0) }() // Create HTTP server based on mode @@ -329,6 +337,14 @@ func run(cmd *cobra.Command, args []string) error { // Wait for shutdown signal <-ctx.Done() + + // Gracefully shutdown HTTP server with timeout + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP server shutdown error: %v", err) + } + return nil } diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index 6b316ebb..2e536f16 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -2,11 +2,14 @@ package cmd import ( "bytes" + "context" "encoding/json" "os" "path/filepath" "testing" + "time" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -77,13 +80,9 @@ func TestRunRequiresConfigSource(t *testing.T) { configStdin = origConfigStdin }) - t.Run("no config source provided", func(t *testing.T) { - configFile = "" - configStdin = false - err := preRun(nil, nil) - require.Error(t, err, "Expected error when neither --config nor --config-stdin is provided") - assert.Contains(t, err.Error(), "configuration source required", "Error should mention configuration source required") - }) + // Note: The validation for "one of config or config-stdin is required" is now + // handled by Cobra's MarkFlagsOneRequired, which validates at command execution time, + // not in preRun. Therefore, preRun should pass validation as long as at least one is set. t.Run("config file provided", func(t *testing.T) { configFile = "test.toml" @@ -138,14 +137,8 @@ func TestPreRunValidation(t *testing.T) { assert.NoError(t, err) }) - t.Run("validation fails without config source", func(t *testing.T) { - configFile = "" - configStdin = false - verbosity = 0 - err := preRun(nil, nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "configuration source required") - }) + // Note: validation for "one of config or config-stdin is required" is now + // handled by Cobra's MarkFlagsOneRequired, so preRun doesn't check this anymore t.Run("verbosity level 1 does not set DEBUG", func(t *testing.T) { // Save and clear DEBUG env var @@ -499,3 +492,66 @@ func TestWriteGatewayConfig(t *testing.T) { assert.Contains(t, output, DefaultListenPort) }) } + +// TestContextCancellation tests that context cancellation works properly +func TestContextCancellation(t *testing.T) { + t.Run("context with timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Wait for context to be done + <-ctx.Done() + + // Verify context was cancelled due to timeout + assert.Equal(t, context.DeadlineExceeded, ctx.Err()) + }) + + t.Run("context with cancellation", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel immediately + cancel() + + // Wait for context to be done + <-ctx.Done() + + // Verify context was cancelled + assert.Equal(t, context.Canceled, ctx.Err()) + }) +} + +// TestFlagValidationGroups tests that flag validation groups work correctly +func TestFlagValidationGroups(t *testing.T) { + // Note: This tests that the flag validation groups are registered correctly. + // Actual validation is performed by Cobra during command execution. + t.Run("mutually exclusive flags registered", func(t *testing.T) { + // Create a new root command to test + cmd := &cobra.Command{ + Use: "test", + } + registerCoreFlags(cmd) + + // Verify flags are registered + assert.NotNil(t, cmd.Flags().Lookup("routed")) + assert.NotNil(t, cmd.Flags().Lookup("unified")) + assert.NotNil(t, cmd.Flags().Lookup("config")) + assert.NotNil(t, cmd.Flags().Lookup("config-stdin")) + }) +} + +// TestVersionTemplate tests that custom version template is set +func TestVersionTemplate(t *testing.T) { + t.Run("version template is set", func(t *testing.T) { + // The version template should be set during init + // We can verify the version command works by checking it's not empty + assert.NotEmpty(t, rootCmd.Version, "Version should be set") + }) +} + +// TestPostRunCleanup tests that postRun cleanup is called +func TestPostRunCleanup(t *testing.T) { + t.Run("postRun is registered", func(t *testing.T) { + // Verify that postRun hook is set + assert.NotNil(t, rootCmd.PersistentPostRun, "PersistentPostRun should be set") + }) +} diff --git a/test/integration/binary_test.go b/test/integration/binary_test.go index e44480ce..4bd5c39c 100644 --- a/test/integration/binary_test.go +++ b/test/integration/binary_test.go @@ -502,14 +502,15 @@ func TestBinaryInvocation_NoConfigRequired(t *testing.T) { require.Error(t, err) outputStr := string(output) - // Should contain the error message about requiring config - if !bytes.Contains(output, []byte("configuration source required")) { - t.Errorf("Expected 'configuration source required' error message, got: %s", outputStr) + // Should contain the error message about requiring at least one of the flags + // Note: Cobra's MarkFlagsOneRequired produces a different error message than manual validation + if !bytes.Contains(output, []byte("at least one of the flags in the group [config config-stdin] is required")) { + t.Errorf("Expected 'at least one of the flags in the group [config config-stdin] is required' error message, got: %s", outputStr) } // Should mention both --config and --config-stdin - if !bytes.Contains(output, []byte("--config")) || !bytes.Contains(output, []byte("--config-stdin")) { - t.Errorf("Expected error message to mention both --config and --config-stdin, got: %s", outputStr) + if !bytes.Contains(output, []byte("config")) || !bytes.Contains(output, []byte("config-stdin")) { + t.Errorf("Expected error message to mention both config and config-stdin, got: %s", outputStr) } t.Logf("✓ Binary correctly requires config source: %s", outputStr)