From 913993c0bf05885f7637c5249371ec3975bdfd3a Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Thu, 30 Nov 2023 09:13:47 -0500 Subject: [PATCH 1/2] Fix:(issue_1826) Make IsSet work with persistent flags --- command.go | 32 ++++++++++++++++++++++++++------ command_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ flag_impl.go | 1 + 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/command.go b/command.go index f56ab607da..d9beb9f6b8 100644 --- a/command.go +++ b/command.go @@ -459,7 +459,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { tracef("running flag actions (cmd=%[1]q)", cmd.Name) - if err := runFlagActions(ctx, cmd, cmd.appliedFlags); err != nil { + if err := cmd.runFlagActions(ctx); err != nil { return err } @@ -1014,19 +1014,39 @@ func hasCommand(commands []*Command, command *Command) bool { return false } -func runFlagActions(ctx context.Context, cmd *Command, flags []Flag) error { - for _, fl := range flags { +func (cmd *Command) runFlagActions(ctx context.Context) error { + if cmd.flagSet == nil { + return nil + } + + for _, fl := range cmd.appliedFlags { isSet := false + // check only local flagset for running local flag actions for _, name := range fl.Names() { - if cmd.IsSet(name) { - isSet = true + cmd.flagSet.Visit(func(f *flag.Flag) { + if f.Name == name { + isSet = true + } + }) + if isSet { break } } + // If the flag hasnt been set on cmd line then we need to further + // check if it has been set via other means. If however it has + // been set by other means but it is persistent(and not set via current cmd) + // do not run the flag action if !isSet { - continue + if !fl.IsSet() { + continue + } + if fl.IsSet() { + if pf, ok := fl.(PersistentFlag); ok && pf.IsPersistent() { + continue + } + } } if af, ok := fl.(ActionableFlag); ok { diff --git a/command_test.go b/command_test.go index 4c1a4e8631..05b83f41cc 100644 --- a/command_test.go +++ b/command_test.go @@ -2588,6 +2588,8 @@ func TestWhenExitSubCommandWithCodeThenCommandQuitUnexpectedly(t *testing.T) { } func buildMinimalTestCommand() *Command { + // reset the help flag because tests may have set it + HelpFlag.(*BoolFlag).hasBeenSet = false return &Command{Writer: io.Discard} } @@ -3055,6 +3057,44 @@ func TestPersistentFlag(t *testing.T) { } } +func TestPersistentFlagIsSet(t *testing.T) { + + result := "" + resultIsSet := false + + app := &Command{ + Name: "root", + Flags: []Flag{ + &StringFlag{ + Name: "result", + Persistent: true, + }, + }, + Commands: []*Command{ + { + Name: "sub", + Action: func(_ context.Context, cmd *Command) error { + result = cmd.String("result") + resultIsSet = cmd.IsSet("result") + return nil + }, + }, + }, + } + + r := require.New(t) + + err := app.Run(context.Background(), []string{"root", "--result", "before", "sub"}) + r.NoError(err) + r.Equal("before", result) + r.True(resultIsSet) + + err = app.Run(context.Background(), []string{"root", "sub", "--result", "after"}) + r.NoError(err) + r.Equal("after", result) + r.True(resultIsSet) +} + func TestFlagDuplicates(t *testing.T) { tests := []struct { name string diff --git a/flag_impl.go b/flag_impl.go index e4930195b9..014a2a97dc 100644 --- a/flag_impl.go +++ b/flag_impl.go @@ -180,6 +180,7 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { if err := f.value.Set(val); err != nil { return err } + f.hasBeenSet = true if f.Validator != nil { if v, ok := f.value.Get().(T); !ok { return &typeError[T]{ From cafdd41afbb096de97358c580f8d43e68f656a42 Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Fri, 1 Dec 2023 08:43:42 -0500 Subject: [PATCH 2/2] Remove redundant checks --- command.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/command.go b/command.go index d9beb9f6b8..5b41bf9d8c 100644 --- a/command.go +++ b/command.go @@ -1015,9 +1015,6 @@ func hasCommand(commands []*Command, command *Command) bool { } func (cmd *Command) runFlagActions(ctx context.Context) error { - if cmd.flagSet == nil { - return nil - } for _, fl := range cmd.appliedFlags { isSet := false @@ -1042,10 +1039,8 @@ func (cmd *Command) runFlagActions(ctx context.Context) error { if !fl.IsSet() { continue } - if fl.IsSet() { - if pf, ok := fl.(PersistentFlag); ok && pf.IsPersistent() { - continue - } + if pf, ok := fl.(PersistentFlag); ok && pf.IsPersistent() { + continue } }