[cherry-pick]add url raw query check middleware (#17073)

add url raw query check middleware

The middleware can give a uniform url validation and raised error early.

Signed-off-by: Wang Yan <wangyan@vmware.com>
This commit is contained in:
Wang Yan 2022-06-27 11:32:37 +08:00 committed by GitHub
parent 86056cab75
commit a4cb1a481f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 0 deletions

View File

@ -34,6 +34,7 @@ import (
"github.com/goharbor/harbor/src/server/middleware/session"
"github.com/goharbor/harbor/src/server/middleware/trace"
"github.com/goharbor/harbor/src/server/middleware/transaction"
"github.com/goharbor/harbor/src/server/middleware/url"
)
var (
@ -79,6 +80,7 @@ var (
// MiddleWares returns global middlewares
func MiddleWares() []beego.MiddleWare {
return []beego.MiddleWare{
url.Middleware(),
mergeslash.Middleware(),
trace.Middleware(),
metric.Middleware(),

View File

@ -0,0 +1,24 @@
package url
import (
"net/http"
"net/url"
"github.com/goharbor/harbor/src/lib/errors"
lib_http "github.com/goharbor/harbor/src/lib/http"
"github.com/goharbor/harbor/src/server/middleware"
)
// Middleware middleware which validates the raw query, especially for the invalid semicolon separator.
func Middleware(skippers ...middleware.Skipper) func(http.Handler) http.Handler {
return middleware.New(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
if r.URL != nil && r.URL.RawQuery != "" {
_, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
lib_http.SendError(w, errors.New(err).WithCode(errors.BadRequestCode))
return
}
}
next.ServeHTTP(w, r)
}, skippers...)
}

View File

@ -0,0 +1,36 @@
package url
import (
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestURL(t *testing.T) {
assert := assert.New(t)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req1 := httptest.NewRequest(http.MethodPost, "/req1?mount=sha256&from=test", nil)
rec1 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec1, req1)
assert.Equal(http.StatusOK, rec1.Code)
req2 := httptest.NewRequest(http.MethodPost, "/req2?mount=sha256&from=test;", nil)
rec2 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec2, req2)
assert.Equal(http.StatusBadRequest, rec2.Code)
req3 := httptest.NewRequest(http.MethodGet, "/req3?foo=bar?", nil)
rec3 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec3, req3)
assert.Equal(http.StatusOK, rec3.Code)
req4 := httptest.NewRequest(http.MethodGet, "/req4", nil)
rec4 := httptest.NewRecorder()
Middleware()(next).ServeHTTP(rec4, req4)
assert.Equal(http.StatusOK, rec4.Code)
}