diff --git a/pkg/orchestrator/scanestimationwatcher/watcher.go b/pkg/orchestrator/scanestimationwatcher/watcher.go index f6d7c7bca..4198c0dd8 100644 --- a/pkg/orchestrator/scanestimationwatcher/watcher.go +++ b/pkg/orchestrator/scanestimationwatcher/watcher.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "math" "sync" "time" @@ -496,6 +497,10 @@ func (w *Watcher) reconcileInProgress(ctx context.Context, scanEstimation *model *assetScanEstimations.Count-failedAssetScanEstimations, failedAssetScanEstimations, *assetScanEstimations.Count)) scanEstimation.EndTime = utils.PointerTo(time.Now()) + + if err := updateTotalScanTimeWithParallelScans(scanEstimation); err != nil { + return fmt.Errorf("failed to update scan time from paraller scans: %w", err) + } } scanEstimationPatch := &models.ScanEstimation{ @@ -512,6 +517,41 @@ func (w *Watcher) reconcileInProgress(ctx context.Context, scanEstimation *model return nil } +func updateTotalScanTimeWithParallelScans(scanEstimation *models.ScanEstimation) error { + if scanEstimation == nil { + return fmt.Errorf("empty scan estimation") + } + + if scanEstimation.ScanTemplate == nil { + return fmt.Errorf("empty scan template") + } + + if scanEstimation.Summary == nil { + return fmt.Errorf("empty summary") + } + + if scanEstimation.Summary.JobsCompleted == nil { + return fmt.Errorf("jobsCompleted is not set") + } + + if *scanEstimation.Summary.JobsCompleted == 0 { + return fmt.Errorf("0 completed jobs in summary") + } + + maxParallelScanners := utils.ValueOrZero(scanEstimation.ScanTemplate.MaxParallelScanners) + + if maxParallelScanners > 1 { + numberOfJobs := *scanEstimation.Summary.JobsCompleted + + actualParallelScanners := int(math.Min(float64(maxParallelScanners), float64(numberOfJobs))) + + // Note: This is a rough estimation, as we don't know which jobs will be running in parallel. + *scanEstimation.Summary.TotalScanTime = *scanEstimation.Summary.TotalScanTime / actualParallelScanners + } + + return nil +} + // nolint:cyclop func (w *Watcher) reconcileAborted(ctx context.Context, scanEstimation *models.ScanEstimation) error { logger := log.GetLoggerFromContextOrDiscard(ctx) diff --git a/pkg/orchestrator/scanestimationwatcher/watcher_test.go b/pkg/orchestrator/scanestimationwatcher/watcher_test.go index 95bdb9698..9293ba9e0 100644 --- a/pkg/orchestrator/scanestimationwatcher/watcher_test.go +++ b/pkg/orchestrator/scanestimationwatcher/watcher_test.go @@ -62,3 +62,104 @@ func Test_updateScanEstimationSummaryFromAssetScanEstimation(t *testing.T) { }) } } + +func Test_updateTotalScanTimeWithParallelScans(t *testing.T) { + type args struct { + scanEstimation *models.ScanEstimation + } + tests := []struct { + name string + args args + wantErr bool + wantTotalScanTime int + }{ + { + name: "max parallel scanners == nil", + args: args{ + scanEstimation: &models.ScanEstimation{ + Summary: &models.ScanEstimationSummary{ + JobsCompleted: utils.PointerTo(5), + TotalScanTime: utils.PointerTo(30), + }, + ScanTemplate: &models.ScanTemplate{}, + }, + }, + wantTotalScanTime: 30, + wantErr: false, + }, + { + name: "max parallel scanners == 1", + args: args{ + scanEstimation: &models.ScanEstimation{ + Summary: &models.ScanEstimationSummary{ + JobsCompleted: utils.PointerTo(5), + TotalScanTime: utils.PointerTo(30), + }, + ScanTemplate: &models.ScanTemplate{ + MaxParallelScanners: utils.PointerTo(1), + }, + }, + }, + wantTotalScanTime: 30, + wantErr: false, + }, + { + name: "max parallel scanners == number of jobs", + args: args{ + scanEstimation: &models.ScanEstimation{ + Summary: &models.ScanEstimationSummary{ + JobsCompleted: utils.PointerTo(5), + TotalScanTime: utils.PointerTo(30), + }, + ScanTemplate: &models.ScanTemplate{ + MaxParallelScanners: utils.PointerTo(5), + }, + }, + }, + wantTotalScanTime: 6, + wantErr: false, + }, + { + name: "max parallel scanners < number of jobs", + args: args{ + scanEstimation: &models.ScanEstimation{ + Summary: &models.ScanEstimationSummary{ + JobsCompleted: utils.PointerTo(3), + TotalScanTime: utils.PointerTo(30), + }, + ScanTemplate: &models.ScanTemplate{ + MaxParallelScanners: utils.PointerTo(2), + }, + }, + }, + wantTotalScanTime: 15, + wantErr: false, + }, + { + name: "max parallel scanners > number of jobs", + args: args{ + scanEstimation: &models.ScanEstimation{ + Summary: &models.ScanEstimationSummary{ + JobsCompleted: utils.PointerTo(2), + TotalScanTime: utils.PointerTo(30), + }, + ScanTemplate: &models.ScanTemplate{ + MaxParallelScanners: utils.PointerTo(3), + }, + }, + }, + wantTotalScanTime: 15, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := updateTotalScanTimeWithParallelScans(tt.args.scanEstimation); (err != nil) != tt.wantErr { + t.Errorf("updateTotalScanTimeWithParallelScans() error = %v, wantErr %v", err, tt.wantErr) + } + if *tt.args.scanEstimation.Summary.TotalScanTime != tt.wantTotalScanTime { + t.Errorf("updateTotalScanTimeWithParallelScans() failed. wantTotalScanTime = %v, gotTotalScanTime = %v", tt.wantTotalScanTime, *tt.args.scanEstimation.Summary.TotalScanTime) + } + }) + } +}