From 7a4cb174509acbad7bec94fc224fbea8a344ec84 Mon Sep 17 00:00:00 2001 From: He Weiwei Date: Tue, 31 Dec 2019 18:30:52 +0800 Subject: [PATCH] feat(orm): add orm support with context (#10337) 1. Get and set orm from context. 2. Add WithTransaction decorator make func run in transaction. 3. Support nested transaction by Savepoint. Signed-off-by: He Weiwei --- src/core/filter/orm.go | 30 ++++ src/core/main.go | 1 + src/internal/orm/orm.go | 68 ++++++++ src/internal/orm/orm_test.go | 315 +++++++++++++++++++++++++++++++++++ src/internal/orm/tx.go | 77 +++++++++ 5 files changed, 491 insertions(+) create mode 100644 src/core/filter/orm.go create mode 100644 src/internal/orm/orm.go create mode 100644 src/internal/orm/orm_test.go create mode 100644 src/internal/orm/tx.go diff --git a/src/core/filter/orm.go b/src/core/filter/orm.go new file mode 100644 index 000000000..1310db398 --- /dev/null +++ b/src/core/filter/orm.go @@ -0,0 +1,30 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filter + +import ( + "github.com/astaxie/beego/context" + o "github.com/astaxie/beego/orm" + "github.com/goharbor/harbor/src/internal/orm" +) + +// OrmFilter set orm.Ormer instance to the context of the http.Request +func OrmFilter(ctx *context.Context) { + if ctx == nil || ctx.Request == nil { + return + } + + ctx.Request = ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm())) +} diff --git a/src/core/main.go b/src/core/main.go index 53a9a7f65..50a3a34e8 100755 --- a/src/core/main.go +++ b/src/core/main.go @@ -247,6 +247,7 @@ func main() { filter.Init() beego.InsertFilter("/api/*", beego.BeforeStatic, filter.SessionCheck) + beego.InsertFilter("/*", beego.BeforeRouter, filter.OrmFilter) beego.InsertFilter("/*", beego.BeforeRouter, filter.SecurityFilter) beego.InsertFilter("/*", beego.BeforeRouter, filter.ReadonlyFilter) diff --git a/src/internal/orm/orm.go b/src/internal/orm/orm.go new file mode 100644 index 000000000..8748128f2 --- /dev/null +++ b/src/internal/orm/orm.go @@ -0,0 +1,68 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "fmt" + + "github.com/astaxie/beego/orm" + "github.com/goharbor/harbor/src/common/utils/log" +) + +type ormKey struct{} + +// FromContext returns orm from context +func FromContext(ctx context.Context) (orm.Ormer, bool) { + o, ok := ctx.Value(ormKey{}).(orm.Ormer) + return o, ok +} + +// NewContext returns new context with orm +func NewContext(ctx context.Context, o orm.Ormer) context.Context { + return context.WithValue(ctx, ormKey{}, o) +} + +// WithTransaction a decorator which make f run in transaction +func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error { + return func(ctx context.Context) error { + o, ok := FromContext(ctx) + if !ok { + return fmt.Errorf("ormer value not found in context") + } + + tx := ormerTx{Ormer: o} + if err := tx.Begin(); err != nil { + log.Errorf("begin transaction failed: %v", err) + return err + } + + if err := f(ctx); err != nil { + if e := tx.Rollback(); e != nil { + log.Errorf("rollback transaction failed: %v", e) + return e + } + + return err + } + + if err := tx.Commit(); err != nil { + log.Errorf("commit transaction failed: %v", err) + return err + } + + return nil + } +} diff --git a/src/internal/orm/orm_test.go b/src/internal/orm/orm_test.go new file mode 100644 index 000000000..bc07178f6 --- /dev/null +++ b/src/internal/orm/orm_test.go @@ -0,0 +1,315 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "errors" + "testing" + + "github.com/astaxie/beego/orm" + "github.com/goharbor/harbor/src/common/dao" + "github.com/goharbor/harbor/src/common/models" + "github.com/stretchr/testify/suite" +) + +func addProject(ctx context.Context, project models.Project) (int64, error) { + o, ok := FromContext(ctx) + if !ok { + return 0, errors.New("orm not found in context") + } + + return o.Insert(&project) +} + +func readProject(ctx context.Context, id int64) (*models.Project, error) { + o, ok := FromContext(ctx) + if !ok { + return nil, errors.New("orm not found in context") + } + + project := &models.Project{ + ProjectID: id, + } + + if err := o.Read(project, "project_id"); err != nil { + return nil, err + } + + return project, nil +} + +func deleteProject(ctx context.Context, id int64) error { + o, ok := FromContext(ctx) + if !ok { + return errors.New("orm not found in context") + } + + project := &models.Project{ + ProjectID: id, + } + + _, err := o.Delete(project, "project_id") + return err +} + +func existProject(ctx context.Context, id int64) bool { + o, ok := FromContext(ctx) + if !ok { + return false + } + + project := &models.Project{ + ProjectID: id, + } + + if err := o.Read(project, "project_id"); err != nil { + return false + } + + return true +} + +// Suite ... +type OrmSuite struct { + suite.Suite +} + +// SetupSuite ... +func (suite *OrmSuite) SetupSuite() { + dao.PrepareTestForPostgresSQL() +} + +func (suite *OrmSuite) TestContext() { + ctx := context.TODO() + + o, ok := FromContext(ctx) + suite.False(ok) + suite.Nil(o) + + o, ok = FromContext(NewContext(ctx, orm.NewOrm())) + suite.True(ok) + suite.NotNil(o) +} + +func (suite *OrmSuite) TestWithTransaction() { + ctx := NewContext(context.TODO(), orm.NewOrm()) + + var id int64 + t1 := WithTransaction(func(ctx context.Context) (err error) { + id, err = addProject(ctx, models.Project{Name: "t1", OwnerID: 1}) + return err + }) + + suite.Nil(t1(ctx)) + suite.True(existProject(ctx, id)) + suite.Nil(deleteProject(ctx, id)) +} + +func (suite *OrmSuite) TestSequentialTransactions() { + ctx := NewContext(context.TODO(), orm.NewOrm()) + + var id1, id2 int64 + t1 := func(ctx context.Context, retErr error) error { + return WithTransaction(func(ctx context.Context) (err error) { + id1, err = addProject(ctx, models.Project{Name: "t1", OwnerID: 1}) + if err != nil { + return err + } + + // Ensure t1 created success + suite.True(existProject(ctx, id1)) + + return retErr + })(ctx) + } + t2 := func(ctx context.Context, retErr error) error { + return WithTransaction(func(ctx context.Context) (err error) { + id2, _ = addProject(ctx, models.Project{Name: "t2", OwnerID: 1}) + if err != nil { + return err + } + + // Ensure t2 created success + suite.True(existProject(ctx, id2)) + + return retErr + })(ctx) + } + + if suite.Nil(t1(ctx, nil)) { + suite.True(existProject(ctx, id1)) + } + + if suite.Nil(t2(ctx, nil)) { + suite.True(existProject(ctx, id2)) + } + + // delete project t1 and t2 in db + suite.Nil(deleteProject(ctx, id1)) + suite.Nil(deleteProject(ctx, id2)) + + if suite.Error(t1(ctx, errors.New("oops"))) { + suite.False(existProject(ctx, id1)) + } + + if suite.Nil(t2(ctx, nil)) { + suite.True(existProject(ctx, id2)) + suite.Nil(deleteProject(ctx, id2)) + } +} + +func (suite *OrmSuite) TestNestedTransaction() { + ctx := NewContext(context.TODO(), orm.NewOrm()) + + var id1, id2 int64 + nt1 := WithTransaction(func(ctx context.Context) (err error) { + id1, err = addProject(ctx, models.Project{Name: "nt1", OwnerID: 1}) + return err + }) + nt2 := WithTransaction(func(ctx context.Context) (err error) { + id2, err = addProject(ctx, models.Project{Name: "nt2", OwnerID: 1}) + return err + }) + + nt := func(ctx context.Context, retErr error) error { + return WithTransaction(func(ctx context.Context) error { + if err := nt1(ctx); err != nil { + return err + } + + if err := nt2(ctx); err != nil { + return err + } + + // Ensure nt1 and nt2 created success + suite.True(existProject(ctx, id1)) + suite.True(existProject(ctx, id2)) + + return retErr + })(ctx) + } + + if suite.Nil(nt(ctx, nil)) { + suite.True(existProject(ctx, id1)) + suite.True(existProject(ctx, id2)) + + // delete project nt1 and nt2 in db + suite.Nil(deleteProject(ctx, id1)) + suite.Nil(deleteProject(ctx, id2)) + suite.False(existProject(ctx, id1)) + suite.False(existProject(ctx, id2)) + } + + if suite.Error(nt(ctx, errors.New("oops"))) { + suite.False(existProject(ctx, id1)) + suite.False(existProject(ctx, id2)) + } + + // test nt1 failed but we skip it and nt2 success + suite.Nil(nt1(ctx)) + suite.True(existProject(ctx, id1)) + + // delete nt1 here because id1 will overwrite in the following transaction + defer func(id int64) { + suite.Nil(deleteProject(ctx, id)) + }(id1) + + t := WithTransaction(func(ctx context.Context) error { + suite.Error(nt1(ctx)) + + if err := nt2(ctx); err != nil { + return err + } + + // Ensure t2 created success + suite.True(existProject(ctx, id2)) + + return nil + }) + + if suite.Nil(t(ctx)) { + suite.True(existProject(ctx, id2)) + + // delete project t2 in db + suite.Nil(deleteProject(ctx, id2)) + } +} + +func (suite *OrmSuite) TestNestedSavepoint() { + ctx := NewContext(context.TODO(), orm.NewOrm()) + + var id1, id2 int64 + ns1 := WithTransaction(func(ctx context.Context) (err error) { + id1, err = addProject(ctx, models.Project{Name: "ns1", OwnerID: 1}) + return err + }) + ns2 := WithTransaction(func(ctx context.Context) (err error) { + id2, err = addProject(ctx, models.Project{Name: "ns2", OwnerID: 1}) + return err + }) + + ns := func(ctx context.Context, retErr error) error { + return WithTransaction(func(ctx context.Context) error { + if err := ns1(ctx); err != nil { + return err + } + + if err := ns2(ctx); err != nil { + return err + } + + // Ensure nt1 and nt2 created success + suite.True(existProject(ctx, id1)) + suite.True(existProject(ctx, id2)) + + return retErr + })(ctx) + } + + t := func(ctx context.Context, tErr, pErr error) error { + return WithTransaction(func(c context.Context) error { + ns(c, pErr) + return tErr + })(ctx) + } + + // transaction commit and s1s2 commit + suite.Nil(t(ctx, nil, nil)) + // Ensure nt1 and nt2 created success + suite.True(existProject(ctx, id1)) + suite.True(existProject(ctx, id2)) + // delete project nt1 and nt2 in db + suite.Nil(deleteProject(ctx, id1)) + suite.Nil(deleteProject(ctx, id2)) + suite.False(existProject(ctx, id1)) + suite.False(existProject(ctx, id2)) + + // transaction commit and s1s2 rollback + suite.Nil(t(ctx, nil, errors.New("oops"))) + // Ensure nt1 and nt2 created failed + suite.False(existProject(ctx, id1)) + suite.False(existProject(ctx, id2)) + + // transaction rollback and s1s2 commit + suite.Error(t(ctx, errors.New("oops"), nil)) + // Ensure nt1 and nt2 created failed + suite.False(existProject(ctx, id1)) + suite.False(existProject(ctx, id2)) +} + +func TestRunOrmSuite(t *testing.T) { + suite.Run(t, new(OrmSuite)) +} diff --git a/src/internal/orm/tx.go b/src/internal/orm/tx.go new file mode 100644 index 000000000..9627c482f --- /dev/null +++ b/src/internal/orm/tx.go @@ -0,0 +1,77 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "encoding/hex" + "fmt" + + "github.com/astaxie/beego/orm" + "github.com/google/uuid" +) + +// ormerTx transaction which support savepoint +type ormerTx struct { + orm.Ormer + savepoint string +} + +func (o *ormerTx) savepointMode() bool { + return o.savepoint != "" +} + +func (o *ormerTx) createSavepoint() error { + val := uuid.New() + o.savepoint = fmt.Sprintf("p%s", hex.EncodeToString(val[:])) + + _, err := o.Raw(fmt.Sprintf("SAVEPOINT %s", o.savepoint)).Exec() + return err +} + +func (o *ormerTx) releaseSavepoint() error { + _, err := o.Raw(fmt.Sprintf("RELEASE SAVEPOINT %s", o.savepoint)).Exec() + return err +} + +func (o *ormerTx) rollbackToSavepoint() error { + _, err := o.Raw(fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", o.savepoint)).Exec() + return err +} + +func (o *ormerTx) Begin() error { + err := o.Ormer.Begin() + if err == orm.ErrTxHasBegan { + // transaction has began for the ormer, so begin nested transaction by savepoint + return o.createSavepoint() + } + + return err +} + +func (o *ormerTx) Commit() error { + if o.savepointMode() { + return o.releaseSavepoint() + } + + return o.Ormer.Commit() +} + +func (o *ormerTx) Rollback() error { + if o.savepointMode() { + return o.rollbackToSavepoint() + } + + return o.Ormer.Rollback() +}