diff --git a/src/core/middlewares/middlewares.go b/src/core/middlewares/middlewares.go index ce2ea5981..84efc9942 100644 --- a/src/core/middlewares/middlewares.go +++ b/src/core/middlewares/middlewares.go @@ -24,6 +24,7 @@ import ( "github.com/goharbor/harbor/src/server/middleware/artifactinfo" "github.com/goharbor/harbor/src/server/middleware/csrf" "github.com/goharbor/harbor/src/server/middleware/log" + "github.com/goharbor/harbor/src/server/middleware/mergeslash" "github.com/goharbor/harbor/src/server/middleware/metric" "github.com/goharbor/harbor/src/server/middleware/notification" "github.com/goharbor/harbor/src/server/middleware/orm" @@ -77,6 +78,7 @@ var ( // MiddleWares returns global middlewares func MiddleWares() []beego.MiddleWare { return []beego.MiddleWare{ + mergeslash.Middleware(), metric.Middleware(), requestid.Middleware(), log.Middleware(), diff --git a/src/server/middleware/mergeslash/mergeslash.go b/src/server/middleware/mergeslash/mergeslash.go new file mode 100644 index 000000000..622546808 --- /dev/null +++ b/src/server/middleware/mergeslash/mergeslash.go @@ -0,0 +1,19 @@ +package mergeslash + +import ( + "net/http" + "regexp" + + "github.com/goharbor/harbor/src/server/middleware" +) + +var multiSlash = regexp.MustCompile(`(/+)`) + +// Middleware creates the middleware to merge slashes in the URL path of the request +func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler { + return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) { + path := multiSlash.ReplaceAll([]byte(r.URL.Path), []byte("/")) + r.URL.Path = string(path) + next.ServeHTTP(w, r) + }, skippers...) +} diff --git a/src/server/middleware/mergeslash/mergeslash_test.go b/src/server/middleware/mergeslash/mergeslash_test.go new file mode 100644 index 000000000..9397ccff2 --- /dev/null +++ b/src/server/middleware/mergeslash/mergeslash_test.go @@ -0,0 +1,52 @@ +package mergeslash + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type handler struct { + path string +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + h.path = req.URL.Path + w.WriteHeader(200) +} + +func TestMergeSlash(t *testing.T) { + next := &handler{} + rec := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "https://test.local/api/v2.0/systeminfo/", nil) + req2, _ := http.NewRequest(http.MethodGet, "https://test.local/v2//////_catalog", nil) + req3, _ := http.NewRequest(http.MethodPost, "https://test.local/v2/library///////ubuntu//blobs/uploads///////", nil) + req4, _ := http.NewRequest(http.MethodGet, "https://test.local//api/v2.0///////artifacts?scan_overview=false", nil) + cases := []struct { + req *http.Request + expectedPath string + }{ + { + req: req1, + expectedPath: "/api/v2.0/systeminfo/", + }, + { + req: req2, + expectedPath: "/v2/_catalog", + }, + { + req: req3, + expectedPath: "/v2/library/ubuntu/blobs/uploads/", + }, + { + req: req4, + expectedPath: "/api/v2.0/artifacts", + }, + } + for _, tt := range cases { + Middleware()(next).ServeHTTP(rec, tt.req) + assert.Equal(t, tt.expectedPath, next.path) + } +}