Consider the default port when comparing the hosts

This commit cover the cases when the port is set in one of the Host of
request or the core URL, to make sure the comparison works as expected
when the default port (80, 443) is added in only one of them.

Signed-off-by: Daniel Jiang <jiangd@vmware.com>
This commit is contained in:
Daniel Jiang 2020-11-30 20:07:08 +08:00
parent dec12308a1
commit 413dad98a8
2 changed files with 84 additions and 7 deletions

View File

@ -110,18 +110,33 @@ func getChallenge(req *http.Request, accessList []access) string {
} }
func tokenSvcEndpoint(req *http.Request) (string, error) { func tokenSvcEndpoint(req *http.Request) (string, error) {
logger := log.G(req.Context())
rawCoreURL := config.InternalCoreURL() rawCoreURL := config.InternalCoreURL()
if coreURL, err := url.Parse(rawCoreURL); err == nil { if match(req.Context(), req.Host, rawCoreURL) {
if req.Host == coreURL.Host { return rawCoreURL, nil
return rawCoreURL, nil
}
} else {
logger.Errorf("Failed to parse core url, error: %v, fallback to external endpoint", err)
} }
return config.ExtEndpoint() return config.ExtEndpoint()
} }
func match(ctx context.Context, reqHost, rawURL string) bool {
logger := log.G(ctx)
cfgURL, err := url.Parse(rawURL)
if err != nil {
logger.Errorf("Failed to parse url: %s, error: %v", rawURL, err)
return false
}
if cfgURL.Scheme == "http" && cfgURL.Port() == "80" ||
cfgURL.Scheme == "https" && cfgURL.Port() == "443" {
cfgURL.Host = cfgURL.Hostname()
}
if cfgURL.Scheme == "http" && strings.HasSuffix(reqHost, ":80") {
reqHost = strings.TrimSuffix(reqHost, ":80")
}
if cfgURL.Scheme == "https" && strings.HasSuffix(reqHost, ":443") {
reqHost = strings.TrimSuffix(reqHost, ":443")
}
return reqHost == cfgURL.Host
}
var ( var (
once sync.Once once sync.Once
checker reqChecker checker reqChecker

View File

@ -266,3 +266,65 @@ func TestGetChallenge(t *testing.T) {
} }
} }
func TestMatch(t *testing.T) {
cases := []struct {
reqHost string
rawURL string
expect bool
}{
{
"abc.com",
"http://abc.com",
true,
},
{
"abc.com",
"https://abc.com",
true,
},
{
"abc.com:80",
"http://abc.com",
true,
},
{
"abc.com:80",
"https://abc.com",
false,
},
{
"abc.com:443",
"http://abc.com",
false,
},
{
"abc.com:443",
"https://abc.com",
true,
},
{
"abcd.com:443",
"https://abc.com",
false,
},
{
"abc.com:8443",
"https://abc.com:8443",
true,
},
{
"abc.com",
"https://abc.com:443",
true,
},
{
"abc.com",
"http://abc.com:443",
false,
},
}
for _, c := range cases {
assert.Equal(t, c.expect, match(context.Background(), c.reqHost, c.rawURL))
}
}