diff --git a/src/controller/replication/execution.go b/src/controller/replication/execution.go index 8eb610365..c5635b3dd 100644 --- a/src/controller/replication/execution.go +++ b/src/controller/replication/execution.go @@ -149,11 +149,8 @@ func (c *controller) Start(ctx context.Context, policy *replicationmodel.Policy, func (c *controller) markError(ctx context.Context, executionID int64, err error) { logger := log.GetLogger(ctx) // try to stop the execution first in case that some tasks are already created - if err := c.execMgr.StopAndWait(ctx, executionID, 10*time.Second); err != nil { - logger.Errorf("failed to stop the execution %d: %v", executionID, err) - } - if err := c.execMgr.MarkError(ctx, executionID, err.Error()); err != nil { - logger.Errorf("failed to mark error for the execution %d: %v", executionID, err) + if e := c.execMgr.StopAndWaitWithError(ctx, executionID, 10*time.Second, err); e != nil { + logger.Errorf("failed to stop the execution %d: %v", executionID, e) } } diff --git a/src/controller/replication/execution_test.go b/src/controller/replication/execution_test.go index 31c5873d4..bbfe8bc95 100644 --- a/src/controller/replication/execution_test.go +++ b/src/controller/replication/execution_test.go @@ -75,8 +75,7 @@ func (r *replicationTestSuite) TestStart() { // got error when running the replication flow r.execMgr.On("Create", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil) r.execMgr.On("Get", mock.Anything, mock.Anything).Return(&task.Execution{}, nil) - r.execMgr.On("StopAndWait", mock.Anything, mock.Anything, mock.Anything).Return(nil) - r.execMgr.On("MarkError", mock.Anything, mock.Anything, mock.Anything).Return(nil) + r.execMgr.On("StopAndWaitWithError", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) r.flowCtl.On("Start", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) r.ormCreator.On("Create").Return(nil) id, err = r.ctl.Start(context.Background(), &repctlmodel.Policy{Enabled: true}, nil, task.ExecutionTriggerManual) diff --git a/src/controller/retention/controller.go b/src/controller/retention/controller.go index 642d07403..08ac62d01 100644 --- a/src/controller/retention/controller.go +++ b/src/controller/retention/controller.go @@ -233,12 +233,9 @@ func (r *defaultController) TriggerRetentionExec(ctx context.Context, policyID i }, ) if num, err := r.launcher.Launch(ctx, p, id, dryRun); err != nil { - if err1 := r.execMgr.StopAndWait(ctx, id, 10*time.Second); err1 != nil { + if err1 := r.execMgr.StopAndWaitWithError(ctx, id, 10*time.Second, err); err1 != nil { logger.Errorf("failed to stop the retention execution %d: %v", id, err1) } - if err1 := r.execMgr.MarkError(ctx, id, err.Error()); err1 != nil { - logger.Errorf("failed to mark error for the retention execution %d: %v", id, err1) - } return 0, err } else if num == 0 { // no candidates, mark the execution as done directly diff --git a/src/controller/systemartifact/execution.go b/src/controller/systemartifact/execution.go index b7e4dcd6c..9ea6b59da 100644 --- a/src/controller/systemartifact/execution.go +++ b/src/controller/systemartifact/execution.go @@ -105,11 +105,8 @@ func (c *controller) createCleanupTask(ctx context.Context, jobParams job.Parame func (c *controller) markError(ctx context.Context, executionID int64, err error) { // try to stop the execution first in case that some tasks are already created - if err := c.execMgr.StopAndWait(ctx, executionID, 10*time.Second); err != nil { - log.Errorf("failed to stop the execution %d: %v", executionID, err) - } - if err := c.execMgr.MarkError(ctx, executionID, err.Error()); err != nil { - log.Errorf("failed to mark error for the execution %d: %v", executionID, err) + if e := c.execMgr.StopAndWaitWithError(ctx, executionID, 10*time.Second, err); e != nil { + log.Errorf("failed to stop the execution %d: %v", executionID, e) } } diff --git a/src/controller/systemartifact/execution_test.go b/src/controller/systemartifact/execution_test.go index 68ba21191..7656fae37 100644 --- a/src/controller/systemartifact/execution_test.go +++ b/src/controller/systemartifact/execution_test.go @@ -117,7 +117,7 @@ func (suite *SystemArtifactCleanupTestSuite) TestStartCleanupErrorDuringTaskCrea suite.taskMgr.On("Create", ctx, executionID, mock.Anything).Return(taskId, errors.New("test error")).Once() suite.execMgr.On("MarkError", ctx, executionID, mock.Anything).Return(nil).Once() - suite.execMgr.On("StopAndWait", ctx, executionID, mock.Anything).Return(nil).Once() + suite.execMgr.On("StopAndWaitWithError", ctx, executionID, mock.Anything, mock.Anything).Return(nil).Once() err := suite.ctl.Start(ctx, false, "SCHEDULE") suite.Error(err) diff --git a/src/pkg/task/dao/task.go b/src/pkg/task/dao/task.go index de3c71bd2..177738561 100644 --- a/src/pkg/task/dao/task.go +++ b/src/pkg/task/dao/task.go @@ -25,6 +25,8 @@ import ( "github.com/goharbor/harbor/src/lib/log" "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" + + "github.com/google/uuid" ) // TaskDAO is the data access object interface for task @@ -91,22 +93,33 @@ func (t *taskDAO) List(ctx context.Context, query *q.Query) ([]*Task, error) { return tasks, nil } +func isValidUUID(id string) bool { + if len(id) == 0 { + return false + } + if _, err := uuid.Parse(id); err != nil { + return false + } + return true +} + func (t *taskDAO) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*Task, error) { ormer, err := orm.FromContext(ctx) if err != nil { return nil, err } - tasks := []*Task{} - // Due to the limitation of the beego's orm, the SQL cannot be converted by orm framework, - // so we can only execute the query by raw SQL, the SQL filters the task contains the report uuid in the column extra_attrs, - // consider from performance side which can using indexes to speed up queries. - sql := fmt.Sprintf(`SELECT * FROM task WHERE extra_attrs::jsonb->'report_uuids' @> '["%s"]'`, uuid) - _, err = ormer.Raw(sql).QueryRows(&tasks) + if !isValidUUID(uuid) { + return nil, errors.BadRequestError(fmt.Errorf("invalid UUID %v", uuid)) + } + + var tasks []*Task + param := fmt.Sprintf(`{"report_uuids":["%s"]}`, uuid) + sql := `SELECT * FROM task WHERE extra_attrs::jsonb @> cast( ? as jsonb )` + _, err = ormer.Raw(sql, param).QueryRows(&tasks) if err != nil { return nil, err } - return tasks, nil } diff --git a/src/pkg/task/dao/task_test.go b/src/pkg/task/dao/task_test.go index aeca41a1c..58a41e551 100644 --- a/src/pkg/task/dao/task_test.go +++ b/src/pkg/task/dao/task_test.go @@ -113,8 +113,9 @@ func (t *taskDAOTestSuite) TestList() { } func (t *taskDAOTestSuite) TestListScanTasksByReportUUID() { + reportUUID := `7f20b1b9-6117-4a2e-820b-e4cc0401f15e` // should not exist if non set - tasks, err := t.taskDAO.ListScanTasksByReportUUID(t.ctx, "fake-report-uuid") + tasks, err := t.taskDAO.ListScanTasksByReportUUID(t.ctx, reportUUID) t.Require().Nil(err) t.Require().Len(tasks, 0) // create one with report uuid @@ -122,12 +123,12 @@ func (t *taskDAOTestSuite) TestListScanTasksByReportUUID() { ExecutionID: t.executionID, Status: "success", StatusCode: 1, - ExtraAttrs: `{"report_uuids": ["fake-report-uuid"]}`, + ExtraAttrs: fmt.Sprintf(`{"report_uuids": ["%s"]}`, reportUUID), }) t.Require().Nil(err) defer t.taskDAO.Delete(t.ctx, taskID) // should exist as created - tasks, err = t.taskDAO.ListScanTasksByReportUUID(t.ctx, "fake-report-uuid") + tasks, err = t.taskDAO.ListScanTasksByReportUUID(t.ctx, reportUUID) t.Require().Nil(err) t.Require().Len(tasks, 1) t.Equal(taskID, tasks[0].ID) @@ -299,6 +300,29 @@ func (t *taskDAOTestSuite) TestExecutionIDsByVendorAndStatus() { defer t.taskDAO.Delete(t.ctx, tid) } +func TestIsValidUUID(t *testing.T) { + tests := []struct { + name string + uuid string + expected bool + }{ + {"Valid UUID", "7f20b1b9-6117-4a2e-820b-e4cc0401f15f", true}, + {"Invalid UUID - Short", "7f20b1b9-6117-4a2e-820b", false}, + {"Invalid UUID - Long", "7f20b1b9-6117-4a2e-820b-e4cc0401f15f-extra", false}, + {"Invalid UUID - Invalid Characters", "7f20b1b9-6117-4z2e-820b-e4cc0401f15f", false}, + {"Empty String", "", false}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := isValidUUID(test.uuid) + if result != test.expected { + t.Errorf("Expected isValidUUID(%s) to be %t, got %t", test.uuid, test.expected, result) + } + }) + } +} + func TestTaskDAOSuite(t *testing.T) { suite.Run(t, &taskDAOTestSuite{}) } diff --git a/src/pkg/task/execution.go b/src/pkg/task/execution.go index 49fee7f96..e2160cd89 100644 --- a/src/pkg/task/execution.go +++ b/src/pkg/task/execution.go @@ -59,6 +59,8 @@ type ExecutionManager interface { // StopAndWait stops all linked tasks of the specified execution and waits until all tasks are stopped // or get an error StopAndWait(ctx context.Context, id int64, timeout time.Duration) (err error) + // StopAndWaitWithError calls the StopAndWait first, if it doesn't return error, then it call MarkError if the origError is not empty + StopAndWaitWithError(ctx context.Context, id int64, timeout time.Duration, origError error) (err error) // Delete the specified execution and its tasks Delete(ctx context.Context, id int64) (err error) // Delete all executions and tasks of the specific vendor. They can be deleted only when all the executions/tasks @@ -250,6 +252,16 @@ func (e *executionManager) StopAndWait(ctx context.Context, id int64, timeout ti } } +func (e *executionManager) StopAndWaitWithError(ctx context.Context, id int64, timeout time.Duration, origError error) error { + if err := e.StopAndWait(ctx, id, timeout); err != nil { + return err + } + if origError != nil { + return e.MarkError(ctx, id, origError.Error()) + } + return nil +} + func (e *executionManager) Delete(ctx context.Context, id int64) error { tasks, err := e.taskDAO.List(ctx, &q.Query{ Keywords: map[string]interface{}{ diff --git a/src/testing/pkg/task/execution_manager.go b/src/testing/pkg/task/execution_manager.go index d09ac5214..e36c01923 100644 --- a/src/testing/pkg/task/execution_manager.go +++ b/src/testing/pkg/task/execution_manager.go @@ -209,6 +209,20 @@ func (_m *ExecutionManager) StopAndWait(ctx context.Context, id int64, timeout t return r0 } +// StopAndWaitWithError provides a mock function with given fields: ctx, id, timeout, origError +func (_m *ExecutionManager) StopAndWaitWithError(ctx context.Context, id int64, timeout time.Duration, origError error) error { + ret := _m.Called(ctx, id, timeout, origError) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, time.Duration, error) error); ok { + r0 = rf(ctx, id, timeout, origError) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateExtraAttrs provides a mock function with given fields: ctx, id, extraAttrs func (_m *ExecutionManager) UpdateExtraAttrs(ctx context.Context, id int64, extraAttrs map[string]interface{}) error { ret := _m.Called(ctx, id, extraAttrs)