fix(transaction): change to use value in the ctx to decide whether commit tx (#11062)

Type assertion not work when the ctx in the request changed in the next
handler, so change to use value in the ctx to decide whether to commit
tx.

Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
He Weiwei 2020-03-13 15:19:13 +08:00 committed by GitHub
parent 2e7eb8872e
commit 37e6fa5c92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 20 deletions

View File

@ -18,12 +18,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
serror "github.com/goharbor/harbor/src/server/error"
"net/http" "net/http"
"github.com/goharbor/harbor/src/common/utils/log" "github.com/goharbor/harbor/src/common/utils/log"
"github.com/goharbor/harbor/src/internal" "github.com/goharbor/harbor/src/internal"
"github.com/goharbor/harbor/src/internal/orm" "github.com/goharbor/harbor/src/internal/orm"
serror "github.com/goharbor/harbor/src/server/error"
"github.com/goharbor/harbor/src/server/middleware" "github.com/goharbor/harbor/src/server/middleware"
) )
@ -31,28 +31,17 @@ var (
errNonSuccess = errors.New("non success status code") errNonSuccess = errors.New("non success status code")
) )
type committableContext struct { type committedKey struct{}
context.Context
committed bool
}
func (ctx *committableContext) Commit() {
ctx.committed = true
}
type committable interface {
Commit()
}
// MustCommit mark http.Request as committed so that transaction // MustCommit mark http.Request as committed so that transaction
// middleware ignore the status code of the response and commit transaction for this request // middleware ignore the status code of the response and commit transaction for this request
func MustCommit(r *http.Request) error { func MustCommit(r *http.Request) error {
c, ok := r.Context().(committable) committed, ok := r.Context().Value(committedKey{}).(*bool)
if !ok { if !ok {
return fmt.Errorf("%s URL %s is not committable, please enable transaction middleware for it", r.Method, r.URL.Path) return fmt.Errorf("%s URL %s is not committable, please enable transaction middleware for it", r.Method, r.URL.Path)
} }
c.Commit() *committed = true
return nil return nil
} }
@ -67,10 +56,12 @@ func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler
} }
h := func(ctx context.Context) error { h := func(ctx context.Context) error {
cc := &committableContext{Context: ctx} committed := new(bool) // default false, not must commit
cc := context.WithValue(ctx, committedKey{}, committed)
next.ServeHTTP(res, r.WithContext(cc)) next.ServeHTTP(res, r.WithContext(cc))
if !cc.committed && !res.Success() { if !(*committed) && !res.Success() {
return errNonSuccess return errNonSuccess
} }

View File

@ -24,6 +24,7 @@ import (
o "github.com/astaxie/beego/orm" o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm" "github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/server/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -129,7 +130,13 @@ func TestTransaction(t *testing.T) {
req5 := newRequest(http.MethodGet, "/req", nil) req5 := newRequest(http.MethodGet, "/req", nil)
rec5 := httptest.NewRecorder() rec5 := httptest.NewRecorder()
Middleware()(txMustCommit(http.StatusBadRequest)).ServeHTTP(rec5, req5)
m1 := middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
type key struct{}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key{}, "value")))
})
Middleware()(m1((txMustCommit(http.StatusBadRequest)))).ServeHTTP(rec5, req5)
assert.Equal(http.StatusBadRequest, rec2.Code) assert.Equal(http.StatusBadRequest, rec2.Code)
assert.NotEmpty(mo.records) assert.NotEmpty(mo.records)
} }
@ -140,6 +147,9 @@ func TestMustCommit(t *testing.T) {
return req.WithContext(ctx) return req.WithContext(ctx)
} }
ctx := context.Background()
committableCtx := context.WithValue(ctx, committedKey{}, new(bool))
type args struct { type args struct {
r *http.Request r *http.Request
} }
@ -148,8 +158,8 @@ func TestMustCommit(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"request committable", args{newRequest(&committableContext{Context: context.Background()})}, false}, {"request committable", args{newRequest(committableCtx)}, false},
{"request not committable", args{newRequest(context.Background())}, true}, {"request not committable", args{newRequest(ctx)}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {