diff --git a/src/core/filter/orm.go b/src/core/filter/orm.go index 1310db398..ba034d877 100644 --- a/src/core/filter/orm.go +++ b/src/core/filter/orm.go @@ -25,6 +25,8 @@ func OrmFilter(ctx *context.Context) { if ctx == nil || ctx.Request == nil { return } - - ctx.Request = ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm())) + // This is a temp workaround for beego bug: https://github.com/goharbor/harbor/issues/10446 + // After we upgrading beego to the latest version and moving the filter to middleware, + // this workaround can be removed + *(ctx.Request) = *(ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm()))) } diff --git a/src/core/filter/sessionchecker.go b/src/core/filter/sessionchecker.go index 8c15beb6c..fe22b6ebb 100644 --- a/src/core/filter/sessionchecker.go +++ b/src/core/filter/sessionchecker.go @@ -18,7 +18,10 @@ func SessionCheck(ctx *beegoctx.Context) { req := ctx.Request _, err := req.Cookie(config.SessionCookieName) if err == nil { - ctx.Request = req.WithContext(context.WithValue(req.Context(), SessionReqKey, true)) + // This is a temp workaround for beego bug: https://github.com/goharbor/harbor/issues/10446 + // After we upgrading beego to the latest version and moving the filter to middleware, + // this workaround can be removed + *(ctx.Request) = *(req.WithContext(context.WithValue(req.Context(), SessionReqKey, true))) log.Debug("Mark the request as no-session") } } diff --git a/src/internal/orm/orm.go b/src/internal/orm/orm.go index 8748128f2..06668fb9c 100644 --- a/src/internal/orm/orm.go +++ b/src/internal/orm/orm.go @@ -16,8 +16,7 @@ package orm import ( "context" - "fmt" - + "errors" "github.com/astaxie/beego/orm" "github.com/goharbor/harbor/src/common/utils/log" ) @@ -25,22 +24,28 @@ import ( type ormKey struct{} // FromContext returns orm from context -func FromContext(ctx context.Context) (orm.Ormer, bool) { +func FromContext(ctx context.Context) (orm.Ormer, error) { o, ok := ctx.Value(ormKey{}).(orm.Ormer) - return o, ok + if !ok { + return nil, errors.New("cannot get the ORM from context") + } + return o, nil } // NewContext returns new context with orm func NewContext(ctx context.Context, o orm.Ormer) context.Context { + if ctx == nil { + ctx = context.Background() + } return context.WithValue(ctx, ormKey{}, o) } // WithTransaction a decorator which make f run in transaction func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error { return func(ctx context.Context) error { - o, ok := FromContext(ctx) - if !ok { - return fmt.Errorf("ormer value not found in context") + o, err := FromContext(ctx) + if err != nil { + return err } tx := ormerTx{Ormer: o} diff --git a/src/internal/orm/orm_test.go b/src/internal/orm/orm_test.go index bc07178f6..0eb7b9da3 100644 --- a/src/internal/orm/orm_test.go +++ b/src/internal/orm/orm_test.go @@ -26,18 +26,18 @@ import ( ) func addProject(ctx context.Context, project models.Project) (int64, error) { - o, ok := FromContext(ctx) - if !ok { - return 0, errors.New("orm not found in context") + o, err := FromContext(ctx) + if err != nil { + return 0, err } return o.Insert(&project) } func readProject(ctx context.Context, id int64) (*models.Project, error) { - o, ok := FromContext(ctx) - if !ok { - return nil, errors.New("orm not found in context") + o, err := FromContext(ctx) + if err != nil { + return nil, err } project := &models.Project{ @@ -52,22 +52,21 @@ func readProject(ctx context.Context, id int64) (*models.Project, error) { } func deleteProject(ctx context.Context, id int64) error { - o, ok := FromContext(ctx) - if !ok { - return errors.New("orm not found in context") + o, err := FromContext(ctx) + if err != nil { + return err } - project := &models.Project{ ProjectID: id, } - _, err := o.Delete(project, "project_id") + _, err = o.Delete(project, "project_id") return err } func existProject(ctx context.Context, id int64) bool { - o, ok := FromContext(ctx) - if !ok { + o, err := FromContext(ctx) + if err != nil { return false } @@ -95,12 +94,11 @@ func (suite *OrmSuite) SetupSuite() { func (suite *OrmSuite) TestContext() { ctx := context.TODO() - o, ok := FromContext(ctx) - suite.False(ok) - suite.Nil(o) + o, err := FromContext(ctx) + suite.NotNil(err) - o, ok = FromContext(NewContext(ctx, orm.NewOrm())) - suite.True(ok) + o, err = FromContext(NewContext(ctx, orm.NewOrm())) + suite.Nil(err) suite.NotNil(o) } diff --git a/src/internal/orm/query.go b/src/internal/orm/query.go index f98148ecc..3cad0b9a7 100644 --- a/src/internal/orm/query.go +++ b/src/internal/orm/query.go @@ -17,15 +17,18 @@ package orm import ( "context" "github.com/astaxie/beego/orm" - "github.com/goharbor/harbor/src/common/dao" "github.com/goharbor/harbor/src/pkg/q" ) // QuerySetter generates the query setter according to the query -func QuerySetter(ctx context.Context, model interface{}, query *q.Query) orm.QuerySeter { - qs := GetOrmer(ctx).QueryTable(model) +func QuerySetter(ctx context.Context, model interface{}, query *q.Query) (orm.QuerySeter, error) { + ormer, err := FromContext(ctx) + if err != nil { + return nil, err + } + qs := ormer.QueryTable(model) if query == nil { - return qs + return qs, nil } for k, v := range query.Keywords { qs = qs.Filter(k, v) @@ -36,11 +39,5 @@ func QuerySetter(ctx context.Context, model interface{}, query *q.Query) orm.Que qs = qs.Offset(query.PageSize * (query.PageNumber - 1)) } } - return qs -} - -// GetOrmer returns an ormer -// TODO remove it after weiwei's PR merged -func GetOrmer(ctx context.Context) orm.Ormer { - return dao.GetOrmer() + return qs, nil } diff --git a/src/pkg/artifact/dao/dao.go b/src/pkg/artifact/dao/dao.go index 3900a4a3d..4b321eb7a 100644 --- a/src/pkg/artifact/dao/dao.go +++ b/src/pkg/artifact/dao/dao.go @@ -57,11 +57,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) { Keywords: query.Keywords, } } - return orm.QuerySetter(ctx, &Artifact{}, query).Count() + qs, err := orm.QuerySetter(ctx, &Artifact{}, query) + if err != nil { + return 0, err + } + return qs.Count() } func (d *dao) List(ctx context.Context, query *q.Query) ([]*Artifact, error) { artifacts := []*Artifact{} - if _, err := orm.QuerySetter(ctx, &Artifact{}, query).All(&artifacts); err != nil { + qs, err := orm.QuerySetter(ctx, &Artifact{}, query) + if err != nil { + return nil, err + } + if _, err = qs.All(&artifacts); err != nil { return nil, err } return artifacts, nil @@ -70,7 +78,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*Artifact, error) { artifact := &Artifact{ ID: id, } - if err := orm.GetOrmer(ctx).Read(artifact); err != nil { + ormer, err := orm.FromContext(ctx) + if err != nil { + return nil, err + } + if err = ormer.Read(artifact); err != nil { if e, ok := orm.IsNotFoundError(err, "artifact %d not found", id); ok { err = e } @@ -79,15 +91,25 @@ func (d *dao) Get(ctx context.Context, id int64) (*Artifact, error) { return artifact, nil } func (d *dao) Create(ctx context.Context, artifact *Artifact) (int64, error) { - id, err := orm.GetOrmer(ctx).Insert(artifact) - if e, ok := orm.IsConflictError(err, "artifact %s already exists under the repository %d", - artifact.Digest, artifact.RepositoryID); ok { - err = e + ormer, err := orm.FromContext(ctx) + if err != nil { + return 0, err + } + id, err := ormer.Insert(artifact) + if err != nil { + if e, ok := orm.IsConflictError(err, "artifact %s already exists under the repository %d", + artifact.Digest, artifact.RepositoryID); ok { + err = e + } } return id, err } func (d *dao) Delete(ctx context.Context, id int64) error { - n, err := orm.GetOrmer(ctx).Delete(&Artifact{ + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Delete(&Artifact{ ID: id, }) if err != nil { @@ -99,7 +121,11 @@ func (d *dao) Delete(ctx context.Context, id int64) error { return nil } func (d *dao) Update(ctx context.Context, artifact *Artifact, props ...string) error { - n, err := orm.GetOrmer(ctx).Update(artifact, props...) + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Update(artifact, props...) if err != nil { return err } @@ -109,7 +135,11 @@ func (d *dao) Update(ctx context.Context, artifact *Artifact, props ...string) e return nil } func (d *dao) CreateReference(ctx context.Context, reference *ArtifactReference) (int64, error) { - id, err := orm.GetOrmer(ctx).Insert(reference) + ormer, err := orm.FromContext(ctx) + if err != nil { + return 0, err + } + id, err := ormer.Insert(reference) if e, ok := orm.IsConflictError(err, "reference already exists, parent artifact ID: %d, child artifact ID: %d", reference.ParentID, reference.ChildID); ok { err = e @@ -118,7 +148,11 @@ func (d *dao) CreateReference(ctx context.Context, reference *ArtifactReference) } func (d *dao) ListReferences(ctx context.Context, query *q.Query) ([]*ArtifactReference, error) { references := []*ArtifactReference{} - if _, err := orm.QuerySetter(ctx, &ArtifactReference{}, query).All(&references); err != nil { + qs, err := orm.QuerySetter(ctx, &ArtifactReference{}, query) + if err != nil { + return nil, err + } + if _, err = qs.All(&references); err != nil { return nil, err } return references, nil @@ -129,10 +163,14 @@ func (d *dao) DeleteReferences(ctx context.Context, parentID int64) error { if err != nil { return err } - _, err = orm.QuerySetter(ctx, &ArtifactReference{}, &q.Query{ + qs, err := orm.QuerySetter(ctx, &ArtifactReference{}, &q.Query{ Keywords: map[string]interface{}{ "parent_id": parentID, }, - }).Delete() + }) + if err != nil { + return err + } + _, err = qs.Delete() return err } diff --git a/src/pkg/artifact/dao/dao_test.go b/src/pkg/artifact/dao/dao_test.go index f7204148f..ee3fcec77 100644 --- a/src/pkg/artifact/dao/dao_test.go +++ b/src/pkg/artifact/dao/dao_test.go @@ -15,9 +15,12 @@ package dao import ( + "context" "errors" + beegoorm "github.com/astaxie/beego/orm" common_dao "github.com/goharbor/harbor/src/common/dao" ierror "github.com/goharbor/harbor/src/internal/error" + "github.com/goharbor/harbor/src/internal/orm" "github.com/goharbor/harbor/src/pkg/q" "github.com/stretchr/testify/suite" "testing" @@ -37,11 +40,13 @@ type daoTestSuite struct { suite.Suite dao DAO artifactID int64 + ctx context.Context } func (d *daoTestSuite) SetupSuite() { d.dao = New() common_dao.PrepareTestForPostgresSQL() + d.ctx = orm.NewContext(nil, beegoorm.NewOrm()) } func (d *daoTestSuite) SetupTest() { @@ -58,24 +63,24 @@ func (d *daoTestSuite) SetupTest() { ExtraAttrs: `{"attr1":"value1"}`, Annotations: `{"anno1":"value1"}`, } - id, err := d.dao.Create(nil, artifact) + id, err := d.dao.Create(d.ctx, artifact) d.Require().Nil(err) d.artifactID = id } func (d *daoTestSuite) TearDownTest() { - err := d.dao.Delete(nil, d.artifactID) + err := d.dao.Delete(d.ctx, d.artifactID) d.Require().Nil(err) } func (d *daoTestSuite) TestCount() { // nil query - total, err := d.dao.Count(nil, nil) + total, err := d.dao.Count(d.ctx, nil) d.Require().Nil(err) d.True(total > 0) // query by repository ID and digest - total, err = d.dao.Count(nil, &q.Query{ + total, err = d.dao.Count(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "repository_id": repositoryID, "digest": digest, @@ -85,7 +90,7 @@ func (d *daoTestSuite) TestCount() { d.Equal(int64(1), total) // query by repository ID and digest - total, err = d.dao.Count(nil, &q.Query{ + total, err = d.dao.Count(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "repository_id": repositoryID, "digest": digest, @@ -95,7 +100,7 @@ func (d *daoTestSuite) TestCount() { d.Equal(int64(1), total) // populate more data - id, err := d.dao.Create(nil, &Artifact{ + id, err := d.dao.Create(d.ctx, &Artifact{ Type: typee, MediaType: mediaType, ManifestMediaType: manifestMediaType, @@ -105,11 +110,11 @@ func (d *daoTestSuite) TestCount() { }) d.Require().Nil(err) defer func() { - err = d.dao.Delete(nil, id) + err = d.dao.Delete(d.ctx, id) d.Require().Nil(err) }() // set pagination in query - total, err = d.dao.Count(nil, &q.Query{ + total, err = d.dao.Count(d.ctx, &q.Query{ PageNumber: 1, PageSize: 1, }) @@ -119,7 +124,7 @@ func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestList() { // nil query - artifacts, err := d.dao.List(nil, nil) + artifacts, err := d.dao.List(d.ctx, nil) d.Require().Nil(err) found := false for _, artifact := range artifacts { @@ -131,7 +136,7 @@ func (d *daoTestSuite) TestList() { d.True(found) // query by repository ID and digest - artifacts, err = d.dao.List(nil, &q.Query{ + artifacts, err = d.dao.List(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "repository_id": repositoryID, "digest": digest, @@ -144,12 +149,12 @@ func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestGet() { // get the non-exist artifact - _, err := d.dao.Get(nil, 10000) + _, err := d.dao.Get(d.ctx, 10000) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.NotFoundCode)) // get the exist artifact - artifact, err := d.dao.Get(nil, d.artifactID) + artifact, err := d.dao.Get(d.ctx, d.artifactID) d.Require().Nil(err) d.Require().NotNil(artifact) d.Equal(d.artifactID, artifact.ID) @@ -172,7 +177,7 @@ func (d *daoTestSuite) TestCreate() { ExtraAttrs: `{"attr1":"value1"}`, Annotations: `{"anno1":"value1"}`, } - _, err := d.dao.Create(nil, artifact) + _, err := d.dao.Create(d.ctx, artifact) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.ConflictCode)) } @@ -181,7 +186,7 @@ func (d *daoTestSuite) TestDelete() { // the happy pass case is covered in TearDown // not exist - err := d.dao.Delete(nil, 100021) + err := d.dao.Delete(d.ctx, 100021) d.Require().NotNil(err) var e *ierror.Error d.Require().True(errors.As(err, &e)) @@ -191,19 +196,19 @@ func (d *daoTestSuite) TestDelete() { func (d *daoTestSuite) TestUpdate() { // pass now := time.Now() - err := d.dao.Update(nil, &Artifact{ + err := d.dao.Update(d.ctx, &Artifact{ ID: d.artifactID, PushTime: now, }, "PushTime") d.Require().Nil(err) - artifact, err := d.dao.Get(nil, d.artifactID) + artifact, err := d.dao.Get(d.ctx, d.artifactID) d.Require().Nil(err) d.Require().NotNil(artifact) d.Equal(now.Unix(), artifact.PullTime.Unix()) // not exist - err = d.dao.Update(nil, &Artifact{ + err = d.dao.Update(d.ctx, &Artifact{ ID: 10000, }) d.Require().NotNil(err) @@ -214,14 +219,14 @@ func (d *daoTestSuite) TestUpdate() { func (d *daoTestSuite) TestReference() { // create reference - id, err := d.dao.CreateReference(nil, &ArtifactReference{ + id, err := d.dao.CreateReference(d.ctx, &ArtifactReference{ ParentID: d.artifactID, ChildID: 10000, }) d.Require().Nil(err) // conflict - _, err = d.dao.CreateReference(nil, &ArtifactReference{ + _, err = d.dao.CreateReference(d.ctx, &ArtifactReference{ ParentID: d.artifactID, ChildID: 10000, }) @@ -229,7 +234,7 @@ func (d *daoTestSuite) TestReference() { d.True(ierror.IsErr(err, ierror.ConflictCode)) // list reference - references, err := d.dao.ListReferences(nil, &q.Query{ + references, err := d.dao.ListReferences(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "parent_id": d.artifactID, }, @@ -238,11 +243,11 @@ func (d *daoTestSuite) TestReference() { d.Equal(id, references[0].ID) // delete reference - err = d.dao.DeleteReferences(nil, d.artifactID) + err = d.dao.DeleteReferences(d.ctx, d.artifactID) d.Require().Nil(err) // parent artifact not exist - err = d.dao.DeleteReferences(nil, 10000) + err = d.dao.DeleteReferences(d.ctx, 10000) d.Require().NotNil(err) var e *ierror.Error d.Require().True(errors.As(err, &e)) diff --git a/src/pkg/repository/dao/dao.go b/src/pkg/repository/dao/dao.go index 63fabd7df..a6b5d52eb 100644 --- a/src/pkg/repository/dao/dao.go +++ b/src/pkg/repository/dao/dao.go @@ -52,11 +52,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) { Keywords: query.Keywords, } } - return orm.QuerySetter(ctx, &models.RepoRecord{}, query).Count() + qs, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query) + if err != nil { + return 0, err + } + return qs.Count() } func (d *dao) List(ctx context.Context, query *q.Query) ([]*models.RepoRecord, error) { repositories := []*models.RepoRecord{} - if _, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query).All(&repositories); err != nil { + qs, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query) + if err != nil { + return nil, err + } + if _, err = qs.All(&repositories); err != nil { return nil, err } return repositories, nil @@ -66,7 +74,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*models.RepoRecord, error) { repository := &models.RepoRecord{ RepositoryID: id, } - if err := orm.GetOrmer(ctx).Read(repository); err != nil { + ormer, err := orm.FromContext(ctx) + if err != nil { + return nil, err + } + if err := ormer.Read(repository); err != nil { if e, ok := orm.IsNotFoundError(err, "repository %d not found", id); ok { err = e } @@ -76,7 +88,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*models.RepoRecord, error) { } func (d *dao) Create(ctx context.Context, repository *models.RepoRecord) (int64, error) { - id, err := orm.GetOrmer(ctx).Insert(repository) + ormer, err := orm.FromContext(ctx) + if err != nil { + return 0, err + } + id, err := ormer.Insert(repository) if e, ok := orm.IsConflictError(err, "repository %s already exists", repository.Name); ok { err = e } @@ -84,7 +100,11 @@ func (d *dao) Create(ctx context.Context, repository *models.RepoRecord) (int64, } func (d *dao) Delete(ctx context.Context, id int64) error { - n, err := orm.GetOrmer(ctx).Delete(&models.RepoRecord{ + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Delete(&models.RepoRecord{ RepositoryID: id, }) if err != nil { @@ -97,7 +117,11 @@ func (d *dao) Delete(ctx context.Context, id int64) error { } func (d *dao) Update(ctx context.Context, repository *models.RepoRecord, props ...string) error { - n, err := orm.GetOrmer(ctx).Update(repository, props...) + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Update(repository, props...) if err != nil { return err } diff --git a/src/pkg/repository/dao/dao_test.go b/src/pkg/repository/dao/dao_test.go index 9d73a33e8..4d242ca46 100644 --- a/src/pkg/repository/dao/dao_test.go +++ b/src/pkg/repository/dao/dao_test.go @@ -15,11 +15,14 @@ package dao import ( + "context" "errors" "fmt" + beegoorm "github.com/astaxie/beego/orm" common_dao "github.com/goharbor/harbor/src/common/dao" "github.com/goharbor/harbor/src/common/models" ierror "github.com/goharbor/harbor/src/internal/error" + "github.com/goharbor/harbor/src/internal/orm" "github.com/goharbor/harbor/src/pkg/q" "github.com/stretchr/testify/suite" "testing" @@ -34,11 +37,13 @@ type daoTestSuite struct { suite.Suite dao DAO id int64 + ctx context.Context } func (d *daoTestSuite) SetupSuite() { d.dao = New() common_dao.PrepareTestForPostgresSQL() + d.ctx = orm.NewContext(nil, beegoorm.NewOrm()) } func (d *daoTestSuite) SetupTest() { @@ -47,24 +52,24 @@ func (d *daoTestSuite) SetupTest() { ProjectID: 1, Description: "", } - id, err := d.dao.Create(nil, repository) + id, err := d.dao.Create(d.ctx, repository) d.Require().Nil(err) d.id = id } func (d *daoTestSuite) TearDownTest() { - err := d.dao.Delete(nil, d.id) + err := d.dao.Delete(d.ctx, d.id) d.Require().Nil(err) } func (d *daoTestSuite) TestCount() { // nil query - total, err := d.dao.Count(nil, nil) + total, err := d.dao.Count(d.ctx, nil) d.Require().Nil(err) d.True(total > 0) // query by name - total, err = d.dao.Count(nil, &q.Query{ + total, err = d.dao.Count(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "name": repository, }, @@ -75,7 +80,7 @@ func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestList() { // nil query - repositories, err := d.dao.List(nil, nil) + repositories, err := d.dao.List(d.ctx, nil) d.Require().Nil(err) found := false for _, repository := range repositories { @@ -87,7 +92,7 @@ func (d *daoTestSuite) TestList() { d.True(found) // query by name - repositories, err = d.dao.List(nil, &q.Query{ + repositories, err = d.dao.List(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "name": repository, }, @@ -99,12 +104,12 @@ func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestGet() { // get the non-exist repository - _, err := d.dao.Get(nil, 10000) + _, err := d.dao.Get(d.ctx, 10000) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.NotFoundCode)) // get the exist repository - repository, err := d.dao.Get(nil, d.id) + repository, err := d.dao.Get(d.ctx, d.id) d.Require().Nil(err) d.Require().NotNil(repository) d.Equal(d.id, repository.RepositoryID) @@ -118,7 +123,7 @@ func (d *daoTestSuite) TestCreate() { Name: repository, ProjectID: 1, } - _, err := d.dao.Create(nil, repository) + _, err := d.dao.Create(d.ctx, repository) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.ConflictCode)) } @@ -127,7 +132,7 @@ func (d *daoTestSuite) TestDelete() { // the happy pass case is covered in TearDown // not exist - err := d.dao.Delete(nil, 100021) + err := d.dao.Delete(d.ctx, 100021) d.Require().NotNil(err) var e *ierror.Error d.Require().True(errors.As(err, &e)) @@ -136,19 +141,19 @@ func (d *daoTestSuite) TestDelete() { func (d *daoTestSuite) TestUpdate() { // pass - err := d.dao.Update(nil, &models.RepoRecord{ + err := d.dao.Update(d.ctx, &models.RepoRecord{ RepositoryID: d.id, PullCount: 1, }, "PullCount") d.Require().Nil(err) - repository, err := d.dao.Get(nil, d.id) + repository, err := d.dao.Get(d.ctx, d.id) d.Require().Nil(err) d.Require().NotNil(repository) d.Equal(int64(1), repository.PullCount) // not exist - err = d.dao.Update(nil, &models.RepoRecord{ + err = d.dao.Update(d.ctx, &models.RepoRecord{ RepositoryID: 10000, }) d.Require().NotNil(err) diff --git a/src/pkg/retention/launcher.go b/src/pkg/retention/launcher.go index 05f7abaf8..2faec790d 100644 --- a/src/pkg/retention/launcher.go +++ b/src/pkg/retention/launcher.go @@ -16,6 +16,8 @@ package retention import ( "fmt" + beegoorm "github.com/astaxie/beego/orm" + "github.com/goharbor/harbor/src/internal/orm" "time" "github.com/goharbor/harbor/src/jobservice/job" @@ -348,7 +350,7 @@ func getRepositories(projectMgr project.Manager, repositoryMgr repository.Manage */ // get image repositories // TODO set the context which contains the ORM - _, imageRepositories, err := repositoryMgr.List(nil, &pq.Query{ + _, imageRepositories, err := repositoryMgr.List(orm.NewContext(nil, beegoorm.NewOrm()), &pq.Query{ Keywords: map[string]interface{}{ "ProjectID": projectID, }, diff --git a/src/pkg/tag/dao/dao.go b/src/pkg/tag/dao/dao.go index c743ccb1e..73d03609f 100644 --- a/src/pkg/tag/dao/dao.go +++ b/src/pkg/tag/dao/dao.go @@ -57,11 +57,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) { Keywords: query.Keywords, } } - return orm.QuerySetter(ctx, &tag.Tag{}, query).Count() + qs, err := orm.QuerySetter(ctx, &tag.Tag{}, query) + if err != nil { + return 0, err + } + return qs.Count() } func (d *dao) List(ctx context.Context, query *q.Query) ([]*tag.Tag, error) { tags := []*tag.Tag{} - if _, err := orm.QuerySetter(ctx, &tag.Tag{}, query).All(&tags); err != nil { + qs, err := orm.QuerySetter(ctx, &tag.Tag{}, query) + if err != nil { + return nil, err + } + if _, err = qs.All(&tags); err != nil { return nil, err } return tags, nil @@ -70,7 +78,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*tag.Tag, error) { tag := &tag.Tag{ ID: id, } - if err := orm.GetOrmer(ctx).Read(tag); err != nil { + ormer, err := orm.FromContext(ctx) + if err != nil { + return nil, err + } + if err := ormer.Read(tag); err != nil { if e, ok := orm.IsNotFoundError(err, "tag %d not found", id); ok { err = e } @@ -79,7 +91,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*tag.Tag, error) { return tag, nil } func (d *dao) Create(ctx context.Context, tag *tag.Tag) (int64, error) { - id, err := orm.GetOrmer(ctx).Insert(tag) + ormer, err := orm.FromContext(ctx) + if err != nil { + return 0, err + } + id, err := ormer.Insert(tag) if e, ok := orm.IsConflictError(err, "tag %s already exists under the repository %d", tag.Name, tag.RepositoryID); ok { err = e @@ -87,7 +103,11 @@ func (d *dao) Create(ctx context.Context, tag *tag.Tag) (int64, error) { return id, err } func (d *dao) Update(ctx context.Context, tag *tag.Tag, props ...string) error { - n, err := orm.GetOrmer(ctx).Update(tag, props...) + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Update(tag, props...) if err != nil { return err } @@ -97,7 +117,11 @@ func (d *dao) Update(ctx context.Context, tag *tag.Tag, props ...string) error { return nil } func (d *dao) Delete(ctx context.Context, id int64) error { - n, err := orm.GetOrmer(ctx).Delete(&tag.Tag{ + ormer, err := orm.FromContext(ctx) + if err != nil { + return err + } + n, err := ormer.Delete(&tag.Tag{ ID: id, }) if err != nil { diff --git a/src/pkg/tag/dao/dao_test.go b/src/pkg/tag/dao/dao_test.go index 80695bb40..b419ada30 100644 --- a/src/pkg/tag/dao/dao_test.go +++ b/src/pkg/tag/dao/dao_test.go @@ -15,9 +15,12 @@ package dao import ( + "context" "errors" + beegoorm "github.com/astaxie/beego/orm" common_dao "github.com/goharbor/harbor/src/common/dao" ierror "github.com/goharbor/harbor/src/internal/error" + "github.com/goharbor/harbor/src/internal/orm" "github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/tag/model/tag" "github.com/stretchr/testify/suite" @@ -35,11 +38,13 @@ type daoTestSuite struct { suite.Suite dao DAO tagID int64 + ctx context.Context } func (d *daoTestSuite) SetupSuite() { d.dao = New() common_dao.PrepareTestForPostgresSQL() + d.ctx = orm.NewContext(nil, beegoorm.NewOrm()) } func (d *daoTestSuite) SetupTest() { @@ -50,23 +55,23 @@ func (d *daoTestSuite) SetupTest() { PushTime: time.Time{}, PullTime: time.Time{}, } - id, err := d.dao.Create(nil, tag) + id, err := d.dao.Create(d.ctx, tag) d.Require().Nil(err) d.tagID = id } func (d *daoTestSuite) TearDownTest() { - err := d.dao.Delete(nil, d.tagID) + err := d.dao.Delete(d.ctx, d.tagID) d.Require().Nil(err) } func (d *daoTestSuite) TestCount() { // nil query - total, err := d.dao.Count(nil, nil) + total, err := d.dao.Count(d.ctx, nil) d.Require().Nil(err) d.True(total > 0) // query by repository ID and name - total, err = d.dao.Count(nil, &q.Query{ + total, err = d.dao.Count(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "repository_id": repositoryID, "name": name, @@ -78,7 +83,7 @@ func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestList() { // nil query - tags, err := d.dao.List(nil, nil) + tags, err := d.dao.List(d.ctx, nil) d.Require().Nil(err) found := false for _, tag := range tags { @@ -90,7 +95,7 @@ func (d *daoTestSuite) TestList() { d.True(found) // query by repository ID and name - tags, err = d.dao.List(nil, &q.Query{ + tags, err = d.dao.List(d.ctx, &q.Query{ Keywords: map[string]interface{}{ "repository_id": repositoryID, "name": name, @@ -103,12 +108,12 @@ func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestGet() { // get the non-exist tag - _, err := d.dao.Get(nil, 10000) + _, err := d.dao.Get(d.ctx, 10000) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.NotFoundCode)) // get the exist tag - tag, err := d.dao.Get(nil, d.tagID) + tag, err := d.dao.Get(d.ctx, d.tagID) d.Require().Nil(err) d.Require().NotNil(tag) d.Equal(d.tagID, tag.ID) @@ -125,7 +130,7 @@ func (d *daoTestSuite) TestCreate() { PushTime: time.Time{}, PullTime: time.Time{}, } - _, err := d.dao.Create(nil, tag) + _, err := d.dao.Create(d.ctx, tag) d.Require().NotNil(err) d.True(ierror.IsErr(err, ierror.ConflictCode)) } @@ -134,7 +139,7 @@ func (d *daoTestSuite) TestDelete() { // happy pass is covered in TearDown // not exist - err := d.dao.Delete(nil, 10000) + err := d.dao.Delete(d.ctx, 10000) d.Require().NotNil(err) var e *ierror.Error d.Require().True(errors.As(err, &e)) @@ -143,19 +148,19 @@ func (d *daoTestSuite) TestDelete() { func (d *daoTestSuite) TestUpdate() { // pass - err := d.dao.Update(nil, &tag.Tag{ + err := d.dao.Update(d.ctx, &tag.Tag{ ID: d.tagID, ArtifactID: 2, }, "ArtifactID") d.Require().Nil(err) - tg, err := d.dao.Get(nil, d.tagID) + tg, err := d.dao.Get(d.ctx, d.tagID) d.Require().Nil(err) d.Require().NotNil(tg) d.Equal(int64(2), tg.ArtifactID) // not exist - err = d.dao.Update(nil, &tag.Tag{ + err = d.dao.Update(d.ctx, &tag.Tag{ ID: 10000, }) d.Require().NotNil(err)