From 1efb810631204eba16b2fed9d5255b9c247b5dbd Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Mon, 29 Jun 2026 17:04:01 -0400 Subject: [PATCH] feat: add Command.ArgValidator for tree-wide argument validation ArgValidator is an optional hook that runs before Before and Action. It is inherited by subcommands unless they set their own. Returning a non-nil error short-circuits the command. Closes #2327 --- command.go | 5 ++ command_run.go | 17 ++++++ command_test.go | 115 ++++++++++++++++++++++++++++++++++++++++ funcs.go | 7 +++ godoc-current.txt | 12 +++++ testdata/godoc-v3.x.txt | 12 +++++ 6 files changed, 168 insertions(+) diff --git a/command.go b/command.go index 4cd907a558..2e04736be8 100644 --- a/command.go +++ b/command.go @@ -64,6 +64,11 @@ type Command struct { // An action to execute after any subcommands are run, but after the subcommand has finished // It is run even if Action() panics After AfterFunc `json:"-"` + // An action to validate arguments before the command is run. If non-nil, it + // is called before Before and Action. If the current command does not set + // ArgValidator, the nearest ancestor that does is used instead. + // Returning a non-nil error short-circuits the command. + ArgValidator ArgValidatorFunc `json:"-"` // The function to call when this command is invoked Action ActionFunc `json:"-"` // Execute this function if the proper command cannot be found diff --git a/command_run.go b/command_run.go index 8d5907151e..9fca52603e 100644 --- a/command_run.go +++ b/command_run.go @@ -327,6 +327,14 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context // First, resolve the chain of nested commands up to the parent. cmdChain := commandChain(cmd) + // Run ArgValidator from the nearest ancestor that sets one. + if validator := findArgValidator(cmd); validator != nil { + if err := validator(ctx, cmd); err != nil { + deferErr = cmd.handleExitCoder(ctx, err) + return ctx, deferErr + } + } + // Run Before actions in order. if ctx, err = runBefore(ctx, cmdChain); err != nil { deferErr = err @@ -397,6 +405,15 @@ func commandChain(cmd *Command) []*Command { return cmdChain } +func findArgValidator(cmd *Command) ArgValidatorFunc { + for c := cmd; c != nil; c = c.parent { + if c.ArgValidator != nil { + return c.ArgValidator + } + } + return nil +} + func runBefore(ctx context.Context, cmdChain []*Command) (context.Context, error) { for _, cmd := range cmdChain { if cmd.Before == nil { diff --git a/command_test.go b/command_test.go index cad70b7801..958e0ec146 100644 --- a/command_test.go +++ b/command_test.go @@ -6450,3 +6450,118 @@ func TestCommand_Walk_NilFn(t *testing.T) { cmd := &Command{Name: "foo"} assert.Nil(t, cmd.Walk(nil)) } + +func TestCommand_ArgValidator_RunsBeforeAction(t *testing.T) { + var validated bool + var actionRan bool + + cmd := &Command{ + Name: "test", + ArgValidator: func(_ context.Context, _ *Command) error { + validated = true + return nil + }, + Action: func(_ context.Context, _ *Command) error { + actionRan = true + return nil + }, + } + + err := cmd.Run(buildTestContext(t), []string{"test"}) + require.NoError(t, err) + assert.True(t, validated) + assert.True(t, actionRan) +} + +func TestCommand_ArgValidator_ErrorShortCircuitsAction(t *testing.T) { + var actionRan bool + + cmd := &Command{ + Name: "test", + ArgValidator: func(_ context.Context, _ *Command) error { + return fmt.Errorf("validation failed") + }, + Action: func(_ context.Context, _ *Command) error { + actionRan = true + return nil + }, + } + + err := cmd.Run(buildTestContext(t), []string{"test"}) + assert.ErrorContains(t, err, "validation failed") + assert.False(t, actionRan) +} + +func TestCommand_ArgValidator_InheritsFromParent(t *testing.T) { + var validated bool + + root := &Command{ + Name: "root", + ArgValidator: func(_ context.Context, _ *Command) error { + validated = true + return nil + }, + Commands: []*Command{ + { + Name: "sub", + Action: func(_ context.Context, _ *Command) error { return nil }, + }, + }, + } + + err := root.Run(buildTestContext(t), []string{"root", "sub"}) + require.NoError(t, err) + assert.True(t, validated) +} + +func TestCommand_ArgValidator_SubcommandOverride(t *testing.T) { + var parentValidated bool + var childValidated bool + + root := &Command{ + Name: "root", + ArgValidator: func(_ context.Context, _ *Command) error { + parentValidated = true + return nil + }, + Commands: []*Command{ + { + Name: "sub", + ArgValidator: func(_ context.Context, _ *Command) error { + childValidated = true + return nil + }, + Action: func(_ context.Context, _ *Command) error { return nil }, + }, + }, + } + + err := root.Run(buildTestContext(t), []string{"root", "sub"}) + require.NoError(t, err) + assert.False(t, parentValidated, "should use child's validator, not parent's") + assert.True(t, childValidated) +} + +func TestCommand_ArgValidator_RunsBeforeBefore(t *testing.T) { + var order []string + + cmd := &Command{ + Name: "test", + ArgValidator: func(_ context.Context, _ *Command) error { + order = append(order, "validator") + return nil + }, + Before: func(_ context.Context, _ *Command) (context.Context, error) { + order = append(order, "before") + return nil, nil + }, + Action: func(_ context.Context, _ *Command) error { + order = append(order, "action") + return nil + }, + } + + err := cmd.Run(buildTestContext(t), []string{"test"}) + require.NoError(t, err) + assert.Equal(t, []string{"validator", "before", "action"}, order) +} diff --git a/funcs.go b/funcs.go index fe1224c44e..17a73eca67 100644 --- a/funcs.go +++ b/funcs.go @@ -17,6 +17,13 @@ type AfterFunc func(context.Context, *Command) error // ActionFunc is the action to execute when no subcommands are specified type ActionFunc func(context.Context, *Command) error +// ArgValidatorFunc is an action to validate arguments before the command is run. +// If non-nil, it is called before the command's After and Action functions. +// Returning a non-nil error short-circuits the command and propagates as +// the exit error. If the current command does not set ArgValidator, the +// nearest ancestor that does is used instead. +type ArgValidatorFunc func(context.Context, *Command) error + // CommandNotFoundFunc is executed if the proper command cannot be found type CommandNotFoundFunc func(context.Context, *Command, string) diff --git a/godoc-current.txt b/godoc-current.txt index d680c5f544..93210c6bb6 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -258,6 +258,13 @@ type AfterFunc func(context.Context, *Command) error AfterFunc is an action that executes after any subcommands are run and have finished. The AfterFunc is run even if Action() panics. +type ArgValidatorFunc func(context.Context, *Command) error + ArgValidatorFunc is an action to validate arguments before the command + is run. If non-nil, it is called before the command's After and Action + functions. Returning a non-nil error short-circuits the command and + propagates as the exit error. If the current command does not set + ArgValidator, the nearest ancestor that does is used instead. + type Args interface { // Get returns the nth argument, or else a blank string Get(n int) string @@ -482,6 +489,11 @@ type Command struct { // An action to execute after any subcommands are run, but after the subcommand has finished // It is run even if Action() panics After AfterFunc `json:"-"` + // An action to validate arguments before the command is run. If non-nil, it + // is called before Before and Action. If the current command does not set + // ArgValidator, the nearest ancestor that does is used instead. + // Returning a non-nil error short-circuits the command. + ArgValidator ArgValidatorFunc `json:"-"` // The function to call when this command is invoked Action ActionFunc `json:"-"` // Execute this function if the proper command cannot be found diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index d680c5f544..93210c6bb6 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -258,6 +258,13 @@ type AfterFunc func(context.Context, *Command) error AfterFunc is an action that executes after any subcommands are run and have finished. The AfterFunc is run even if Action() panics. +type ArgValidatorFunc func(context.Context, *Command) error + ArgValidatorFunc is an action to validate arguments before the command + is run. If non-nil, it is called before the command's After and Action + functions. Returning a non-nil error short-circuits the command and + propagates as the exit error. If the current command does not set + ArgValidator, the nearest ancestor that does is used instead. + type Args interface { // Get returns the nth argument, or else a blank string Get(n int) string @@ -482,6 +489,11 @@ type Command struct { // An action to execute after any subcommands are run, but after the subcommand has finished // It is run even if Action() panics After AfterFunc `json:"-"` + // An action to validate arguments before the command is run. If non-nil, it + // is called before Before and Action. If the current command does not set + // ArgValidator, the nearest ancestor that does is used instead. + // Returning a non-nil error short-circuits the command. + ArgValidator ArgValidatorFunc `json:"-"` // The function to call when this command is invoked Action ActionFunc `json:"-"` // Execute this function if the proper command cannot be found