fix: add checkpoint when enqueue scan tasks for scan all (#18680)

Fix scanAll not being stoppable when there is a large number of artifacts:
add a checkpoint before submitting scan tasks, and mark the scanAll stopped
flag in Redis so that the enqueue loop can detect it.

Fixes: #18044

Signed-off-by: chlins <chenyuzh@vmware.com>
This commit is contained in:
Chlins Zhang 2023-06-05 15:12:54 +08:00 committed by GitHub
parent 9d28d1f43f
commit fbeeaa7537
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 127 additions and 19 deletions

View File

@ -40,7 +40,13 @@ func Iterator(ctx context.Context, chunkSize int, query *q.Query, option *Option
} }
for _, artifact := range artifacts { for _, artifact := range artifacts {
ch <- artifact select {
case <-ctx.Done():
log.G(ctx).Errorf("context done, list artifacts exited, error: %v", ctx.Err())
return
case ch <- artifact:
continue
}
} }
if len(artifacts) < chunkSize { if len(artifacts) < chunkSize {

View File

@ -21,6 +21,7 @@ import (
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
"time"
"github.com/google/uuid" "github.com/google/uuid"
@ -30,6 +31,7 @@ import (
sc "github.com/goharbor/harbor/src/controller/scanner" sc "github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/controller/tag" "github.com/goharbor/harbor/src/controller/tag"
"github.com/goharbor/harbor/src/jobservice/job" "github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config" "github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/errors" "github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log" "github.com/goharbor/harbor/src/lib/log"
@ -50,8 +52,12 @@ import (
"github.com/goharbor/harbor/src/pkg/task" "github.com/goharbor/harbor/src/pkg/task"
) )
// DefaultController is a default singleton scan API controller. var (
var DefaultController = NewController() // DefaultController is a default singleton scan API controller.
DefaultController = NewController()
errScanAllStopped = errors.New("scanAll stopped")
)
// const definitions // const definitions
const ( const (
@ -74,6 +80,9 @@ type uuidGenerator func() (string, error)
// utility methods. // utility methods.
type configGetter func(cfg string) (string, error) type configGetter func(cfg string) (string, error)
// cacheGetter is a function that returns the cache.Cache to use. The scan
// controller holds one of these and calls it whenever it needs the cache to
// save or look up the scan all stop marks.
type cacheGetter func() cache.Cache
// launchScanJobParam is a param to launch scan job. // launchScanJobParam is a param to launch scan job.
type launchScanJobParam struct { type launchScanJobParam struct {
ExecutionID int64 ExecutionID int64
@ -109,6 +118,8 @@ type basicController struct {
taskMgr task.Manager taskMgr task.Manager
// Converter for V1 report to V2 report // Converter for V1 report to V2 report
reportConverter postprocessors.NativeScanReportConverter reportConverter postprocessors.NativeScanReportConverter
// cache stores the stop scan all marks
cache cacheGetter
} }
// NewController news a scan API controller // NewController news a scan API controller
@ -154,6 +165,9 @@ func NewController() Controller {
taskMgr: task.Mgr, taskMgr: task.Mgr,
// Get the scan V1 to V2 report converters // Get the scan V1 to V2 report converters
reportConverter: postprocessors.Converter, reportConverter: postprocessors.Converter,
cache: func() cache.Cache {
return cache.Default()
},
} }
} }
@ -368,6 +382,44 @@ func (bc *basicController) ScanAll(ctx context.Context, trigger string, async bo
return executionID, nil return executionID, nil
} }
// StopScanAll stops the scan all execution identified by executionID.
//
// It first records the stop mark for the execution in the cache and then asks
// the execution manager to stop the execution. If recording the mark fails,
// the execution manager is not called and that error is the result.
//
// When async is true the work is done in a background goroutine: any failure
// is only logged and nil is returned immediately. When async is false the
// work is done inline and its error, if any, is returned to the caller.
func (bc *basicController) StopScanAll(ctx context.Context, executionID int64, async bool) error {
	// stop performs the two steps in order: mark first, then stop.
	stop := func() error {
		if err := bc.markScanAllStopped(ctx, executionID); err != nil {
			return err
		}
		// stop the execution and sub tasks
		return bc.execMgr.Stop(ctx, executionID)
	}

	if !async {
		return stop()
	}

	// fire and forget: the caller is not told about failures in this mode.
	go func() {
		if err := stop(); err != nil {
			log.Errorf("failed to stop scan all, error: %v", err)
		}
	}()

	return nil
}
// scanAllStoppedKey returns the cache key under which the stop mark of the
// scan all execution with the given id is stored.
func scanAllStoppedKey(execID int64) string {
	const keyFormat = "scan_all:execution_id:%d:stopped"
	return fmt.Sprintf(keyFormat, execID)
}
// markScanAllStopped saves the stop mark for the given scan all execution
// into the cache, with an empty value and a 2 hour expiration.
func (bc *basicController) markScanAllStopped(ctx context.Context, execID int64) error {
	// The expiration is meant to be long enough for the controller to notice
	// the mark. The key is left to expire on its own rather than being removed
	// by the scan controller: every new scan all gets a new unique execution
	// id, so a leftover key from an old execution does not affect anything.
	const stopMarkTTL = 2 * time.Hour

	store := bc.cache()
	key := scanAllStoppedKey(execID)

	return store.Save(ctx, key, "", stopMarkTTL)
}
// isScanAllStopped reports whether the cache contains the stop mark for the
// given scan all execution.
func (bc *basicController) isScanAllStopped(ctx context.Context, execID int64) bool {
	store := bc.cache()
	key := scanAllStoppedKey(execID)

	return store.Contains(ctx, key)
}
func (bc *basicController) startScanAll(ctx context.Context, executionID int64) error { func (bc *basicController) startScanAll(ctx context.Context, executionID int64) error {
batchSize := 50 batchSize := 50
@ -379,8 +431,15 @@ func (bc *basicController) startScanAll(ctx context.Context, executionID int64)
UnsupportCount int `json:"unsupport_count"` UnsupportCount int `json:"unsupport_count"`
UnknowCount int `json:"unknow_count"` UnknowCount int `json:"unknow_count"`
}{} }{}
// with cancel function to signal downstream worker
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for artifact := range ar.Iterator(ctx, batchSize, nil, nil) { for artifact := range ar.Iterator(ctx, batchSize, nil, nil) {
if bc.isScanAllStopped(ctx, executionID) {
return errScanAllStopped
}
summary.TotalCount++ summary.TotalCount++
scan := func(ctx context.Context) error { scan := func(ctx context.Context) error {

View File

@ -30,6 +30,7 @@ import (
"github.com/goharbor/harbor/src/common/rbac" "github.com/goharbor/harbor/src/common/rbac"
"github.com/goharbor/harbor/src/controller/artifact" "github.com/goharbor/harbor/src/controller/artifact"
"github.com/goharbor/harbor/src/controller/robot" "github.com/goharbor/harbor/src/controller/robot"
"github.com/goharbor/harbor/src/lib/cache"
"github.com/goharbor/harbor/src/lib/config" "github.com/goharbor/harbor/src/lib/config"
"github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q" "github.com/goharbor/harbor/src/lib/q"
@ -49,6 +50,7 @@ import (
robottesting "github.com/goharbor/harbor/src/testing/controller/robot" robottesting "github.com/goharbor/harbor/src/testing/controller/robot"
scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner" scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner"
tagtesting "github.com/goharbor/harbor/src/testing/controller/tag" tagtesting "github.com/goharbor/harbor/src/testing/controller/tag"
mockcache "github.com/goharbor/harbor/src/testing/lib/cache"
ormtesting "github.com/goharbor/harbor/src/testing/lib/orm" ormtesting "github.com/goharbor/harbor/src/testing/lib/orm"
"github.com/goharbor/harbor/src/testing/mock" "github.com/goharbor/harbor/src/testing/mock"
accessorytesting "github.com/goharbor/harbor/src/testing/pkg/accessory" accessorytesting "github.com/goharbor/harbor/src/testing/pkg/accessory"
@ -77,6 +79,7 @@ type ControllerTestSuite struct {
ar artifact.Controller ar artifact.Controller
c Controller c Controller
reportConverter *postprocessorstesting.ScanReportV1ToV2Converter reportConverter *postprocessorstesting.ScanReportV1ToV2Converter
cache *mockcache.Cache
} }
// TestController is the entry point of ControllerTestSuite. // TestController is the entry point of ControllerTestSuite.
@ -271,6 +274,8 @@ func (suite *ControllerTestSuite) SetupSuite() {
suite.taskMgr = &tasktesting.Manager{} suite.taskMgr = &tasktesting.Manager{}
suite.cache = &mockcache.Cache{}
suite.c = &basicController{ suite.c = &basicController{
manager: mgr, manager: mgr,
ar: suite.ar, ar: suite.ar,
@ -298,6 +303,7 @@ func (suite *ControllerTestSuite) SetupSuite() {
execMgr: suite.execMgr, execMgr: suite.execMgr,
taskMgr: suite.taskMgr, taskMgr: suite.taskMgr,
reportConverter: &postprocessorstesting.ScanReportV1ToV2Converter{}, reportConverter: &postprocessorstesting.ScanReportV1ToV2Converter{},
cache: func() cache.Cache { return suite.cache },
} }
} }
@ -522,25 +528,25 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() {
func (suite *ControllerTestSuite) TestScanAll() { func (suite *ControllerTestSuite) TestScanAll() {
{ {
// no artifacts found when scan all // no artifacts found when scan all
ctx := context.TODO()
executionID := int64(1) executionID := int64(1)
suite.execMgr.On( suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE", "Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once() ).Return(executionID, nil).Once()
mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once() mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()
mock.OnAnything(suite.artifactCtl, "List").Return([]*artifact.Artifact{}, nil).Once() mock.OnAnything(suite.artifactCtl, "List").Return([]*artifact.Artifact{}, nil).Once()
suite.taskMgr.On("Count", ctx, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once() suite.taskMgr.On("Count", mock.Anything, q.New(q.KeyWords{"execution_id": executionID})).Return(int64(0), nil).Once()
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkDone", ctx, executionID, mock.Anything).Return(nil).Once() suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()
_, err := suite.c.ScanAll(ctx, "SCHEDULE", false) suite.cache.On("Contains", mock.Anything, scanAllStoppedKey(1)).Return(false).Once()
_, err := suite.c.ScanAll(context.TODO(), "SCHEDULE", false)
suite.NoError(err) suite.NoError(err)
} }
@ -551,7 +557,7 @@ func (suite *ControllerTestSuite) TestScanAll() {
executionID := int64(1) executionID := int64(1)
suite.execMgr.On( suite.execMgr.On(
"Create", ctx, "SCAN_ALL", int64(0), "SCHEDULE", "Create", mock.Anything, "SCAN_ALL", int64(0), "SCHEDULE",
).Return(executionID, nil).Once() ).Return(executionID, nil).Once()
mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once() mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once()
@ -568,13 +574,28 @@ func (suite *ControllerTestSuite) TestScanAll() {
mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once() mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once()
mock.OnAnything(suite.taskMgr, "Create").Return(int64(0), fmt.Errorf("failed")).Once() mock.OnAnything(suite.taskMgr, "Create").Return(int64(0), fmt.Errorf("failed")).Once()
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkError", ctx, executionID, mock.Anything).Return(nil).Once() suite.execMgr.On("MarkError", mock.Anything, executionID, mock.Anything).Return(nil).Once()
_, err := suite.c.ScanAll(ctx, "SCHEDULE", false) _, err := suite.c.ScanAll(ctx, "SCHEDULE", false)
suite.NoError(err) suite.NoError(err)
} }
} }
// TestStopScanAll exercises StopScanAll in synchronous mode for both the
// failing and the successful path.
func (suite *ControllerTestSuite) TestStopScanAll() {
	execID := int64(100)
	stoppedKey := scanAllStoppedKey(execID)

	{
		// saving the stop mark fails: the same error must be returned
		saveErr := fmt.Errorf("stop scan all error")
		suite.cache.On("Save", mock.Anything, stoppedKey, mock.Anything, mock.Anything).Return(saveErr).Once()

		got := suite.c.StopScanAll(context.TODO(), execID, false)
		suite.EqualError(got, saveErr.Error())
	}

	{
		// saving the stop mark and stopping the execution both succeed
		suite.cache.On("Save", mock.Anything, stoppedKey, mock.Anything, mock.Anything).Return(nil).Once()
		suite.execMgr.On("Stop", mock.Anything, execID).Return(nil).Once()

		got := suite.c.StopScanAll(context.TODO(), execID, false)
		suite.NoError(got)
	}
}
func (suite *ControllerTestSuite) TestDeleteReports() { func (suite *ControllerTestSuite) TestDeleteReports() {
suite.reportMgr.On("DeleteByDigests", context.TODO(), "digest").Return(nil).Once() suite.reportMgr.On("DeleteByDigests", context.TODO(), "digest").Return(nil).Once()

View File

@ -157,7 +157,7 @@ func (suite *CallbackTestSuite) TestScanAllCallback() {
mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once() mock.OnAnything(suite.execMgr, "UpdateExtraAttrs").Return(nil).Once()
suite.execMgr.On("MarkDone", context.TODO(), executionID, mock.Anything).Return(nil).Once() suite.execMgr.On("MarkDone", mock.Anything, executionID, mock.Anything).Return(nil).Once()
suite.NoError(scanAllCallback(context.TODO(), "")) suite.NoError(scanAllCallback(context.TODO(), ""))
} }

View File

@ -115,6 +115,16 @@ type Controller interface {
// error : non nil error if any errors occurred // error : non nil error if any errors occurred
ScanAll(ctx context.Context, trigger string, async bool) (int64, error) ScanAll(ctx context.Context, trigger string, async bool) (int64, error)
// StopScanAll stops the scanAll
//
// Arguments:
// ctx context.Context : the context for this method
// executionID int64 : the id of scan all execution
// async bool : stop scan all in background
// Returns:
// error : non nil error if any errors occurred
StopScanAll(ctx context.Context, executionID int64, async bool) error
// GetVulnerable returns the vulnerable of the artifact for the allowlist // GetVulnerable returns the vulnerable of the artifact for the allowlist
// //
// Arguments: // Arguments:

View File

@ -28,7 +28,6 @@ import (
"github.com/goharbor/harbor/src/controller/scanner" "github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/jobservice/job" "github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/errors" "github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
"github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q" "github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scheduler" "github.com/goharbor/harbor/src/pkg/scheduler"
@ -74,12 +73,10 @@ func (s *scanAllAPI) StopScanAll(ctx context.Context, params operation.StopScanA
if execution == nil { if execution == nil {
return s.SendError(ctx, errors.BadRequestError(nil).WithMessage("no scan all job is found currently")) return s.SendError(ctx, errors.BadRequestError(nil).WithMessage("no scan all job is found currently"))
} }
go func(ctx context.Context, eid int64) {
err := s.execMgr.Stop(ctx, eid) if err = s.scanCtl.StopScanAll(s.makeCtx(), execution.ID, true); err != nil {
if err != nil { return s.SendError(ctx, err)
log.Errorf("failed to stop the execution of executionID=%+v", execution.ID) }
}
}(s.makeCtx(), execution.ID)
return operation.NewStopScanAllAccepted() return operation.NewStopScanAllAccepted()
} }

View File

@ -247,6 +247,7 @@ func (suite *ScanAllTestSuite) TestStopScanAll() {
times := 3 times := 3
suite.Security.On("IsAuthenticated").Return(true).Times(times) suite.Security.On("IsAuthenticated").Return(true).Times(times)
suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(times) suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(times)
mock.OnAnything(suite.scanCtl, "StopScanAll").Return(nil).Times(times)
mock.OnAnything(suite.scannerCtl, "ListRegistrations").Return([]*scanner.Registration{{ID: int64(1)}}, nil).Times(times) mock.OnAnything(suite.scannerCtl, "ListRegistrations").Return([]*scanner.Registration{{ID: int64(1)}}, nil).Times(times)
{ {

View File

@ -205,6 +205,20 @@ func (_m *Controller) Stop(ctx context.Context, _a1 *artifact.Artifact) error {
return r0 return r0
} }
// StopScanAll provides a mock function with given fields: ctx, executionID, async
//
// If the first configured return value is a function with the same signature,
// it is invoked with the received arguments and its result is returned;
// otherwise the first configured return value is returned as an error.
func (_m *Controller) StopScanAll(ctx context.Context, executionID int64, async bool) error {
	ret := _m.Called(ctx, executionID, async)

	if rf, ok := ret.Get(0).(func(context.Context, int64, bool) error); ok {
		return rf(ctx, executionID, async)
	}

	return ret.Error(0)
}
type mockConstructorTestingTNewController interface { type mockConstructorTestingTNewController interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())