diff --git a/src/server/middleware/csrf/csrf.go b/src/server/middleware/csrf/csrf.go index b5ebb87da..8ec8ff5a8 100644 --- a/src/server/middleware/csrf/csrf.go +++ b/src/server/middleware/csrf/csrf.go @@ -1,6 +1,11 @@ package csrf import ( + "net/http" + "os" + "strings" + "sync" + "github.com/goharbor/harbor/src/common/utils" "github.com/goharbor/harbor/src/common/utils/log" "github.com/goharbor/harbor/src/core/config" @@ -8,10 +13,6 @@ import ( serror "github.com/goharbor/harbor/src/server/error" "github.com/goharbor/harbor/src/server/middleware" "github.com/gorilla/csrf" - "net/http" - "os" - "strings" - "sync" ) const ( @@ -21,8 +22,9 @@ const ( ) var ( - once sync.Once - protect func(handler http.Handler) http.Handler + once sync.Once + secureFlag = true + protect func(handler http.Handler) http.Handler ) // attachToken makes sure if csrf generate a new token it will be included in the response header @@ -30,7 +32,7 @@ func attachToken(w http.ResponseWriter, r *http.Request) { if t := csrf.Token(r); len(t) > 0 { http.SetCookie(w, &http.Cookie{ Name: tokenCookie, - Secure: true, + Secure: secureFlag, Value: t, Path: "/", SameSite: http.SameSiteStrictMode, @@ -60,9 +62,10 @@ func Middleware() func(handler http.Handler) http.Handler { if len(key) != 32 { log.Warningf("Invalid CSRF key from environment: %s, generating random key...", key) key = utils.GenerateRandomString() - } + secureFlag = secureCookie() protect = csrf.Protect([]byte(key), csrf.RequestHeader(tokenHeader), + csrf.Secure(secureFlag), csrf.ErrorHandler(http.HandlerFunc(handleError)), csrf.SameSite(csrf.SameSiteStrictMode), csrf.Path("/")) @@ -87,3 +90,12 @@ func csrfSkipper(req *http.Request) bool { } return false } + +func secureCookie() bool { + ep, err := config.ExtEndpoint() + if err != nil { + log.Warningf("Failed to get external endpoint: %v, set cookie secure flag to true", err) + return true + } + return !strings.HasPrefix(strings.ToLower(ep), "http://") +} diff --git a/src/server/middleware/csrf/csrf_test.go b/src/server/middleware/csrf/csrf_test.go index 94b000e01..14eaf86e6 100644 --- a/src/server/middleware/csrf/csrf_test.go +++ b/src/server/middleware/csrf/csrf_test.go @@ -1,12 +1,24 @@ package csrf import ( + "github.com/goharbor/harbor/src/common" + "github.com/goharbor/harbor/src/core/config" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" + "os" "testing" ) +func TestMain(m *testing.M) { + conf := map[string]interface{}{} + config.InitWithSettings(conf) + result := m.Run() + if result != 0 { + os.Exit(result) + } +} + type handler struct { } @@ -58,3 +70,15 @@ func hasCookie(resp *http.Response, name string) bool { } return false } + +func TestSecureCookie(t *testing.T) { + assert.True(t, secureCookie()) + conf := map[string]interface{}{ + common.ExtEndpoint: "http://harbor.test", + } + config.InitWithSettings(conf) + + assert.False(t, secureCookie()) + conf = map[string]interface{}{} + config.InitWithSettings(conf) +}