Get the ormer from context instead of creating a new one

This commit updates the dao methods with the ormer got from context

Signed-off-by: Wenkai Yin <yinw@vmware.com>
This commit is contained in:
Wenkai Yin 2020-01-14 16:35:36 +08:00
parent a3e84380fa
commit ff1a03cccc
12 changed files with 221 additions and 113 deletions

View File

@ -25,6 +25,8 @@ func OrmFilter(ctx *context.Context) {
if ctx == nil || ctx.Request == nil { if ctx == nil || ctx.Request == nil {
return return
} }
// This is a temp workaround for beego bug: https://github.com/goharbor/harbor/issues/10446
ctx.Request = ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm())) // After we upgrading beego to the latest version and moving the filter to middleware,
// this workaround can be removed
*(ctx.Request) = *(ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm())))
} }

View File

@ -18,7 +18,10 @@ func SessionCheck(ctx *beegoctx.Context) {
req := ctx.Request req := ctx.Request
_, err := req.Cookie(config.SessionCookieName) _, err := req.Cookie(config.SessionCookieName)
if err == nil { if err == nil {
ctx.Request = req.WithContext(context.WithValue(req.Context(), SessionReqKey, true)) // This is a temp workaround for beego bug: https://github.com/goharbor/harbor/issues/10446
// After we upgrading beego to the latest version and moving the filter to middleware,
// this workaround can be removed
*(ctx.Request) = *(req.WithContext(context.WithValue(req.Context(), SessionReqKey, true)))
log.Debug("Mark the request as no-session") log.Debug("Mark the request as no-session")
} }
} }

View File

