feat(orm): add orm support with context (#10337)

1. Get and set orm from context.
2. Add WithTransaction decorator make func run in transaction.
3. Support nested transaction by Savepoint.

Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
He Weiwei 2019-12-31 18:30:52 +08:00 committed by GitHub
parent 803e676ee7
commit 7a4cb17450
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 491 additions and 0 deletions

30
src/core/filter/orm.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package filter
import (
"github.com/astaxie/beego/context"
o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
)
// OrmFilter set orm.Ormer instance to the context of the http.Request
func OrmFilter(ctx *context.Context) {
if ctx == nil || ctx.Request == nil {
return
}
ctx.Request = ctx.Request.WithContext(orm.NewContext(ctx.Request.Context(), o.NewOrm()))
}

View File

@ -247,6 +247,7 @@ func main() {
filter.Init()
beego.InsertFilter("/api/*", beego.BeforeStatic, filter.SessionCheck)
beego.InsertFilter("/*", beego.BeforeRouter, filter.OrmFilter)
beego.InsertFilter("/*", beego.BeforeRouter, filter.SecurityFilter)
beego.InsertFilter("/*", beego.BeforeRouter, filter.ReadonlyFilter)

68
src/internal/orm/orm.go Normal file
View File

@ -0,0 +1,68 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
"github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/common/utils/log"
)
type ormKey struct{}
// FromContext returns orm from context
func FromContext(ctx context.Context) (orm.Ormer, bool) {
o, ok := ctx.Value(ormKey{}).(orm.Ormer)
return o, ok
}
// NewContext returns new context with orm
func NewContext(ctx context.Context, o orm.Ormer) context.Context {
return context.WithValue(ctx, ormKey{}, o)
}
// WithTransaction a decorator which make f run in transaction
func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error {
return func(ctx context.Context) error {
o, ok := FromContext(ctx)
if !ok {
return fmt.Errorf("ormer value not found in context")
}
tx := ormerTx{Ormer: o}
if err := tx.Begin(); err != nil {
log.Errorf("begin transaction failed: %v", err)
return err
}
if err := f(ctx); err != nil {
if e := tx.Rollback(); e != nil {
log.Errorf("rollback transaction failed: %v", e)
return e
}
return err
}
if err := tx.Commit(); err != nil {
log.Errorf("commit transaction failed: %v", err)
return err
}
return nil
}
}

View File

@ -0,0 +1,315 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"errors"
"testing"
"github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/common/models"
"github.com/stretchr/testify/suite"
)
func addProject(ctx context.Context, project models.Project) (int64, error) {
o, ok := FromContext(ctx)
if !ok {
return 0, errors.New("orm not found in context")
}
return o.Insert(&project)
}
func readProject(ctx context.Context, id int64) (*models.Project, error) {
o, ok := FromContext(ctx)
if !ok {
return nil, errors.New("orm not found in context")
}
project := &models.Project{
ProjectID: id,
}
if err := o.Read(project, "project_id"); err != nil {
return nil, err
}
return project, nil
}
func deleteProject(ctx context.Context, id int64) error {
o, ok := FromContext(ctx)
if !ok {
return errors.New("orm not found in context")
}
project := &models.Project{
ProjectID: id,
}
_, err := o.Delete(project, "project_id")
return err
}
func existProject(ctx context.Context, id int64) bool {
o, ok := FromContext(ctx)
if !ok {
return false
}
project := &models.Project{
ProjectID: id,
}
if err := o.Read(project, "project_id"); err != nil {
return false
}
return true
}
// Suite ...
type OrmSuite struct {
suite.Suite
}
// SetupSuite ...
func (suite *OrmSuite) SetupSuite() {
dao.PrepareTestForPostgresSQL()
}
func (suite *OrmSuite) TestContext() {
ctx := context.TODO()
o, ok := FromContext(ctx)
suite.False(ok)
suite.Nil(o)
o, ok = FromContext(NewContext(ctx, orm.NewOrm()))
suite.True(ok)
suite.NotNil(o)
}
func (suite *OrmSuite) TestWithTransaction() {
ctx := NewContext(context.TODO(), orm.NewOrm())
var id int64
t1 := WithTransaction(func(ctx context.Context) (err error) {
id, err = addProject(ctx, models.Project{Name: "t1", OwnerID: 1})
return err
})
suite.Nil(t1(ctx))
suite.True(existProject(ctx, id))
suite.Nil(deleteProject(ctx, id))
}
func (suite *OrmSuite) TestSequentialTransactions() {
ctx := NewContext(context.TODO(), orm.NewOrm())
var id1, id2 int64
t1 := func(ctx context.Context, retErr error) error {
return WithTransaction(func(ctx context.Context) (err error) {
id1, err = addProject(ctx, models.Project{Name: "t1", OwnerID: 1})
if err != nil {
return err
}
// Ensure t1 created success
suite.True(existProject(ctx, id1))
return retErr
})(ctx)
}
t2 := func(ctx context.Context, retErr error) error {
return WithTransaction(func(ctx context.Context) (err error) {
id2, _ = addProject(ctx, models.Project{Name: "t2", OwnerID: 1})
if err != nil {
return err
}
// Ensure t2 created success
suite.True(existProject(ctx, id2))
return retErr
})(ctx)
}
if suite.Nil(t1(ctx, nil)) {
suite.True(existProject(ctx, id1))
}
if suite.Nil(t2(ctx, nil)) {
suite.True(existProject(ctx, id2))
}
// delete project t1 and t2 in db
suite.Nil(deleteProject(ctx, id1))
suite.Nil(deleteProject(ctx, id2))
if suite.Error(t1(ctx, errors.New("oops"))) {
suite.False(existProject(ctx, id1))
}
if suite.Nil(t2(ctx, nil)) {
suite.True(existProject(ctx, id2))
suite.Nil(deleteProject(ctx, id2))
}
}
func (suite *OrmSuite) TestNestedTransaction() {
ctx := NewContext(context.TODO(), orm.NewOrm())
var id1, id2 int64
nt1 := WithTransaction(func(ctx context.Context) (err error) {
id1, err = addProject(ctx, models.Project{Name: "nt1", OwnerID: 1})
return err
})
nt2 := WithTransaction(func(ctx context.Context) (err error) {
id2, err = addProject(ctx, models.Project{Name: "nt2", OwnerID: 1})
return err
})
nt := func(ctx context.Context, retErr error) error {
return WithTransaction(func(ctx context.Context) error {
if err := nt1(ctx); err != nil {
return err
}
if err := nt2(ctx); err != nil {
return err
}
// Ensure nt1 and nt2 created success
suite.True(existProject(ctx, id1))
suite.True(existProject(ctx, id2))
return retErr
})(ctx)
}
if suite.Nil(nt(ctx, nil)) {
suite.True(existProject(ctx, id1))
suite.True(existProject(ctx, id2))
// delete project nt1 and nt2 in db
suite.Nil(deleteProject(ctx, id1))
suite.Nil(deleteProject(ctx, id2))
suite.False(existProject(ctx, id1))
suite.False(existProject(ctx, id2))
}
if suite.Error(nt(ctx, errors.New("oops"))) {
suite.False(existProject(ctx, id1))
suite.False(existProject(ctx, id2))
}
// test nt1 failed but we skip it and nt2 success
suite.Nil(nt1(ctx))
suite.True(existProject(ctx, id1))
// delete nt1 here because id1 will overwrite in the following transaction
defer func(id int64) {
suite.Nil(deleteProject(ctx, id))
}(id1)
t := WithTransaction(func(ctx context.Context) error {
suite.Error(nt1(ctx))
if err := nt2(ctx); err != nil {
return err
}
// Ensure t2 created success
suite.True(existProject(ctx, id2))
return nil
})
if suite.Nil(t(ctx)) {
suite.True(existProject(ctx, id2))
// delete project t2 in db
suite.Nil(deleteProject(ctx, id2))
}
}
func (suite *OrmSuite) TestNestedSavepoint() {
ctx := NewContext(context.TODO(), orm.NewOrm())
var id1, id2 int64
ns1 := WithTransaction(func(ctx context.Context) (err error) {
id1, err = addProject(ctx, models.Project{Name: "ns1", OwnerID: 1})
return err
})
ns2 := WithTransaction(func(ctx context.Context) (err error) {
id2, err = addProject(ctx, models.Project{Name: "ns2", OwnerID: 1})
return err
})
ns := func(ctx context.Context, retErr error) error {
return WithTransaction(func(ctx context.Context) error {
if err := ns1(ctx); err != nil {
return err
}
if err := ns2(ctx); err != nil {
return err
}
// Ensure nt1 and nt2 created success
suite.True(existProject(ctx, id1))
suite.True(existProject(ctx, id2))
return retErr
})(ctx)
}
t := func(ctx context.Context, tErr, pErr error) error {
return WithTransaction(func(c context.Context) error {
ns(c, pErr)
return tErr
})(ctx)
}
// transaction commit and s1s2 commit
suite.Nil(t(ctx, nil, nil))
// Ensure nt1 and nt2 created success
suite.True(existProject(ctx, id1))
suite.True(existProject(ctx, id2))
// delete project nt1 and nt2 in db
suite.Nil(deleteProject(ctx, id1))
suite.Nil(deleteProject(ctx, id2))
suite.False(existProject(ctx, id1))
suite.False(existProject(ctx, id2))
// transaction commit and s1s2 rollback
suite.Nil(t(ctx, nil, errors.New("oops")))
// Ensure nt1 and nt2 created failed
suite.False(existProject(ctx, id1))
suite.False(existProject(ctx, id2))
// transaction rollback and s1s2 commit
suite.Error(t(ctx, errors.New("oops"), nil))
// Ensure nt1 and nt2 created failed
suite.False(existProject(ctx, id1))
suite.False(existProject(ctx, id2))
}
func TestRunOrmSuite(t *testing.T) {
suite.Run(t, new(OrmSuite))
}

