1
0
mirror of https://github.com/goharbor/harbor.git synced 2024-12-21 08:07:59 +01:00

Merge pull request from heww/transtaction-middleware

feat(middleware): add transaction middleware for v2 and v2.0 APIs
This commit is contained in:
Wenkai Yin(尹文开) 2020-01-22 16:58:10 +08:00 committed by GitHub
commit d79a0e6030
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 468 additions and 33 deletions

View File

@ -17,6 +17,12 @@ package main
import (
"encoding/gob"
"fmt"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"github.com/astaxie/beego"
_ "github.com/astaxie/beego/session/redis"
"github.com/goharbor/harbor/src/common/dao"
@ -48,15 +54,6 @@ import (
"github.com/goharbor/harbor/src/pkg/version"
"github.com/goharbor/harbor/src/replication"
"github.com/goharbor/harbor/src/server"
"github.com/goharbor/harbor/src/server/middleware/orm"
"github.com/goharbor/harbor/src/server/middleware/requestid"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
)
const (
@ -292,21 +289,5 @@ func main() {
log.Infof("Version: %s, Git commit: %s", version.ReleaseVersion, version.GitCommit)
middlewares := []beego.MiddleWare{
requestid.Middleware(),
orm.Middleware(legacyAPISkipper),
}
beego.RunWithMiddleWares("", middlewares...)
}
// legacyAPISkipper skip middleware for legacy APIs
func legacyAPISkipper(r *http.Request) bool {
for _, prefix := range []string{"/v2/", "/api/v2.0/"} {
if strings.HasPrefix(r.URL.Path, prefix) {
return false
}
}
return true
beego.RunWithMiddleWares("", middlewares.MiddleWares()...)
}

View File

@ -0,0 +1,56 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package middlewares
import (
"net/http"
"regexp"
"strings"
"github.com/astaxie/beego"
"github.com/docker/distribution/reference"
"github.com/goharbor/harbor/src/server/middleware"
"github.com/goharbor/harbor/src/server/middleware/orm"
"github.com/goharbor/harbor/src/server/middleware/requestid"
"github.com/goharbor/harbor/src/server/middleware/transaction"
)
var (
blobURLRe = regexp.MustCompile("^/v2/(" + reference.NameRegexp.String() + ")/blobs/" + reference.DigestRegexp.String())
// fetchBlobAPISkipper skip transaction middleware for fetch blob API
// because transaction use the ResponseBuffer for the response which will degrade the performance for fetch blob
fetchBlobAPISkipper = middleware.MethodAndPathSkipper(http.MethodGet, blobURLRe)
)
// legacyAPISkipper skip middleware for legacy APIs
func legacyAPISkipper(r *http.Request) bool {
for _, prefix := range []string{"/v2/", "/api/v2.0/"} {
if strings.HasPrefix(r.URL.Path, prefix) {
return false
}
}
return true
}
// MiddleWares returns global middlewares
func MiddleWares() []beego.MiddleWare {
return []beego.MiddleWare{
requestid.Middleware(),
orm.Middleware(legacyAPISkipper),
transaction.Middleware(legacyAPISkipper, fetchBlobAPISkipper),
}
}

View File

@ -0,0 +1,43 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
)
func Test_fetchBlobAPISkipper(t *testing.T) {
type args struct {
r *http.Request
}
tests := []struct {
name string
args args
want bool
}{
{"fetch blob", args{httptest.NewRequest(http.MethodGet, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, true},
{"delete blob", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false},
{"get manifest", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/manifests/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := fetchBlobAPISkipper(tt.args.r); got != tt.want {
t.Errorf("fetchBlobAPISkipper() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -16,6 +16,7 @@ package internal
import (
"bytes"
"errors"
"net/http"
)
@ -26,6 +27,7 @@ type ResponseBuffer struct {
header http.Header
buffer bytes.Buffer
wroteHeader bool
flushed bool
}
// NewResponseBuffer creates a ResponseBuffer object
@ -48,7 +50,9 @@ func (r *ResponseBuffer) WriteHeader(statusCode int) {
// Write writes the data into the buffer without writing to the underlying response writer
func (r *ResponseBuffer) Write(data []byte) (int, error) {
r.WriteHeader(http.StatusOK)
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
return r.buffer.Write(data)
}
@ -59,6 +63,8 @@ func (r *ResponseBuffer) Header() http.Header {
// Flush the status code, header and data into the underlying response writer
func (r *ResponseBuffer) Flush() (int, error) {
r.flushed = true
header := r.w.Header()
for k, vs := range r.header {
for _, v := range vs {
@ -73,5 +79,19 @@ func (r *ResponseBuffer) Flush() (int, error) {
// Success checks whether the status code is >= 200 & <= 399
func (r *ResponseBuffer) Success() bool {
return r.code >= 200 && r.code <= 399
return r.code >= http.StatusOK && r.code < http.StatusBadRequest
}
// Reset reset the response buffer
func (r *ResponseBuffer) Reset() error {
if r.flushed {
return errors.New("response flushed")
}
r.code = 0
r.wroteHeader = false
r.header = http.Header{}
r.buffer = bytes.Buffer{}
return nil
}

View File

@ -14,7 +14,9 @@
package middleware
import "net/http"
import (
"net/http"
)
// Middleware receives a handler and returns another handler.
// The returned handler can do some customized task according to
@ -30,10 +32,6 @@ func WithMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handl
return handler
}
// Skipper defines a function to skip middleware.
// Returning true skips processing the middleware.
type Skipper func(*http.Request) bool
// New make a middleware from fn which type is func(w http.ResponseWriter, r *http.Request, next http.Handler)
func New(fn func(http.ResponseWriter, *http.Request, http.Handler), skippers ...Skipper) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {

View File

@ -0,0 +1,37 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package middleware
import (
"net/http"
"regexp"
)
// Skipper defines a function to skip middleware.
// Returning true skips processing the middleware.
type Skipper func(*http.Request) bool
// MethodAndPathSkipper returns skipper which
// will skip the middleware when r.Method equals the method and r.URL.Path matches the re
// when method is "*" it equals all http method
func MethodAndPathSkipper(method string, re *regexp.Regexp) func(r *http.Request) bool {
return func(r *http.Request) bool {
if (method == "*" || r.Method == method) && re.MatchString(r.URL.Path) {
return true
}
return false
}
}

View File

@ -0,0 +1,48 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package middleware
import (
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"testing"
)
func TestMethodAndPathSkipper(t *testing.T) {
type args struct {
method string
re *regexp.Regexp
r *http.Request
}
tests := []struct {
name string
args args
want bool
}{
{"match method and path", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/req", nil)}, true},
{"match method only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/path", nil)}, false},
{"match path only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, false},
{"match all methods", args{"*", regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := MethodAndPathSkipper(tt.args.method, tt.args.re)(tt.args.r); !reflect.DeepEqual(got, tt.want) {
t.Errorf("MethodAndPathSkipper()() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,91 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package transaction
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/goharbor/harbor/src/common/utils/log"
"github.com/goharbor/harbor/src/internal"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/goharbor/harbor/src/server/middleware"
)
var (
errNonSuccess = errors.New("non success status code")
)
type committableContext struct {
context.Context
committed bool
}
func (ctx *committableContext) Commit() {
ctx.committed = true
}
type committable interface {
Commit()
}
// MustCommit mark http.Request as committed so that transaction
// middleware ignore the status code of the response and commit transaction for this request
func MustCommit(r *http.Request) error {
c, ok := r.Context().(committable)
if !ok {
return fmt.Errorf("%s URL %s is not committable, please enable transaction middleware for it", r.Method, r.URL.Path)
}
c.Commit()
return nil
}
// Middleware middleware which add transaction for the http request with default config
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
res, ok := w.(*internal.ResponseBuffer)
if !ok {
res = internal.NewResponseBuffer(w)
defer res.Flush()
}
h := func(ctx context.Context) error {
cc := &committableContext{Context: ctx}
next.ServeHTTP(res, r.WithContext(cc))
if !cc.committed && !res.Success() {
return errNonSuccess
}
return nil
}
if err := orm.WithTransaction(h)(r.Context()); err != nil && err != errNonSuccess {
log.Errorf("deal with %s request in transaction failed: %v", r.URL.Path, err)
// begin, commit or rollback transaction db error happened,
// reset the response and set status code to 500
if err := res.Reset(); err != nil {
log.Errorf("reset the response failed: %v", err)
return
}
res.WriteHeader(http.StatusInternalServerError)
}
}, skippers...)
}

View File

@ -0,0 +1,161 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package transaction
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
o "github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/internal/orm"
"github.com/stretchr/testify/assert"
)
type mockOrmer struct {
o.Ormer
records []interface{}
beginErr error
commitErr error
}
func (m *mockOrmer) Insert(i interface{}) (int64, error) {
m.records = append(m.records, i)
return int64(len(m.records)), nil
}
func (m *mockOrmer) Begin() error {
return m.beginErr
}
func (m *mockOrmer) Commit() error {
return m.commitErr
}
func (m *mockOrmer) Rollback() error {
m.ResetRecords()
return nil
}
func (m *mockOrmer) ResetRecords() {
m.records = nil
}
func (m *mockOrmer) Reset() {
m.ResetRecords()
m.beginErr = nil
m.commitErr = nil
}
func TestTransaction(t *testing.T) {
assert := assert.New(t)
mo := &mockOrmer{}
newRequest := func(method, target string, body io.Reader) *http.Request {
req := httptest.NewRequest(http.MethodGet, "/req1", nil)
return req.WithContext(orm.NewContext(req.Context(), mo))
}
next := func(status int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mo.Insert("record1")
w.WriteHeader(status)
})
}
// test response status code accepted
req1 := newRequest(http.MethodGet, "/req", nil)
rec1 := httptest.NewRecorder()
Middleware()(next(http.StatusOK)).ServeHTTP(rec1, req1)
assert.Equal(http.StatusOK, rec1.Code)
assert.NotEmpty(mo.records)
mo.ResetRecords()
assert.Empty(mo.records)
// test response status code not accepted
req2 := newRequest(http.MethodGet, "/req", nil)
rec2 := httptest.NewRecorder()
Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec2, req2)
assert.Equal(http.StatusBadRequest, rec2.Code)
assert.Empty(mo.records)
// test begin transaction failed
mo.beginErr = errors.New("begin tx failed")
req3 := newRequest(http.MethodGet, "/req", nil)
rec3 := httptest.NewRecorder()
Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec3, req3)
assert.Equal(http.StatusInternalServerError, rec3.Code)
assert.Empty(mo.records)
// test commit transaction failed
mo.beginErr = nil
mo.commitErr = errors.New("commit tx failed")
req4 := newRequest(http.MethodGet, "/req", nil)
rec4 := httptest.NewRecorder()
Middleware()(next(http.StatusOK)).ServeHTTP(rec4, req4)
assert.Equal(http.StatusInternalServerError, rec4.Code)
// test MustCommit
mo.Reset()
assert.Empty(mo.records)
txMustCommit := func(status int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer MustCommit(r)
mo.Insert("record1")
w.WriteHeader(status)
})
}
req5 := newRequest(http.MethodGet, "/req", nil)
rec5 := httptest.NewRecorder()
Middleware()(txMustCommit(http.StatusBadRequest)).ServeHTTP(rec5, req5)
assert.Equal(http.StatusBadRequest, rec2.Code)
assert.NotEmpty(mo.records)
}
func TestMustCommit(t *testing.T) {
newRequest := func(ctx context.Context) *http.Request {
req := httptest.NewRequest(http.MethodGet, "/req", nil)
return req.WithContext(ctx)
}
type args struct {
r *http.Request
}
tests := []struct {
name string
args args
wantErr bool
}{
{"request committable", args{newRequest(&committableContext{Context: context.Background()})}, false},
{"request not committable", args{newRequest(context.Background())}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := MustCommit(tt.args.r); (err != nil) != tt.wantErr {
t.Errorf("MustCommit() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}