diff --git a/api/base.go b/api/base.go index f2529b61e..7a7ff80df 100644 --- a/api/base.go +++ b/api/base.go @@ -17,8 +17,10 @@ package api import ( "encoding/json" + "fmt" "net/http" + "github.com/astaxie/beego/validation" "github.com/vmware/harbor/auth" "github.com/vmware/harbor/dao" "github.com/vmware/harbor/models" @@ -51,6 +53,30 @@ func (b *BaseAPI) DecodeJSONReq(v interface{}) { } } +// Validate validates v if it implements interface validation.ValidFormer +func (b *BaseAPI) Validate(v interface{}) { + validator := validation.Validation{} + isValid, err := validator.Valid(v) + if err != nil { + log.Errorf("failed to validate: %v", err) + b.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if !isValid { + message := "" + for _, e := range validator.Errors { + message += fmt.Sprintf("%s %s \n", e.Field, e.Message) + } + b.CustomAbort(http.StatusBadRequest, message) + } +} + +// DecodeJSONReqAndValidate does both decoding and validation +func (b *BaseAPI) DecodeJSONReqAndValidate(v interface{}) { + b.DecodeJSONReq(v) + b.Validate(v) +} + // ValidateUser checks if the request triggered by a valid user func (b *BaseAPI) ValidateUser() int { diff --git a/api/replication_policy.go b/api/replication_policy.go index 10906ba39..57a7c91a2 100644 --- a/api/replication_policy.go +++ b/api/replication_policy.go @@ -69,9 +69,40 @@ func (pa *RepPolicyAPI) Get() { // Post creates a policy, and if it is enbled, the replication will be triggered right now. func (pa *RepPolicyAPI) Post() { - policy := models.RepPolicy{} - pa.DecodeJSONReq(&policy) - pid, err := dao.AddRepPolicy(policy) + policy := &models.RepPolicy{} + pa.DecodeJSONReqAndValidate(policy) + + po, err := dao.GetRepPolicyByName(policy.Name) + if err != nil { + log.Errorf("failed to get policy %s: %v", policy.Name, err) + pa.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if po != nil { + pa.CustomAbort(http.StatusConflict, "name is already used") + } + + project, err := dao.GetProjectByID(policy.ProjectID) + if err != nil { + log.Errorf("failed to get project %d: %v", policy.ProjectID, err) + pa.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if project == nil { + pa.CustomAbort(http.StatusBadRequest, fmt.Sprintf("project %d does not exist", policy.ProjectID)) + } + + target, err := dao.GetRepTarget(policy.TargetID) + if err != nil { + log.Errorf("failed to get target %d: %v", policy.TargetID, err) + pa.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if target == nil { + pa.CustomAbort(http.StatusBadRequest, fmt.Sprintf("target %d does not exist", policy.TargetID)) + } + + pid, err := dao.AddRepPolicy(*policy) if err != nil { log.Errorf("Failed to add policy to DB, error: %v", err) pa.RenderError(http.StatusInternalServerError, "Internal Error") diff --git a/api/target.go b/api/target.go index e159b5779..9a9366cf1 100644 --- a/api/target.go +++ b/api/target.go @@ -164,10 +164,16 @@ func (t *TargetAPI) Get() { // Post ... func (t *TargetAPI) Post() { target := &models.RepTarget{} - t.DecodeJSONReq(target) + t.DecodeJSONReqAndValidate(target) - if len(target.Name) == 0 || len(target.URL) == 0 { - t.CustomAbort(http.StatusBadRequest, "name or URL is nil") + ta, err := dao.GetRepTargetByName(target.Name) + if err != nil { + log.Errorf("failed to get target %s: %v", target.Name, err) + t.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if ta != nil { + t.CustomAbort(http.StatusConflict, "name is already used") } if len(target.Password) != 0 { @@ -187,16 +193,32 @@ func (t *TargetAPI) Post() { func (t *TargetAPI) Put() { id := t.getIDFromURL() if id == 0 { - t.CustomAbort(http.StatusBadRequest, http.StatusText(http.StatusBadRequest)) + t.CustomAbort(http.StatusBadRequest, "id can not be empty or 0") } target := &models.RepTarget{} - t.DecodeJSONReq(target) + t.DecodeJSONReqAndValidate(target) - if target.ID == 0 { - target.ID = id + originTarget, err := dao.GetRepTarget(id) + if err != nil { + log.Errorf("failed to get target %d: %v", id, err) + t.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } + if target.Name != originTarget.Name { + ta, err := dao.GetRepTargetByName(target.Name) + if err != nil { + log.Errorf("failed to get target %s: %v", target.Name, err) + t.CustomAbort(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if ta != nil { + t.CustomAbort(http.StatusConflict, "name is already used") + } + } + + target.ID = id + if len(target.Password) != 0 { target.Password = utils.ReversibleEncrypt(target.Password) } diff --git a/dao/dao_test.go b/dao/dao_test.go index 25b3f941c..8655a179c 100644 --- a/dao/dao_test.go +++ b/dao/dao_test.go @@ -766,6 +766,78 @@ func TestAddRepTarget(t *testing.T) { } } +func TestGetRepTargetByName(t *testing.T) { + target, err := GetRepTarget(targetID) + if err != nil { + t.Fatalf("failed to get target %d: %v", targetID, err) + } + + target2, err := GetRepTargetByName(target.Name) + if err != nil { + t.Fatalf("failed to get target %s: %v", target.Name, err) + } + + if target.Name != target2.Name { + t.Errorf("unexpected target name: %s, expected: %s", target2.Name, target.Name) + } +} + +func TestUpdateRepTarget(t *testing.T) { + target := &models.RepTarget{ + Name: "name", + URL: "http://url", + Username: "username", + Password: "password", + } + + id, err := AddRepTarget(*target) + if err != nil { + t.Fatalf("failed to add target: %v", err) + } + defer func() { + if err := DeleteRepTarget(id); err != nil { + t.Logf("failed to delete target %d: %v", id, err) + } + }() + + target.ID = id + target.Name = "new_name" + target.URL = "http://new_url" + target.Username = "new_username" + target.Password = "new_password" + + if err = UpdateRepTarget(*target); err != nil { + t.Fatalf("failed to update target: %v", err) + } + + target, err = GetRepTarget(id) + if err != nil { + t.Fatalf("failed to get target %d: %v", id, err) + } + + if target.Name != "new_name" { + t.Errorf("unexpected name: %s, expected: %s", target.Name, "new_name") + } + + if target.URL != "http://new_url" { + t.Errorf("unexpected url: %s, expected: %s", target.URL, "http://new_url") + } + + if target.Username != "new_username" { + t.Errorf("unexpected username: %s, expected: %s", target.Username, "new_username") + } + + if target.Password != "new_password" { + t.Errorf("unexpected password: %s, expected: %s", target.Password, "new_password") + } +} + +func TestGetAllRepTargets(t *testing.T) { + if _, err := GetAllRepTargets(); err != nil { + t.Fatalf("failed to get all targets: %v", err) + } +} + func TestAddRepPolicy(t *testing.T) { policy := models.RepPolicy{ ProjectID: 1, @@ -800,6 +872,23 @@ func TestAddRepPolicy(t *testing.T) { } +func TestGetRepPolicyByName(t *testing.T) { + policy, err := GetRepPolicy(policyID) + if err != nil { + t.Fatalf("failed to get policy %d: %v", policyID, err) + } + + policy2, err := GetRepPolicyByName(policy.Name) + if err != nil { + t.Fatalf("failed to get policy %s: %v", policy.Name, err) + } + + if policy.Name != policy2.Name { + t.Errorf("unexpected name: %s, expected: %s", policy2.Name, policy.Name) + } + +} + func TestDisableRepPolicy(t *testing.T) { err := DisableRepPolicy(policyID) if err != nil { diff --git a/dao/replication_job.go b/dao/replication_job.go index 58d5669a5..dd2b29702 100644 --- a/dao/replication_job.go +++ b/dao/replication_job.go @@ -11,13 +11,13 @@ import ( // AddRepTarget ... func AddRepTarget(target models.RepTarget) (int64, error) { - o := orm.NewOrm() + o := GetOrmer() return o.Insert(&target) } // GetRepTarget ... func GetRepTarget(id int64) (*models.RepTarget, error) { - o := orm.NewOrm() + o := GetOrmer() t := models.RepTarget{ID: id} err := o.Read(&t) if err == orm.ErrNoRows { @@ -26,28 +26,34 @@ func GetRepTarget(id int64) (*models.RepTarget, error) { return &t, err } +// GetRepTargetByName ... +func GetRepTargetByName(name string) (*models.RepTarget, error) { + o := GetOrmer() + t := models.RepTarget{Name: name} + err := o.Read(&t, "Name") + if err == orm.ErrNoRows { + return nil, nil + } + return &t, err +} + // DeleteRepTarget ... func DeleteRepTarget(id int64) error { - o := orm.NewOrm() + o := GetOrmer() _, err := o.Delete(&models.RepTarget{ID: id}) return err } // UpdateRepTarget ... func UpdateRepTarget(target models.RepTarget) error { - o := orm.NewOrm() - if len(target.Password) != 0 { - _, err := o.Update(&target) - return err - } - - _, err := o.Update(&target, "URL", "Name", "Username") + o := GetOrmer() + _, err := o.Update(&target, "URL", "Name", "Username", "Password") return err } // GetAllRepTargets ... func GetAllRepTargets() ([]*models.RepTarget, error) { - o := orm.NewOrm() + o := GetOrmer() qs := o.QueryTable(&models.RepTarget{}) var targets []*models.RepTarget _, err := qs.All(&targets) @@ -56,7 +62,7 @@ func GetAllRepTargets() ([]*models.RepTarget, error) { // AddRepPolicy ... func AddRepPolicy(policy models.RepPolicy) (int64, error) { - o := orm.NewOrm() + o := GetOrmer() sqlTpl := `insert into replication_policy (name, project_id, target_id, enabled, description, cron_str, start_time, creation_time, update_time ) values (?, ?, ?, ?, ?, ?, %s, NOW(), NOW())` var sql string if policy.Enabled == 1 { @@ -78,7 +84,7 @@ func AddRepPolicy(policy models.RepPolicy) (int64, error) { // GetRepPolicy ... func GetRepPolicy(id int64) (*models.RepPolicy, error) { - o := orm.NewOrm() + o := GetOrmer() p := models.RepPolicy{ID: id} err := o.Read(&p) if err == orm.ErrNoRows { @@ -87,24 +93,35 @@ func GetRepPolicy(id int64) (*models.RepPolicy, error) { return &p, err } +// GetRepPolicyByName ... +func GetRepPolicyByName(name string) (*models.RepPolicy, error) { + o := GetOrmer() + p := models.RepPolicy{Name: name} + err := o.Read(&p, "Name") + if err == orm.ErrNoRows { + return nil, nil + } + return &p, err +} + // GetRepPolicyByProject ... func GetRepPolicyByProject(projectID int64) ([]*models.RepPolicy, error) { var res []*models.RepPolicy - o := orm.NewOrm() + o := GetOrmer() _, err := o.QueryTable("replication_policy").Filter("project_id", projectID).All(&res) return res, err } // DeleteRepPolicy ... func DeleteRepPolicy(id int64) error { - o := orm.NewOrm() + o := GetOrmer() _, err := o.Delete(&models.RepPolicy{ID: id}) return err } // UpdateRepPolicyEnablement ... func UpdateRepPolicyEnablement(id int64, enabled int) error { - o := orm.NewOrm() + o := GetOrmer() p := models.RepPolicy{ ID: id, Enabled: enabled} @@ -125,7 +142,7 @@ func DisableRepPolicy(id int64) error { // AddRepJob ... func AddRepJob(job models.RepJob) (int64, error) { - o := orm.NewOrm() + o := GetOrmer() if len(job.Status) == 0 { job.Status = models.JobPending } @@ -137,7 +154,7 @@ func AddRepJob(job models.RepJob) (int64, error) { // GetRepJob ... func GetRepJob(id int64) (*models.RepJob, error) { - o := orm.NewOrm() + o := GetOrmer() j := models.RepJob{ID: id} err := o.Read(&j) if err == orm.ErrNoRows { @@ -164,20 +181,20 @@ func GetRepJobToStop(policyID int64) ([]*models.RepJob, error) { } func repJobPolicyIDQs(policyID int64) orm.QuerySeter { - o := orm.NewOrm() + o := GetOrmer() return o.QueryTable("replication_job").Filter("policy_id", policyID) } // DeleteRepJob ... func DeleteRepJob(id int64) error { - o := orm.NewOrm() + o := GetOrmer() _, err := o.Delete(&models.RepJob{ID: id}) return err } // UpdateRepJobStatus ... func UpdateRepJobStatus(id int64, status string) error { - o := orm.NewOrm() + o := GetOrmer() j := models.RepJob{ ID: id, Status: status, diff --git a/models/replication_job.go b/models/replication_job.go index 97fc8443a..7f3f94082 100644 --- a/models/replication_job.go +++ b/models/replication_job.go @@ -2,6 +2,8 @@ package models import ( "time" + + "github.com/astaxie/beego/validation" ) const ( @@ -42,6 +44,33 @@ type RepPolicy struct { UpdateTime time.Time `orm:"column(update_time);auto_now" json:"update_time"` } +// Valid ... +func (r *RepPolicy) Valid(v *validation.Validation) { + if len(r.Name) == 0 { + v.SetError("name", "can not be empty") + } + + if len(r.Name) > 256 { + v.SetError("name", "max length is 256") + } + + if r.ProjectID <= 0 { + v.SetError("project_id", "invalid") + } + + if r.TargetID <= 0 { + v.SetError("target_id", "invalid") + } + + if r.Enabled != 0 && r.Enabled != 1 { + v.SetError("enabled", "must be 0 or 1") + } + + if len(r.CronStr) > 256 { + v.SetError("cron_str", "max length is 256") + } +} + // RepJob is the model for a replication job, which is the execution unit on job service, currently it is used to transfer/remove // a repository to/from a remote registry instance. type RepJob struct { @@ -68,17 +97,42 @@ type RepTarget struct { UpdateTime time.Time `orm:"column(update_time);auto_now" json:"update_time"` } +// Valid ... +func (r *RepTarget) Valid(v *validation.Validation) { + if len(r.Name) == 0 { + v.SetError("name", "can not be empty") + } + + if len(r.Name) > 64 { + v.SetError("name", "max length is 64") + } + + if len(r.URL) == 0 { + v.SetError("endpoint", "can not be empty") + } + + if len(r.URL) > 64 { + v.SetError("endpoint", "max length is 64") + } + + // password is encoded using base64, the length of this field + // in DB is 64, so the max length in request is 48 + if len(r.Password) > 48 { + v.SetError("password", "max length is 48") + } +} + //TableName is required by by beego orm to map RepTarget to table replication_target -func (rt *RepTarget) TableName() string { +func (r *RepTarget) TableName() string { return "replication_target" } //TableName is required by by beego orm to map RepJob to table replication_job -func (rj *RepJob) TableName() string { +func (r *RepJob) TableName() string { return "replication_job" } //TableName is required by by beego orm to map RepPolicy to table replication_policy -func (rp *RepPolicy) TableName() string { +func (r *RepPolicy) TableName() string { return "replication_policy" }