Fix Replication to Cross-account AWS ECR (#17583)

Replication to Cross-account AWS ECR
This commit is contained in:
Vincent Ni 2022-09-30 00:07:47 -07:00 committed by GitHub
parent cf5197246a
commit 805a36e7f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 17 deletions

View File

@ -29,11 +29,11 @@ import (
)
const (
regionPattern = "https://(?:api|\\d+\\.dkr)\\.ecr\\.([\\w\\-]+)\\.amazonaws\\.com"
ecrPattern = "https://(?:api|(\\d+)\\.dkr)\\.ecr\\.([\\w\\-]+)\\.amazonaws\\.com"
)
var (
regionRegexp = regexp.MustCompile(regionPattern)
ecrRegexp = regexp.MustCompile(ecrPattern)
)
func init() {
@ -45,12 +45,12 @@ func init() {
}
func newAdapter(registry *model.Registry) (*adapter, error) {
region, err := parseRegion(registry.URL)
_, region, err := parseAccountRegion(registry.URL)
if err != nil {
return nil, err
}
svc, err := getAwsSvc(
region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, &registry.URL)
region, registry.Credential.AccessKey, registry.Credential.AccessSecret, registry.Insecure, nil)
if err != nil {
return nil, err
}
@ -62,12 +62,12 @@ func newAdapter(registry *model.Registry) (*adapter, error) {
}, nil
}
func parseRegion(url string) (string, error) {
rs := regionRegexp.FindStringSubmatch(url)
func parseAccountRegion(url string) (string, string, error) {
rs := ecrRegexp.FindStringSubmatch(url)
if rs == nil {
return "", errors.New("bad aws url")
return "", "", errors.New("bad aws url")
}
return rs[1], nil
return rs[1], rs[2], nil
}
type factory struct {

View File

@ -289,18 +289,18 @@ var urlForBenchmark = []string{
"https://test-region.amazonaws.com",
}
func compileRegexpEveryTime(url string) (string, error) {
rs := regexp.MustCompile(regionPattern).FindStringSubmatch(url)
func compileRegexpEveryTime(url string) (string, string, error) {
rs := regexp.MustCompile(ecrPattern).FindStringSubmatch(url)
if rs == nil {
return "", errors.New("bad aws url")
return "", "", errors.New("bad aws url")
}
return rs[1], nil
return rs[1], rs[2], nil
}
func BenchmarkGetRegion(b *testing.B) {
func BenchmarkGetAccountRegion(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, url := range urlForBenchmark {
parseRegion(url)
parseAccountRegion(url)
}
}
}

View File

@ -63,7 +63,7 @@ func (a *awsAuthCredential) Modify(req *http.Request) error {
return nil
}
if !a.isTokenValid() {
endpoint, user, pass, expiresAt, err := a.getAuthorization()
endpoint, user, pass, expiresAt, err := a.getAuthorization(req.URL.String())
if err != nil {
return err
@ -121,9 +121,15 @@ func getAwsSvc(region, accessKey, accessSecret string, insecure bool, forceEndpo
return svc, nil
}
func (a *awsAuthCredential) getAuthorization() (string, string, string, *time.Time, error) {
func (a *awsAuthCredential) getAuthorization(url string) (string, string, string, *time.Time, error) {
id, _, err := parseAccountRegion(url)
if err != nil {
return "", "", "", nil, err
}
regIds := []*string{&id}
svc := a.awssvc
result, err := svc.GetAuthorizationToken(nil)
result, err := svc.GetAuthorizationToken(&awsecrapi.GetAuthorizationTokenInput{RegistryIds: regIds})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return "", "", "", nil, fmt.Errorf("%s: %s", aerr.Code(), aerr.Error())