77
src/internal/orm/tx.go Normal file
View File

@ -0,0 +1,77 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"encoding/hex"
"fmt"
"github.com/astaxie/beego/orm"
"github.com/google/uuid"
)
// ormerTx transaction which support savepoint
type ormerTx struct {
orm.Ormer
savepoint string
}
func (o *ormerTx) savepointMode() bool {
return o.savepoint != ""
}
func (o *ormerTx) createSavepoint() error {
val := uuid.New()
o.savepoint = fmt.Sprintf("p%s", hex.EncodeToString(val[:]))
_, err := o.Raw(fmt.Sprintf("SAVEPOINT %s", o.savepoint)).Exec()
return err
}
func (o *ormerTx) releaseSavepoint() error {
_, err := o.Raw(fmt.Sprintf("RELEASE SAVEPOINT %s", o.savepoint)).Exec()
return err
}
func (o *ormerTx) rollbackToSavepoint() error {
_, err := o.Raw(fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", o.savepoint)).Exec()
return err
}
func (o *ormerTx) Begin() error {
err := o.Ormer.Begin()
if err == orm.ErrTxHasBegan {
// transaction has began for the ormer, so begin nested transaction by savepoint
return o.createSavepoint()
}
return err
}
func (o *ormerTx) Commit() error {
if o.savepointMode() {
return o.releaseSavepoint()
}
return o.Ormer.Commit()
}
func (o *ormerTx) Rollback() error {
if o.savepointMode() {
return o.rollbackToSavepoint()
}
return o.Ormer.Rollback()
}