diff --git a/utils/registry/auth/authorizer.go b/utils/registry/auth/authorizer.go index cea731246..26ea177a5 100644 --- a/utils/registry/auth/authorizer.go +++ b/utils/registry/auth/authorizer.go @@ -46,8 +46,8 @@ func NewRequestAuthorizer(handlers []Handler, challenges []au.Challenge) *Reques // ModifyRequest adds authorization to the request func (r *RequestAuthorizer) ModifyRequest(req *http.Request) error { - for _, handler := range r.handlers { - for _, challenge := range r.challenges { + for _, challenge := range r.challenges { + for _, handler := range r.handlers { if handler.Scheme() == challenge.Scheme { if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { return err diff --git a/utils/registry/auth/tokenhandler.go b/utils/registry/auth/tokenhandler.go index f546bac0c..9d075b25d 100644 --- a/utils/registry/auth/tokenhandler.go +++ b/utils/registry/auth/tokenhandler.go @@ -22,6 +22,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" token_util "github.com/vmware/harbor/service/token" @@ -48,6 +49,7 @@ type tokenHandler struct { cache string // cached token expiresIn int // The duration in seconds since the token was issued that it will remain valid issuedAt *time.Time // The RFC3339-serialized UTC standard time at which a given token was issued + sync.Mutex } // Scheme returns the scheme that the handler can handle @@ -77,8 +79,10 @@ func (t *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]str expired := true - if t.expiresIn != 0 && t.issuedAt != nil { - expired = t.issuedAt.Add(time.Duration(t.expiresIn) * time.Second).Before(time.Now().UTC()) + cachedToken, cachedExpiredIn, cachedIssuedAt := t.getCachedToken() + + if len(cachedToken) != 0 && cachedExpiredIn != 0 && cachedIssuedAt != nil { + expired = cachedIssuedAt.Add(time.Duration(cachedExpiredIn) * time.Second).Before(time.Now().UTC()) } if expired || hasFrom { @@ -93,13 +97,11 @@ func (t *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]str token = to if !hasFrom { - t.cache = token - t.expiresIn = expiresIn - t.issuedAt = issuedAt + t.updateCachedToken(to, expiresIn, issuedAt) log.Debug("add token to cache") } } else { - token = t.cache + token = cachedToken log.Debug("get token from cache") } @@ -109,6 +111,20 @@ func (t *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]str return nil } +func (t *tokenHandler) getCachedToken() (string, int, *time.Time) { + t.Lock() + defer t.Unlock() + return t.cache, t.expiresIn, t.issuedAt +} + +func (t *tokenHandler) updateCachedToken(token string, expiresIn int, issuedAt *time.Time) { + t.Lock() + defer t.Unlock() + t.cache = token + t.expiresIn = expiresIn + t.issuedAt = issuedAt +} + // Implements interface Handler type standardTokenHandler struct { tokenHandler @@ -168,6 +184,7 @@ func (s *standardTokenHandler) generateToken(realm, service string, scopes []str if resp.StatusCode != http.StatusOK { err = registry_errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } return diff --git a/utils/registry/errors/error.go b/utils/registry/errors/error.go index 60b8d6ce5..7a1311b00 100644 --- a/utils/registry/errors/error.go +++ b/utils/registry/errors/error.go @@ -23,12 +23,13 @@ import ( // an Error instance will be returned type Error struct { StatusCode int + StatusText string Message string } // Error ... func (e Error) Error() string { - return fmt.Sprintf("%d %s", e.StatusCode, e.Message) + return fmt.Sprintf("%d %s %s", e.StatusCode, e.StatusText, e.Message) } // ParseError parses err, if err is type Error, convert it to Error diff --git a/utils/registry/registry.go b/utils/registry/registry.go index 1ee01892e..baaf70e91 100644 --- a/utils/registry/registry.go +++ b/utils/registry/registry.go @@ -89,6 +89,10 @@ func (r *Registry) Catalog() ([]string, error) { resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return repos, e + } return repos, err } @@ -115,6 +119,7 @@ func (r *Registry) Catalog() ([]string, error) { return repos, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } diff --git a/utils/registry/repository.go b/utils/registry/repository.go index 507634415..ac49e04f8 100644 --- a/utils/registry/repository.go +++ b/utils/registry/repository.go @@ -72,6 +72,9 @@ func NewRepositoryWithCredential(name, endpoint string, credential auth.Credenti } client, err := newClient(endpoint, "", credential, "repository", name, "pull", "push") + if err != nil { + return nil, err + } repository := &Repository{ Name: name, @@ -108,6 +111,17 @@ func NewRepositoryWithUsername(name, endpoint, username string) (*Repository, er return repository, nil } +// try to convert err to errors.Error if it is +func isUnauthorizedError(err error) (bool, error) { + if strings.Contains(err.Error(), http.StatusText(http.StatusUnauthorized)) { + return true, errors.Error{ + StatusCode: http.StatusUnauthorized, + StatusText: http.StatusText(http.StatusUnauthorized), + } + } + return false, err +} + // ListTag ... func (r *Repository) ListTag() ([]string, error) { tags := []string{} @@ -118,6 +132,10 @@ func (r *Repository) ListTag() ([]string, error) { resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return tags, e + } return tags, err } @@ -141,9 +159,9 @@ func (r *Repository) ListTag() ([]string, error) { return tags, nil } - return tags, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -161,6 +179,11 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool, resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -183,6 +206,7 @@ func (r *Repository) ManifestExist(reference string) (digest string, exist bool, err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } return @@ -201,6 +225,11 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) ( resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -219,6 +248,7 @@ func (r *Repository) PullManifest(reference string, acceptMediaTypes []string) ( err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -236,6 +266,11 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) ( resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -253,6 +288,7 @@ func (r *Repository) PushManifest(reference, mediaType string, payload []byte) ( err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -268,6 +304,10 @@ func (r *Repository) DeleteManifest(digest string) error { resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -284,6 +324,7 @@ func (r *Repository) DeleteManifest(digest string) error { return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -298,6 +339,7 @@ func (r *Repository) DeleteTag(tag string) error { if !exist { return errors.Error{ StatusCode: http.StatusNotFound, + StatusText: http.StatusText(http.StatusNotFound), } } @@ -313,6 +355,10 @@ func (r *Repository) BlobExist(digest string) (bool, error) { resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return false, e + } return false, err } @@ -333,6 +379,7 @@ func (r *Repository) BlobExist(digest string) (bool, error) { return false, errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -346,6 +393,11 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -367,6 +419,7 @@ func (r *Repository) PullBlob(digest string) (size int64, data []byte, err error err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -379,6 +432,11 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + err = e + return + } return } @@ -397,6 +455,7 @@ func (r *Repository) initiateBlobUpload(name string) (location, uploadUUID strin err = errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } @@ -411,6 +470,10 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -427,6 +490,7 @@ func (r *Repository) monolithicBlobUpload(location, digest string, size int64, d return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } @@ -460,6 +524,10 @@ func (r *Repository) DeleteBlob(digest string) error { resp, err := r.client.Do(req) if err != nil { + ok, e := isUnauthorizedError(err) + if ok { + return e + } return err } @@ -476,6 +544,7 @@ func (r *Repository) DeleteBlob(digest string) error { return errors.Error{ StatusCode: resp.StatusCode, + StatusText: resp.Status, Message: string(b), } } diff --git a/utils/registry/repository_test.go b/utils/registry/repository_test.go new file mode 100644 index 000000000..07b46237b --- /dev/null +++ b/utils/registry/repository_test.go @@ -0,0 +1,176 @@ +/* + Copyright (c) 2016 VMware, Inc. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package registry + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + //"github.com/vmware/harbor/utils/log" + "github.com/vmware/harbor/utils/registry/auth" + "github.com/vmware/harbor/utils/registry/errors" +) + +var ( + username = "user" + password = "P@ssw0rd" + repo = "samalba/my-app" + tags = tagResp{Tags: []string{"1.0", "2.0", "3.0"}} + validToken = "valid_token" + invalidToken = "invalid_token" + credential auth.Credential + registryServer *httptest.Server + tokenServer *httptest.Server + repositoryClient *Repository +) + +type tagResp struct { + Tags []string `json:"tags"` +} + +func TestMain(m *testing.M) { + //log.SetLevel(log.DebugLevel) + credential = auth.NewBasicAuthCredential(username, password) + + tokenServer = initTokenServer() + defer tokenServer.Close() + + registryServer = initRegistryServer() + defer registryServer.Close() + + os.Exit(m.Run()) +} + +func initRegistryServer() *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/v2/", servePing) + mux.HandleFunc(fmt.Sprintf("/v2/%s/tags/list", repo), serveTaglisting) + + return httptest.NewServer(mux) +} + +//response ping request: http://registry/v2 +func servePing(w http.ResponseWriter, r *http.Request) { + if !isTokenValid(r) { + challenge(w) + return + } +} + +func serveTaglisting(w http.ResponseWriter, r *http.Request) { + if !isTokenValid(r) { + challenge(w) + return + } + + if err := json.NewEncoder(w).Encode(tags); err != nil { + w.Write([]byte(err.Error())) + w.WriteHeader(http.StatusInternalServerError) + return + } + +} + +func isTokenValid(r *http.Request) bool { + valid := false + auth := r.Header.Get(http.CanonicalHeaderKey("Authorization")) + if len(auth) != 0 { + auth = strings.TrimSpace(auth) + index := strings.Index(auth, "Bearer") + token := auth[index+6:] + token = strings.TrimSpace(token) + if token == validToken { + valid = true + } + } + return valid +} + +func challenge(w http.ResponseWriter) { + challenge := "Bearer realm=\"" + tokenServer.URL + "/service/token\",service=\"token-service\"" + w.Header().Set("Www-Authenticate", challenge) + w.WriteHeader(http.StatusUnauthorized) + return +} + +func initTokenServer() *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/service/token", serveToken) + + return httptest.NewServer(mux) +} + +func serveToken(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok || u != username || p != password { + w.WriteHeader(http.StatusUnauthorized) + return + } + + result := make(map[string]interface{}) + result["token"] = validToken + result["expires_in"] = 300 + result["issued_at"] = time.Now().Format(time.RFC3339) + + encoder := json.NewEncoder(w) + if err := encoder.Encode(result); err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(err.Error())) + return + } +} + +func TestListTag(t *testing.T) { + client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential) + if err != nil { + t.Error(err) + } + + list, err := client.ListTag() + if err != nil { + t.Error(err) + return + } + if len(list) != len(tags.Tags) { + t.Errorf("expected length: %d, actual length: %d", len(tags.Tags), len(list)) + return + } + +} + +func TestListTagWithInvalidCredential(t *testing.T) { + credential := auth.NewBasicAuthCredential(username, "wrong_password") + client, err := NewRepositoryWithCredential(repo, registryServer.URL, credential) + if err != nil { + t.Error(err) + } + + _, err = client.ListTag() + if err != nil { + e, ok := errors.ParseError(err) + if ok && e.StatusCode == http.StatusUnauthorized { + return + } + t.Error(err) + return + } +}