@ -16,8 +16,7 @@ package orm
import ( import (
"context" "context"
"fmt" "errors"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/common/utils/log" "github.com/goharbor/harbor/src/common/utils/log"
) )
@ -25,22 +24,28 @@ import (
type ormKey struct{} type ormKey struct{}
// FromContext returns orm from context // FromContext returns orm from context
func FromContext(ctx context.Context) (orm.Ormer, bool) { func FromContext(ctx context.Context) (orm.Ormer, error) {
o, ok := ctx.Value(ormKey{}).(orm.Ormer) o, ok := ctx.Value(ormKey{}).(orm.Ormer)
return o, ok if !ok {
return nil, errors.New("cannot get the ORM from context")
}
return o, nil
} }
// NewContext returns new context with orm // NewContext returns new context with orm
func NewContext(ctx context.Context, o orm.Ormer) context.Context { func NewContext(ctx context.Context, o orm.Ormer) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, ormKey{}, o) return context.WithValue(ctx, ormKey{}, o)
} }
// WithTransaction a decorator which make f run in transaction // WithTransaction a decorator which make f run in transaction
func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error { func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error {
return func(ctx context.Context) error { return func(ctx context.Context) error {
o, ok := FromContext(ctx) o, err := FromContext(ctx)
if !ok { if err != nil {
return fmt.Errorf("ormer value not found in context") return err
} }
tx := ormerTx{Ormer: o} tx := ormerTx{Ormer: o}

View File

@ -26,18 +26,18 @@ import (
) )
func addProject(ctx context.Context, project models.Project) (int64, error) { func addProject(ctx context.Context, project models.Project) (int64, error) {
o, ok := FromContext(ctx) o, err := FromContext(ctx)
if !ok { if err != nil {
return 0, errors.New("orm not found in context") return 0, err
} }
return o.Insert(&project) return o.Insert(&project)
} }
func readProject(ctx context.Context, id int64) (*models.Project, error) { func readProject(ctx context.Context, id int64) (*models.Project, error) {
o, ok := FromContext(ctx) o, err := FromContext(ctx)
if !ok { if err != nil {
return nil, errors.New("orm not found in context") return nil, err
} }
project := &models.Project{ project := &models.Project{
@ -52,22 +52,21 @@ func readProject(ctx context.Context, id int64) (*models.Project, error) {
} }
func deleteProject(ctx context.Context, id int64) error { func deleteProject(ctx context.Context, id int64) error {
o, ok := FromContext(ctx) o, err := FromContext(ctx)
if !ok { if err != nil {
return errors.New("orm not found in context") return err
} }
project := &models.Project{ project := &models.Project{
ProjectID: id, ProjectID: id,
} }
_, err := o.Delete(project, "project_id") _, err = o.Delete(project, "project_id")
return err return err
} }
func existProject(ctx context.Context, id int64) bool { func existProject(ctx context.Context, id int64) bool {
o, ok := FromContext(ctx) o, err := FromContext(ctx)
if !ok { if err != nil {
return false return false
} }
@ -95,12 +94,11 @@ func (suite *OrmSuite) SetupSuite() {
func (suite *OrmSuite) TestContext() { func (suite *OrmSuite) TestContext() {
ctx := context.TODO() ctx := context.TODO()
o, ok := FromContext(ctx) o, err := FromContext(ctx)
suite.False(ok) suite.NotNil(err)
suite.Nil(o)
o, ok = FromContext(NewContext(ctx, orm.NewOrm())) o, err = FromContext(NewContext(ctx, orm.NewOrm()))
suite.True(ok) suite.Nil(err)
suite.NotNil(o) suite.NotNil(o)
} }

View File

@ -17,15 +17,18 @@ package orm
import ( import (
"context" "context"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/q"
) )
// QuerySetter generates the query setter according to the query // QuerySetter generates the query setter according to the query
func QuerySetter(ctx context.Context, model interface{}, query *q.Query) orm.QuerySeter { func QuerySetter(ctx context.Context, model interface{}, query *q.Query) (orm.QuerySeter, error) {
qs := GetOrmer(ctx).QueryTable(model) ormer, err := FromContext(ctx)
if err != nil {
return nil, err
}
qs := ormer.QueryTable(model)
if query == nil { if query == nil {
return qs return qs, nil
} }
for k, v := range query.Keywords { for k, v := range query.Keywords {
qs = qs.Filter(k, v) qs = qs.Filter(k, v)
@ -36,11 +39,5 @@ func QuerySetter(ctx context.Context, model interface{}, query *q.Query) orm.Que
qs = qs.Offset(query.PageSize * (query.PageNumber - 1)) qs = qs.Offset(query.PageSize * (query.PageNumber - 1))
} }
} }
return qs return qs, nil
}
// GetOrmer returns an ormer
// TODO remove it after weiwei's PR merged
func GetOrmer(ctx context.Context) orm.Ormer {
return dao.GetOrmer()
} }

View File

@ -57,11 +57,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) {
Keywords: query.Keywords, Keywords: query.Keywords,
} }
} }
return orm.QuerySetter(ctx, &Artifact{}, query).Count() qs, err := orm.QuerySetter(ctx, &Artifact{}, query)
if err != nil {
return 0, err
}
return qs.Count()
} }
func (d *dao) List(ctx context.Context, query *q.Query) ([]*Artifact, error) { func (d *dao) List(ctx context.Context, query *q.Query) ([]*Artifact, error) {
artifacts := []*Artifact{} artifacts := []*Artifact{}
if _, err := orm.QuerySetter(ctx, &Artifact{}, query).All(&artifacts); err != nil { qs, err := orm.QuerySetter(ctx, &Artifact{}, query)
if err != nil {
return nil, err
}
if _, err = qs.All(&artifacts); err != nil {
return nil, err return nil, err
} }
return artifacts, nil return artifacts, nil
@ -70,7 +78,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*Artifact, error) {
artifact := &Artifact{ artifact := &Artifact{
ID: id, ID: id,
} }
if err := orm.GetOrmer(ctx).Read(artifact); err != nil { ormer, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
if err = ormer.Read(artifact); err != nil {
if e, ok := orm.IsNotFoundError(err, "artifact %d not found", id); ok { if e, ok := orm.IsNotFoundError(err, "artifact %d not found", id); ok {
err = e err = e
} }
@ -79,15 +91,25 @@ func (d *dao) Get(ctx context.Context, id int64) (*Artifact, error) {
return artifact, nil return artifact, nil
} }
func (d *dao) Create(ctx context.Context, artifact *Artifact) (int64, error) { func (d *dao) Create(ctx context.Context, artifact *Artifact) (int64, error) {
id, err := orm.GetOrmer(ctx).Insert(artifact) ormer, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
id, err := ormer.Insert(artifact)
if err != nil {
if e, ok := orm.IsConflictError(err, "artifact %s already exists under the repository %d", if e, ok := orm.IsConflictError(err, "artifact %s already exists under the repository %d",
artifact.Digest, artifact.RepositoryID); ok { artifact.Digest, artifact.RepositoryID); ok {
err = e err = e
} }
}
return id, err return id, err
} }
func (d *dao) Delete(ctx context.Context, id int64) error { func (d *dao) Delete(ctx context.Context, id int64) error {
n, err := orm.GetOrmer(ctx).Delete(&Artifact{ ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Delete(&Artifact{
ID: id, ID: id,
}) })
if err != nil { if err != nil {
@ -99,7 +121,11 @@ func (d *dao) Delete(ctx context.Context, id int64) error {
return nil return nil
} }
func (d *dao) Update(ctx context.Context, artifact *Artifact, props ...string) error { func (d *dao) Update(ctx context.Context, artifact *Artifact, props ...string) error {
n, err := orm.GetOrmer(ctx).Update(artifact, props...) ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Update(artifact, props...)
if err != nil { if err != nil {
return err return err
} }
@ -109,7 +135,11 @@ func (d *dao) Update(ctx context.Context, artifact *Artifact, props ...string) e
return nil return nil
} }
func (d *dao) CreateReference(ctx context.Context, reference *ArtifactReference) (int64, error) { func (d *dao) CreateReference(ctx context.Context, reference *ArtifactReference) (int64, error) {
id, err := orm.GetOrmer(ctx).Insert(reference) ormer, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
id, err := ormer.Insert(reference)
if e, ok := orm.IsConflictError(err, "reference already exists, parent artifact ID: %d, child artifact ID: %d", if e, ok := orm.IsConflictError(err, "reference already exists, parent artifact ID: %d, child artifact ID: %d",
reference.ParentID, reference.ChildID); ok { reference.ParentID, reference.ChildID); ok {
err = e err = e
@ -118,7 +148,11 @@ func (d *dao) CreateReference(ctx context.Context, reference *ArtifactReference)
} }
func (d *dao) ListReferences(ctx context.Context, query *q.Query) ([]*ArtifactReference, error) { func (d *dao) ListReferences(ctx context.Context, query *q.Query) ([]*ArtifactReference, error) {
references := []*ArtifactReference{} references := []*ArtifactReference{}
if _, err := orm.QuerySetter(ctx, &ArtifactReference{}, query).All(&references); err != nil { qs, err := orm.QuerySetter(ctx, &ArtifactReference{}, query)
if err != nil {
return nil, err
}
if _, err = qs.All(&references); err != nil {
return nil, err return nil, err
} }
return references, nil return references, nil
@ -129,10 +163,14 @@ func (d *dao) DeleteReferences(ctx context.Context, parentID int64) error {
if err != nil { if err != nil {
return err return err
} }
_, err = orm.QuerySetter(ctx, &ArtifactReference{}, &q.Query{ qs, err := orm.QuerySetter(ctx, &ArtifactReference{}, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"parent_id": parentID, "parent_id": parentID,
}, },
}).Delete() })
if err != nil {
return err
}
_, err = qs.Delete()
return err return err
} }

View File

@ -15,9 +15,12 @@
package dao package dao
import ( import (
"context"
"errors" "errors"
beegoorm "github.com/astaxie/beego/orm"
common_dao "github.com/goharbor/harbor/src/common/dao" common_dao "github.com/goharbor/harbor/src/common/dao"
ierror "github.com/goharbor/harbor/src/internal/error" ierror "github.com/goharbor/harbor/src/internal/error"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/q"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"testing" "testing"
@ -37,11 +40,13 @@ type daoTestSuite struct {
suite.Suite suite.Suite
dao DAO dao DAO
artifactID int64 artifactID int64
ctx context.Context
} }
func (d *daoTestSuite) SetupSuite() { func (d *daoTestSuite) SetupSuite() {
d.dao = New() d.dao = New()
common_dao.PrepareTestForPostgresSQL() common_dao.PrepareTestForPostgresSQL()
d.ctx = orm.NewContext(nil, beegoorm.NewOrm())
} }
func (d *daoTestSuite) SetupTest() { func (d *daoTestSuite) SetupTest() {
@ -58,24 +63,24 @@ func (d *daoTestSuite) SetupTest() {
ExtraAttrs: `{"attr1":"value1"}`, ExtraAttrs: `{"attr1":"value1"}`,
Annotations: `{"anno1":"value1"}`, Annotations: `{"anno1":"value1"}`,
} }
id, err := d.dao.Create(nil, artifact) id, err := d.dao.Create(d.ctx, artifact)
d.Require().Nil(err) d.Require().Nil(err)
d.artifactID = id d.artifactID = id
} }
func (d *daoTestSuite) TearDownTest() { func (d *daoTestSuite) TearDownTest() {
err := d.dao.Delete(nil, d.artifactID) err := d.dao.Delete(d.ctx, d.artifactID)
d.Require().Nil(err) d.Require().Nil(err)
} }
func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestCount() {
// nil query // nil query
total, err := d.dao.Count(nil, nil) total, err := d.dao.Count(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
d.True(total > 0) d.True(total > 0)
// query by repository ID and digest // query by repository ID and digest
total, err = d.dao.Count(nil, &q.Query{ total, err = d.dao.Count(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"repository_id": repositoryID, "repository_id": repositoryID,
"digest": digest, "digest": digest,
@ -85,7 +90,7 @@ func (d *daoTestSuite) TestCount() {
d.Equal(int64(1), total) d.Equal(int64(1), total)
// query by repository ID and digest // query by repository ID and digest
total, err = d.dao.Count(nil, &q.Query{ total, err = d.dao.Count(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"repository_id": repositoryID, "repository_id": repositoryID,
"digest": digest, "digest": digest,
@ -95,7 +100,7 @@ func (d *daoTestSuite) TestCount() {
d.Equal(int64(1), total) d.Equal(int64(1), total)
// populate more data // populate more data
id, err := d.dao.Create(nil, &Artifact{ id, err := d.dao.Create(d.ctx, &Artifact{
Type: typee, Type: typee,
MediaType: mediaType, MediaType: mediaType,
ManifestMediaType: manifestMediaType, ManifestMediaType: manifestMediaType,
@ -105,11 +110,11 @@ func (d *daoTestSuite) TestCount() {
}) })
d.Require().Nil(err) d.Require().Nil(err)
defer func() { defer func() {
err = d.dao.Delete(nil, id) err = d.dao.Delete(d.ctx, id)
d.Require().Nil(err) d.Require().Nil(err)
}() }()
// set pagination in query // set pagination in query
total, err = d.dao.Count(nil, &q.Query{ total, err = d.dao.Count(d.ctx, &q.Query{
PageNumber: 1, PageNumber: 1,
PageSize: 1, PageSize: 1,
}) })
@ -119,7 +124,7 @@ func (d *daoTestSuite) TestCount() {
func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestList() {
// nil query // nil query
artifacts, err := d.dao.List(nil, nil) artifacts, err := d.dao.List(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
found := false found := false
for _, artifact := range artifacts { for _, artifact := range artifacts {
@ -131,7 +136,7 @@ func (d *daoTestSuite) TestList() {
d.True(found) d.True(found)
// query by repository ID and digest // query by repository ID and digest
artifacts, err = d.dao.List(nil, &q.Query{ artifacts, err = d.dao.List(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"repository_id": repositoryID, "repository_id": repositoryID,
"digest": digest, "digest": digest,
@ -144,12 +149,12 @@ func (d *daoTestSuite) TestList() {
func (d *daoTestSuite) TestGet() { func (d *daoTestSuite) TestGet() {
// get the non-exist artifact // get the non-exist artifact
_, err := d.dao.Get(nil, 10000) _, err := d.dao.Get(d.ctx, 10000)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.NotFoundCode)) d.True(ierror.IsErr(err, ierror.NotFoundCode))
// get the exist artifact // get the exist artifact
artifact, err := d.dao.Get(nil, d.artifactID) artifact, err := d.dao.Get(d.ctx, d.artifactID)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(artifact) d.Require().NotNil(artifact)
d.Equal(d.artifactID, artifact.ID) d.Equal(d.artifactID, artifact.ID)
@ -172,7 +177,7 @@ func (d *daoTestSuite) TestCreate() {
ExtraAttrs: `{"attr1":"value1"}`, ExtraAttrs: `{"attr1":"value1"}`,
Annotations: `{"anno1":"value1"}`, Annotations: `{"anno1":"value1"}`,
} }
_, err := d.dao.Create(nil, artifact) _, err := d.dao.Create(d.ctx, artifact)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.ConflictCode)) d.True(ierror.IsErr(err, ierror.ConflictCode))
} }
@ -181,7 +186,7 @@ func (d *daoTestSuite) TestDelete() {
// the happy pass case is covered in TearDown // the happy pass case is covered in TearDown
// not exist // not exist
err := d.dao.Delete(nil, 100021) err := d.dao.Delete(d.ctx, 100021)
d.Require().NotNil(err) d.Require().NotNil(err)
var e *ierror.Error var e *ierror.Error
d.Require().True(errors.As(err, &e)) d.Require().True(errors.As(err, &e))
@ -191,19 +196,19 @@ func (d *daoTestSuite) TestDelete() {
func (d *daoTestSuite) TestUpdate() { func (d *daoTestSuite) TestUpdate() {
// pass // pass
now := time.Now() now := time.Now()
err := d.dao.Update(nil, &Artifact{ err := d.dao.Update(d.ctx, &Artifact{
ID: d.artifactID, ID: d.artifactID,
PushTime: now, PushTime: now,
}, "PushTime") }, "PushTime")
d.Require().Nil(err) d.Require().Nil(err)
artifact, err := d.dao.Get(nil, d.artifactID) artifact, err := d.dao.Get(d.ctx, d.artifactID)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(artifact) d.Require().NotNil(artifact)
d.Equal(now.Unix(), artifact.PullTime.Unix()) d.Equal(now.Unix(), artifact.PullTime.Unix())
// not exist // not exist
err = d.dao.Update(nil, &Artifact{ err = d.dao.Update(d.ctx, &Artifact{
ID: 10000, ID: 10000,
}) })
d.Require().NotNil(err) d.Require().NotNil(err)
@ -214,14 +219,14 @@ func (d *daoTestSuite) TestUpdate() {
func (d *daoTestSuite) TestReference() { func (d *daoTestSuite) TestReference() {
// create reference // create reference
id, err := d.dao.CreateReference(nil, &ArtifactReference{ id, err := d.dao.CreateReference(d.ctx, &ArtifactReference{
ParentID: d.artifactID, ParentID: d.artifactID,
ChildID: 10000, ChildID: 10000,
}) })
d.Require().Nil(err) d.Require().Nil(err)
// conflict // conflict
_, err = d.dao.CreateReference(nil, &ArtifactReference{ _, err = d.dao.CreateReference(d.ctx, &ArtifactReference{
ParentID: d.artifactID, ParentID: d.artifactID,
ChildID: 10000, ChildID: 10000,
}) })
@ -229,7 +234,7 @@ func (d *daoTestSuite) TestReference() {
d.True(ierror.IsErr(err, ierror.ConflictCode)) d.True(ierror.IsErr(err, ierror.ConflictCode))
// list reference // list reference
references, err := d.dao.ListReferences(nil, &q.Query{ references, err := d.dao.ListReferences(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"parent_id": d.artifactID, "parent_id": d.artifactID,
}, },
@ -238,11 +243,11 @@ func (d *daoTestSuite) TestReference() {
d.Equal(id, references[0].ID) d.Equal(id, references[0].ID)
// delete reference // delete reference
err = d.dao.DeleteReferences(nil, d.artifactID) err = d.dao.DeleteReferences(d.ctx, d.artifactID)
d.Require().Nil(err) d.Require().Nil(err)
// parent artifact not exist // parent artifact not exist
err = d.dao.DeleteReferences(nil, 10000) err = d.dao.DeleteReferences(d.ctx, 10000)
d.Require().NotNil(err) d.Require().NotNil(err)
var e *ierror.Error var e *ierror.Error
d.Require().True(errors.As(err, &e)) d.Require().True(errors.As(err, &e))

View File

@ -52,11 +52,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) {
Keywords: query.Keywords, Keywords: query.Keywords,
} }
} }
return orm.QuerySetter(ctx, &models.RepoRecord{}, query).Count() qs, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query)
if err != nil {
return 0, err
}
return qs.Count()
} }
func (d *dao) List(ctx context.Context, query *q.Query) ([]*models.RepoRecord, error) { func (d *dao) List(ctx context.Context, query *q.Query) ([]*models.RepoRecord, error) {
repositories := []*models.RepoRecord{} repositories := []*models.RepoRecord{}
if _, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query).All(&repositories); err != nil { qs, err := orm.QuerySetter(ctx, &models.RepoRecord{}, query)
if err != nil {
return nil, err
}
if _, err = qs.All(&repositories); err != nil {
return nil, err return nil, err
} }
return repositories, nil return repositories, nil
@ -66,7 +74,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*models.RepoRecord, error) {
repository := &models.RepoRecord{ repository := &models.RepoRecord{
RepositoryID: id, RepositoryID: id,
} }
if err := orm.GetOrmer(ctx).Read(repository); err != nil { ormer, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
if err := ormer.Read(repository); err != nil {
if e, ok := orm.IsNotFoundError(err, "repository %d not found", id); ok { if e, ok := orm.IsNotFoundError(err, "repository %d not found", id); ok {
err = e err = e
} }
@ -76,7 +88,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*models.RepoRecord, error) {
} }
func (d *dao) Create(ctx context.Context, repository *models.RepoRecord) (int64, error) { func (d *dao) Create(ctx context.Context, repository *models.RepoRecord) (int64, error) {
id, err := orm.GetOrmer(ctx).Insert(repository) ormer, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
id, err := ormer.Insert(repository)
if e, ok := orm.IsConflictError(err, "repository %s already exists", repository.Name); ok { if e, ok := orm.IsConflictError(err, "repository %s already exists", repository.Name); ok {
err = e err = e
} }
@ -84,7 +100,11 @@ func (d *dao) Create(ctx context.Context, repository *models.RepoRecord) (int64,
} }
func (d *dao) Delete(ctx context.Context, id int64) error { func (d *dao) Delete(ctx context.Context, id int64) error {
n, err := orm.GetOrmer(ctx).Delete(&models.RepoRecord{ ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Delete(&models.RepoRecord{
RepositoryID: id, RepositoryID: id,
}) })
if err != nil { if err != nil {
@ -97,7 +117,11 @@ func (d *dao) Delete(ctx context.Context, id int64) error {
} }
func (d *dao) Update(ctx context.Context, repository *models.RepoRecord, props ...string) error { func (d *dao) Update(ctx context.Context, repository *models.RepoRecord, props ...string) error {
n, err := orm.GetOrmer(ctx).Update(repository, props...) ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Update(repository, props...)
if err != nil { if err != nil {
return err return err
} }

View File

@ -15,11 +15,14 @@
package dao package dao
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
beegoorm "github.com/astaxie/beego/orm"
common_dao "github.com/goharbor/harbor/src/common/dao" common_dao "github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/common/models" "github.com/goharbor/harbor/src/common/models"
ierror "github.com/goharbor/harbor/src/internal/error" ierror "github.com/goharbor/harbor/src/internal/error"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/q"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"testing" "testing"
@ -34,11 +37,13 @@ type daoTestSuite struct {
suite.Suite suite.Suite
dao DAO dao DAO
id int64 id int64
ctx context.Context
} }
func (d *daoTestSuite) SetupSuite() { func (d *daoTestSuite) SetupSuite() {
d.dao = New() d.dao = New()
common_dao.PrepareTestForPostgresSQL() common_dao.PrepareTestForPostgresSQL()
d.ctx = orm.NewContext(nil, beegoorm.NewOrm())
} }
func (d *daoTestSuite) SetupTest() { func (d *daoTestSuite) SetupTest() {
@ -47,24 +52,24 @@ func (d *daoTestSuite) SetupTest() {
ProjectID: 1, ProjectID: 1,
Description: "", Description: "",
} }
id, err := d.dao.Create(nil, repository) id, err := d.dao.Create(d.ctx, repository)
d.Require().Nil(err) d.Require().Nil(err)
d.id = id d.id = id
} }
func (d *daoTestSuite) TearDownTest() { func (d *daoTestSuite) TearDownTest() {
err := d.dao.Delete(nil, d.id) err := d.dao.Delete(d.ctx, d.id)
d.Require().Nil(err) d.Require().Nil(err)
} }
func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestCount() {
// nil query // nil query
total, err := d.dao.Count(nil, nil) total, err := d.dao.Count(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
d.True(total > 0) d.True(total > 0)
// query by name // query by name
total, err = d.dao.Count(nil, &q.Query{ total, err = d.dao.Count(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"name": repository, "name": repository,
}, },
@ -75,7 +80,7 @@ func (d *daoTestSuite) TestCount() {
func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestList() {
// nil query // nil query
repositories, err := d.dao.List(nil, nil) repositories, err := d.dao.List(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
found := false found := false
for _, repository := range repositories { for _, repository := range repositories {
@ -87,7 +92,7 @@ func (d *daoTestSuite) TestList() {
d.True(found) d.True(found)
// query by name // query by name
repositories, err = d.dao.List(nil, &q.Query{ repositories, err = d.dao.List(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"name": repository, "name": repository,
}, },
@ -99,12 +104,12 @@ func (d *daoTestSuite) TestList() {
func (d *daoTestSuite) TestGet() { func (d *daoTestSuite) TestGet() {
// get the non-exist repository // get the non-exist repository
_, err := d.dao.Get(nil, 10000) _, err := d.dao.Get(d.ctx, 10000)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.NotFoundCode)) d.True(ierror.IsErr(err, ierror.NotFoundCode))
// get the exist repository // get the exist repository
repository, err := d.dao.Get(nil, d.id) repository, err := d.dao.Get(d.ctx, d.id)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(repository) d.Require().NotNil(repository)
d.Equal(d.id, repository.RepositoryID) d.Equal(d.id, repository.RepositoryID)
@ -118,7 +123,7 @@ func (d *daoTestSuite) TestCreate() {
Name: repository, Name: repository,
ProjectID: 1, ProjectID: 1,
} }
_, err := d.dao.Create(nil, repository) _, err := d.dao.Create(d.ctx, repository)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.ConflictCode)) d.True(ierror.IsErr(err, ierror.ConflictCode))
} }
@ -127,7 +132,7 @@ func (d *daoTestSuite) TestDelete() {
// the happy pass case is covered in TearDown // the happy pass case is covered in TearDown
// not exist // not exist
err := d.dao.Delete(nil, 100021) err := d.dao.Delete(d.ctx, 100021)
d.Require().NotNil(err) d.Require().NotNil(err)
var e *ierror.Error var e *ierror.Error
d.Require().True(errors.As(err, &e)) d.Require().True(errors.As(err, &e))
@ -136,19 +141,19 @@ func (d *daoTestSuite) TestDelete() {
func (d *daoTestSuite) TestUpdate() { func (d *daoTestSuite) TestUpdate() {
// pass // pass
err := d.dao.Update(nil, &models.RepoRecord{ err := d.dao.Update(d.ctx, &models.RepoRecord{
RepositoryID: d.id, RepositoryID: d.id,
PullCount: 1, PullCount: 1,
}, "PullCount") }, "PullCount")
d.Require().Nil(err) d.Require().Nil(err)
repository, err := d.dao.Get(nil, d.id) repository, err := d.dao.Get(d.ctx, d.id)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(repository) d.Require().NotNil(repository)
d.Equal(int64(1), repository.PullCount) d.Equal(int64(1), repository.PullCount)
// not exist // not exist
err = d.dao.Update(nil, &models.RepoRecord{ err = d.dao.Update(d.ctx, &models.RepoRecord{
RepositoryID: 10000, RepositoryID: 10000,
}) })
d.Require().NotNil(err) d.Require().NotNil(err)

View File

@ -16,6 +16,8 @@ package retention
import ( import (
"fmt" "fmt"
beegoorm "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
"time" "time"
"github.com/goharbor/harbor/src/jobservice/job" "github.com/goharbor/harbor/src/jobservice/job"
@ -348,7 +350,7 @@ func getRepositories(projectMgr project.Manager, repositoryMgr repository.Manage
*/ */
// get image repositories // get image repositories
// TODO set the context which contains the ORM // TODO set the context which contains the ORM
_, imageRepositories, err := repositoryMgr.List(nil, &pq.Query{ _, imageRepositories, err := repositoryMgr.List(orm.NewContext(nil, beegoorm.NewOrm()), &pq.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"ProjectID": projectID, "ProjectID": projectID,
}, },

View File

@ -57,11 +57,19 @@ func (d *dao) Count(ctx context.Context, query *q.Query) (int64, error) {
Keywords: query.Keywords, Keywords: query.Keywords,
} }
} }
return orm.QuerySetter(ctx, &tag.Tag{}, query).Count() qs, err := orm.QuerySetter(ctx, &tag.Tag{}, query)
if err != nil {
return 0, err
}
return qs.Count()
} }
func (d *dao) List(ctx context.Context, query *q.Query) ([]*tag.Tag, error) { func (d *dao) List(ctx context.Context, query *q.Query) ([]*tag.Tag, error) {
tags := []*tag.Tag{} tags := []*tag.Tag{}
if _, err := orm.QuerySetter(ctx, &tag.Tag{}, query).All(&tags); err != nil { qs, err := orm.QuerySetter(ctx, &tag.Tag{}, query)
if err != nil {
return nil, err
}
if _, err = qs.All(&tags); err != nil {
return nil, err return nil, err
} }
return tags, nil return tags, nil
@ -70,7 +78,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*tag.Tag, error) {
tag := &tag.Tag{ tag := &tag.Tag{
ID: id, ID: id,
} }
if err := orm.GetOrmer(ctx).Read(tag); err != nil { ormer, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
if err := ormer.Read(tag); err != nil {
if e, ok := orm.IsNotFoundError(err, "tag %d not found", id); ok { if e, ok := orm.IsNotFoundError(err, "tag %d not found", id); ok {
err = e err = e
} }
@ -79,7 +91,11 @@ func (d *dao) Get(ctx context.Context, id int64) (*tag.Tag, error) {
return tag, nil return tag, nil
} }
func (d *dao) Create(ctx context.Context, tag *tag.Tag) (int64, error) { func (d *dao) Create(ctx context.Context, tag *tag.Tag) (int64, error) {
id, err := orm.GetOrmer(ctx).Insert(tag) ormer, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
id, err := ormer.Insert(tag)
if e, ok := orm.IsConflictError(err, "tag %s already exists under the repository %d", if e, ok := orm.IsConflictError(err, "tag %s already exists under the repository %d",
tag.Name, tag.RepositoryID); ok { tag.Name, tag.RepositoryID); ok {
err = e err = e
@ -87,7 +103,11 @@ func (d *dao) Create(ctx context.Context, tag *tag.Tag) (int64, error) {
return id, err return id, err
} }
func (d *dao) Update(ctx context.Context, tag *tag.Tag, props ...string) error { func (d *dao) Update(ctx context.Context, tag *tag.Tag, props ...string) error {
n, err := orm.GetOrmer(ctx).Update(tag, props...) ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Update(tag, props...)
if err != nil { if err != nil {
return err return err
} }
@ -97,7 +117,11 @@ func (d *dao) Update(ctx context.Context, tag *tag.Tag, props ...string) error {
return nil return nil
} }
func (d *dao) Delete(ctx context.Context, id int64) error { func (d *dao) Delete(ctx context.Context, id int64) error {
n, err := orm.GetOrmer(ctx).Delete(&tag.Tag{ ormer, err := orm.FromContext(ctx)
if err != nil {
return err
}
n, err := ormer.Delete(&tag.Tag{
ID: id, ID: id,
}) })
if err != nil { if err != nil {

View File

@ -15,9 +15,12 @@
package dao package dao
import ( import (
"context"
"errors" "errors"
beegoorm "github.com/astaxie/beego/orm"
common_dao "github.com/goharbor/harbor/src/common/dao" common_dao "github.com/goharbor/harbor/src/common/dao"
ierror "github.com/goharbor/harbor/src/internal/error" ierror "github.com/goharbor/harbor/src/internal/error"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/pkg/q" "github.com/goharbor/harbor/src/pkg/q"
"github.com/goharbor/harbor/src/pkg/tag/model/tag" "github.com/goharbor/harbor/src/pkg/tag/model/tag"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -35,11 +38,13 @@ type daoTestSuite struct {
suite.Suite suite.Suite
dao DAO dao DAO
tagID int64 tagID int64
ctx context.Context
} }
func (d *daoTestSuite) SetupSuite() { func (d *daoTestSuite) SetupSuite() {
d.dao = New() d.dao = New()
common_dao.PrepareTestForPostgresSQL() common_dao.PrepareTestForPostgresSQL()
d.ctx = orm.NewContext(nil, beegoorm.NewOrm())
} }
func (d *daoTestSuite) SetupTest() { func (d *daoTestSuite) SetupTest() {
@ -50,23 +55,23 @@ func (d *daoTestSuite) SetupTest() {
PushTime: time.Time{}, PushTime: time.Time{},
PullTime: time.Time{}, PullTime: time.Time{},
} }
id, err := d.dao.Create(nil, tag) id, err := d.dao.Create(d.ctx, tag)
d.Require().Nil(err) d.Require().Nil(err)
d.tagID = id d.tagID = id
} }
func (d *daoTestSuite) TearDownTest() { func (d *daoTestSuite) TearDownTest() {
err := d.dao.Delete(nil, d.tagID) err := d.dao.Delete(d.ctx, d.tagID)
d.Require().Nil(err) d.Require().Nil(err)
} }
func (d *daoTestSuite) TestCount() { func (d *daoTestSuite) TestCount() {
// nil query // nil query
total, err := d.dao.Count(nil, nil) total, err := d.dao.Count(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
d.True(total > 0) d.True(total > 0)
// query by repository ID and name // query by repository ID and name
total, err = d.dao.Count(nil, &q.Query{ total, err = d.dao.Count(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"repository_id": repositoryID, "repository_id": repositoryID,
"name": name, "name": name,
@ -78,7 +83,7 @@ func (d *daoTestSuite) TestCount() {
func (d *daoTestSuite) TestList() { func (d *daoTestSuite) TestList() {
// nil query // nil query
tags, err := d.dao.List(nil, nil) tags, err := d.dao.List(d.ctx, nil)
d.Require().Nil(err) d.Require().Nil(err)
found := false found := false
for _, tag := range tags { for _, tag := range tags {
@ -90,7 +95,7 @@ func (d *daoTestSuite) TestList() {
d.True(found) d.True(found)
// query by repository ID and name // query by repository ID and name
tags, err = d.dao.List(nil, &q.Query{ tags, err = d.dao.List(d.ctx, &q.Query{
Keywords: map[string]interface{}{ Keywords: map[string]interface{}{
"repository_id": repositoryID, "repository_id": repositoryID,
"name": name, "name": name,
@ -103,12 +108,12 @@ func (d *daoTestSuite) TestList() {
func (d *daoTestSuite) TestGet() { func (d *daoTestSuite) TestGet() {
// get the non-exist tag // get the non-exist tag
_, err := d.dao.Get(nil, 10000) _, err := d.dao.Get(d.ctx, 10000)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.NotFoundCode)) d.True(ierror.IsErr(err, ierror.NotFoundCode))
// get the exist tag // get the exist tag
tag, err := d.dao.Get(nil, d.tagID) tag, err := d.dao.Get(d.ctx, d.tagID)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(tag) d.Require().NotNil(tag)
d.Equal(d.tagID, tag.ID) d.Equal(d.tagID, tag.ID)
@ -125,7 +130,7 @@ func (d *daoTestSuite) TestCreate() {
PushTime: time.Time{}, PushTime: time.Time{},
PullTime: time.Time{}, PullTime: time.Time{},
} }
_, err := d.dao.Create(nil, tag) _, err := d.dao.Create(d.ctx, tag)
d.Require().NotNil(err) d.Require().NotNil(err)
d.True(ierror.IsErr(err, ierror.ConflictCode)) d.True(ierror.IsErr(err, ierror.ConflictCode))
} }
@ -134,7 +139,7 @@ func (d *daoTestSuite) TestDelete() {
// happy pass is covered in TearDown // happy pass is covered in TearDown
// not exist // not exist
err := d.dao.Delete(nil, 10000) err := d.dao.Delete(d.ctx, 10000)
d.Require().NotNil(err) d.Require().NotNil(err)
var e *ierror.Error var e *ierror.Error
d.Require().True(errors.As(err, &e)) d.Require().True(errors.As(err, &e))
@ -143,19 +148,19 @@ func (d *daoTestSuite) TestDelete() {
func (d *daoTestSuite) TestUpdate() { func (d *daoTestSuite) TestUpdate() {
// pass // pass
err := d.dao.Update(nil, &tag.Tag{ err := d.dao.Update(d.ctx, &tag.Tag{
ID: d.tagID, ID: d.tagID,
ArtifactID: 2, ArtifactID: 2,
}, "ArtifactID") }, "ArtifactID")
d.Require().Nil(err) d.Require().Nil(err)
tg, err := d.dao.Get(nil, d.tagID) tg, err := d.dao.Get(d.ctx, d.tagID)
d.Require().Nil(err) d.Require().Nil(err)
d.Require().NotNil(tg) d.Require().NotNil(tg)
d.Equal(int64(2), tg.ArtifactID) d.Equal(int64(2), tg.ArtifactID)
// not exist // not exist
err = d.dao.Update(nil, &tag.Tag{ err = d.dao.Update(d.ctx, &tag.Tag{
ID: 10000, ID: 10000,
}) })
d.Require().NotNil(err) d.Require().NotNil(err)