diff --git a/src/pkg/p2p/preheat/dao/instance/dao.go b/src/pkg/p2p/preheat/dao/instance/dao.go index 5de519ff5..07095d59f 100644 --- a/src/pkg/p2p/preheat/dao/instance/dao.go +++ b/src/pkg/p2p/preheat/dao/instance/dao.go @@ -63,10 +63,13 @@ func (d *dao) Get(ctx context.Context, id int64) (*provider.Instance, error) { } di := provider.Instance{ID: id} - err = o.Read(&di, "ID") - if err == beego_orm.ErrNoRows { - return nil, nil + if err = o.Read(&di, "ID"); err != nil { + if e := orm.AsNotFoundError(err, "instance %d not found", id); e != nil { + err = e + } + return nil, err } + return &di, err } diff --git a/src/pkg/p2p/preheat/dao/instance/dao_test.go b/src/pkg/p2p/preheat/dao/instance/dao_test.go index 9205dd219..67f2446da 100644 --- a/src/pkg/p2p/preheat/dao/instance/dao_test.go +++ b/src/pkg/p2p/preheat/dao/instance/dao_test.go @@ -62,6 +62,7 @@ func (is *instanceSuite) TestGet() { // not exist i, err = is.dao.Get(is.ctx, 0) assert.Nil(t, i) + assert.True(t, errors.IsNotFoundErr(err)) } // TestCreate tests create instance. diff --git a/src/server/v2.0/handler/preheat.go b/src/server/v2.0/handler/preheat.go index 2c200f3db..abe4c666f 100644 --- a/src/server/v2.0/handler/preheat.go +++ b/src/server/v2.0/handler/preheat.go @@ -258,6 +258,12 @@ func (api *preheatAPI) CreatePolicy(ctx context.Context, params operation.Create // override project ID policy.ProjectID = project.ProjectID + // validate provider whether exist + _, err = api.preheatCtl.GetInstance(ctx, policy.ProviderID) + if err != nil { + return api.SendError(ctx, err) + } + _, err = api.preheatCtl.CreatePolicy(ctx, policy) if err != nil { return api.SendError(ctx, err) @@ -289,6 +295,12 @@ func (api *preheatAPI) UpdatePolicy(ctx context.Context, params operation.Update return api.SendError(ctx, err) } + // validate provider whether exist + _, err = api.preheatCtl.GetInstance(ctx, policy.ProviderID) + if err != nil { + return api.SendError(ctx, err) + } + err = api.preheatCtl.UpdatePolicy(ctx, policy) if err != nil { return api.SendError(ctx, err) @@ -716,6 +728,11 @@ func (api *preheatAPI) ListTasks(ctx context.Context, params operation.ListTasks if err := api.RequireProjectAccess(ctx, params.ProjectName, rbac.ActionList, rbac.ResourcePreatPolicy); err != nil { return api.SendError(ctx, err) } + + if err := api.requireExecutionInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.ExecutionID); err != nil { + return api.SendError(ctx, err) + } + query, err := api.BuildQuery(ctx, params.Q, params.Sort, params.Page, params.PageSize) if err != nil { return api.SendError(ctx, err) @@ -754,7 +771,7 @@ func (api *preheatAPI) GetPreheatLog(ctx context.Context, params operation.GetPr return api.SendError(ctx, err) } - if err := api.requireTaskInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.TaskID); err != nil { + if err := api.requireTaskInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.ExecutionID, params.TaskID); err != nil { return api.SendError(ctx, err) } @@ -766,39 +783,26 @@ func (api *preheatAPI) GetPreheatLog(ctx context.Context, params operation.GetPr return operation.NewGetPreheatLogOK().WithPayload(string(l)) } -func (api *preheatAPI) requireTaskInProject(ctx context.Context, projectNameOrID interface{}, policyName string, taskID int64) error { +func (api *preheatAPI) requireTaskInProject(ctx context.Context, projectNameOrID interface{}, policyName string, executionID, taskID int64) error { projectID, err := getProjectID(ctx, projectNameOrID) notFoundErr := fmt.Errorf("project id %d, task id %d not found", projectID, taskID) if err != nil { return err } - plc, err := api.preheatCtl.GetPolicyByName(ctx, projectID, policyName) + // require execution before require task + if err := api.requireExecutionInProject(ctx, projectID, policyName, executionID); err != nil { + return err + } + + task, err := api.taskCtl.Get(ctx, taskID) if err != nil { return err } - execs, err := api.executionCtl.List(ctx, q.New(q.KeyWords{"VendorType": job.P2PPreheat, "VendorID": plc.ID})) - if err != nil { - return err - } - if len(execs) == 0 { - return errors.NotFoundError(notFoundErr) - } - var execIds []interface{} - for _, item := range execs { - execIds = append(execIds, item.ID) - } - tasks, err := api.taskCtl.List(ctx, q.New(q.KeyWords{"ExecutionID": q.NewOrList(execIds)})) - if err != nil { - return err - } - if len(tasks) == 0 { - return errors.NotFoundError(notFoundErr) - } - for _, t := range tasks { - if t.ID == taskID { - return nil - } + + if task != nil && task.ExecutionID == executionID { + return nil } + return errors.NotFoundError(notFoundErr) } @@ -812,15 +816,16 @@ func (api *preheatAPI) requireExecutionInProject(ctx context.Context, projectNam if err != nil { return err } - execs, err := api.executionCtl.List(ctx, q.New(q.KeyWords{"VendorType": job.P2PPreheat, "VendorID": plc.ID})) + + exec, err := api.executionCtl.Get(ctx, executionID) if err != nil { return err } - for _, e := range execs { - if e.ID == executionID { - return nil - } + + if exec != nil && exec.VendorType == job.P2PPreheat && exec.VendorID == plc.ID { + return nil } + return errors.NotFoundError(notFoundErr) }