fix: update preheat api handler and DAO (#17079)

1. fix preheat dao Get method
2. update preheat tasks and getLog api

Signed-off-by: chlins <chenyuzh@vmware.com>
This commit is contained in:
Chenyu Zhang 2022-06-28 19:01:08 +08:00 committed by GitHub
parent ff4eb7f27c
commit 1c3eb6974c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 33 deletions

View File

@ -63,10 +63,13 @@ func (d *dao) Get(ctx context.Context, id int64) (*provider.Instance, error) {
}
di := provider.Instance{ID: id}
err = o.Read(&di, "ID")
if err == beego_orm.ErrNoRows {
return nil, nil
if err = o.Read(&di, "ID"); err != nil {
if e := orm.AsNotFoundError(err, "instance %d not found", id); e != nil {
err = e
}
return nil, err
}
return &di, err
}

View File

@ -62,6 +62,7 @@ func (is *instanceSuite) TestGet() {
// not exist
i, err = is.dao.Get(is.ctx, 0)
assert.Nil(t, i)
assert.True(t, errors.IsNotFoundErr(err))
}
// TestCreate tests create instance.

View File

@ -258,6 +258,12 @@ func (api *preheatAPI) CreatePolicy(ctx context.Context, params operation.Create
// override project ID
policy.ProjectID = project.ProjectID
// validate provider whether exist
_, err = api.preheatCtl.GetInstance(ctx, policy.ProviderID)
if err != nil {
return api.SendError(ctx, err)
}
_, err = api.preheatCtl.CreatePolicy(ctx, policy)
if err != nil {
return api.SendError(ctx, err)
@ -289,6 +295,12 @@ func (api *preheatAPI) UpdatePolicy(ctx context.Context, params operation.Update
return api.SendError(ctx, err)
}
// validate provider whether exist
_, err = api.preheatCtl.GetInstance(ctx, policy.ProviderID)
if err != nil {
return api.SendError(ctx, err)
}
err = api.preheatCtl.UpdatePolicy(ctx, policy)
if err != nil {
return api.SendError(ctx, err)
@ -716,6 +728,11 @@ func (api *preheatAPI) ListTasks(ctx context.Context, params operation.ListTasks
if err := api.RequireProjectAccess(ctx, params.ProjectName, rbac.ActionList, rbac.ResourcePreatPolicy); err != nil {
return api.SendError(ctx, err)
}
if err := api.requireExecutionInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.ExecutionID); err != nil {
return api.SendError(ctx, err)
}
query, err := api.BuildQuery(ctx, params.Q, params.Sort, params.Page, params.PageSize)
if err != nil {
return api.SendError(ctx, err)
@ -754,7 +771,7 @@ func (api *preheatAPI) GetPreheatLog(ctx context.Context, params operation.GetPr
return api.SendError(ctx, err)
}
if err := api.requireTaskInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.TaskID); err != nil {
if err := api.requireTaskInProject(ctx, params.ProjectName, params.PreheatPolicyName, params.ExecutionID, params.TaskID); err != nil {
return api.SendError(ctx, err)
}
@ -766,39 +783,26 @@ func (api *preheatAPI) GetPreheatLog(ctx context.Context, params operation.GetPr
return operation.NewGetPreheatLogOK().WithPayload(string(l))
}
func (api *preheatAPI) requireTaskInProject(ctx context.Context, projectNameOrID interface{}, policyName string, taskID int64) error {
func (api *preheatAPI) requireTaskInProject(ctx context.Context, projectNameOrID interface{}, policyName string, executionID, taskID int64) error {
projectID, err := getProjectID(ctx, projectNameOrID)
notFoundErr := fmt.Errorf("project id %d, task id %d not found", projectID, taskID)
if err != nil {
return err
}
plc, err := api.preheatCtl.GetPolicyByName(ctx, projectID, policyName)
// require execution before require task
if err := api.requireExecutionInProject(ctx, projectID, policyName, executionID); err != nil {
return err
}
task, err := api.taskCtl.Get(ctx, taskID)
if err != nil {
return err
}
execs, err := api.executionCtl.List(ctx, q.New(q.KeyWords{"VendorType": job.P2PPreheat, "VendorID": plc.ID}))
if err != nil {
return err
}
if len(execs) == 0 {
return errors.NotFoundError(notFoundErr)
}
var execIds []interface{}
for _, item := range execs {
execIds = append(execIds, item.ID)
}
tasks, err := api.taskCtl.List(ctx, q.New(q.KeyWords{"ExecutionID": q.NewOrList(execIds)}))
if err != nil {
return err
}
if len(tasks) == 0 {
return errors.NotFoundError(notFoundErr)
}
for _, t := range tasks {
if t.ID == taskID {
return nil
}
if task != nil && task.ExecutionID == executionID {
return nil
}
return errors.NotFoundError(notFoundErr)
}
@ -812,15 +816,16 @@ func (api *preheatAPI) requireExecutionInProject(ctx context.Context, projectNam
if err != nil {
return err
}
execs, err := api.executionCtl.List(ctx, q.New(q.KeyWords{"VendorType": job.P2PPreheat, "VendorID": plc.ID}))
exec, err := api.executionCtl.Get(ctx, executionID)
if err != nil {
return err
}
for _, e := range execs {
if e.ID == executionID {
return nil
}
if exec != nil && exec.VendorType == job.P2PPreheat && exec.VendorID == plc.ID {
return nil
}
return errors.NotFoundError(notFoundErr)
}