mirror of
https://github.com/goharbor/harbor.git
synced 2025-01-04 23:17:45 +01:00
Merge pull request #10550 from heww/transtaction-middleware
feat(middleware): add transaction middleware for v2 and v2.0 APIs
This commit is contained in:
commit
d79a0e6030
@ -17,6 +17,12 @@ package main
|
|||||||
import (
|
import (
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/astaxie/beego"
|
"github.com/astaxie/beego"
|
||||||
_ "github.com/astaxie/beego/session/redis"
|
_ "github.com/astaxie/beego/session/redis"
|
||||||
"github.com/goharbor/harbor/src/common/dao"
|
"github.com/goharbor/harbor/src/common/dao"
|
||||||
@ -48,15 +54,6 @@ import (
|
|||||||
"github.com/goharbor/harbor/src/pkg/version"
|
"github.com/goharbor/harbor/src/pkg/version"
|
||||||
"github.com/goharbor/harbor/src/replication"
|
"github.com/goharbor/harbor/src/replication"
|
||||||
"github.com/goharbor/harbor/src/server"
|
"github.com/goharbor/harbor/src/server"
|
||||||
"github.com/goharbor/harbor/src/server/middleware/orm"
|
|
||||||
"github.com/goharbor/harbor/src/server/middleware/requestid"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -292,21 +289,5 @@ func main() {
|
|||||||
|
|
||||||
log.Infof("Version: %s, Git commit: %s", version.ReleaseVersion, version.GitCommit)
|
log.Infof("Version: %s, Git commit: %s", version.ReleaseVersion, version.GitCommit)
|
||||||
|
|
||||||
middlewares := []beego.MiddleWare{
|
beego.RunWithMiddleWares("", middlewares.MiddleWares()...)
|
||||||
requestid.Middleware(),
|
|
||||||
orm.Middleware(legacyAPISkipper),
|
|
||||||
}
|
|
||||||
beego.RunWithMiddleWares("", middlewares...)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacyAPISkipper skip middleware for legacy APIs
|
|
||||||
func legacyAPISkipper(r *http.Request) bool {
|
|
||||||
for _, prefix := range []string{"/v2/", "/api/v2.0/"} {
|
|
||||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
56
src/core/middlewares/middlewares.go
Normal file
56
src/core/middlewares/middlewares.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/docker/distribution/reference"
|
||||||
|
"github.com/goharbor/harbor/src/server/middleware"
|
||||||
|
"github.com/goharbor/harbor/src/server/middleware/orm"
|
||||||
|
"github.com/goharbor/harbor/src/server/middleware/requestid"
|
||||||
|
"github.com/goharbor/harbor/src/server/middleware/transaction"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
blobURLRe = regexp.MustCompile("^/v2/(" + reference.NameRegexp.String() + ")/blobs/" + reference.DigestRegexp.String())
|
||||||
|
|
||||||
|
// fetchBlobAPISkipper skip transaction middleware for fetch blob API
|
||||||
|
// because transaction use the ResponseBuffer for the response which will degrade the performance for fetch blob
|
||||||
|
fetchBlobAPISkipper = middleware.MethodAndPathSkipper(http.MethodGet, blobURLRe)
|
||||||
|
)
|
||||||
|
|
||||||
|
// legacyAPISkipper skip middleware for legacy APIs
|
||||||
|
func legacyAPISkipper(r *http.Request) bool {
|
||||||
|
for _, prefix := range []string{"/v2/", "/api/v2.0/"} {
|
||||||
|
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiddleWares returns global middlewares
|
||||||
|
func MiddleWares() []beego.MiddleWare {
|
||||||
|
return []beego.MiddleWare{
|
||||||
|
requestid.Middleware(),
|
||||||
|
orm.Middleware(legacyAPISkipper),
|
||||||
|
transaction.Middleware(legacyAPISkipper, fetchBlobAPISkipper),
|
||||||
|
}
|
||||||
|
}
|
43
src/core/middlewares/middlewares_test.go
Normal file
43
src/core/middlewares/middlewares_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_fetchBlobAPISkipper(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
r *http.Request
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"fetch blob", args{httptest.NewRequest(http.MethodGet, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, true},
|
||||||
|
{"delete blob", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false},
|
||||||
|
{"get manifest", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/manifests/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := fetchBlobAPISkipper(tt.args.r); got != tt.want {
|
||||||
|
t.Errorf("fetchBlobAPISkipper() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -16,6 +16,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,6 +27,7 @@ type ResponseBuffer struct {
|
|||||||
header http.Header
|
header http.Header
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
wroteHeader bool
|
wroteHeader bool
|
||||||
|
flushed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseBuffer creates a ResponseBuffer object
|
// NewResponseBuffer creates a ResponseBuffer object
|
||||||
@ -48,7 +50,9 @@ func (r *ResponseBuffer) WriteHeader(statusCode int) {
|
|||||||
|
|
||||||
// Write writes the data into the buffer without writing to the underlying response writer
|
// Write writes the data into the buffer without writing to the underlying response writer
|
||||||
func (r *ResponseBuffer) Write(data []byte) (int, error) {
|
func (r *ResponseBuffer) Write(data []byte) (int, error) {
|
||||||
r.WriteHeader(http.StatusOK)
|
if !r.wroteHeader {
|
||||||
|
r.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
return r.buffer.Write(data)
|
return r.buffer.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,6 +63,8 @@ func (r *ResponseBuffer) Header() http.Header {
|
|||||||
|
|
||||||
// Flush the status code, header and data into the underlying response writer
|
// Flush the status code, header and data into the underlying response writer
|
||||||
func (r *ResponseBuffer) Flush() (int, error) {
|
func (r *ResponseBuffer) Flush() (int, error) {
|
||||||
|
r.flushed = true
|
||||||
|
|
||||||
header := r.w.Header()
|
header := r.w.Header()
|
||||||
for k, vs := range r.header {
|
for k, vs := range r.header {
|
||||||
for _, v := range vs {
|
for _, v := range vs {
|
||||||
@ -73,5 +79,19 @@ func (r *ResponseBuffer) Flush() (int, error) {
|
|||||||
|
|
||||||
// Success checks whether the status code is >= 200 & <= 399
|
// Success checks whether the status code is >= 200 & <= 399
|
||||||
func (r *ResponseBuffer) Success() bool {
|
func (r *ResponseBuffer) Success() bool {
|
||||||
return r.code >= 200 && r.code <= 399
|
return r.code >= http.StatusOK && r.code < http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset reset the response buffer
|
||||||
|
func (r *ResponseBuffer) Reset() error {
|
||||||
|
if r.flushed {
|
||||||
|
return errors.New("response flushed")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.code = 0
|
||||||
|
r.wroteHeader = false
|
||||||
|
r.header = http.Header{}
|
||||||
|
r.buffer = bytes.Buffer{}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,9 @@
|
|||||||
|
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
// Middleware receives a handler and returns another handler.
|
// Middleware receives a handler and returns another handler.
|
||||||
// The returned handler can do some customized task according to
|
// The returned handler can do some customized task according to
|
||||||
@ -30,10 +32,6 @@ func WithMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handl
|
|||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skipper defines a function to skip middleware.
|
|
||||||
// Returning true skips processing the middleware.
|
|
||||||
type Skipper func(*http.Request) bool
|
|
||||||
|
|
||||||
// New make a middleware from fn which type is func(w http.ResponseWriter, r *http.Request, next http.Handler)
|
// New make a middleware from fn which type is func(w http.ResponseWriter, r *http.Request, next http.Handler)
|
||||||
func New(fn func(http.ResponseWriter, *http.Request, http.Handler), skippers ...Skipper) func(http.Handler) http.Handler {
|
func New(fn func(http.ResponseWriter, *http.Request, http.Handler), skippers ...Skipper) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
37
src/server/middleware/skipper.go
Normal file
37
src/server/middleware/skipper.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Skipper defines a function to skip middleware.
|
||||||
|
// Returning true skips processing the middleware.
|
||||||
|
type Skipper func(*http.Request) bool
|
||||||
|
|
||||||
|
// MethodAndPathSkipper returns skipper which
|
||||||
|
// will skip the middleware when r.Method equals the method and r.URL.Path matches the re
|
||||||
|
// when method is "*" it equals all http method
|
||||||
|
func MethodAndPathSkipper(method string, re *regexp.Regexp) func(r *http.Request) bool {
|
||||||
|
return func(r *http.Request) bool {
|
||||||
|
if (method == "*" || r.Method == method) && re.MatchString(r.URL.Path) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
48
src/server/middleware/skipper_test.go
Normal file
48
src/server/middleware/skipper_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMethodAndPathSkipper(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
method string
|
||||||
|
re *regexp.Regexp
|
||||||
|
r *http.Request
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"match method and path", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/req", nil)}, true},
|
||||||
|
{"match method only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/path", nil)}, false},
|
||||||
|
{"match path only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, false},
|
||||||
|
{"match all methods", args{"*", regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := MethodAndPathSkipper(tt.args.method, tt.args.re)(tt.args.r); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("MethodAndPathSkipper()() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
91
src/server/middleware/transaction/transaction.go
Normal file
91
src/server/middleware/transaction/transaction.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package transaction
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/goharbor/harbor/src/common/utils/log"
|
||||||
|
"github.com/goharbor/harbor/src/internal"
|
||||||
|
"github.com/goharbor/harbor/src/internal/orm"
|
||||||
|
"github.com/goharbor/harbor/src/server/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNonSuccess = errors.New("non success status code")
|
||||||
|
)
|
||||||
|
|
||||||
|
type committableContext struct {
|
||||||
|
context.Context
|
||||||
|
committed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *committableContext) Commit() {
|
||||||
|
ctx.committed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
type committable interface {
|
||||||
|
Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustCommit mark http.Request as committed so that transaction
|
||||||
|
// middleware ignore the status code of the response and commit transaction for this request
|
||||||
|
func MustCommit(r *http.Request) error {
|
||||||
|
c, ok := r.Context().(committable)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%s URL %s is not committable, please enable transaction middleware for it", r.Method, r.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Commit()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware middleware which add transaction for the http request with default config
|
||||||
|
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
|
||||||
|
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||||
|
res, ok := w.(*internal.ResponseBuffer)
|
||||||
|
if !ok {
|
||||||
|
res = internal.NewResponseBuffer(w)
|
||||||
|
defer res.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
h := func(ctx context.Context) error {
|
||||||
|
cc := &committableContext{Context: ctx}
|
||||||
|
next.ServeHTTP(res, r.WithContext(cc))
|
||||||
|
|
||||||
|
if !cc.committed && !res.Success() {
|
||||||
|
return errNonSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := orm.WithTransaction(h)(r.Context()); err != nil && err != errNonSuccess {
|
||||||
|
log.Errorf("deal with %s request in transaction failed: %v", r.URL.Path, err)
|
||||||
|
|
||||||
|
// begin, commit or rollback transaction db error happened,
|
||||||
|
// reset the response and set status code to 500
|
||||||
|
if err := res.Reset(); err != nil {
|
||||||
|
log.Errorf("reset the response failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}, skippers...)
|
||||||
|
}
|
161
src/server/middleware/transaction/transaction_test.go
Normal file
161
src/server/middleware/transaction/transaction_test.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
// Copyright Project Harbor Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package transaction
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
o "github.com/astaxie/beego/orm"
|
||||||
|
"github.com/goharbor/harbor/src/internal/orm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockOrmer struct {
|
||||||
|
o.Ormer
|
||||||
|
records []interface{}
|
||||||
|
beginErr error
|
||||||
|
commitErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) Insert(i interface{}) (int64, error) {
|
||||||
|
m.records = append(m.records, i)
|
||||||
|
|
||||||
|
return int64(len(m.records)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) Begin() error {
|
||||||
|
return m.beginErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) Commit() error {
|
||||||
|
return m.commitErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) Rollback() error {
|
||||||
|
m.ResetRecords()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) ResetRecords() {
|
||||||
|
m.records = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOrmer) Reset() {
|
||||||
|
m.ResetRecords()
|
||||||
|
|
||||||
|
m.beginErr = nil
|
||||||
|
m.commitErr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
mo := &mockOrmer{}
|
||||||
|
|
||||||
|
newRequest := func(method, target string, body io.Reader) *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/req1", nil)
|
||||||
|
return req.WithContext(orm.NewContext(req.Context(), mo))
|
||||||
|
}
|
||||||
|
|
||||||
|
next := func(status int) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mo.Insert("record1")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// test response status code accepted
|
||||||
|
req1 := newRequest(http.MethodGet, "/req", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
Middleware()(next(http.StatusOK)).ServeHTTP(rec1, req1)
|
||||||
|
assert.Equal(http.StatusOK, rec1.Code)
|
||||||
|
assert.NotEmpty(mo.records)
|
||||||
|
|
||||||
|
mo.ResetRecords()
|
||||||
|
assert.Empty(mo.records)
|
||||||
|
|
||||||
|
// test response status code not accepted
|
||||||
|
req2 := newRequest(http.MethodGet, "/req", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec2, req2)
|
||||||
|
assert.Equal(http.StatusBadRequest, rec2.Code)
|
||||||
|
assert.Empty(mo.records)
|
||||||
|
|
||||||
|
// test begin transaction failed
|
||||||
|
mo.beginErr = errors.New("begin tx failed")
|
||||||
|
req3 := newRequest(http.MethodGet, "/req", nil)
|
||||||
|
rec3 := httptest.NewRecorder()
|
||||||
|
Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec3, req3)
|
||||||
|
assert.Equal(http.StatusInternalServerError, rec3.Code)
|
||||||
|
assert.Empty(mo.records)
|
||||||
|
|
||||||
|
// test commit transaction failed
|
||||||
|
mo.beginErr = nil
|
||||||
|
mo.commitErr = errors.New("commit tx failed")
|
||||||
|
req4 := newRequest(http.MethodGet, "/req", nil)
|
||||||
|
rec4 := httptest.NewRecorder()
|
||||||
|
Middleware()(next(http.StatusOK)).ServeHTTP(rec4, req4)
|
||||||
|
assert.Equal(http.StatusInternalServerError, rec4.Code)
|
||||||
|
|
||||||
|
// test MustCommit
|
||||||
|
mo.Reset()
|
||||||
|
assert.Empty(mo.records)
|
||||||
|
|
||||||
|
txMustCommit := func(status int) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer MustCommit(r)
|
||||||
|
mo.Insert("record1")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
req5 := newRequest(http.MethodGet, "/req", nil)
|
||||||
|
rec5 := httptest.NewRecorder()
|
||||||
|
Middleware()(txMustCommit(http.StatusBadRequest)).ServeHTTP(rec5, req5)
|
||||||
|
assert.Equal(http.StatusBadRequest, rec2.Code)
|
||||||
|
assert.NotEmpty(mo.records)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMustCommit(t *testing.T) {
|
||||||
|
newRequest := func(ctx context.Context) *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/req", nil)
|
||||||
|
return req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
r *http.Request
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"request committable", args{newRequest(&committableContext{Context: context.Background()})}, false},
|
||||||
|
{"request not committable", args{newRequest(context.Background())}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := MustCommit(tt.args.r); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MustCommit() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user