diff --git a/src/core/main.go b/src/core/main.go index ff06b9195..cd389549e 100755 --- a/src/core/main.go +++ b/src/core/main.go @@ -17,6 +17,12 @@ package main import ( "encoding/gob" "fmt" + "os" + "os/signal" + "strconv" + "syscall" + "time" + "github.com/astaxie/beego" _ "github.com/astaxie/beego/session/redis" "github.com/goharbor/harbor/src/common/dao" @@ -48,15 +54,6 @@ import ( "github.com/goharbor/harbor/src/pkg/version" "github.com/goharbor/harbor/src/replication" "github.com/goharbor/harbor/src/server" - "github.com/goharbor/harbor/src/server/middleware/orm" - "github.com/goharbor/harbor/src/server/middleware/requestid" - "net/http" - "os" - "os/signal" - "strconv" - "strings" - "syscall" - "time" ) const ( @@ -292,21 +289,5 @@ func main() { log.Infof("Version: %s, Git commit: %s", version.ReleaseVersion, version.GitCommit) - middlewares := []beego.MiddleWare{ - requestid.Middleware(), - orm.Middleware(legacyAPISkipper), - } - beego.RunWithMiddleWares("", middlewares...) - -} - -// legacyAPISkipper skip middleware for legacy APIs -func legacyAPISkipper(r *http.Request) bool { - for _, prefix := range []string{"/v2/", "/api/v2.0/"} { - if strings.HasPrefix(r.URL.Path, prefix) { - return false - } - } - - return true + beego.RunWithMiddleWares("", middlewares.MiddleWares()...) } diff --git a/src/core/middlewares/middlewares.go b/src/core/middlewares/middlewares.go new file mode 100644 index 000000000..e5fbc3eae --- /dev/null +++ b/src/core/middlewares/middlewares.go @@ -0,0 +1,56 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package middlewares + +import ( + "net/http" + "regexp" + "strings" + + "github.com/astaxie/beego" + "github.com/docker/distribution/reference" + "github.com/goharbor/harbor/src/server/middleware" + "github.com/goharbor/harbor/src/server/middleware/orm" + "github.com/goharbor/harbor/src/server/middleware/requestid" + "github.com/goharbor/harbor/src/server/middleware/transaction" +) + +var ( + blobURLRe = regexp.MustCompile("^/v2/(" + reference.NameRegexp.String() + ")/blobs/" + reference.DigestRegexp.String()) + + // fetchBlobAPISkipper skip transaction middleware for fetch blob API + // because transaction use the ResponseBuffer for the response which will degrade the performance for fetch blob + fetchBlobAPISkipper = middleware.MethodAndPathSkipper(http.MethodGet, blobURLRe) +) + +// legacyAPISkipper skip middleware for legacy APIs +func legacyAPISkipper(r *http.Request) bool { + for _, prefix := range []string{"/v2/", "/api/v2.0/"} { + if strings.HasPrefix(r.URL.Path, prefix) { + return false + } + } + + return true +} + +// MiddleWares returns global middlewares +func MiddleWares() []beego.MiddleWare { + return []beego.MiddleWare{ + requestid.Middleware(), + orm.Middleware(legacyAPISkipper), + transaction.Middleware(legacyAPISkipper, fetchBlobAPISkipper), + } +} diff --git a/src/core/middlewares/middlewares_test.go b/src/core/middlewares/middlewares_test.go new file mode 100644 index 000000000..60b7d5995 --- /dev/null +++ b/src/core/middlewares/middlewares_test.go @@ -0,0 +1,43 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package middlewares + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_fetchBlobAPISkipper(t *testing.T) { + type args struct { + r *http.Request + } + tests := []struct { + name string + args args + want bool + }{ + {"fetch blob", args{httptest.NewRequest(http.MethodGet, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, true}, + {"delete blob", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/blobs/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false}, + {"get manifest", args{httptest.NewRequest(http.MethodDelete, "/v2/library/photon/manifests/sha256:6e0447537050cf871f9ab6a3fec5715f9c6fff5212f6666993f1fc46b1f717a3", nil)}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := fetchBlobAPISkipper(tt.args.r); got != tt.want { + t.Errorf("fetchBlobAPISkipper() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/internal/response_buffer.go b/src/internal/response_buffer.go index 79669a41b..dba8dfc4b 100644 --- a/src/internal/response_buffer.go +++ b/src/internal/response_buffer.go @@ -16,6 +16,7 @@ package internal import ( "bytes" + "errors" "net/http" ) @@ -26,6 +27,7 @@ type ResponseBuffer struct { header http.Header buffer bytes.Buffer wroteHeader bool + flushed bool } // NewResponseBuffer creates a ResponseBuffer object @@ -48,7 +50,9 @@ func (r *ResponseBuffer) WriteHeader(statusCode int) { // Write writes the data into the buffer without writing to the underlying response writer func (r *ResponseBuffer) Write(data []byte) (int, error) { - r.WriteHeader(http.StatusOK) + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } return r.buffer.Write(data) } @@ -59,6 +63,8 @@ func (r *ResponseBuffer) Header() http.Header { // Flush the status code, header and data into the underlying response writer func (r *ResponseBuffer) Flush() (int, error) { + r.flushed = true + header := r.w.Header() for k, vs := range r.header { for _, v := range vs { @@ -73,5 +79,19 @@ func (r *ResponseBuffer) Flush() (int, error) { // Success checks whether the status code is >= 200 & <= 399 func (r *ResponseBuffer) Success() bool { - return r.code >= 200 && r.code <= 399 + return r.code >= http.StatusOK && r.code < http.StatusBadRequest +} + +// Reset reset the response buffer +func (r *ResponseBuffer) Reset() error { + if r.flushed { + return errors.New("response flushed") + } + + r.code = 0 + r.wroteHeader = false + r.header = http.Header{} + r.buffer = bytes.Buffer{} + + return nil } diff --git a/src/server/middleware/middleware.go b/src/server/middleware/middleware.go index eeae2a356..ccac9b07c 100644 --- a/src/server/middleware/middleware.go +++ b/src/server/middleware/middleware.go @@ -14,7 +14,9 @@ package middleware -import "net/http" +import ( + "net/http" +) // Middleware receives a handler and returns another handler. // The returned handler can do some customized task according to @@ -30,10 +32,6 @@ func WithMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handl return handler } -// Skipper defines a function to skip middleware. -// Returning true skips processing the middleware. -type Skipper func(*http.Request) bool - // New make a middleware from fn which type is func(w http.ResponseWriter, r *http.Request, next http.Handler) func New(fn func(http.ResponseWriter, *http.Request, http.Handler), skippers ...Skipper) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/src/server/middleware/skipper.go b/src/server/middleware/skipper.go new file mode 100644 index 000000000..b99d4807d --- /dev/null +++ b/src/server/middleware/skipper.go @@ -0,0 +1,37 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package middleware + +import ( + "net/http" + "regexp" +) + +// Skipper defines a function to skip middleware. +// Returning true skips processing the middleware. +type Skipper func(*http.Request) bool + +// MethodAndPathSkipper returns skipper which +// will skip the middleware when r.Method equals the method and r.URL.Path matches the re +// when method is "*" it equals all http method +func MethodAndPathSkipper(method string, re *regexp.Regexp) func(r *http.Request) bool { + return func(r *http.Request) bool { + if (method == "*" || r.Method == method) && re.MatchString(r.URL.Path) { + return true + } + + return false + } +} diff --git a/src/server/middleware/skipper_test.go b/src/server/middleware/skipper_test.go new file mode 100644 index 000000000..dd45de525 --- /dev/null +++ b/src/server/middleware/skipper_test.go @@ -0,0 +1,48 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package middleware + +import ( + "net/http" + "net/http/httptest" + "reflect" + "regexp" + "testing" +) + +func TestMethodAndPathSkipper(t *testing.T) { + type args struct { + method string + re *regexp.Regexp + r *http.Request + } + tests := []struct { + name string + args args + want bool + }{ + {"match method and path", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/req", nil)}, true}, + {"match method only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodGet, "/path", nil)}, false}, + {"match path only", args{http.MethodGet, regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, false}, + {"match all methods", args{"*", regexp.MustCompile(`/req`), httptest.NewRequest(http.MethodPost, "/req", nil)}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := MethodAndPathSkipper(tt.args.method, tt.args.re)(tt.args.r); !reflect.DeepEqual(got, tt.want) { + t.Errorf("MethodAndPathSkipper()() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/src/server/middleware/transaction/transaction.go b/src/server/middleware/transaction/transaction.go new file mode 100644 index 000000000..f7dd8f25d --- /dev/null +++ b/src/server/middleware/transaction/transaction.go @@ -0,0 +1,91 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/goharbor/harbor/src/common/utils/log" + "github.com/goharbor/harbor/src/internal" + "github.com/goharbor/harbor/src/internal/orm" + "github.com/goharbor/harbor/src/server/middleware" +) + +var ( + errNonSuccess = errors.New("non success status code") +) + +type committableContext struct { + context.Context + committed bool +} + +func (ctx *committableContext) Commit() { + ctx.committed = true +} + +type committable interface { + Commit() +} + +// MustCommit mark http.Request as committed so that transaction +// middleware ignore the status code of the response and commit transaction for this request +func MustCommit(r *http.Request) error { + c, ok := r.Context().(committable) + if !ok { + return fmt.Errorf("%s URL %s is not committable, please enable transaction middleware for it", r.Method, r.URL.Path) + } + + c.Commit() + + return nil +} + +// Middleware middleware which add transaction for the http request with default config +func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler { + return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) { + res, ok := w.(*internal.ResponseBuffer) + if !ok { + res = internal.NewResponseBuffer(w) + defer res.Flush() + } + + h := func(ctx context.Context) error { + cc := &committableContext{Context: ctx} + next.ServeHTTP(res, r.WithContext(cc)) + + if !cc.committed && !res.Success() { + return errNonSuccess + } + + return nil + } + + if err := orm.WithTransaction(h)(r.Context()); err != nil && err != errNonSuccess { + log.Errorf("deal with %s request in transaction failed: %v", r.URL.Path, err) + + // begin, commit or rollback transaction db error happened, + // reset the response and set status code to 500 + if err := res.Reset(); err != nil { + log.Errorf("reset the response failed: %v", err) + return + } + res.WriteHeader(http.StatusInternalServerError) + } + }, skippers...) +} diff --git a/src/server/middleware/transaction/transaction_test.go b/src/server/middleware/transaction/transaction_test.go new file mode 100644 index 000000000..0c5e179a7 --- /dev/null +++ b/src/server/middleware/transaction/transaction_test.go @@ -0,0 +1,161 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + o "github.com/astaxie/beego/orm" + "github.com/goharbor/harbor/src/internal/orm" + "github.com/stretchr/testify/assert" +) + +type mockOrmer struct { + o.Ormer + records []interface{} + beginErr error + commitErr error +} + +func (m *mockOrmer) Insert(i interface{}) (int64, error) { + m.records = append(m.records, i) + + return int64(len(m.records)), nil +} + +func (m *mockOrmer) Begin() error { + return m.beginErr +} + +func (m *mockOrmer) Commit() error { + return m.commitErr +} + +func (m *mockOrmer) Rollback() error { + m.ResetRecords() + + return nil +} + +func (m *mockOrmer) ResetRecords() { + m.records = nil +} + +func (m *mockOrmer) Reset() { + m.ResetRecords() + + m.beginErr = nil + m.commitErr = nil +} + +func TestTransaction(t *testing.T) { + assert := assert.New(t) + + mo := &mockOrmer{} + + newRequest := func(method, target string, body io.Reader) *http.Request { + req := httptest.NewRequest(http.MethodGet, "/req1", nil) + return req.WithContext(orm.NewContext(req.Context(), mo)) + } + + next := func(status int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mo.Insert("record1") + w.WriteHeader(status) + }) + } + + // test response status code accepted + req1 := newRequest(http.MethodGet, "/req", nil) + rec1 := httptest.NewRecorder() + Middleware()(next(http.StatusOK)).ServeHTTP(rec1, req1) + assert.Equal(http.StatusOK, rec1.Code) + assert.NotEmpty(mo.records) + + mo.ResetRecords() + assert.Empty(mo.records) + + // test response status code not accepted + req2 := newRequest(http.MethodGet, "/req", nil) + rec2 := httptest.NewRecorder() + Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec2, req2) + assert.Equal(http.StatusBadRequest, rec2.Code) + assert.Empty(mo.records) + + // test begin transaction failed + mo.beginErr = errors.New("begin tx failed") + req3 := newRequest(http.MethodGet, "/req", nil) + rec3 := httptest.NewRecorder() + Middleware()(next(http.StatusBadRequest)).ServeHTTP(rec3, req3) + assert.Equal(http.StatusInternalServerError, rec3.Code) + assert.Empty(mo.records) + + // test commit transaction failed + mo.beginErr = nil + mo.commitErr = errors.New("commit tx failed") + req4 := newRequest(http.MethodGet, "/req", nil) + rec4 := httptest.NewRecorder() + Middleware()(next(http.StatusOK)).ServeHTTP(rec4, req4) + assert.Equal(http.StatusInternalServerError, rec4.Code) + + // test MustCommit + mo.Reset() + assert.Empty(mo.records) + + txMustCommit := func(status int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer MustCommit(r) + mo.Insert("record1") + w.WriteHeader(status) + }) + } + + req5 := newRequest(http.MethodGet, "/req", nil) + rec5 := httptest.NewRecorder() + Middleware()(txMustCommit(http.StatusBadRequest)).ServeHTTP(rec5, req5) + assert.Equal(http.StatusBadRequest, rec2.Code) + assert.NotEmpty(mo.records) +} + +func TestMustCommit(t *testing.T) { + newRequest := func(ctx context.Context) *http.Request { + req := httptest.NewRequest(http.MethodGet, "/req", nil) + return req.WithContext(ctx) + } + + type args struct { + r *http.Request + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"request committable", args{newRequest(&committableContext{Context: context.Background()})}, false}, + {"request not committable", args{newRequest(context.Background())}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := MustCommit(tt.args.r); (err != nil) != tt.wantErr { + t.Errorf("MustCommit() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}