From b4715fde6241ecee3cb0f3561cc126d6c43f42c0 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Mon, 3 Feb 2025 15:45:14 +0200 Subject: [PATCH] Abort Each/ParallelEach early on context errors Signed-off-by: Kimmo Lehto --- .../k0sctl.k0sproject.io/v1beta1/cluster/hosts.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pkg/apis/k0sctl.k0sproject.io/v1beta1/cluster/hosts.go b/pkg/apis/k0sctl.k0sproject.io/v1beta1/cluster/hosts.go index 0a7b6c6d..b027d9f3 100644 --- a/pkg/apis/k0sctl.k0sproject.io/v1beta1/cluster/hosts.go +++ b/pkg/apis/k0sctl.k0sproject.io/v1beta1/cluster/hosts.go @@ -101,6 +101,9 @@ func (hosts Hosts) Workers() Hosts { func (hosts Hosts) Each(ctx context.Context, filters ...func(context.Context, *Host) error) error { for _, filter := range filters { for _, h := range hosts { + if err := ctx.Err(); err != nil { + return fmt.Errorf("error from context: %w", err) + } if err := filter(ctx, h); err != nil { return err } @@ -122,6 +125,12 @@ func (hosts Hosts) ParallelEach(ctx context.Context, filters ...func(context.Con wg.Add(1) go func(h *Host) { defer wg.Done() + if err := ctx.Err(); err != nil { + mu.Lock() + errors = append(errors, fmt.Sprintf("error from context: %v", err)) + mu.Unlock() + return + } if err := filter(ctx, h); err != nil { mu.Lock() errors = append(errors, fmt.Sprintf("%s: %s", h.String(), err.Error())) @@ -146,6 +155,9 @@ func (hosts Hosts) BatchedParallelEach(ctx context.Context, batchSize int, filte if end > len(hosts) { end = len(hosts) } + if err := ctx.Err(); err != nil { + return fmt.Errorf("error from context: %w", err) + } if err := hosts[i:end].ParallelEach(ctx, filter...); err != nil { return err }