Merge pull request #15011 from reasonerjt/merge-slash-middleware

Add merge slash middleware
This commit is contained in:
Daniel Jiang 2021-05-31 13:09:39 +08:00 committed by GitHub
commit 486554caa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 0 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/goharbor/harbor/src/server/middleware/artifactinfo"
"github.com/goharbor/harbor/src/server/middleware/csrf"
"github.com/goharbor/harbor/src/server/middleware/log"
"github.com/goharbor/harbor/src/server/middleware/mergeslash"
"github.com/goharbor/harbor/src/server/middleware/metric"
"github.com/goharbor/harbor/src/server/middleware/notification"
"github.com/goharbor/harbor/src/server/middleware/orm"
@ -77,6 +78,7 @@ var (
// MiddleWares returns global middlewares
func MiddleWares() []beego.MiddleWare {
return []beego.MiddleWare{
mergeslash.Middleware(),
metric.Middleware(),
requestid.Middleware(),
log.Middleware(),

View File

@ -0,0 +1,19 @@
package mergeslash
import (
"net/http"
"regexp"
"github.com/goharbor/harbor/src/server/middleware"
)
var multiSlash = regexp.MustCompile(`(/+)`)
// Middleware creates the middleware to merge slashes in the URL path of the request
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
path := multiSlash.ReplaceAll([]byte(r.URL.Path), []byte("/"))
r.URL.Path = string(path)
next.ServeHTTP(w, r)
}, skippers...)
}

View File

@ -0,0 +1,52 @@
package mergeslash
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
type handler struct {
path string
}
func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
h.path = req.URL.Path
w.WriteHeader(200)
}
func TestMergeSlash(t *testing.T) {
next := &handler{}
rec := httptest.NewRecorder()
req1, _ := http.NewRequest(http.MethodGet, "https://test.local/api/v2.0/systeminfo/", nil)
req2, _ := http.NewRequest(http.MethodGet, "https://test.local/v2//////_catalog", nil)
req3, _ := http.NewRequest(http.MethodPost, "https://test.local/v2/library///////ubuntu//blobs/uploads///////", nil)
req4, _ := http.NewRequest(http.MethodGet, "https://test.local//api/v2.0///////artifacts?scan_overview=false", nil)
cases := []struct {
req *http.Request
expectedPath string
}{
{
req: req1,
expectedPath: "/api/v2.0/systeminfo/",
},
{
req: req2,
expectedPath: "/v2/_catalog",
},
{
req: req3,
expectedPath: "/v2/library/ubuntu/blobs/uploads/",
},
{
req: req4,
expectedPath: "/api/v2.0/artifacts",
},
}
for _, tt := range cases {
Middleware()(next).ServeHTTP(rec, tt.req)
assert.Equal(t, tt.expectedPath, next.path)
}
}