From 805a36e7f0756cc543d42a1179ebe599f7ffe85d Mon Sep 17 00:00:00 2001 From: Vincent Ni Date: Fri, 30 Sep 2022 00:07:47 -0700 Subject: [PATCH] Fix Replication to Cross-account AWS ECR (#17583) Replication to Cross-account AWS ECR --- src/pkg/reg/adapter/awsecr/adapter.go | 16 ++++++++-------- src/pkg/reg/adapter/awsecr/adapter_test.go | 12 ++++++------ src/pkg/reg/adapter/awsecr/auth.go | 12 +++++++++--- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/pkg/reg/adapter/awsecr/adapter.go b/src/pkg/reg/adapter/awsecr/adapter.go index aa120056e..ca8de422e 100644 --- a/src/pkg/reg/adapter/awsecr/adapter.go +++ b/src/pkg/reg/adapter/awsecr/adapter.go @@ -29,11 +29,11 @@ import ( ) const ( - regionPattern = "https://(?:api|\\d+\\.dkr)\\.ecr\\.([\\w\\-]+)\\.amazonaws\\.com" + ecrPattern = "https://(?:api|(\\d+)\\.dkr)\\.ecr\\.([\\w\\-]+)\\.amazonaws\\.com" ) var ( - regionRegexp = regexp.MustCompile(regionPattern) + ecrRegexp = regexp.MustCompile(ecrPattern) ) func init() { @@ -45,12 +45,12 @@ func init() { } func newAdapter(registry *model.Registry) (*adapter, error) { - region, err := parseRegion(registry.URL) + _, region, err := parseAccountRegion(registry.URL) if err != nil { return nil, err } svc, err := getAwsSvc( - region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, ®istry.URL) + region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, nil) if err != nil { return nil, err } @@ -62,12 +62,12 @@ func newAdapter(registry *model.Registry) (*adapter, error) { }, nil } -func parseRegion(url string) (string, error) { - rs := regionRegexp.FindStringSubmatch(url) +func parseAccountRegion(url string) (string, string, error) { + rs := ecrRegexp.FindStringSubmatch(url) if rs == nil { - return "", errors.New("bad aws url") + return "", "", errors.New("bad aws url") } - return rs[1], nil + return rs[1], rs[2], nil } type factory struct { diff --git a/src/pkg/reg/adapter/awsecr/adapter_test.go b/src/pkg/reg/adapter/awsecr/adapter_test.go index 076b73fda..3089ca6cd 100644 --- a/src/pkg/reg/adapter/awsecr/adapter_test.go +++ b/src/pkg/reg/adapter/awsecr/adapter_test.go @@ -289,18 +289,18 @@ var urlForBenchmark = []string{ "https://test-region.amazonaws.com", } -func compileRegexpEveryTime(url string) (string, error) { - rs := regexp.MustCompile(regionPattern).FindStringSubmatch(url) +func compileRegexpEveryTime(url string) (string, string, error) { + rs := regexp.MustCompile(ecrPattern).FindStringSubmatch(url) if rs == nil { - return "", errors.New("bad aws url") + return "", "", errors.New("bad aws url") } - return rs[1], nil + return rs[1], rs[2], nil } -func BenchmarkGetRegion(b *testing.B) { +func BenchmarkGetAccountRegion(b *testing.B) { for i := 0; i < b.N; i++ { for _, url := range urlForBenchmark { - parseRegion(url) + parseAccountRegion(url) } } } diff --git a/src/pkg/reg/adapter/awsecr/auth.go b/src/pkg/reg/adapter/awsecr/auth.go index bbfce5e7f..30099e8f1 100644 --- a/src/pkg/reg/adapter/awsecr/auth.go +++ b/src/pkg/reg/adapter/awsecr/auth.go @@ -63,7 +63,7 @@ func (a *awsAuthCredential) Modify(req *http.Request) error { return nil } if !a.isTokenValid() { - endpoint, user, pass, expiresAt, err := a.getAuthorization() + endpoint, user, pass, expiresAt, err := a.getAuthorization(req.URL.String()) if err != nil { return err @@ -121,9 +121,15 @@ func getAwsSvc(region, accessKey, accessSecret string, insecure bool, forceEndpo return svc, nil } -func (a *awsAuthCredential) getAuthorization() (string, string, string, *time.Time, error) { +func (a *awsAuthCredential) getAuthorization(url string) (string, string, string, *time.Time, error) { + id, _, err := parseAccountRegion(url) + if err != nil { + return "", "", "", nil, err + } + + regIds := []*string{&id} svc := a.awssvc - result, err := svc.GetAuthorizationToken(nil) + result, err := svc.GetAuthorizationToken(&awsecrapi.GetAuthorizationTokenInput{RegistryIds: regIds}) if err != nil { if aerr, ok := err.(awserr.Error); ok { return "", "", "", nil, fmt.Errorf("%s: %s", aerr.Code(), aerr.Error())