Upgrade aws-sdk-go

Versions of the Go AWS SDK newer than 1.23.13 support OIDC in EKS.
Running Harbor on EKS doesn't require keys in a configmap for the
registry to authenticate to S3 when using the newer library.

Signed-off-by: Phil Fenstermacher <pcfens@wm.edu>
This commit is contained in:
Phil Fenstermacher 2020-06-19 00:30:51 -04:00
parent 58b7242a25
commit 33069e0f98
No known key found for this signature in database
GPG Key ID: 457B18DB9B6B362A
135 changed files with 22284 additions and 3008 deletions

View File

@ -11,7 +11,7 @@ require (
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect
github.com/aliyun/alibaba-cloud-sdk-go v0.0.0-20190726115642-cd293c93fd97
github.com/astaxie/beego v1.12.1
github.com/aws/aws-sdk-go v1.19.47
github.com/aws/aws-sdk-go v1.32.5
github.com/beego/i18n v0.0.0-20140604031826-e87155e8f0c0
github.com/bmatcuk/doublestar v1.1.1
github.com/bugsnag/bugsnag-go v1.5.2 // indirect
@ -61,7 +61,7 @@ require (
github.com/robfig/cron v1.0.0
github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 // indirect
github.com/spf13/viper v1.4.0 // indirect
github.com/stretchr/testify v1.4.0
github.com/stretchr/testify v1.5.1
github.com/theupdateframework/notary v0.6.1
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073
golang.org/x/net v0.0.0-20200202094626-16171245cfb2

View File

@ -100,8 +100,8 @@ github.com/astaxie/beego v1.12.1 h1:dfpuoxpzLVgclveAXe4PyNKqkzgm5zF4tgF2B3kkM2I=
github.com/astaxie/beego v1.12.1/go.mod h1:kPBWpSANNbSdIqOc8SUL9h+1oyBMZhROeYsXQDbidWQ=
github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.17.7/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.19.47 h1:ZEze0mpk8Fttrsz6UNLqhH/jRGYbMPfWFA2ILas4AmM=
github.com/aws/aws-sdk-go v1.19.47/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.32.5 h1:Sz0C7deIoMu5lFGTVkIN92IEZrUz1AWIDDW+9p6n1Rk=
github.com/aws/aws-sdk-go v1.32.5/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ=
github.com/beego/i18n v0.0.0-20140604031826-e87155e8f0c0 h1:fQaDnUQvBXHHQdGBu9hz8nPznB4BeiPQokvmQVjmNEw=
@ -484,6 +484,8 @@ github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht
github.com/jmespath/go-jmespath v0.0.0-20160803190731-bd40a432e4c7/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@ -703,6 +705,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0=
github.com/theupdateframework/notary v0.6.1 h1:7wshjstgS9x9F5LuB1L5mBI2xNMObWqjz+cjWoom6l0=

93
src/vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go generated vendored Normal file
View File

@ -0,0 +1,93 @@
// Package arn provides a parser for interacting with Amazon Resource Names.
package arn
import (
"errors"
"strings"
)
const (
arnDelimiter = ":"
arnSections = 6
arnPrefix = "arn:"
// zero-indexed
sectionPartition = 1
sectionService = 2
sectionRegion = 3
sectionAccountID = 4
sectionResource = 5
// errors
invalidPrefix = "arn: invalid prefix"
invalidSections = "arn: not enough sections"
)
// ARN captures the individual fields of an Amazon Resource Name.
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
type ARN struct {
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
// (Beijing) region is "aws-cn".
Partition string
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
// namespaces, see
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
Service string
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
// component might be omitted.
Region string
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
// ARNs for some resources don't require an account number, so this component might be omitted.
AccountID string
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
// resource name itself. Some services allows paths for resource names, as described in
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
Resource string
}
// Parse parses an ARN into its constituent parts.
//
// Some example ARNs:
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
// arn:aws:iam::123456789012:user/David
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
func Parse(arn string) (ARN, error) {
if !strings.HasPrefix(arn, arnPrefix) {
return ARN{}, errors.New(invalidPrefix)
}
sections := strings.SplitN(arn, arnDelimiter, arnSections)
if len(sections) != arnSections {
return ARN{}, errors.New(invalidSections)
}
return ARN{
Partition: sections[sectionPartition],
Service: sections[sectionService],
Region: sections[sectionRegion],
AccountID: sections[sectionAccountID],
Resource: sections[sectionResource],
}, nil
}
// IsARN returns whether the given string is an ARN by looking for
// whether the string starts with "arn:" and contains the correct number
// of sections delimited by colons(:).
func IsARN(arn string) bool {
return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1
}
// String returns the canonical representation of the ARN
func (arn ARN) String() string {
return arnPrefix +
arn.Partition + arnDelimiter +
arn.Service + arnDelimiter +
arn.Region + arnDelimiter +
arn.AccountID + arnDelimiter +
arn.Resource
}

View File

@ -208,7 +208,7 @@ func (e errorList) Error() string {
// How do we want to handle the array size being zero
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error())
msg += e[i].Error()
// We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n'

View File

@ -70,7 +70,7 @@ func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTer
value = value.FieldByNameFunc(func(name string) bool {
if c == name {
return true
} else if !caseSensitive && strings.ToLower(name) == strings.ToLower(c) {
} else if !caseSensitive && strings.EqualFold(name, c) {
return true
}
return false
@ -185,14 +185,13 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil {
rvals := rValuesAtPath(i, path, true, false, v == nil)
for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v)
}
}
}
func setValue(dstVal reflect.Value, src interface{}) {

View File

@ -12,6 +12,7 @@ import (
type Config struct {
Config *aws.Config
Handlers request.Handlers
PartitionID string
Endpoint string
SigningRegion string
SigningName string
@ -64,7 +65,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
default:
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3
maxRetries = DefaultRetryerMaxNumRetries
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
}

View File

@ -1,6 +1,7 @@
package client
import (
"math"
"strconv"
"time"
@ -9,82 +10,142 @@ import (
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
// most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface.
//
// type retryer struct {
// client.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() int { return 100 }
type DefaultRetryer struct {
// Num max Retries is the number of max retries that will be performed.
// By default, this is zero.
NumMaxRetries int
// MinRetryDelay is the minimum retry delay after which retry will be performed.
// If not set, the value is 0ns.
MinRetryDelay time.Duration
// MinThrottleRetryDelay is the minimum retry delay when throttled.
// If not set, the value is 0ns.
MinThrottleDelay time.Duration
// MaxRetryDelay is the maximum retry delay before which retry must be performed.
// If not set, the value is 0ns.
MaxRetryDelay time.Duration
// MaxThrottleDelay is the maximum retry delay when throttled.
// If not set, the value is 0ns.
MaxThrottleDelay time.Duration
}
const (
// DefaultRetryerMaxNumRetries sets maximum number of retries
DefaultRetryerMaxNumRetries = 3
// DefaultRetryerMinRetryDelay sets minimum retry delay
DefaultRetryerMinRetryDelay = 30 * time.Millisecond
// DefaultRetryerMinThrottleDelay sets minimum delay when throttled
DefaultRetryerMinThrottleDelay = 500 * time.Millisecond
// DefaultRetryerMaxRetryDelay sets maximum retry delay
DefaultRetryerMaxRetryDelay = 300 * time.Second
// DefaultRetryerMaxThrottleDelay sets maximum delay when throttled
DefaultRetryerMaxThrottleDelay = 300 * time.Second
)
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
// setRetryerDefaults sets the default values of the retryer if not set
func (d *DefaultRetryer) setRetryerDefaults() {
if d.MinRetryDelay == 0 {
d.MinRetryDelay = DefaultRetryerMinRetryDelay
}
if d.MaxRetryDelay == 0 {
d.MaxRetryDelay = DefaultRetryerMaxRetryDelay
}
if d.MinThrottleDelay == 0 {
d.MinThrottleDelay = DefaultRetryerMinThrottleDelay
}
if d.MaxThrottleDelay == 0 {
d.MaxThrottleDelay = DefaultRetryerMaxThrottleDelay
}
}
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
if delay, ok := getRetryDelay(r); ok {
return delay
// if number of max retries is zero, no retries will be performed.
if d.NumMaxRetries == 0 {
return 0
}
minTime = 500
// Sets default value for retryer members
d.setRetryerDefaults()
// minDelay is the minimum retryer delay
minDelay := d.MinRetryDelay
var initialDelay time.Duration
isThrottle := r.IsErrorThrottle()
if isThrottle {
if delay, ok := getRetryAfterDelay(r); ok {
initialDelay = delay
}
minDelay = d.MinThrottleDelay
}
retryCount := r.RetryCount
if throttle && retryCount > 8 {
retryCount = 8
} else if retryCount > 13 {
retryCount = 13
// maxDelay the maximum retryer delay
maxDelay := d.MaxRetryDelay
if isThrottle {
maxDelay = d.MaxThrottleDelay
}
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
var delay time.Duration
// Logic to cap the retry count based on the minDelay provided
actualRetryCount := int(math.Log2(float64(minDelay))) + 1
if actualRetryCount < 63-retryCount {
delay = time.Duration(1<<uint64(retryCount)) * getJitterDelay(minDelay)
if delay > maxDelay {
delay = getJitterDelay(maxDelay / 2)
}
} else {
delay = getJitterDelay(maxDelay / 2)
}
return delay + initialDelay
}
// getJitterDelay returns a jittered delay for retry
func getJitterDelay(duration time.Duration) time.Duration {
return time.Duration(sdkrand.SeededRand.Int63n(int64(duration)) + int64(duration))
}
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldRetry returns false if number of max retries is 0.
if d.NumMaxRetries == 0 {
return false
}
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return true
return r.IsErrorRetryable() || r.IsErrorThrottle()
}
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) {
func getRetryAfterDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}

View File

@ -67,10 +67,14 @@ func logRequest(r *request.Request) {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
// Reset the request body because dumpRequest will re-wrap the
// r.HTTPRequest's Body as a NoOpCloser and will not be reset after
// read by the HTTP client reader.
if err := r.Error; err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,

View File

@ -5,6 +5,7 @@ type ClientInfo struct {
ServiceName string
ServiceID string
APIVersion string
PartitionID string
Endpoint string
SigningName string
SigningRegion string

View File

@ -0,0 +1,28 @@
package client
import (
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// NoOpRetryer provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type NoOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d NoOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d NoOpRetryer) ShouldRetry(_ *request.Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d NoOpRetryer) RetryRules(_ *request.Request) time.Duration {
return 0
}

View File

@ -20,7 +20,7 @@ type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
@ -43,7 +43,7 @@ type Config struct {
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
// to `nil` to use the default generated endpoint.
//
// Note: You must still provide a `Region` value when specifying an
// endpoint for a client.
@ -161,6 +161,17 @@ type Config struct {
// on GetObject API calls.
S3DisableContentMD5Validation *bool
// Set this to `true` to have the S3 service client to use the region specified
// in the ARN, when an ARN is provided as an argument to a bucket parameter.
S3UseARNRegion *bool
// Set this to `true` to enable the SDK to unmarshal API response header maps to
// normalized lower case map keys.
//
// For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case
// Metadata member's map keys. The value of the header in the map is unaffected.
LowerCaseHeaderMaps *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
@ -227,6 +238,7 @@ type Config struct {
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
// To use EndpointDiscovery, Endpoint should be unset or set to an empty string.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
@ -246,12 +258,18 @@ type Config struct {
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
// STSRegionalEndpoint will enable regional or legacy endpoint resolving
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// S3UsEast1RegionalEndpoint will enable regional or legacy endpoint resolving
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
}
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
@ -379,6 +397,13 @@ func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
}
// WithS3UseARNRegion sets a config S3UseARNRegion value and
// returning a Config pointer for chaining
func (c *Config) WithS3UseARNRegion(enable bool) *Config {
c.S3UseARNRegion = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config {
@ -420,6 +445,20 @@ func (c *Config) MergeIn(cfgs ...*Config) {
}
}
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
c.STSRegionalEndpoint = sre
return c
}
// WithS3UsEast1RegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service
func (c *Config) WithS3UsEast1RegionalEndpoint(sre endpoints.S3UsEast1RegionalEndpoint) *Config {
c.S3UsEast1RegionalEndpoint = sre
return c
}
func mergeInConfig(dst *Config, other *Config) {
if other == nil {
return
@ -493,6 +532,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
}
if other.S3UseARNRegion != nil {
dst.S3UseARNRegion = other.S3UseARNRegion
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
@ -520,6 +563,14 @@ func mergeInConfig(dst *Config, other *Config) {
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
if other.STSRegionalEndpoint != endpoints.UnsetSTSEndpoint {
dst.STSRegionalEndpoint = other.STSRegionalEndpoint
}
if other.S3UsEast1RegionalEndpoint != endpoints.UnsetS3UsEast1Endpoint {
dst.S3UsEast1RegionalEndpoint = other.S3UsEast1RegionalEndpoint
}
}
// Copy will return a shallow copy of the Config object. If any additional

View File

@ -2,42 +2,8 @@
package aws
import "time"
// An emptyCtx is a copy of the Go 1.7 context.emptyCtx type. This is copied to
// provide a 1.6 and 1.5 safe version of context that is compatible with Go
// 1.7's Context.
//
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case backgroundCtx:
return "aws.BackgroundContext"
}
return "unknown empty Context"
}
var (
backgroundCtx = new(emptyCtx)
import (
"github.com/aws/aws-sdk-go/internal/context"
)
// BackgroundContext returns a context that will never be canceled, has no
@ -52,5 +18,5 @@ var (
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
return context.BackgroundCtx
}

View File

@ -179,6 +179,242 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst
}
// Uint returns a pointer to the uint value passed in.
func Uint(v uint) *uint {
return &v
}
// UintValue returns the value of the uint pointer passed in or
// 0 if the pointer is nil.
func UintValue(v *uint) uint {
if v != nil {
return *v
}
return 0
}
// UintSlice converts a slice of uint values uinto a slice of
// uint pointers
func UintSlice(src []uint) []*uint {
dst := make([]*uint, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// UintValueSlice converts a slice of uint pointers uinto a slice of
// uint values
func UintValueSlice(src []*uint) []uint {
dst := make([]uint, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// UintMap converts a string map of uint values uinto a string
// map of uint pointers
func UintMap(src map[string]uint) map[string]*uint {
dst := make(map[string]*uint)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// UintValueMap converts a string map of uint pointers uinto a string
// map of uint values
func UintValueMap(src map[string]*uint) map[string]uint {
dst := make(map[string]uint)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int8 returns a pointer to the int8 value passed in.
func Int8(v int8) *int8 {
return &v
}
// Int8Value returns the value of the int8 pointer passed in or
// 0 if the pointer is nil.
func Int8Value(v *int8) int8 {
if v != nil {
return *v
}
return 0
}
// Int8Slice converts a slice of int8 values into a slice of
// int8 pointers
func Int8Slice(src []int8) []*int8 {
dst := make([]*int8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int8ValueSlice converts a slice of int8 pointers into a slice of
// int8 values
func Int8ValueSlice(src []*int8) []int8 {
dst := make([]int8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int8Map converts a string map of int8 values into a string
// map of int8 pointers
func Int8Map(src map[string]int8) map[string]*int8 {
dst := make(map[string]*int8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int8ValueMap converts a string map of int8 pointers into a string
// map of int8 values
func Int8ValueMap(src map[string]*int8) map[string]int8 {
dst := make(map[string]int8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int16 returns a pointer to the int16 value passed in.
func Int16(v int16) *int16 {
return &v
}
// Int16Value returns the value of the int16 pointer passed in or
// 0 if the pointer is nil.
func Int16Value(v *int16) int16 {
if v != nil {
return *v
}
return 0
}
// Int16Slice converts a slice of int16 values into a slice of
// int16 pointers
func Int16Slice(src []int16) []*int16 {
dst := make([]*int16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int16ValueSlice converts a slice of int16 pointers into a slice of
// int16 values
func Int16ValueSlice(src []*int16) []int16 {
dst := make([]int16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int16Map converts a string map of int16 values into a string
// map of int16 pointers
func Int16Map(src map[string]int16) map[string]*int16 {
dst := make(map[string]*int16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int16ValueMap converts a string map of int16 pointers into a string
// map of int16 values
func Int16ValueMap(src map[string]*int16) map[string]int16 {
dst := make(map[string]int16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int32 returns a pointer to the int32 value passed in.
func Int32(v int32) *int32 {
return &v
}
// Int32Value returns the value of the int32 pointer passed in or
// 0 if the pointer is nil.
func Int32Value(v *int32) int32 {
if v != nil {
return *v
}
return 0
}
// Int32Slice converts a slice of int32 values into a slice of
// int32 pointers
func Int32Slice(src []int32) []*int32 {
dst := make([]*int32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int32ValueSlice converts a slice of int32 pointers into a slice of
// int32 values
func Int32ValueSlice(src []*int32) []int32 {
dst := make([]int32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int32Map converts a string map of int32 values into a string
// map of int32 pointers
func Int32Map(src map[string]int32) map[string]*int32 {
dst := make(map[string]*int32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int32ValueMap converts a string map of int32 pointers into a string
// map of int32 values
func Int32ValueMap(src map[string]*int32) map[string]int32 {
dst := make(map[string]int32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 {
return &v
@ -238,6 +474,301 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst
}
// Uint8 returns a pointer to the uint8 value passed in.
func Uint8(v uint8) *uint8 {
return &v
}
// Uint8Value returns the value of the uint8 pointer passed in or
// 0 if the pointer is nil.
func Uint8Value(v *uint8) uint8 {
if v != nil {
return *v
}
return 0
}
// Uint8Slice converts a slice of uint8 values into a slice of
// uint8 pointers
func Uint8Slice(src []uint8) []*uint8 {
dst := make([]*uint8, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint8ValueSlice converts a slice of uint8 pointers into a slice of
// uint8 values
func Uint8ValueSlice(src []*uint8) []uint8 {
dst := make([]uint8, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint8Map converts a string map of uint8 values into a string
// map of uint8 pointers
func Uint8Map(src map[string]uint8) map[string]*uint8 {
dst := make(map[string]*uint8)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint8ValueMap converts a string map of uint8 pointers into a string
// map of uint8 values
func Uint8ValueMap(src map[string]*uint8) map[string]uint8 {
dst := make(map[string]uint8)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint16 returns a pointer to the uint16 value passed in.
func Uint16(v uint16) *uint16 {
return &v
}
// Uint16Value returns the value of the uint16 pointer passed in or
// 0 if the pointer is nil.
func Uint16Value(v *uint16) uint16 {
if v != nil {
return *v
}
return 0
}
// Uint16Slice converts a slice of uint16 values into a slice of
// uint16 pointers
func Uint16Slice(src []uint16) []*uint16 {
dst := make([]*uint16, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint16ValueSlice converts a slice of uint16 pointers into a slice of
// uint16 values
func Uint16ValueSlice(src []*uint16) []uint16 {
dst := make([]uint16, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint16Map converts a string map of uint16 values into a string
// map of uint16 pointers
func Uint16Map(src map[string]uint16) map[string]*uint16 {
dst := make(map[string]*uint16)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint16ValueMap converts a string map of uint16 pointers into a string
// map of uint16 values
func Uint16ValueMap(src map[string]*uint16) map[string]uint16 {
dst := make(map[string]uint16)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint32 returns a pointer to the uint32 value passed in.
func Uint32(v uint32) *uint32 {
return &v
}
// Uint32Value returns the value of the uint32 pointer passed in or
// 0 if the pointer is nil.
func Uint32Value(v *uint32) uint32 {
if v != nil {
return *v
}
return 0
}
// Uint32Slice converts a slice of uint32 values into a slice of
// uint32 pointers
func Uint32Slice(src []uint32) []*uint32 {
dst := make([]*uint32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint32ValueSlice converts a slice of uint32 pointers into a slice of
// uint32 values
func Uint32ValueSlice(src []*uint32) []uint32 {
dst := make([]uint32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint32Map converts a string map of uint32 values into a string
// map of uint32 pointers
func Uint32Map(src map[string]uint32) map[string]*uint32 {
dst := make(map[string]*uint32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint32ValueMap converts a string map of uint32 pointers into a string
// map of uint32 values
func Uint32ValueMap(src map[string]*uint32) map[string]uint32 {
dst := make(map[string]uint32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Uint64 returns a pointer to the uint64 value passed in.
func Uint64(v uint64) *uint64 {
return &v
}
// Uint64Value returns the value of the uint64 pointer passed in or
// 0 if the pointer is nil.
func Uint64Value(v *uint64) uint64 {
if v != nil {
return *v
}
return 0
}
// Uint64Slice converts a slice of uint64 values into a slice of
// uint64 pointers
func Uint64Slice(src []uint64) []*uint64 {
dst := make([]*uint64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Uint64ValueSlice converts a slice of uint64 pointers into a slice of
// uint64 values
func Uint64ValueSlice(src []*uint64) []uint64 {
dst := make([]uint64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Uint64Map converts a string map of uint64 values into a string
// map of uint64 pointers
func Uint64Map(src map[string]uint64) map[string]*uint64 {
dst := make(map[string]*uint64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Uint64ValueMap converts a string map of uint64 pointers into a string
// map of uint64 values
func Uint64ValueMap(src map[string]*uint64) map[string]uint64 {
dst := make(map[string]uint64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float32 returns a pointer to the float32 value passed in.
func Float32(v float32) *float32 {
return &v
}
// Float32Value returns the value of the float32 pointer passed in or
// 0 if the pointer is nil.
func Float32Value(v *float32) float32 {
if v != nil {
return *v
}
return 0
}
// Float32Slice converts a slice of float32 values into a slice of
// float32 pointers
func Float32Slice(src []float32) []*float32 {
dst := make([]*float32, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float32ValueSlice converts a slice of float32 pointers into a slice of
// float32 values
func Float32ValueSlice(src []*float32) []float32 {
dst := make([]float32, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float32Map converts a string map of float32 values into a string
// map of float32 pointers
func Float32Map(src map[string]float32) map[string]*float32 {
dst := make(map[string]*float32)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float32ValueMap converts a string map of float32 pointers into a string
// map of float32 values
func Float32ValueMap(src map[string]*float32) map[string]float32 {
dst := make(map[string]float32)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 {
return &v

View File

@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) {
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
// Catch all request errors, and let the default retrier determine
// if the error is retryable.
r.Error = awserr.New(request.ErrCodeRequestError, "send request failed", err)
// Override the error with a context canceled error, if that was canceled.
ctx := r.Context()
@ -184,7 +184,9 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
var AfterRetryHandler = request.NamedHandler{
Name: "core.AfterRetryHandler",
Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
@ -214,7 +216,7 @@ var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn:
r.RetryCount++
r.Error = nil
}
}}
}}
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or

View File

@ -0,0 +1,22 @@
// +build !go1.7
package credentials
import (
"github.com/aws/aws-sdk-go/internal/context"
)
// backgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func backgroundContext() Context {
return context.BackgroundCtx
}

View File

@ -0,0 +1,20 @@
// +build go1.7
package credentials
import "context"
// backgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func backgroundContext() Context {
return context.Background()
}

View File

@ -0,0 +1,39 @@
// +build !go1.9
package credentials
import "time"
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
// API methods with Go v1.6 and a Context type such as golang.org/x/net/context.
//
// This type, aws.Context, and context.Context are equivalent.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
Value(key interface{}) interface{}
}

View File

@ -0,0 +1,13 @@
// +build go1.9
package credentials
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// This type, aws.Context, and context.Context are equivalent.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

View File

@ -50,9 +50,11 @@ package credentials
import (
"fmt"
"github.com/aws/aws-sdk-go/aws/awserr"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/sync/singleflight"
)
// AnonymousCredentials is an empty Credential object that can be used as
@ -83,6 +85,12 @@ type Value struct {
ProviderName string
}
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
@ -99,6 +107,13 @@ type Provider interface {
IsExpired() bool
}
// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider
RetrieveWithContext(Context) (Value, error)
}
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
@ -190,20 +205,68 @@ func (e *Expiry) ExpiresAt() time.Time {
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
creds Value
forceRefresh bool
m sync.RWMutex
creds atomic.Value
sf singleflight.Group
provider Provider
}
// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
return &Credentials{
c := &Credentials{
provider: provider,
forceRefresh: true,
}
c.creds.Store(Value{})
return c
}
// GetWithContext returns the credentials value, or error if the credentials
// Value failed to be retrieved. Will return early if the passed in context is
// canceled.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
//
// Passed in Context is equivalent to aws.Context, and context.Context.
func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
if curCreds := c.creds.Load(); !c.isExpired(curCreds) {
return curCreds.(Value), nil
}
// Cannot pass context down to the actual retrieve, because the first
// context would cancel the whole group when there is not direct
// association of items in the group.
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
case <-ctx.Done():
return Value{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
}
func (c *Credentials) singleRetrieve(ctx Context) (creds interface{}, err error) {
if curCreds := c.creds.Load(); !c.isExpired(curCreds) {
return curCreds.(Value), nil
}
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds.Store(creds)
}
return creds, err
}
// Get returns the credentials value, or error if the credentials Value failed
@ -216,30 +279,7 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock.
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
}
c.m.RUnlock()
// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
if c.isExpired() {
creds, err := c.provider.Retrieve()
if err != nil {
return Value{}, err
}
c.creds = creds
c.forceRefresh = false
}
return c.creds, nil
return c.GetWithContext(backgroundContext())
}
// Expire expires the credentials and forces them to be retrieved on the
@ -248,10 +288,7 @@ func (c *Credentials) Get() (Value, error) {
// This will override the Provider's expired state, and force Credentials
// to call the Provider's Retrieve().
func (c *Credentials) Expire() {
c.m.Lock()
defer c.m.Unlock()
c.forceRefresh = true
c.creds.Store(Value{})
}
// IsExpired returns if the credentials are no longer valid, and need
@ -260,33 +297,43 @@ func (c *Credentials) Expire() {
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.RLock()
defer c.m.RUnlock()
return c.isExpired()
return c.isExpired(c.creds.Load())
}
// isExpired helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
func (c *Credentials) isExpired(creds interface{}) bool {
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
}
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.Load().(Value).ProviderName),
nil)
}
if c.forceRefresh {
if c.creds.Load().(Value) == (Value{}) {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}
type suppressedContext struct {
Context
}
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
func (s *suppressedContext) Done() <-chan struct{} {
return nil
}
func (s *suppressedContext) Err() error {
return nil
}

View File

@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
@ -87,7 +88,14 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client)
return m.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
credsList, err := requestCredList(ctx, m.Client)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
@ -97,7 +105,7 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, credsName)
roleCreds, err := requestCred(ctx, m.Client, credsName)
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
@ -130,8 +138,8 @@ const iamSecurityCredsPath = "iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
func requestCredList(ctx aws.Context, client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadataWithContext(ctx, iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err)
}
@ -154,8 +162,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
func requestCred(ctx aws.Context, client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadataWithContext(ctx, sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",

View File

@ -98,8 +98,8 @@ func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint strin
return p
}
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
// from an arbitrary endpoint concurrently. The client will request the
// NewCredentialsClient returns a pointer to a new Credentials object
// wrapping the endpoint credentials Provider.
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
}
@ -116,7 +116,13 @@ func (p *Provider) IsExpired() bool {
// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials()
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
resp, err := p.getCredentials(ctx)
if err != nil {
return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
@ -148,7 +154,7 @@ type errorOutput struct {
Message string `json:"message"`
}
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error) {
op := &request.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
@ -156,6 +162,7 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)

View File

@ -90,6 +90,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
const (
@ -142,7 +143,7 @@ const (
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = 1024
DefaultBufSize = int(8 * sdkio.KibiByte)
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute

View File

@ -17,8 +17,9 @@ var (
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
)
// A SharedCredentialsProvider retrieves credentials from the current user's home
// directory, and keeps track if those credentials are expired.
// A SharedCredentialsProvider retrieves access key pair (access key ID,
// secret access key, and session token if present) credentials from the current
// user's home directory, and keeps track if those credentials are expired.
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {

View File

@ -19,7 +19,9 @@ type StaticProvider struct {
}
// NewStaticCredentials returns a pointer to a new Credentials object
// wrapping a static credentials value provider.
// wrapping a static credentials value provider. Token is only required
// for temporary security credentials retrieved via STS, otherwise an empty
// string can be passed for this parameter.
func NewStaticCredentials(id, secret, token string) *Credentials {
return NewCredentials(&StaticProvider{Value: Value{
AccessKeyID: id,

View File

@ -87,6 +87,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
"github.com/aws/aws-sdk-go/service/sts"
)
@ -118,6 +119,10 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
type assumeRolerWithContext interface {
AssumeRoleWithContext(aws.Context, *sts.AssumeRoleInput, ...request.Option) (*sts.AssumeRoleOutput, error)
}
// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
@ -144,6 +149,13 @@ type AssumeRoleProvider struct {
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Optional, you can pass tag key-value pairs to your session. These tags are called session tags.
Tags []*sts.Tag
// A list of keys for session tags that you want to set as transitive.
// If you set a tag key as transitive, the corresponding key and value passes to subsequent sessions in a role chain.
TransitiveTagKeys []*string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
@ -157,6 +169,29 @@ type AssumeRoleProvider struct {
// size.
Policy *string
// The ARNs of IAM managed policies you want to use as managed session policies.
// The policies must exist in the same account as the role.
//
// This parameter is optional. You can provide up to 10 managed policy ARNs.
// However, the plain text that you use for both inline and managed session
// policies can't exceed 2,048 characters.
//
// An AWS conversion compresses the passed session policies and session tags
// into a packed binary format that has a separate limit. Your request can fail
// for this limit even if your plain text meets the other requirements. The
// PackedPolicySize response element indicates by percentage how close the policies
// and tags for your request are to the upper size limit.
//
// Passing policies to this operation returns new temporary credentials. The
// resulting session's permissions are the intersection of the role's identity-based
// policy and the session policies. You can use the role's temporary credentials
// in subsequent AWS API calls to access resources in the account that owns
// the role. You cannot use session policies to grant more permissions than
// those allowed by the identity-based policy of the role that is being assumed.
// For more information, see Session Policies (https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html#policies_session)
// in the IAM User Guide.
PolicyArns []*sts.PolicyDescriptorType
// The identification number of the MFA device that is associated with the user
// who is making the AssumeRole call. Specify this value if the trust policy
// of the role being assumed includes a condition that requires MFA authentication.
@ -258,7 +293,11 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
@ -274,6 +313,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
Tags: p.Tags,
PolicyArns: p.PolicyArns,
TransitiveTagKeys: p.TransitiveTagKeys,
}
if p.Policy != nil {
input.Policy = p.Policy
@ -296,7 +338,15 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
}
}
roleOutput, err := p.Client.AssumeRole(input)
var roleOutput *sts.AssumeRoleOutput
var err error
if c, ok := p.Client.(assumeRolerWithContext); ok {
roleOutput, err = c.AssumeRoleWithContext(ctx, input)
} else {
roleOutput, err = p.Client.AssumeRole(input)
}
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}

View File

@ -0,0 +1,135 @@
package stscreds
import (
"fmt"
"io/ioutil"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"
// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)
// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = time.Now
// TokenFetcher shuold return WebIdentity token bytes or an error
type TokenFetcher interface {
FetchToken(credentials.Context) ([]byte, error)
}
// FetchTokenPath is a path to a WebIdentity token file
type FetchTokenPath string
// FetchToken returns a token by reading from the filesystem
func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
data, err := ioutil.ReadFile(string(f))
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", f)
return nil, awserr.New(ErrCodeWebIdentity, errMsg, err)
}
return data, nil
}
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry
PolicyArns []*sts.PolicyDescriptorType
client stsiface.STSAPI
ExpiryWindow time.Duration
tokenFetcher TokenFetcher
roleARN string
roleSessionName string
}
// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path))
}
// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI and a TokenFetcher
func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFetcher: tokenFetcher,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}
// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := p.tokenFetcher.FetchToken(ctx)
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed fetching WebIdentity token: ", err)
}
sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
PolicyArns: p.PolicyArns,
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
req.SetContext(ctx)
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
if err := req.Send(); err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}

View File

@ -1,30 +1,61 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
// via UDP connection. Using the Start function will enable the reporting of
// metrics on a given port. If Start is called, with different parameters, again,
// a panic will occur.
// Package csm provides the Client Side Monitoring (CSM) client which enables
// sending metrics via UDP connection to the CSM agent. This package provides
// control options, and configuration for the CSM client. The client can be
// controlled manually, or automatically via the SDK's Session configuration.
//
// Pause can be called to pause any metrics publishing on a given port. Sessions
// that have had their handlers modified via InjectHandlers may still be used.
// However, the handlers will act as a no-op meaning no metrics will be published.
// Enabling CSM client via SDK's Session configuration
//
// The CSM client can be enabled automatically via SDK's Session configuration.
// The SDK's session configuration enables the CSM client if the AWS_CSM_PORT
// environment variable is set to a non-empty value.
//
// The configuration options for the CSM client via the SDK's session
// configuration are:
//
// * AWS_CSM_PORT=<port number>
// The port number the CSM agent will receive metrics on.
//
// * AWS_CSM_HOST=<hostname or ip>
// The hostname, or IP address the CSM agent will receive metrics on.
// Without port number.
//
// Manually enabling the CSM client
//
// The CSM client can be started, paused, and resumed manually. The Start
// function will enable the CSM client to publish metrics to the CSM agent. It
// is safe to call Start concurrently, but if Start is called additional times
// with different ClientID or address it will panic.
//
// Example:
// r, err := csm.Start("clientID", ":31000")
// if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err))
// }
//
// When controlling the CSM client manually, you must also inject its request
// handlers into the SDK's Session configuration for the SDK's API clients to
// publish metrics.
//
// sess, err := session.NewSession(&aws.Config{})
// if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err))
// }
//
// // Add CSM client's metric publishing request handlers to the SDK's
// // Session Configuration.
// r.InjectHandlers(&sess.Handlers)
//
// client := s3.New(sess)
// resp, err := client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
// Controlling CSM client
//
// Once the CSM client has been enabled the Get function will return a Reporter
// value that you can use to pause and resume the metrics published to the CSM
// agent. If Get function is called before the reporter is enabled with the
// Start function or via SDK's Session configuration nil will be returned.
//
// The Pause method can be called to stop the CSM client publishing metrics to
// the CSM agent. The Continue method will resume metric publishing.
//
// // Get the CSM client Reporter.
// r := csm.Get()
//
// // Will pause monitoring
// r.Pause()
@ -35,12 +66,4 @@
//
// // Resume monitoring
// r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm

View File

@ -2,6 +2,7 @@ package csm
import (
"fmt"
"strings"
"sync"
)
@ -9,19 +10,40 @@ var (
lock sync.Mutex
)
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
// DefaultPort is used when no port is specified.
DefaultPort = "31000"
// DefaultHost is the host that will be used when none is specified.
DefaultHost = "127.0.0.1"
)
// Start will start the a long running go routine to capture
// AddressWithDefaults returns a CSM address built from the host and port
// values. If the host or port is not set, default values will be used
// instead. If host is "localhost" it will be replaced with "127.0.0.1".
func AddressWithDefaults(host, port string) string {
if len(host) == 0 || strings.EqualFold(host, "localhost") {
host = DefaultHost
}
if len(port) == 0 {
port = DefaultPort
}
// Only IP6 host can contain a colon
if strings.Contains(host, ":") {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// Start will start a long running go routine to capture
// client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different
// client ID or port is passed in.
//
// Example:
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// r, err := csm.Start("clientID", "127.0.0.1:31000")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }

View File

@ -16,25 +16,26 @@ var (
type metricChan struct {
ch chan metric
paused int64
paused *int64
}
func newMetricChan(size int) metricChan {
return metricChan{
ch: make(chan metric, size),
paused: new(int64),
}
}
func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum)
atomic.StoreInt64(ch.paused, pausedEnum)
}
func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum)
atomic.StoreInt64(ch.paused, runningEnum)
}
func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused)
v := atomic.LoadInt64(ch.paused)
return v == pausedEnum
}

View File

@ -10,11 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
)
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint.
type Reporter struct {
@ -71,7 +66,6 @@ func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID),
}
@ -95,7 +89,7 @@ func getMetricException(err awserr.Error) metricException {
code := err.Code()
switch code {
case "RequestError",
case request.ErrCodeRequestError,
request.ErrCodeSerialization,
request.CanceledErrorCode:
return sdkException{
@ -123,7 +117,7 @@ func (rep *Reporter) sendAPICallMetric(r *request.Request) {
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
Latency: aws.Int(int(time.Since(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
}
@ -190,8 +184,9 @@ func (rep *Reporter) start() {
}
}
// Pause will pause the metric channel preventing any new metrics from
// being added.
// Pause will pause the metric channel preventing any new metrics from being
// added. It is safe to call concurrently with other calls to Pause, but if
// called concurently with Continue can lead to unexpected state.
func (rep *Reporter) Pause() {
lock.Lock()
defer lock.Unlock()
@ -203,8 +198,9 @@ func (rep *Reporter) Pause() {
rep.close()
}
// Continue will reopen the metric channel and allow for monitoring
// to be resumed.
// Continue will reopen the metric channel and allow for monitoring to be
// resumed. It is safe to call concurrently with other calls to Continue, but
// if called concurently with Pause can lead to unexpected state.
func (rep *Reporter) Continue() {
lock.Lock()
defer lock.Unlock()
@ -219,10 +215,18 @@ func (rep *Reporter) Continue() {
rep.metricsCh.Continue()
}
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent.
//
// Example:
// InjectHandlers is NOT safe to call concurrently. Calling InjectHandlers
// multiple times may lead to unexpected behavior, (e.g. duplicate metrics).
//
// // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {

View File

@ -4,28 +4,73 @@ import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// getToken uses the duration to return a token for EC2 metadata service,
// or an error if the request failed.
func (c *EC2Metadata) getToken(ctx aws.Context, duration time.Duration) (tokenOutput, error) {
op := &request.Operation{
Name: "GetToken",
HTTPMethod: "PUT",
HTTPPath: "/api/token",
}
var output tokenOutput
req := c.NewRequest(op, nil, &output)
req.SetContext(ctx)
// remove the fetch token handler from the request handlers to avoid infinite recursion
req.Handlers.Sign.RemoveByName(fetchTokenHandlerName)
// Swap the unmarshalMetadataHandler with unmarshalTokenHandler on this request.
req.Handlers.Unmarshal.Swap(unmarshalMetadataHandlerName, unmarshalTokenHandler)
ttl := strconv.FormatInt(int64(duration/time.Second), 10)
req.HTTPRequest.Header.Set(ttlHeader, ttl)
err := req.Send()
// Errors with bad request status should be returned.
if err != nil {
err = awserr.NewRequestFailure(
awserr.New(req.HTTPResponse.Status, http.StatusText(req.HTTPResponse.StatusCode), err),
req.HTTPResponse.StatusCode, req.RequestID)
}
return output, err
}
// GetMetadata uses the path provided to request information from the EC2
// instance metdata service. The content will be returned as a string, or
// instance metadata service. The content will be returned as a string, or
// error if the request failed.
func (c *EC2Metadata) GetMetadata(p string) (string, error) {
return c.GetMetadataWithContext(aws.BackgroundContext(), p)
}
// GetMetadataWithContext uses the path provided to request information from the EC2
// instance metadata service. The content will be returned as a string, or
// error if the request failed.
func (c *EC2Metadata) GetMetadataWithContext(ctx aws.Context, p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/meta-data", p),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
req := c.NewRequest(op, nil, output)
req.SetContext(ctx)
err := req.Send()
return output.Content, err
}
@ -33,6 +78,13 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
// there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserData() (string, error) {
return c.GetUserDataWithContext(aws.BackgroundContext())
}
// GetUserDataWithContext returns the userdata that was configured for the service. If
// there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserDataWithContext(ctx aws.Context) (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
@ -41,13 +93,9 @@ func (c *EC2Metadata) GetUserData() (string, error) {
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
err := req.Send()
req.SetContext(ctx)
err := req.Send()
return output.Content, err
}
@ -55,6 +103,13 @@ func (c *EC2Metadata) GetUserData() (string, error) {
// instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed.
func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
return c.GetDynamicDataWithContext(aws.BackgroundContext(), p)
}
// GetDynamicDataWithContext uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed.
func (c *EC2Metadata) GetDynamicDataWithContext(ctx aws.Context, p string) (string, error) {
op := &request.Operation{
Name: "GetDynamicData",
HTTPMethod: "GET",
@ -63,8 +118,9 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
err := req.Send()
req.SetContext(ctx)
err := req.Send()
return output.Content, err
}
@ -72,7 +128,14 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
// instance. Error is returned if the request fails or is unable to parse
// the response.
func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) {
resp, err := c.GetDynamicData("instance-identity/document")
return c.GetInstanceIdentityDocumentWithContext(aws.BackgroundContext())
}
// GetInstanceIdentityDocumentWithContext retrieves an identity document describing an
// instance. Error is returned if the request fails or is unable to parse
// the response.
func (c *EC2Metadata) GetInstanceIdentityDocumentWithContext(ctx aws.Context) (EC2InstanceIdentityDocument, error) {
resp, err := c.GetDynamicDataWithContext(ctx, "instance-identity/document")
if err != nil {
return EC2InstanceIdentityDocument{},
awserr.New("EC2MetadataRequestError",
@ -91,7 +154,12 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument
// IAMInfo retrieves IAM info from the metadata API
func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
resp, err := c.GetMetadata("iam/info")
return c.IAMInfoWithContext(aws.BackgroundContext())
}
// IAMInfoWithContext retrieves IAM info from the metadata API
func (c *EC2Metadata) IAMInfoWithContext(ctx aws.Context) (EC2IAMInfo, error) {
resp, err := c.GetMetadataWithContext(ctx, "iam/info")
if err != nil {
return EC2IAMInfo{},
awserr.New("EC2MetadataRequestError",
@ -116,24 +184,36 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
// Region returns the region the instance is running in.
func (c *EC2Metadata) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
return c.RegionWithContext(aws.BackgroundContext())
}
// RegionWithContext returns the region the instance is running in.
func (c *EC2Metadata) RegionWithContext(ctx aws.Context) (string, error) {
ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocumentWithContext(ctx)
if err != nil {
return "", err
}
if len(resp) == 0 {
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
// extract region from the ec2InstanceIdentityDocument
region := ec2InstanceIdentityDocument.Region
if len(region) == 0 {
return "", awserr.New("EC2MetadataError", "invalid region received for ec2metadata instance", nil)
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
// returns region
return region, nil
}
// Available returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *EC2Metadata) Available() bool {
if _, err := c.GetMetadata("instance-id"); err != nil {
return c.AvailableWithContext(aws.BackgroundContext())
}
// AvailableWithContext returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *EC2Metadata) AvailableWithContext(ctx aws.Context) bool {
if _, err := c.GetMetadataWithContext(ctx, "instance-id"); err != nil {
return false
}
@ -153,6 +233,7 @@ type EC2IAMInfo struct {
// an instance identity document
type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"`
MarketplaceProductCodes []string `json:"marketplaceProductCodes"`
AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"`
Version string `json:"version"`

View File

@ -13,6 +13,7 @@ import (
"io"
"net/http"
"os"
"strconv"
"strings"
"time"
@ -24,9 +25,25 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
)
// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
const (
// ServiceName is the name of the service.
ServiceName = "ec2metadata"
disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// Headers for Token and TTL
ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds"
tokenHeader = "x-aws-ec2-metadata-token"
// Named Handler constants
fetchTokenHandlerName = "FetchTokenHandler"
unmarshalMetadataHandlerName = "unmarshalMetadataHandler"
unmarshalTokenHandlerName = "unmarshalTokenHandler"
enableTokenProviderHandlerName = "enableTokenProviderHandler"
// TTL constants
defaultTTL = 21600 * time.Second
ttlExpirationWindow = 30 * time.Second
)
// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
@ -63,8 +80,10 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
Timeout: 1 * time.Second,
}
// max number of retries on the client operation
cfg.MaxRetries = aws.Int(2)
}
svc := &EC2Metadata{
@ -80,13 +99,27 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
),
}
svc.Handlers.Unmarshal.PushBack(unmarshalHandler)
// token provider instance
tp := newTokenProvider(svc, defaultTTL)
// NamedHandler for fetching token
svc.Handlers.Sign.PushBackNamed(request.NamedHandler{
Name: fetchTokenHandlerName,
Fn: tp.fetchTokenHandler,
})
// NamedHandler for enabling token provider
svc.Handlers.Complete.PushBackNamed(request.NamedHandler{
Name: enableTokenProviderHandlerName,
Fn: tp.enableTokenProviderHandler,
})
svc.Handlers.Unmarshal.PushBackNamed(unmarshalHandler)
svc.Handlers.UnmarshalError.PushBack(unmarshalError)
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// This short-circuits the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{
@ -107,7 +140,6 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
for _, option := range opts {
option(svc.Client)
}
return svc
}
@ -119,30 +151,74 @@ type metadataOutput struct {
Content string
}
func unmarshalHandler(r *request.Request) {
type tokenOutput struct {
Token string
TTL time.Duration
}
// unmarshal token handler is used to parse the response of a getToken operation
var unmarshalTokenHandler = request.NamedHandler{
Name: unmarshalTokenHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata respose", err)
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
v := r.HTTPResponse.Header.Get(ttlHeader)
data, ok := r.Data.(*tokenOutput)
if !ok {
return
}
data.Token = b.String()
// TTL is in seconds
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ParamFormatErrCode,
"unable to parse EC2 token TTL response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
t := time.Duration(i) * time.Second
data.TTL = t
},
}
var unmarshalHandler = request.NamedHandler{
Name: unmarshalMetadataHandlerName,
Fn: func(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(awserr.New(request.ErrCodeSerialization,
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
return
}
if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String()
}
},
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error respose", err)
var b bytes.Buffer
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err),
r.HTTPResponse.StatusCode, r.RequestID)
return
}
// Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error
r.Error = awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String()))
r.Error = awserr.NewRequestFailure(awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())),
r.HTTPResponse.StatusCode, r.RequestID)
}
func validateEndpointHandler(r *request.Request) {

View File

@ -0,0 +1,92 @@
package ec2metadata
import (
"net/http"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
// A tokenProvider struct provides access to EC2Metadata client
// and atomic instance of a token, along with configuredTTL for it.
// tokenProvider also provides an atomic flag to disable the
// fetch token operation.
// The disabled member will use 0 as false, and 1 as true.
type tokenProvider struct {
client *EC2Metadata
token atomic.Value
configuredTTL time.Duration
disabled uint32
}
// A ec2Token struct helps use of token in EC2 Metadata service ops
type ec2Token struct {
token string
credentials.Expiry
}
// newTokenProvider provides a pointer to a tokenProvider instance
func newTokenProvider(c *EC2Metadata, duration time.Duration) *tokenProvider {
return &tokenProvider{client: c, configuredTTL: duration}
}
// fetchTokenHandler fetches token for EC2Metadata service client by default.
func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
// short-circuits to insecure data flow if tokenProvider is disabled.
if v := atomic.LoadUint32(&t.disabled); v == 1 {
return
}
if ec2Token, ok := t.token.Load().(ec2Token); ok && !ec2Token.IsExpired() {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
return
}
output, err := t.client.getToken(r.Context(), t.configuredTTL)
if err != nil {
// change the disabled flag on token provider to true,
// when error is request timeout error.
if requestFailureError, ok := err.(awserr.RequestFailure); ok {
switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1)
case http.StatusBadRequest:
r.Error = requestFailureError
}
// Check if request timed out while waiting for response
if e, ok := requestFailureError.OrigErr().(awserr.Error); ok {
if e.Code() == request.ErrCodeRequestError {
atomic.StoreUint32(&t.disabled, 1)
}
}
}
return
}
newToken := ec2Token{
token: output.Token,
}
newToken.SetExpiration(time.Now().Add(output.TTL), ttlExpirationWindow)
t.token.Store(newToken)
// Inject token header to the request.
if ec2Token, ok := t.token.Load().(ec2Token); ok {
r.HTTPRequest.Header.Set(tokenHeader, ec2Token.token)
}
}
// enableTokenProviderHandler enables the token provider
func (t *tokenProvider) enableTokenProviderHandler(r *request.Request) {
// If the error code status is 401, we enable the token provider
if e, ok := r.Error.(awserr.RequestFailure); ok && e != nil &&
e.StatusCode() == http.StatusUnauthorized {
atomic.StoreUint32(&t.disabled, 0)
}
}

View File

@ -83,6 +83,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
p := &ps[i]
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRegionalS3(p)
custRmIotDataService(p)
custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
@ -92,7 +93,7 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
}
func custAddS3DualStack(p *partition) {
if p.ID != "aws" {
if !(p.ID == "aws" || p.ID == "aws-cn" || p.ID == "aws-us-gov") {
return
}
@ -100,6 +101,33 @@ func custAddS3DualStack(p *partition) {
custAddDualstack(p, "s3-control")
}
func custRegionalS3(p *partition) {
if p.ID != "aws" {
return
}
service, ok := p.Services["s3"]
if !ok {
return
}
// If global endpoint already exists no customization needed.
if _, ok := service.Endpoints["aws-global"]; ok {
return
}
service.PartitionEndpoint = "aws-global"
service.Endpoints["us-east-1"] = endpoint{}
service.Endpoints["aws-global"] = endpoint{
Hostname: "s3.amazonaws.com",
CredentialScope: credentialScope{
Region: "us-east-1",
},
}
p.Services["s3"] = service
}
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok {

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@ package endpoints
import (
"fmt"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
@ -46,6 +47,108 @@ type Options struct {
//
// This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool
// STS Regional Endpoint flag helps with resolving the STS endpoint
STSRegionalEndpoint STSRegionalEndpoint
// S3 Regional Endpoint flag helps with resolving the S3 endpoint
S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint
}
// STSRegionalEndpoint is an enum for the states of the STS Regional Endpoint
// options.
type STSRegionalEndpoint int
func (e STSRegionalEndpoint) String() string {
switch e {
case LegacySTSEndpoint:
return "legacy"
case RegionalSTSEndpoint:
return "regional"
case UnsetSTSEndpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetSTSEndpoint represents that STS Regional Endpoint flag is not specified.
UnsetSTSEndpoint STSRegionalEndpoint = iota
// LegacySTSEndpoint represents when STS Regional Endpoint flag is specified
// to use legacy endpoints.
LegacySTSEndpoint
// RegionalSTSEndpoint represents when STS Regional Endpoint flag is specified
// to use regional endpoints.
RegionalSTSEndpoint
)
// GetSTSRegionalEndpoint function returns the STSRegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the STS regional Endpoint flag.
func GetSTSRegionalEndpoint(s string) (STSRegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacySTSEndpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalSTSEndpoint, nil
default:
return UnsetSTSEndpoint, fmt.Errorf("unable to resolve the value of STSRegionalEndpoint for %v", s)
}
}
// S3UsEast1RegionalEndpoint is an enum for the states of the S3 us-east-1
// Regional Endpoint options.
type S3UsEast1RegionalEndpoint int
func (e S3UsEast1RegionalEndpoint) String() string {
switch e {
case LegacyS3UsEast1Endpoint:
return "legacy"
case RegionalS3UsEast1Endpoint:
return "regional"
case UnsetS3UsEast1Endpoint:
return ""
default:
return "unknown"
}
}
const (
// UnsetS3UsEast1Endpoint represents that S3 Regional Endpoint flag is not
// specified.
UnsetS3UsEast1Endpoint S3UsEast1RegionalEndpoint = iota
// LegacyS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use legacy endpoints.
LegacyS3UsEast1Endpoint
// RegionalS3UsEast1Endpoint represents when S3 Regional Endpoint flag is
// specified to use regional endpoints.
RegionalS3UsEast1Endpoint
)
// GetS3UsEast1RegionalEndpoint function returns the S3UsEast1RegionalEndpointFlag based
// on the input string provided in env config or shared config by the user.
//
// `legacy`, `regional` are the only case-insensitive valid strings for
// resolving the S3 regional Endpoint flag.
func GetS3UsEast1RegionalEndpoint(s string) (S3UsEast1RegionalEndpoint, error) {
switch {
case strings.EqualFold(s, "legacy"):
return LegacyS3UsEast1Endpoint, nil
case strings.EqualFold(s, "regional"):
return RegionalS3UsEast1Endpoint, nil
default:
return UnsetS3UsEast1Endpoint,
fmt.Errorf("unable to resolve the value of S3UsEast1RegionalEndpoint for %v", s)
}
}
// Set combines all of the option functions together.
@ -79,6 +182,12 @@ func ResolveUnknownServiceOption(o *Options) {
o.ResolveUnknownService = true
}
// STSRegionalEndpointOption enables the STS endpoint resolver behavior to resolve
// STS endpoint to their regional endpoint, instead of the global endpoint.
func STSRegionalEndpointOption(o *Options) {
o.STSRegionalEndpoint = RegionalSTSEndpoint
}
// A Resolver provides the interface for functionality to resolve endpoints.
// The build in Partition and DefaultResolver return value satisfy this interface.
type Resolver interface {
@ -170,10 +279,13 @@ func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
// A Partition provides the ability to enumerate the partition's regions
// and services.
type Partition struct {
id string
id, dnsSuffix string
p *partition
}
// DNSSuffix returns the base domain name of the partition.
func (p Partition) DNSSuffix() string { return p.dnsSuffix }
// ID returns the identifier of the partition.
func (p Partition) ID() string { return p.id }
@ -191,7 +303,7 @@ func (p Partition) ID() string { return p.id }
// require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned my look valid but may not work.
// StrictMatching enabled the endpoint returned may look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage
// of new regions and services expansions.
//
@ -205,7 +317,7 @@ func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (
// Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region {
rs := map[string]Region{}
rs := make(map[string]Region, len(p.p.Regions))
for id, r := range p.p.Regions {
rs[id] = Region{
id: id,
@ -220,7 +332,7 @@ func (p Partition) Regions() map[string]Region {
// Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition.
func (p Partition) Services() map[string]Service {
ss := map[string]Service{}
ss := make(map[string]Service, len(p.p.Services))
for id := range p.p.Services {
ss[id] = Service{
id: id,
@ -307,7 +419,7 @@ func (s Service) Regions() map[string]Region {
// A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service.
func (s Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{}
es := make(map[string]Endpoint, len(s.p.Services[s.id].Endpoints))
for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{
id: id,
@ -347,6 +459,9 @@ type ResolvedEndpoint struct {
// The endpoint URL
URL string
// The endpoint partition
PartitionID string
// The region that should be used for signing requests.
SigningRegion string

View File

@ -0,0 +1,24 @@
package endpoints
var legacyGlobalRegions = map[string]map[string]struct{}{
"sts": {
"ap-northeast-1": {},
"ap-south-1": {},
"ap-southeast-1": {},
"ap-southeast-2": {},
"ca-central-1": {},
"eu-central-1": {},
"eu-north-1": {},
"eu-west-1": {},
"eu-west-2": {},
"eu-west-3": {},
"sa-east-1": {},
"us-east-1": {},
"us-east-2": {},
"us-west-1": {},
"us-west-2": {},
},
"s3": {
"us-east-1": {},
},
}

View File

@ -7,6 +7,8 @@ import (
"strings"
)
var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`)
type partitions []partition
func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
@ -54,6 +56,7 @@ type partition struct {
func (p partition) Partition() Partition {
return Partition{
dnsSuffix: p.DNSSuffix,
id: p.ID,
p: &p,
}
@ -74,24 +77,56 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool)
return p.RegionRegex.MatchString(region)
}
func allowLegacyEmptyRegion(service string) bool {
legacy := map[string]struct{}{
"budgets": {},
"ce": {},
"chime": {},
"cloudfront": {},
"ec2metadata": {},
"iam": {},
"importexport": {},
"organizations": {},
"route53": {},
"sts": {},
"support": {},
"waf": {},
}
_, allowed := legacy[service]
return allowed
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options
opt.Set(opts...)
s, hasService := p.Services[service]
if !(hasService || opt.ResolveUnknownService) {
if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
}
if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 {
region = s.PartitionEndpoint
}
if (service == "sts" && opt.STSRegionalEndpoint != RegionalSTSEndpoint) ||
(service == "s3" && opt.S3UsEast1RegionalEndpoint != RegionalS3UsEast1Endpoint) {
if _, ok := legacyGlobalRegions[service][region]; ok {
region = "aws-global"
}
}
e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opt.StrictMatching {
if len(region) == 0 || (!hasEndpoint && opt.StrictMatching) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}
defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt)
}
func serviceList(ss services) []string {
@ -200,7 +235,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0]
}
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) {
var merged endpoint
for _, def := range defs {
merged.mergeIn(def)
@ -208,20 +243,6 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
merged.mergeIn(e)
e = merged
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
signingRegion := e.CredentialScope.Region
if len(signingRegion) == 0 {
signingRegion = region
@ -234,13 +255,32 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
signingNameDerived = true
}
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
region = signingRegion
}
if !validateInputRegion(region) {
return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided")
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
return ResolvedEndpoint{
URL: u,
PartitionID: partitionID,
SigningRegion: signingRegion,
SigningName: signingName,
SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}, nil
}
func getEndpointScheme(protocols []string, disableSSL bool) string {
@ -305,3 +345,7 @@ const (
boxedFalse
boxedTrue
)
func validateInputRegion(region string) bool {
return regionValidationRegex.MatchString(region)
}

View File

@ -5,5 +5,14 @@ import (
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
if strings.Contains(err.Error(), "read: connection reset") {
return false
}
if strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe") {
return true
}
return false
}

View File

@ -10,6 +10,7 @@ import (
type Handlers struct {
Validate HandlerList
Build HandlerList
BuildStream HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
@ -23,11 +24,12 @@ type Handlers struct {
Complete HandlerList
}
// Copy returns of this handler's lists.
// Copy returns a copy of this handler's lists.
func (h *Handlers) Copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
BuildStream: h.BuildStream.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
@ -42,10 +44,11 @@ func (h *Handlers) Copy() Handlers {
}
}
// Clear removes callback functions for all handlers
// Clear removes callback functions for all handlers.
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.BuildStream.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
@ -59,6 +62,54 @@ func (h *Handlers) Clear() {
h.Complete.Clear()
}
// IsEmpty returns if there are no handlers in any of the handlerlists.
func (h *Handlers) IsEmpty() bool {
if h.Validate.Len() != 0 {
return false
}
if h.Build.Len() != 0 {
return false
}
if h.BuildStream.Len() != 0 {
return false
}
if h.Send.Len() != 0 {
return false
}
if h.Sign.Len() != 0 {
return false
}
if h.Unmarshal.Len() != 0 {
return false
}
if h.UnmarshalStream.Len() != 0 {
return false
}
if h.UnmarshalMeta.Len() != 0 {
return false
}
if h.UnmarshalError.Len() != 0 {
return false
}
if h.ValidateResponse.Len() != 0 {
return false
}
if h.Retry.Len() != 0 {
return false
}
if h.AfterRetry.Len() != 0 {
return false
}
if h.CompleteAttempt.Len() != 0 {
return false
}
if h.Complete.Len() != 0 {
return false
}
return true
}
// A HandlerListRunItem represents an entry in the HandlerList which
// is being run.
type HandlerListRunItem struct {
@ -275,3 +326,18 @@ func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) {
AddToUserAgent(r, s)
}
}
// WithSetRequestHeaders updates the operation request's HTTP header to contain
// the header key value pairs provided. If the header key already exists in the
// request's HTTP header set, the existing value(s) will be replaced.
func WithSetRequestHeaders(h map[string]string) Option {
return withRequestHeader(h).SetRequestHeaders
}
type withRequestHeader map[string]string
func (h withRequestHeader) SetRequestHeaders(r *Request) {
for k, v := range h {
r.HTTPRequest.Header[k] = []string{v}
}
}

View File

@ -15,12 +15,15 @@ type offsetReader struct {
closed bool
}
func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader {
func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) {
reader := &offsetReader{}
buf.Seek(offset, sdkio.SeekStart)
_, err := buf.Seek(offset, sdkio.SeekStart)
if err != nil {
return nil, err
}
reader.buf = buf
return reader
return reader, nil
}
// Close will close the instance of the offset reader's access to
@ -54,7 +57,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) {
// CloseAndCopy will return a new offsetReader with a copy of the old buffer
// and close the old buffer.
func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader {
o.Close()
func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) {
if err := o.Close(); err != nil {
return nil, err
}
return newOffsetReader(o.buf, offset)
}

View File

@ -36,6 +36,10 @@ const (
// API request that was canceled. Requests given a aws.Context may
// return this error when canceled.
CanceledErrorCode = "RequestCanceled"
// ErrCodeRequestError is an error preventing the SDK from continuing to
// process the request.
ErrCodeRequestError = "RequestError"
)
// A Request is the service request to be made.
@ -51,6 +55,7 @@ type Request struct {
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
streamingBody io.ReadCloser
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
@ -64,6 +69,15 @@ type Request struct {
LastSignedAt time.Time
DisableFollowRedirects bool
// Additional API error codes that should be retried. IsErrorRetryable
// will consider these codes in addition to its built in cases.
RetryErrorCodes []string
// Additional API error codes that should be retried with throttle backoff
// delay. IsErrorThrottle will consider these codes in addition to its
// built in cases.
ThrottleErrorCodes []string
// A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods.
@ -90,8 +104,12 @@ type Operation struct {
BeforePresignFn func(r *Request) error
}
// New returns a new Request pointer for the service API
// operation and parameters.
// New returns a new Request pointer for the service API operation and
// parameters.
//
// A Retryer should be provided to direct how the request is retried. If
// Retryer is nil, a default no retry value will be used. You can use
// NoOpRetryer in the Client package to disable retry behavior directly.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
@ -99,6 +117,10 @@ type Operation struct {
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
if retryer == nil {
retryer = noOpRetryer{}
}
method := operation.HTTPMethod
if method == "" {
method = "POST"
@ -113,8 +135,6 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
}
SanitizeHostForHeader(httpReq)
r := &Request{
Config: cfg,
ClientInfo: clientInfo,
@ -231,6 +251,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
func fmtAttemptCount(retryCount, maxRetries int) string {
return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
}
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
@ -259,10 +283,28 @@ func (r *Request) SetStringBody(s string) {
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.Body = reader
r.BodyStart, _ = reader.Seek(0, sdkio.SeekCurrent) // Get the Bodies current offset.
if aws.IsReaderSeekable(reader) {
var err error
// Get the Bodies current offset so retries will start from the same
// initial position.
r.BodyStart, err = reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
r.Error = awserr.New(ErrCodeSerialization,
"failed to determine start of request body", err)
return
}
}
r.ResetBody()
}
// SetStreamingBody set the reader to be used for the request that will stream
// bytes to the server. Request's Body must not be set to any reader.
func (r *Request) SetStreamingBody(reader io.ReadCloser) {
r.streamingBody = reader
r.SetReaderBody(aws.ReadSeekCloser(reader))
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails. The expire parameter is only used for presigned Amazon
// S3 API requests. All other AWS services will use a fixed expiration
@ -330,16 +372,15 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
}
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
const (
notRetrying = "not retrying"
)
func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
}
@ -358,12 +399,12 @@ func (r *Request) Build() error {
if !r.built {
r.Handlers.Validate.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error)
debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
r.built = true
@ -379,20 +420,30 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
SanitizeHostForHeader(r.HTTPRequest)
r.Handlers.Sign.Run(r)
return r.Error
}
func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) {
if r.streamingBody != nil {
return r.streamingBody, nil
}
if r.safeBody != nil {
r.safeBody.Close()
}
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
r.safeBody, err = newOffsetReader(r.Body, r.BodyStart)
if err != nil {
return nil, awserr.New(ErrCodeSerialization,
"failed to get next request body reader", err)
}
// Go 1.8 tightened and clarified the rules code needs to use when building
// requests with the http package. Go 1.8 removed the automatic detection
@ -409,10 +460,10 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) {
// Related golang/go#18257
l, err := aws.SeekerLen(r.Body)
if err != nil {
return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err)
return nil, awserr.New(ErrCodeSerialization,
"failed to compute request body size", err)
}
var body io.ReadCloser
if l == 0 {
body = NoBody
} else if l > 0 {
@ -473,15 +524,13 @@ func (r *Request) Send() error {
r.AttemptTime = time.Now()
if err := r.Sign(); err != nil {
debugLogReqError(r, "Sign Request", false, err)
debugLogReqError(r, "Sign Request", notRetrying, err)
return err
}
if err := r.sendRequest(); err == nil {
return nil
} else if !shouldRetryCancel(r.Error) {
return err
} else {
}
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
@ -489,13 +538,14 @@ func (r *Request) Send() error {
return r.Error
}
r.prepareRetry()
continue
if err := r.prepareRetry(); err != nil {
r.Error = err
return err
}
}
}
func (r *Request) prepareRetry() {
func (r *Request) prepareRetry() error {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
@ -506,12 +556,19 @@ func (r *Request) prepareRetry() {
// the request's body even though the Client's Do returned.
r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil)
r.ResetBody()
if err := r.Error; err != nil {
return awserr.New(ErrCodeSerialization,
"failed to prepare body for retry", err)
}
// Closing response body to ensure that no response body is leaked
// between retry attempts.
if r.HTTPResponse != nil && r.HTTPResponse.Body != nil {
r.HTTPResponse.Body.Close()
}
return nil
}
func (r *Request) sendRequest() (sendErr error) {
@ -520,7 +577,9 @@ func (r *Request) sendRequest() (sendErr error) {
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", r.WillRetry(), r.Error)
debugLogReqError(r, "Send Request",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
@ -528,13 +587,17 @@ func (r *Request) sendRequest() (sendErr error) {
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
debugLogReqError(r, "Validate Response", r.WillRetry(), r.Error)
debugLogReqError(r, "Validate Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", r.WillRetry(), r.Error)
debugLogReqError(r, "Unmarshal Response",
fmtAttemptCount(r.RetryCount, r.MaxRetries()),
r.Error)
return r.Error
}
@ -561,48 +624,6 @@ func AddToUserAgent(r *Request, s string) {
r.HTTPRequest.Header.Set("User-Agent", s)
}
type temporary interface {
Temporary() bool
}
func shouldRetryCancel(err error) bool {
switch err := err.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
return shouldRetryCancel(err.OrigErr())
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryCancel(err.Err)
case temporary:
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary()
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
}
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)
@ -618,6 +639,10 @@ func getHost(r *http.Request) string {
return r.Host
}
if r.URL == nil {
return ""
}
return r.URL.Host
}

View File

@ -4,6 +4,8 @@ package request
import (
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// NoBody is a http.NoBody reader instructing Go HTTP client to not include
@ -24,7 +26,8 @@ var NoBody = http.NoBody
func (r *Request) ResetBody() {
body, err := r.getNextRequestBody()
if err != nil {
r.Error = err
r.Error = awserr.New(ErrCodeSerialization,
"failed to reset request body", err)
return
}

View File

@ -17,11 +17,13 @@ import (
// does the pagination between API operations, and Paginator defines the
// configuration that will be used per page request.
//
// cont := true
// for p.Next() && cont {
// for p.Next() {
// data := p.Page().(*s3.ListObjectsOutput)
// // process the page's data
// // ...
// // break out of loop to stop fetching additional pages
// }
//
// return p.Err()
//
// See service client API operation Pages methods for examples how the SDK will
@ -146,7 +148,7 @@ func (r *Request) nextPageTokens() []interface{} {
return nil
}
case bool:
if v == false {
if !v {
return nil
}
}

View File

@ -1,32 +1,81 @@
package request
import (
"net"
"net/url"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Retryer is an interface to control retry logic for a given service.
// The default implementation used by most services is the client.DefaultRetryer
// structure, which contains basic retry logic using exponential backoff.
// Retryer provides the interface drive the SDK's request retry behavior. The
// Retryer implementation is responsible for implementing exponential backoff,
// and determine if a request API error should be retried.
//
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It
// uses the which uses the Request.IsErrorRetryable and Request.IsErrorThrottle
// methods to determine if the request is retried.
type Retryer interface {
// RetryRules return the retry delay that should be used by the SDK before
// making another request attempt for the failed request.
RetryRules(*Request) time.Duration
// ShouldRetry returns if the failed request is retryable.
//
// Implementations may consider request attempt count when determining if a
// request is retryable, but the SDK will use MaxRetries to limit the
// number of attempts a request are made.
ShouldRetry(*Request) bool
// MaxRetries is the number of times a request may be retried before
// failing.
MaxRetries() int
}
// WithRetryer sets a config Retryer value to the given Config returning it
// for chaining.
// WithRetryer sets a Retryer value to the given Config returning the Config
// value for chaining. The value must not be nil.
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
if retryer == nil {
if cfg.Logger != nil {
cfg.Logger.Log("ERROR: Request.WithRetryer called with nil retryer. Replacing with retry disabled Retryer.")
}
retryer = noOpRetryer{}
}
cfg.Retryer = retryer
return cfg
}
// noOpRetryer is a internal no op retryer used when a request is created
// without a retryer.
//
// Provides a retryer that performs no retries.
// It should be used when we do not want retries to be performed.
type noOpRetryer struct{}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API; For NoOpRetryer the MaxRetries will always be zero.
func (d noOpRetryer) MaxRetries() int {
return 0
}
// ShouldRetry will always return false for NoOpRetryer, as it should never retry.
func (d noOpRetryer) ShouldRetry(_ *Request) bool {
return false
}
// RetryRules returns the delay duration before retrying this request again;
// since NoOpRetryer does not retry, RetryRules always returns 0.
func (d noOpRetryer) RetryRules(_ *Request) time.Duration {
return 0
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
ErrCodeRequestError: {},
"RequestTimeout": {},
ErrCodeResponseTimeout: {},
"RequestTimeoutException": {}, // Glacier's flavor of RequestTimeout
@ -34,6 +83,7 @@ var retryableCodes = map[string]struct{}{
var throttleCodes = map[string]struct{}{
"ProvisionedThroughputExceededException": {},
"ThrottledException": {}, // SNS, XRay, ResourceGroupsTagging API
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
@ -42,6 +92,7 @@ var throttleCodes = map[string]struct{}{
"TooManyRequestsException": {}, // Lambda functions
"PriorRequestNotComplete": {}, // Route53
"TransactionInProgressException": {},
"EC2ThrottledException": {}, // EC2
}
// credsExpiredCodes is a collection of error codes which signify the credentials
@ -76,10 +127,6 @@ var validParentCodes = map[string]struct{}{
ErrCodeRead: {},
}
type temporaryError interface {
Temporary() bool
}
func isNestedErrorRetryable(parentErr awserr.Error) bool {
if parentErr == nil {
return false
@ -98,7 +145,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
return isCodeRetryable(aerr.Code())
}
if t, ok := err.(temporaryError); ok {
if t, ok := err.(temporary); ok {
return t.Temporary() || isErrConnectionReset(err)
}
@ -108,33 +155,91 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool {
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if error is nil.
func IsErrorRetryable(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr)
}
}
if err == nil {
return false
}
return shouldRetryError(err)
}
type temporary interface {
Temporary() bool
}
func shouldRetryError(origErr error) bool {
switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
}
if isNestedErrorRetryable(err) {
return true
}
origErr := err.OrigErr()
var shouldRetry bool
if origErr != nil {
shouldRetry = shouldRetryError(origErr)
if err.Code() == ErrCodeRequestError && !shouldRetry {
return false
}
}
if isCodeRetryable(err.Code()) {
return true
}
return shouldRetry
case *url.Error:
if strings.Contains(err.Error(), "connection refused") {
// Refused connections should be retried as the service may not yet
// be running on the port. Go TCP dial considers refused
// connections as not temporary.
return true
}
// *url.Error only implements Temporary after golang 1.6 but since
// url.Error only wraps the error:
return shouldRetryError(err.Err)
case temporary:
if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
return true
}
// If the error is temporary, we want to allow continuation of the
// retry process
return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
// TestRequest4xxUnretryable for an example.
return true
default:
switch err.Error() {
case "net/http: request canceled",
"net/http: request canceled while waiting for connection":
// known 1.5 error case when an http request is cancelled
return false
}
// here we don't know the error; so we allow a retry.
return true
}
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if error is nil.
func IsErrorThrottle(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
return isCodeThrottle(aerr.Code())
}
}
return false
}
// IsErrorExpiredCreds returns whether the error code is a credential expiry error.
// Returns false if error is nil.
// IsErrorExpiredCreds returns whether the error code is a credential expiry
// error. Returns false if error is nil.
func IsErrorExpiredCreds(err error) bool {
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
return isCodeExpiredCreds(aerr.Code())
}
}
return false
}
@ -143,17 +248,58 @@ func IsErrorExpiredCreds(err error) bool {
//
// Alias for the utility function IsErrorRetryable
func (r *Request) IsErrorRetryable() bool {
if isErrCode(r.Error, r.RetryErrorCodes) {
return true
}
// HTTP response status code 501 should not be retried.
// 501 represents Not Implemented which means the request method is not
// supported by the server and cannot be handled.
if r.HTTPResponse != nil {
// HTTP response status code 500 represents internal server error and
// should be retried without any throttle.
if r.HTTPResponse.StatusCode == 500 {
return true
}
}
return IsErrorRetryable(r.Error)
}
// IsErrorThrottle returns whether the error is to be throttled based on its code.
// Returns false if the request has no Error set
// IsErrorThrottle returns whether the error is to be throttled based on its
// code. Returns false if the request has no Error set.
//
// Alias for the utility function IsErrorThrottle
func (r *Request) IsErrorThrottle() bool {
if isErrCode(r.Error, r.ThrottleErrorCodes) {
return true
}
if r.HTTPResponse != nil {
switch r.HTTPResponse.StatusCode {
case
429, // error caused due to too many requests
502, // Bad Gateway error should be throttled
503, // caused when service is unavailable
504: // error occurred due to gateway timeout
return true
}
}
return IsErrorThrottle(r.Error)
}
func isErrCode(err error, codes []string) bool {
if aerr, ok := err.(awserr.Error); ok && aerr != nil {
for _, code := range codes {
if code == aerr.Code() {
return true
}
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
//

View File

@ -0,0 +1,267 @@
package session
import (
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (*credentials.Credentials, error) {
switch {
case len(sessOpts.Profile) != 0:
// User explicitly provided an Profile in the session's configuration
// so load that profile from shared config first.
// Github(aws/aws-sdk-go#2727)
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
case envCfg.Creds.HasKeys():
// Environment credentials
return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil
case len(envCfg.WebIdentityTokenFilePath) != 0:
// Web identity token from environment, RoleARN required to also be
// set.
return assumeWebIdentity(cfg, handlers,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
)
default:
// Fallback to the "default" credential resolution chain.
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
}
}
// WebIdentityEmptyRoleARNErr will occur if 'AWS_WEB_IDENTITY_TOKEN_FILE' was set but
// 'AWS_ROLE_ARN' was not set.
var WebIdentityEmptyRoleARNErr = awserr.New(stscreds.ErrCodeWebIdentity, "role ARN is not set", nil)
// WebIdentityEmptyTokenFilePathErr will occur if 'AWS_ROLE_ARN' was set but
// 'AWS_WEB_IDENTITY_TOKEN_FILE' was not set.
var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "token file path is not set", nil)
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
) (*credentials.Credentials, error) {
if len(filepath) == 0 {
return nil, WebIdentityEmptyTokenFilePathErr
}
if len(roleARN) == 0 {
return nil, WebIdentityEmptyRoleARNErr
}
creds := stscreds.NewWebIdentityCredentials(
&Session{
Config: cfg,
Handlers: handlers.Copy(),
},
roleARN,
sessionName,
filepath,
)
return creds, nil
}
func resolveCredsFromProfile(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch {
case sharedCfg.SourceProfile != nil:
// Assume IAM role with credentials source from a different profile.
creds, err = resolveCredsFromProfile(cfg, envCfg,
*sharedCfg.SourceProfile, handlers, sessOpts,
)
case sharedCfg.Creds.HasKeys():
// Static Credentials from Shared Config/Credentials file.
creds = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts,
)
case len(sharedCfg.WebIdentityTokenFile) != 0:
// Credentials from Assume Web Identity token require an IAM Role, and
// that roll will be assumed. May be wrapped with another assume role
// via SourceProfile.
return assumeWebIdentity(cfg, handlers,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
)
default:
// Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to
// be retrieved.
creds = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{
Err: awserr.New("EnvAccessKeyNotFound",
"failed to find credentials in the environment.", nil),
},
&credProviderError{
Err: awserr.New("SharedCredsLoad",
fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil),
},
defaults.RemoteCredProvider(*cfg, handlers),
},
})
}
if err != nil {
return nil, err
}
if len(sharedCfg.RoleARN) > 0 {
cfgCp := *cfg
cfgCp.Credentials = creds
return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts)
}
return creds, nil
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
func resolveCredsFromSource(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) (creds *credentials.Credentials, err error) {
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
case credSourceEnvironment:
creds = credentials.NewStaticCredentialsFromCreds(envCfg.Creds)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return nil, ErrSharedConfigECSContainerEnvVarEmpty
}
p := defaults.RemoteCredProvider(*cfg, handlers)
creds = credentials.NewCredentials(p)
default:
return nil, ErrSharedConfigInvalidCredSource
}
return creds, nil
}
func credsFromAssumeRole(cfg aws.Config,
handlers request.Handlers,
sharedCfg sharedConfig,
sessOpts Options,
) (*credentials.Credentials, error) {
if len(sharedCfg.MFASerial) != 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return nil, AssumeRoleTokenProviderNotSetError{}
}
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.RoleSessionName
if sessOpts.AssumeRoleDuration == 0 &&
sharedCfg.AssumeRoleDuration != nil &&
*sharedCfg.AssumeRoleDuration/time.Minute > 15 {
opt.Duration = *sharedCfg.AssumeRoleDuration
} else if sessOpts.AssumeRoleDuration != 0 {
opt.Duration = sessOpts.AssumeRoleDuration
}
// Assume role with external ID
if len(sharedCfg.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.ExternalID)
}
// Assume role with MFA
if len(sharedCfg.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
), nil
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a
// session when the MFAToken option is not set when shared config is configured
// load assume a role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
}

View File

@ -1,97 +1,93 @@
/*
Package session provides configuration for the SDK's service clients.
Sessions can be shared across all service clients that share the same base
configuration. The Session is built from the SDK's default configuration and
request handlers.
Sessions should be cached when possible, because creating a new Session will
load all configuration values from the environment, and config files each time
the Session is created. Sharing the Session value across all of your service
clients will ensure the configuration is loaded the fewest number of times possible.
Concurrency
Package session provides configuration for the SDK's service clients. Sessions
can be shared across service clients that share the same base configuration.
Sessions are safe to use concurrently as long as the Session is not being
modified. The SDK will not modify the Session once the Session has been created.
Creating service clients concurrently from a shared Session is safe.
modified. Sessions should be cached when possible, because creating a new
Session will load all configuration values from the environment, and config
files each time the Session is created. Sharing the Session value across all of
your service clients will ensure the configuration is loaded the fewest number
of times possible.
Sessions from Shared Config
Sessions can be created using the method above that will only load the
additional config if the AWS_SDK_LOAD_CONFIG environment variable is set.
Alternatively you can explicitly create a Session with shared config enabled.
To do this you can use NewSessionWithOptions to configure how the Session will
be created. Using the NewSessionWithOptions with SharedConfigState set to
SharedConfigEnable will create the session as if the AWS_SDK_LOAD_CONFIG
environment variable was set.
Creating Sessions
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
Sessions options from Shared Config
By default NewSession will only load credentials from the shared credentials
file (~/.aws/credentials). If the AWS_SDK_LOAD_CONFIG environment variable is
set to a truthy value the Session will be created from the configuration
values from the shared config (~/.aws/config) and shared credentials
(~/.aws/credentials) files. See the section Sessions from Shared Config for
more information.
(~/.aws/credentials) files. Using the NewSessionWithOptions with
SharedConfigState set to SharedConfigEnable will create the session as if the
AWS_SDK_LOAD_CONFIG environment variable was set.
Create a Session with the default config and request handlers. With credentials
region, and profile loaded from the environment and shared config automatically.
Requires the AWS_PROFILE to be set, or "default" is used.
Credential and config loading order
The Session will attempt to load configuration and credentials from the
environment, configuration files, and other credential sources. The order
configuration is loaded in is:
* Environment Variables
* Shared Credentials file
* Shared Configuration file (if SharedConfig is enabled)
* EC2 Instance Metadata (credentials only)
The Environment variables for credentials will have precedence over shared
config even if SharedConfig is enabled. To override this behavior, and use
shared config credentials instead specify the session.Options.Profile, (e.g.
when using credential_source=Environment to assume a role).
sess, err := session.NewSessionWithOptions(session.Options{
Profile: "myProfile",
})
Creating Sessions
Creating a Session without additional options will load credentials region, and
profile loaded from the environment and shared config automatically. See,
"Environment Variables" section for information on environment variables used
by Session.
// Create Session
sess := session.Must(session.NewSession())
sess, err := session.NewSession()
When creating Sessions optional aws.Config values can be passed in that will
override the default, or loaded, config values the Session is being created
with. This allows you to provide additional, or case based, configuration
as needed.
// Create a Session with a custom region
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
}))
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-west-2"),
})
// Create a S3 client instance from a session
sess := session.Must(session.NewSession())
svc := s3.New(sess)
Create Session With Option Overrides
In addition to NewSession, Sessions can be created using NewSessionWithOptions.
This func allows you to control and override how the Session will be created
through code instead of being driven by environment variables only.
Use NewSessionWithOptions when you want to provide the config profile, or
override the shared config state (AWS_SDK_LOAD_CONFIG).
Use NewSessionWithOptions to provide additional configuration driving how the
Session's configuration will be loaded. Such as, specifying shared config
profile, or override the shared config state, (AWS_SDK_LOAD_CONFIG).
// Equivalent to session.NewSession()
sess := session.Must(session.NewSessionWithOptions(session.Options{
sess, err := session.NewSessionWithOptions(session.Options{
// Options
}))
})
sess, err := session.NewSessionWithOptions(session.Options{
// Specify profile to load for the session's config
sess := session.Must(session.NewSessionWithOptions(session.Options{
Profile: "profile_name",
}))
// Specify profile for config and region for requests
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String("us-east-1")},
Profile: "profile_name",
}))
// Provide SDK Config options, such as Region.
Config: aws.Config{
Region: aws.String("us-west-2"),
},
// Force enable Shared Config support
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
})
Adding Handlers
You can add handlers to a session for processing HTTP requests. All service
clients that use the session inherit the handlers. For example, the following
handler logs every request and its payload made by a service client:
You can add handlers to a session to decorate API operation, (e.g. adding HTTP
headers). All clients that use the Session receive a copy of the Session's
handlers. For example, the following request handler added to the Session logs
every requests made.
// Create a session, and add additional handlers for all service
// clients created with the Session to inherit. Adds logging handler.
@ -99,22 +95,15 @@ handler logs every request and its payload made by a service client:
sess.Handlers.Send.PushFront(func(r *request.Request) {
// Log every request made and its payload
logger.Printf("Request: %s/%s, Payload: %s",
logger.Printf("Request: %s/%s, Params: %s",
r.ClientInfo.ServiceName, r.Operation, r.Params)
})
Deprecated "New" function
The New session function has been deprecated because it does not provide good
way to return errors that occur when loading the configuration files and values.
Because of this, NewSession was created so errors can be retrieved when
creating a session fails.
Shared Config Fields
By default the SDK will only load the shared credentials file's (~/.aws/credentials)
credentials values, and all other config is provided by the environment variables,
SDK defaults, and user provided aws.Config values.
By default the SDK will only load the shared credentials file's
(~/.aws/credentials) credentials values, and all other config is provided by
the environment variables, SDK defaults, and user provided aws.Config values.
If the AWS_SDK_LOAD_CONFIG environment variable is set, or SharedConfigEnable
option is used to create the Session the full shared config values will be
@ -125,24 +114,31 @@ files have the same format.
If both config files are present the configuration from both files will be
read. The Session will be created from configuration values from the shared
credentials file (~/.aws/credentials) over those in the shared config file (~/.aws/config).
credentials file (~/.aws/credentials) over those in the shared config file
(~/.aws/config).
Credentials are the values the SDK should use for authenticating requests with
AWS Services. They are from a configuration file will need to include both
aws_access_key_id and aws_secret_access_key must be provided together in the
same file to be considered valid. The values will be ignored if not a complete
group. aws_session_token is an optional field that can be provided if both of
the other two fields are also provided.
Credentials are the values the SDK uses to authenticating requests with AWS
Services. When specified in a file, both aws_access_key_id and
aws_secret_access_key must be provided together in the same file to be
considered valid. They will be ignored if both are not present.
aws_session_token is an optional field that can be provided in addition to the
other two fields.
aws_access_key_id = AKID
aws_secret_access_key = SECRET
aws_session_token = TOKEN
Assume Role values allow you to configure the SDK to assume an IAM role using
a set of credentials provided in a config file via the source_profile field.
Both "role_arn" and "source_profile" are required. The SDK supports assuming
a role with MFA token if the session option AssumeRoleTokenProvider
is set.
; region only supported if SharedConfigEnabled.
region = us-east-1
Assume Role configuration
The role_arn field allows you to configure the SDK to assume an IAM role using
a set of credentials from another source. Such as when paired with static
credentials, "profile_source", "credential_process", or "credential_source"
fields. If "role_arn" is provided, a source of credentials must also be
specified, such as "source_profile", "credential_source", or
"credential_process".
role_arn = arn:aws:iam::<account_number>:role/<role_name>
source_profile = profile_with_creds
@ -150,40 +146,16 @@ is set.
mfa_serial = <serial or mfa arn>
role_session_name = session_name
Region is the region the SDK should use for looking up AWS service endpoints
and signing requests.
region = us-east-1
Assume Role with MFA token
To create a session with support for assuming an IAM role with MFA set the
session option AssumeRoleTokenProvider to a function that will prompt for the
MFA token code when the SDK assumes the role and refreshes the role's credentials.
This allows you to configure the SDK via the shared config to assumea role
with MFA tokens.
In order for the SDK to assume a role with MFA the SharedConfigState
session option must be set to SharedConfigEnable, or AWS_SDK_LOAD_CONFIG
environment variable set.
The shared configuration instructs the SDK to assume an IAM role with MFA
when the mfa_serial configuration field is set in the shared config
(~/.aws/config) or shared credentials (~/.aws/credentials) file.
If mfa_serial is set in the configuration, the SDK will assume the role, and
the AssumeRoleTokenProvider session option is not set an an error will
be returned when creating the session.
The SDK supports assuming a role with MFA token. If "mfa_serial" is set, you
must also set the Session Option.AssumeRoleTokenProvider. The Session will fail
to load if the AssumeRoleTokenProvider is not specified.
sess := session.Must(session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
}))
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess)
To setup assume role outside of a session see the stscreds.AssumeRoleProvider
To setup Assume Role outside of a session see the stscreds.AssumeRoleProvider
documentation.
Environment Variables

View File

@ -1,12 +1,15 @@
package session
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
)
// EnvProviderName provides a name of the provider when config is loaded from environment.
@ -99,21 +102,61 @@ type envConfig struct {
CustomCABundle string
csmEnabled string
CSMEnabled bool
CSMEnabled *bool
CSMPort string
CSMHost string
CSMClientID string
enableEndpointDiscovery string
// Enables endpoint discovery via environment variables.
//
// AWS_ENABLE_ENDPOINT_DISCOVERY=true
EnableEndpointDiscovery *bool
enableEndpointDiscovery string
// Specifies the WebIdentity token the SDK should use to assume a role
// with.
//
// AWS_WEB_IDENTITY_TOKEN_FILE=file_path
WebIdentityTokenFilePath string
// Specifies the IAM role arn to use when assuming an role.
//
// AWS_ROLE_ARN=role_arn
RoleARN string
// Specifies the IAM role session name to use when assuming a role.
//
// AWS_ROLE_SESSION_NAME=session_name
RoleSessionName string
// Specifies the STS Regional Endpoint flag for the SDK to resolve the endpoint
// for a service.
//
// AWS_STS_REGIONAL_ENDPOINTS=regional
// This can take value as `regional` or `legacy`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the S3 Regional Endpoint flag for the SDK to resolve the
// endpoint for a service.
//
// AWS_S3_US_EAST_1_REGIONAL_ENDPOINT=regional
// This can take value as `regional` or `legacy`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// AWS_S3_USE_ARN_REGION=true
S3UseARNRegion bool
}
var (
csmEnabledEnvKey = []string{
"AWS_CSM_ENABLED",
}
csmHostEnvKey = []string{
"AWS_CSM_HOST",
}
csmPortEnvKey = []string{
"AWS_CSM_PORT",
}
@ -150,6 +193,24 @@ var (
sharedConfigFileEnvKey = []string{
"AWS_CONFIG_FILE",
}
webIdentityTokenFilePathEnvKey = []string{
"AWS_WEB_IDENTITY_TOKEN_FILE",
}
roleARNEnvKey = []string{
"AWS_ROLE_ARN",
}
roleSessionNameEnvKey = []string{
"AWS_ROLE_SESSION_NAME",
}
stsRegionalEndpointKey = []string{
"AWS_STS_REGIONAL_ENDPOINTS",
}
s3UsEast1RegionalEndpoint = []string{
"AWS_S3_US_EAST_1_REGIONAL_ENDPOINT",
}
s3UseARNRegionEnvKey = []string{
"AWS_S3_USE_ARN_REGION",
}
)
// loadEnvConfig retrieves the SDK's environment configuration.
@ -158,7 +219,7 @@ var (
// If the environment variable `AWS_SDK_LOAD_CONFIG` is set to a truthy value
// the shared SDK config will be loaded in addition to the SDK's specific
// configuration values.
func loadEnvConfig() envConfig {
func loadEnvConfig() (envConfig, error) {
enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
return envConfigLoad(enableSharedConfig)
}
@ -169,30 +230,42 @@ func loadEnvConfig() envConfig {
// Loads the shared configuration in addition to the SDK's specific configuration.
// This will load the same values as `loadEnvConfig` if the `AWS_SDK_LOAD_CONFIG`
// environment variable is set.
func loadSharedEnvConfig() envConfig {
func loadSharedEnvConfig() (envConfig, error) {
return envConfigLoad(true)
}
func envConfigLoad(enableSharedConfig bool) envConfig {
func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
cfg := envConfig{}
cfg.EnableSharedConfig = enableSharedConfig
setFromEnvVal(&cfg.Creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&cfg.Creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&cfg.Creds.SessionToken, credSessionEnvKey)
// Static environment credentials
var creds credentials.Value
setFromEnvVal(&creds.AccessKeyID, credAccessEnvKey)
setFromEnvVal(&creds.SecretAccessKey, credSecretEnvKey)
setFromEnvVal(&creds.SessionToken, credSessionEnvKey)
if creds.HasKeys() {
// Require logical grouping of credentials
creds.ProviderName = EnvProviderName
cfg.Creds = creds
}
// Role Metadata
setFromEnvVal(&cfg.RoleARN, roleARNEnvKey)
setFromEnvVal(&cfg.RoleSessionName, roleSessionNameEnvKey)
// Web identity environment variables
setFromEnvVal(&cfg.WebIdentityTokenFilePath, webIdentityTokenFilePathEnvKey)
// CSM environment variables
setFromEnvVal(&cfg.csmEnabled, csmEnabledEnvKey)
setFromEnvVal(&cfg.CSMHost, csmHostEnvKey)
setFromEnvVal(&cfg.CSMPort, csmPortEnvKey)
setFromEnvVal(&cfg.CSMClientID, csmClientIDEnvKey)
cfg.CSMEnabled = len(cfg.csmEnabled) > 0
// Require logical grouping of credentials
if len(cfg.Creds.AccessKeyID) == 0 || len(cfg.Creds.SecretAccessKey) == 0 {
cfg.Creds = credentials.Value{}
} else {
cfg.Creds.ProviderName = EnvProviderName
if len(cfg.csmEnabled) != 0 {
v, _ := strconv.ParseBool(cfg.csmEnabled)
cfg.CSMEnabled = &v
}
regionKeys := regionEnvKeys
@ -223,12 +296,48 @@ func envConfigLoad(enableSharedConfig bool) envConfig {
cfg.CustomCABundle = os.Getenv("AWS_CA_BUNDLE")
return cfg
var err error
// STS Regional Endpoint variable
for _, k := range stsRegionalEndpointKey {
if v := os.Getenv(k); len(v) != 0 {
cfg.STSRegionalEndpoint, err = endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
// S3 Regional Endpoint variable
for _, k := range s3UsEast1RegionalEndpoint {
if v := os.Getenv(k); len(v) != 0 {
cfg.S3UsEast1RegionalEndpoint, err = endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return cfg, fmt.Errorf("failed to load, %v from env config, %v", k, err)
}
}
}
var s3UseARNRegion string
setFromEnvVal(&s3UseARNRegion, s3UseARNRegionEnvKey)
if len(s3UseARNRegion) != 0 {
switch {
case strings.EqualFold(s3UseARNRegion, "false"):
cfg.S3UseARNRegion = false
case strings.EqualFold(s3UseARNRegion, "true"):
cfg.S3UseARNRegion = true
default:
return envConfig{}, fmt.Errorf(
"invalid value for environment variable, %s=%s, need true or false",
s3UseARNRegionEnvKey[0], s3UseARNRegion)
}
}
return cfg, nil
}
func setFromEnvVal(dst *string, keys []string) {
for _, k := range keys {
if v := os.Getenv(k); len(v) > 0 {
if v := os.Getenv(k); len(v) != 0 {
*dst = v
break
}

View File

@ -8,19 +8,17 @@ import (
"io/ioutil"
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
const (
@ -75,7 +73,7 @@ type Session struct {
// func is called instead of waiting to receive an error until a request is made.
func New(cfgs ...*aws.Config) *Session {
// load initial config from environment
envCfg := loadEnvConfig()
envCfg, envErr := loadEnvConfig()
if envCfg.EnableSharedConfig {
var cfg aws.Config
@ -95,19 +93,28 @@ func New(cfgs ...*aws.Config) *Session {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s = &Session{Config: defaults.Config()}
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
return s
}
s := deprecatedNewSession(cfgs...)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
if envErr != nil {
msg := "failed to load env config"
s.logDeprecatedNewSessionError(msg, envErr, cfgs)
}
if csmCfg, err := loadCSMConfig(envCfg, []string{}); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err := enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
msg := "failed to enable CSM"
s.logDeprecatedNewSessionError(msg, err, cfgs)
}
}
return s
@ -126,7 +133,7 @@ func New(cfgs ...*aws.Config) *Session {
// to be built with retrieving credentials with AssumeRole set in the config.
//
// See the NewSessionWithOptions func for information on how to override or
// control through code how the Session will be created. Such as specifying the
// control through code how the Session will be created, such as specifying the
// config profile, and controlling if shared config is enabled or not.
func NewSession(cfgs ...*aws.Config) (*Session, error) {
opts := Options{}
@ -210,6 +217,12 @@ type Options struct {
// the config enables assume role wit MFA via the mfa_serial field.
AssumeRoleTokenProvider func() (string, error)
// When the SDK's shared config is configured to assume a role this option
// may be provided to set the expiry duration of the STS credentials.
// Defaults to 15 minutes if not set as documented in the
// stscreds.AssumeRoleProvider.
AssumeRoleDuration time.Duration
// Reader for a custom Credentials Authority (CA) bundle in PEM format that
// the SDK will use instead of the default system's root CA bundle. Use this
// only if you want to replace the CA bundle the SDK uses for TLS requests.
@ -224,6 +237,12 @@ type Options struct {
// to also enable this feature. CustomCABundle session option field has priority
// over the AWS_CA_BUNDLE environment variable, and will be used if both are set.
CustomCABundle io.Reader
// The handlers that the session and all API clients will be created with.
// This must be a complete set of handlers. Use the defaults.Handlers()
// function to initialize this value before changing the handlers to be
// used by the SDK.
Handlers request.Handlers
}
// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -257,13 +276,20 @@ type Options struct {
// }))
func NewSessionWithOptions(opts Options) (*Session, error) {
var envCfg envConfig
var err error
if opts.SharedConfigState == SharedConfigEnable {
envCfg = loadSharedEnvConfig()
envCfg, err = loadSharedEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load shared config, %v", err)
}
} else {
envCfg = loadEnvConfig()
envCfg, err = loadEnvConfig()
if err != nil {
return nil, fmt.Errorf("failed to load environment config, %v", err)
}
}
if len(opts.Profile) > 0 {
if len(opts.Profile) != 0 {
envCfg.Profile = opts.Profile
}
@ -329,27 +355,33 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session {
return s
}
func enableCSM(handlers *request.Handlers, clientID string, port string, logger aws.Logger) {
func enableCSM(handlers *request.Handlers, cfg csmConfig, logger aws.Logger) error {
if logger != nil {
logger.Log("Enabling CSM")
if len(port) == 0 {
port = csm.DefaultPort
}
r, err := csm.Start(clientID, "127.0.0.1:"+port)
r, err := csm.Start(cfg.ClientID, csm.AddressWithDefaults(cfg.Host, cfg.Port))
if err != nil {
return
return err
}
r.InjectHandlers(handlers)
return nil
}
func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) {
cfg := defaults.Config()
handlers := defaults.Handlers()
handlers := opts.Handlers
if handlers.IsEmpty() {
handlers = defaults.Handlers()
}
// Get a merged version of the user provided config to determine if
// credentials were.
userCfg := &aws.Config{}
userCfg.MergeIn(cfgs...)
cfg.MergeIn(userCfg)
// Ordered config files will be loaded in with later files overwriting
// previous config file values.
@ -366,10 +398,18 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
}
// Load additional config from file(s)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles)
sharedCfg, err := loadSharedConfig(envCfg.Profile, cfgFiles, envCfg.EnableSharedConfig)
if err != nil {
if len(envCfg.Profile) == 0 && !envCfg.EnableSharedConfig && (envCfg.Creds.HasKeys() || userCfg.Credentials != nil) {
// Special case where the user has not explicitly specified an AWS_PROFILE,
// or session.Options.profile, shared config is not enabled, and the
// environment has credentials, allow the shared config file to fail to
// load since the user has already provided credentials, and nothing else
// is required to be read file. Github(aws/aws-sdk-go#2455)
} else if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return nil, err
}
}
if err := mergeConfigSrcs(cfg, userCfg, envCfg, sharedCfg, handlers, opts); err != nil {
return nil, err
@ -381,8 +421,16 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
}
initHandlers(s)
if envCfg.CSMEnabled {
enableCSM(&s.Handlers, envCfg.CSMClientID, envCfg.CSMPort, s.Config.Logger)
if csmCfg, err := loadCSMConfig(envCfg, cfgFiles); err != nil {
if l := s.Config.Logger; l != nil {
l.Log(fmt.Sprintf("ERROR: failed to load CSM configuration, %v", err))
}
} else if csmCfg.Enabled {
err = enableCSM(&s.Handlers, csmCfg, s.Config.Logger)
if err != nil {
return nil, err
}
}
// Setup HTTP client with custom cert bundle if enabled
@ -395,6 +443,46 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session,
return s, nil
}
type csmConfig struct {
Enabled bool
Host string
Port string
ClientID string
}
var csmProfileName = "aws_csm"
func loadCSMConfig(envCfg envConfig, cfgFiles []string) (csmConfig, error) {
if envCfg.CSMEnabled != nil {
if *envCfg.CSMEnabled {
return csmConfig{
Enabled: true,
ClientID: envCfg.CSMClientID,
Host: envCfg.CSMHost,
Port: envCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
sharedCfg, err := loadSharedConfig(csmProfileName, cfgFiles, false)
if err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); !ok {
return csmConfig{}, err
}
}
if sharedCfg.CSMEnabled != nil && *sharedCfg.CSMEnabled == true {
return csmConfig{
Enabled: true,
ClientID: sharedCfg.CSMClientID,
Host: sharedCfg.CSMHost,
Port: sharedCfg.CSMPort,
}, nil
}
return csmConfig{}, nil
}
func loadCustomCABundle(s *Session, bundle io.Reader) error {
var t *http.Transport
switch v := s.Config.HTTPClient.Transport.(type) {
@ -443,9 +531,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) {
return p, nil
}
func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error {
// Merge in user provided configuration
cfg.MergeIn(userCfg)
func mergeConfigSrcs(cfg, userCfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
sessOpts Options,
) error {
// Region if not already set by user
if len(aws.StringValue(cfg.Region)) == 0 {
@ -464,162 +554,59 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share
}
}
// Configure credentials if not already set
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
// inspect the profile to see if a credential source has been specified.
if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.CredentialSource) > 0 {
// if both credential_source and source_profile have been set, return an error
// as this is undefined behavior.
if len(sharedCfg.AssumeRole.SourceProfile) > 0 {
return ErrSharedConfigSourceCollision
}
// valid credential source values
const (
credSourceEc2Metadata = "Ec2InstanceMetadata"
credSourceEnvironment = "Environment"
credSourceECSContainer = "EcsContainer"
)
switch sharedCfg.AssumeRole.CredentialSource {
case credSourceEc2Metadata:
cfgCp := *cfg
p := defaults.RemoteCredProvider(cfgCp, handlers)
cfgCp.Credentials = credentials.NewCredentials(p)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts)
case credSourceEnvironment:
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)
case credSourceECSContainer:
if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 {
return ErrSharedConfigECSContainerEnvVarEmpty
}
cfgCp := *cfg
p := defaults.RemoteCredProvider(cfgCp, handlers)
creds := credentials.NewCredentials(p)
cfg.Credentials = creds
default:
return ErrSharedConfigInvalidCredSource
}
return nil
}
if len(envCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
envCfg.Creds,
)
} else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil {
cfgCp := *cfg
cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.AssumeRoleSource.Creds,
)
if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil {
// AssumeRole Token provider is required if doing Assume Role
// with MFA.
return AssumeRoleTokenProviderNotSetError{}
}
cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts)
} else if len(sharedCfg.Creds.AccessKeyID) > 0 {
cfg.Credentials = credentials.NewStaticCredentialsFromCreds(
sharedCfg.Creds,
)
} else if len(sharedCfg.CredentialProcess) > 0 {
cfg.Credentials = processcreds.NewCredentials(
sharedCfg.CredentialProcess,
)
} else {
// Fallback to default credentials provider, include mock errors
// for the credential chain so user can identify why credentials
// failed to be retrieved.
cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)},
&credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)},
defaults.RemoteCredProvider(*cfg, handlers),
},
// Regional Endpoint flag for STS endpoint resolving
mergeSTSRegionalEndpointConfig(cfg, []endpoints.STSRegionalEndpoint{
userCfg.STSRegionalEndpoint,
envCfg.STSRegionalEndpoint,
sharedCfg.STSRegionalEndpoint,
endpoints.LegacySTSEndpoint,
})
// Regional Endpoint flag for S3 endpoint resolving
mergeS3UsEast1RegionalEndpointConfig(cfg, []endpoints.S3UsEast1RegionalEndpoint{
userCfg.S3UsEast1RegionalEndpoint,
envCfg.S3UsEast1RegionalEndpoint,
sharedCfg.S3UsEast1RegionalEndpoint,
endpoints.LegacyS3UsEast1Endpoint,
})
// Configure credentials if not already set by the user when creating the
// Session.
if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil {
creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts)
if err != nil {
return err
}
cfg.Credentials = creds
}
cfg.S3UseARNRegion = userCfg.S3UseARNRegion
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &envCfg.S3UseARNRegion
}
if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &sharedCfg.S3UseARNRegion
}
return nil
}
func assumeRoleCredentials(cfg aws.Config, handlers request.Handlers, sharedCfg sharedConfig, sessOpts Options) *credentials.Credentials {
return stscreds.NewCredentials(
&Session{
Config: &cfg,
Handlers: handlers.Copy(),
},
sharedCfg.AssumeRole.RoleARN,
func(opt *stscreds.AssumeRoleProvider) {
opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName
// Assume role with external ID
if len(sharedCfg.AssumeRole.ExternalID) > 0 {
opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID)
func mergeSTSRegionalEndpointConfig(cfg *aws.Config, values []endpoints.STSRegionalEndpoint) {
for _, v := range values {
if v != endpoints.UnsetSTSEndpoint {
cfg.STSRegionalEndpoint = v
break
}
// Assume role with MFA
if len(sharedCfg.AssumeRole.MFASerial) > 0 {
opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial)
opt.TokenProvider = sessOpts.AssumeRoleTokenProvider
}
},
)
}
// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the
// MFAToken option is not set when shared config is configured load assume a
// role with an MFA token.
type AssumeRoleTokenProviderNotSetError struct{}
// Code is the short id of the error.
func (e AssumeRoleTokenProviderNotSetError) Code() string {
return "AssumeRoleTokenProviderNotSetError"
}
// Message is the description of the error
func (e AssumeRoleTokenProviderNotSetError) Message() string {
return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.")
}
// OrigErr is the underlying error that caused the failure.
func (e AssumeRoleTokenProviderNotSetError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e AssumeRoleTokenProviderNotSetError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
type credProviderError struct {
Err error
}
var emptyCreds = credentials.Value{}
func (c credProviderError) Retrieve() (credentials.Value, error) {
return credentials.Value{}, c.Err
}
func (c credProviderError) IsExpired() bool {
return true
func mergeS3UsEast1RegionalEndpointConfig(cfg *aws.Config, values []endpoints.S3UsEast1RegionalEndpoint) {
for _, v := range values {
if v != endpoints.UnsetS3UsEast1Endpoint {
cfg.S3UsEast1RegionalEndpoint = v
break
}
}
}
func initHandlers(s *Session) {
@ -630,7 +617,7 @@ func initHandlers(s *Session) {
}
}
// Copy creates and returns a copy of the current Session, coping the config
// Copy creates and returns a copy of the current Session, copying the config
// and handlers. If any additional configs are provided they will be merged
// on top of the Session's copied config.
//
@ -650,47 +637,67 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
// ClientConfig satisfies the client.ConfigProvider interface and is used to
// configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client.
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
// Backwards compatibility, the error will be eaten if user calls ClientConfig
// directly. All SDK services will use ClientconfigWithError.
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
return cfg
}
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint
var err error
region := aws.StringValue(s.Config.Region)
resolved, err := s.resolveEndpoint(service, region, s.Config)
if err != nil {
s.Handlers.Validate.PushBack(func(r *request.Request) {
if len(r.ClientInfo.Endpoint) != 0 {
// Error occurred while resolving endpoint, but the request
// being invoked has had an endpoint specified after the client
// was created.
return
}
r.Error = err
})
}
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 {
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region
} else {
resolved, err = s.Config.EndpointResolver.EndpointFor(
serviceName, region,
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
PartitionID: resolved.PartitionID,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}
}
func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
return endpoints.ResolvedEndpoint{
URL: endpoints.AddScheme(ep, aws.BoolValue(cfg.DisableSSL)),
SigningRegion: region,
}, nil
}
resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL)
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack)
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.STSRegionalEndpoint = cfg.STSRegionalEndpoint
// Support for S3UsEast1RegionalEndpoint where the S3UsEast1RegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting
// precedence.
opt.S3UsEast1RegionalEndpoint = cfg.S3UsEast1RegionalEndpoint
// Support the condition where the service is modeled but its
// endpoint metadata is not available.
opt.ResolveUnknownService = true
},
)
if err != nil {
return endpoints.ResolvedEndpoint{}, err
}
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName,
}, err
return resolved, nil
}
// ClientConfigNoResolveEndpoint is the same as ClientConfig with the exception
@ -700,12 +707,9 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
s = s.Copy(cfgs...)
var resolved endpoints.ResolvedEndpoint
region := aws.StringValue(s.Config.Region)
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region
resolved.SigningRegion = aws.StringValue(s.Config.Region)
}
return client.Config{
@ -717,3 +721,14 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
SigningName: resolved.SigningName,
}
}
// logDeprecatedNewSessionError function enables error handling for session
func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aws.Config) {
// Session creation failed, need to report the error and prevent
// any requests from succeeding.
s.Config.MergeIn(cfgs...)
s.Config.Logger.Log("ERROR:", msg, "Error:", err)
s.Handlers.Validate.PushBack(func(r *request.Request) {
r.Error = err
})
}

View File

@ -2,10 +2,11 @@ package session
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/internal/ini"
)
@ -22,51 +23,69 @@ const (
externalIDKey = `external_id` // optional
mfaSerialKey = `mfa_serial` // optional
roleSessionNameKey = `role_session_name` // optional
roleDurationSecondsKey = "duration_seconds" // optional
// CSM options
csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host`
csmPortKey = `csm_port`
csmClientIDKey = `csm_client_id`
// Additional Config fields
regionKey = `region`
// endpoint discovery group
enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional
// External Credential Process
credentialProcessKey = `credential_process`
credentialProcessKey = `credential_process` // optional
// Web Identity Token File
webIdentityTokenFileKey = `web_identity_token_file` // optional
// Additional config fields for regional or legacy endpoints
stsRegionalEndpointSharedKey = `sts_regional_endpoints`
// Additional config fields for regional or legacy endpoints
s3UsEast1RegionalSharedKey = `s3_us_east_1_regional_endpoint`
// DefaultSharedConfigProfile is the default profile to be used when
// loading configuration from the config files if another profile name
// is not provided.
DefaultSharedConfigProfile = `default`
)
type assumeRoleConfig struct {
RoleARN string
SourceProfile string
CredentialSource string
ExternalID string
MFASerial string
RoleSessionName string
}
// S3 ARN Region Usage
s3UseARNRegionKey = "s3_use_arn_region"
)
// sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct {
// Credentials values from the config file. Both aws_access_key_id
// and aws_secret_access_key must be provided together in the same file
// to be considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of the
// other two fields are also provided.
// Credentials values from the config file. Both aws_access_key_id and
// aws_secret_access_key must be provided together in the same file to be
// considered valid. The values will be ignored if not a complete group.
// aws_session_token is an optional field that can be provided if both of
// the other two fields are also provided.
//
// aws_access_key_id
// aws_secret_access_key
// aws_session_token
Creds credentials.Value
AssumeRole assumeRoleConfig
AssumeRoleSource *sharedConfig
// An external process to request credentials
CredentialSource string
CredentialProcess string
WebIdentityTokenFile string
// Region is the region the SDK should use for looking up AWS service endpoints
// and signing requests.
RoleARN string
RoleSessionName string
ExternalID string
MFASerial string
AssumeRoleDuration *time.Duration
SourceProfileName string
SourceProfile *sharedConfig
// Region is the region the SDK should use for looking up AWS service
// endpoints and signing requests.
//
// region
Region string
@ -76,6 +95,30 @@ type sharedConfig struct {
//
// endpoint_discovery_enabled = true
EnableEndpointDiscovery *bool
// CSM Options
CSMEnabled *bool
CSMHost string
CSMPort string
CSMClientID string
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// sts_regional_endpoints = regional
// This can take value as `LegacySTSEndpoint` or `RegionalSTSEndpoint`
STSRegionalEndpoint endpoints.STSRegionalEndpoint
// Specifies the Regional Endpoint flag for the SDK to resolve the endpoint for a service
//
// s3_us_east_1_regional_endpoint = regional
// This can take value as `LegacyS3UsEast1Endpoint` or `RegionalS3UsEast1Endpoint`
S3UsEast1RegionalEndpoint endpoints.S3UsEast1RegionalEndpoint
// Specifies if the S3 service should allow ARNs to direct the region
// the client's requests are sent to.
//
// s3_use_arn_region=true
S3UseARNRegion bool
}
type sharedConfigFile struct {
@ -83,17 +126,18 @@ type sharedConfigFile struct {
IniData ini.Sections
}
// loadSharedConfig retrieves the configuration from the list of files
// using the profile provided. The order the files are listed will determine
// loadSharedConfig retrieves the configuration from the list of files using
// the profile provided. The order the files are listed will determine
// precedence. Values in subsequent files will overwrite values defined in
// earlier files.
//
// For example, given two files A and B. Both define credentials. If the order
// of the files are A then B, B's credential values will be used instead of A's.
// of the files are A then B, B's credential values will be used instead of
// A's.
//
// See sharedConfig.setFromFile for information how the config files
// will be loaded.
func loadSharedConfig(profile string, filenames []string) (sharedConfig, error) {
func loadSharedConfig(profile string, filenames []string, exOpts bool) (sharedConfig, error) {
if len(profile) == 0 {
profile = DefaultSharedConfigProfile
}
@ -104,16 +148,11 @@ func loadSharedConfig(profile string, filenames []string) (sharedConfig, error)
}
cfg := sharedConfig{}
if err = cfg.setFromIniFiles(profile, files); err != nil {
profiles := map[string]struct{}{}
if err = cfg.setFromIniFiles(profiles, profile, files, exOpts); err != nil {
return sharedConfig{}, err
}
if len(cfg.AssumeRole.SourceProfile) > 0 {
if err := cfg.setAssumeRoleSource(profile, files); err != nil {
return sharedConfig{}, err
}
}
return cfg, nil
}
@ -137,60 +176,88 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
return files, nil
}
func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedConfigFile) error {
var assumeRoleSrc sharedConfig
if len(cfg.AssumeRole.CredentialSource) > 0 {
// setAssumeRoleSource is only called when source_profile is found.
// If both source_profile and credential_source are set, then
// ErrSharedConfigSourceCollision will be returned
return ErrSharedConfigSourceCollision
}
// Multiple level assume role chains are not support
if cfg.AssumeRole.SourceProfile == origProfile {
assumeRoleSrc = *cfg
assumeRoleSrc.AssumeRole = assumeRoleConfig{}
} else {
err := assumeRoleSrc.setFromIniFiles(cfg.AssumeRole.SourceProfile, files)
if err != nil {
return err
}
}
if len(assumeRoleSrc.Creds.AccessKeyID) == 0 {
return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN}
}
cfg.AssumeRoleSource = &assumeRoleSrc
return nil
}
func (cfg *sharedConfig) setFromIniFiles(profile string, files []sharedConfigFile) error {
func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
// Trim files from the list that don't exist.
var skippedFiles int
var profileNotFoundErr error
for _, f := range files {
if err := cfg.setFromIniFile(profile, f); err != nil {
if err := cfg.setFromIniFile(profile, f, exOpts); err != nil {
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
// Ignore proviles missings
// Ignore profiles not defined in individual files.
profileNotFoundErr = err
skippedFiles++
continue
}
return err
}
}
if skippedFiles == len(files) {
// If all files were skipped because the profile is not found, return
// the original profile not found error.
return profileNotFoundErr
}
if _, ok := profiles[profile]; ok {
// if this is the second instance of the profile the Assume Role
// options must be cleared because they are only valid for the
// first reference of a profile. The self linked instance of the
// profile only have credential provider options.
cfg.clearAssumeRoleOptions()
} else {
// First time a profile has been seen, It must either be a assume role
// or credentials. Assert if the credential type requires a role ARN,
// the ARN is also set.
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
}
}
profiles[profile] = struct{}{}
if err := cfg.validateCredentialType(); err != nil {
return err
}
// Link source profiles for assume roles
if len(cfg.SourceProfileName) != 0 {
// Linked profile via source_profile ignore credential provider
// options, the source profile must provide the credentials.
cfg.clearCredentialOptions()
srcCfg := &sharedConfig{}
err := srcCfg.setFromIniFiles(profiles, cfg.SourceProfileName, files, exOpts)
if err != nil {
// SourceProfile that doesn't exist is an error in configuration.
if _, ok := err.(SharedConfigProfileNotExistsError); ok {
err = SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
return err
}
if !srcCfg.hasCredentials() {
return SharedConfigAssumeRoleError{
RoleARN: cfg.RoleARN,
SourceProfile: cfg.SourceProfileName,
}
}
cfg.SourceProfile = srcCfg
}
return nil
}
// setFromFile loads the configuration from the file using
// the profile provided. A sharedConfig pointer type value is used so that
// multiple config file loadings can be chained.
// setFromFile loads the configuration from the file using the profile
// provided. A sharedConfig pointer type value is used so that multiple config
// file loadings can be chained.
//
// Only loads complete logically grouped values, and will not set fields in cfg
// for incomplete grouped values in the config. Such as credentials. For example
// if a config file only includes aws_access_key_id but no aws_secret_access_key
// the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) error {
// for incomplete grouped values in the config. Such as credentials. For
// example if a config file only includes aws_access_key_id but no
// aws_secret_access_key the aws_access_key_id will be ignored.
func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, exOpts bool) error {
section, ok := file.IniData.GetSection(profile)
if !ok {
// Fallback to to alternate profile name: profile <name>
@ -200,53 +267,176 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) e
}
}
if exOpts {
// Assume Role Parameters
updateString(&cfg.RoleARN, section, roleArnKey)
updateString(&cfg.ExternalID, section, externalIDKey)
updateString(&cfg.MFASerial, section, mfaSerialKey)
updateString(&cfg.RoleSessionName, section, roleSessionNameKey)
updateString(&cfg.SourceProfileName, section, sourceProfileKey)
updateString(&cfg.CredentialSource, section, credentialSourceKey)
updateString(&cfg.Region, section, regionKey)
if section.Has(roleDurationSecondsKey) {
d := time.Duration(section.Int(roleDurationSecondsKey)) * time.Second
cfg.AssumeRoleDuration = &d
}
if v := section.String(stsRegionalEndpointSharedKey); len(v) != 0 {
sre, err := endpoints.GetSTSRegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
stsRegionalEndpointSharedKey, file.Filename, err)
}
cfg.STSRegionalEndpoint = sre
}
if v := section.String(s3UsEast1RegionalSharedKey); len(v) != 0 {
sre, err := endpoints.GetS3UsEast1RegionalEndpoint(v)
if err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
s3UsEast1RegionalSharedKey, file.Filename, err)
}
cfg.S3UsEast1RegionalEndpoint = sre
}
}
updateString(&cfg.CredentialProcess, section, credentialProcessKey)
updateString(&cfg.WebIdentityTokenFile, section, webIdentityTokenFileKey)
// Shared Credentials
akid := section.String(accessKeyIDKey)
secret := section.String(secretAccessKey)
if len(akid) > 0 && len(secret) > 0 {
cfg.Creds = credentials.Value{
AccessKeyID: akid,
SecretAccessKey: secret,
creds := credentials.Value{
AccessKeyID: section.String(accessKeyIDKey),
SecretAccessKey: section.String(secretAccessKey),
SessionToken: section.String(sessionTokenKey),
ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", file.Filename),
}
}
// Assume Role
roleArn := section.String(roleArnKey)
srcProfile := section.String(sourceProfileKey)
credentialSource := section.String(credentialSourceKey)
hasSource := len(srcProfile) > 0 || len(credentialSource) > 0
if len(roleArn) > 0 && hasSource {
cfg.AssumeRole = assumeRoleConfig{
RoleARN: roleArn,
SourceProfile: srcProfile,
CredentialSource: credentialSource,
ExternalID: section.String(externalIDKey),
MFASerial: section.String(mfaSerialKey),
RoleSessionName: section.String(roleSessionNameKey),
}
}
// `credential_process`
if credProc := section.String(credentialProcessKey); len(credProc) > 0 {
cfg.CredentialProcess = credProc
}
// Region
if v := section.String(regionKey); len(v) > 0 {
cfg.Region = v
if creds.HasKeys() {
cfg.Creds = creds
}
// Endpoint discovery
if section.Has(enableEndpointDiscoveryKey) {
v := section.Bool(enableEndpointDiscoveryKey)
cfg.EnableEndpointDiscovery = &v
updateBoolPtr(&cfg.EnableEndpointDiscovery, section, enableEndpointDiscoveryKey)
// CSM options
updateBoolPtr(&cfg.CSMEnabled, section, csmEnabledKey)
updateString(&cfg.CSMHost, section, csmHostKey)
updateString(&cfg.CSMPort, section, csmPortKey)
updateString(&cfg.CSMClientID, section, csmClientIDKey)
updateBool(&cfg.S3UseARNRegion, section, s3UseARNRegionKey)
return nil
}
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string
switch {
case len(cfg.SourceProfileName) != 0:
credSource = sourceProfileKey
case len(cfg.CredentialSource) != 0:
credSource = credentialSourceKey
case len(cfg.WebIdentityTokenFile) != 0:
credSource = webIdentityTokenFileKey
}
if len(credSource) != 0 && len(cfg.RoleARN) == 0 {
return CredentialRequiresARNError{
Type: credSource,
Profile: profile,
}
}
return nil
}
func (cfg *sharedConfig) validateCredentialType() error {
// Only one or no credential type can be defined.
if !oneOrNone(
len(cfg.SourceProfileName) != 0,
len(cfg.CredentialSource) != 0,
len(cfg.CredentialProcess) != 0,
len(cfg.WebIdentityTokenFile) != 0,
) {
return ErrSharedConfigSourceCollision
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool {
switch {
case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0:
case cfg.Creds.HasKeys():
default:
return false
}
return true
}
func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialSource = ""
cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{}
}
func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.RoleARN = ""
cfg.ExternalID = ""
cfg.MFASerial = ""
cfg.RoleSessionName = ""
cfg.SourceProfileName = ""
}
func oneOrNone(bs ...bool) bool {
var count int
for _, b := range bs {
if b {
count++
if count > 1 {
return false
}
}
}
return true
}
// updateString will only update the dst with the value in the section key, key
// is present in the section.
func updateString(dst *string, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.String(key)
}
// updateBool will only update the dst with the value in the section key, key
// is present in the section.
func updateBool(dst *bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = section.Bool(key)
}
// updateBoolPtr will only update the dst with the value in the section key,
// key is present in the section.
func updateBoolPtr(dst **bool, section ini.Section, key string) {
if !section.Has(key) {
return
}
*dst = new(bool)
**dst = section.Bool(key)
}
// SharedConfigLoadError is an error for the shared config file failed to load.
type SharedConfigLoadError struct {
Filename string
@ -305,6 +495,7 @@ func (e SharedConfigProfileNotExistsError) Error() string {
// or not complete.
type SharedConfigAssumeRoleError struct {
RoleARN string
SourceProfile string
}
// Code is the short id of the error.
@ -314,8 +505,10 @@ func (e SharedConfigAssumeRoleError) Code() string {
// Message is the description of the error
func (e SharedConfigAssumeRoleError) Message() string {
return fmt.Sprintf("failed to load assume role for %s, source profile has no shared credentials",
e.RoleARN)
return fmt.Sprintf(
"failed to load assume role for %s, source profile %s has no shared credentials",
e.RoleARN, e.SourceProfile,
)
}
// OrigErr is the underlying error that caused the failure.
@ -327,3 +520,36 @@ func (e SharedConfigAssumeRoleError) OrigErr() error {
func (e SharedConfigAssumeRoleError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}
// CredentialRequiresARNError provides the error for shared config credentials
// that are incorrectly configured in the shared config or credentials file.
type CredentialRequiresARNError struct {
// type of credentials that were configured.
Type string
// Profile name the credentials were in.
Profile string
}
// Code is the short id of the error.
func (e CredentialRequiresARNError) Code() string {
return "CredentialRequiresARNError"
}
// Message is the description of the error
func (e CredentialRequiresARNError) Message() string {
return fmt.Sprintf(
"credential type %s requires role_arn, profile %s",
e.Type, e.Profile,
)
}
// OrigErr is the underlying error that caused the failure.
func (e CredentialRequiresARNError) OrigErr() error {
return nil
}
// Error satisfies the error interface.
func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil)
}

View File

@ -1,8 +1,7 @@
package v4
import (
"net/http"
"strings"
"github.com/aws/aws-sdk-go/internal/strings"
)
// validator houses a set of rule needed for validation of a
@ -61,7 +60,7 @@ type patterns []string
// been found
func (p patterns) IsValid(value string) bool {
for _, pattern := range p {
if strings.HasPrefix(http.CanonicalHeaderKey(value), pattern) {
if strings.HasPrefixFold(value, pattern) {
return true
}
}

View File

@ -0,0 +1,13 @@
// +build !go1.7
package v4
import (
"net/http"
"github.com/aws/aws-sdk-go/aws"
)
func requestContext(r *http.Request) aws.Context {
return aws.BackgroundContext()
}

View File

@ -0,0 +1,13 @@
// +build go1.7
package v4
import (
"net/http"
"github.com/aws/aws-sdk-go/aws"
)
func requestContext(r *http.Request) aws.Context {
return r.Context()
}

View File

@ -0,0 +1,63 @@
package v4
import (
"encoding/hex"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
type credentialValueProvider interface {
Get() (credentials.Value, error)
}
// StreamSigner implements signing of event stream encoded payloads
type StreamSigner struct {
region string
service string
credentials credentialValueProvider
prevSig []byte
}
// NewStreamSigner creates a SigV4 signer used to sign Event Stream encoded messages
func NewStreamSigner(region, service string, seedSignature []byte, credentials *credentials.Credentials) *StreamSigner {
return &StreamSigner{
region: region,
service: service,
credentials: credentials,
prevSig: seedSignature,
}
}
// GetSignature takes an event stream encoded headers and payload and returns a signature
func (s *StreamSigner) GetSignature(headers, payload []byte, date time.Time) ([]byte, error) {
credValue, err := s.credentials.Get()
if err != nil {
return nil, err
}
sigKey := deriveSigningKey(s.region, s.service, credValue.SecretAccessKey, date)
keyPath := buildSigningScope(s.region, s.service, date)
stringToSign := buildEventStreamStringToSign(headers, payload, s.prevSig, keyPath, date)
signature := hmacSHA256(sigKey, []byte(stringToSign))
s.prevSig = signature
return signature, nil
}
func buildEventStreamStringToSign(headers, payload, prevSig []byte, scope string, date time.Time) string {
return strings.Join([]string{
"AWS4-HMAC-SHA256-PAYLOAD",
formatTime(date),
scope,
hex.EncodeToString(prevSig),
hex.EncodeToString(hashSHA256(headers)),
hex.EncodeToString(hashSHA256(payload)),
}, "\n")
}

View File

@ -76,9 +76,14 @@ import (
)
const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
signatureQueryKey = "X-Amz-Signature"
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
@ -87,7 +92,7 @@ const (
var ignoredHeaders = rules{
blacklist{
mapRule{
"Authorization": struct{}{},
authorizationHeader: struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
@ -231,8 +236,6 @@ type signingCtx struct {
credValues credentials.Value
isPresign bool
formattedTime string
formattedShortTime string
unsignedPayload bool
bodyDigest string
@ -337,7 +340,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
}
var err error
ctx.credValues, err = v4.Credentials.Get()
ctx.credValues, err = v4.Credentials.GetWithContext(requestContext(r))
if err != nil {
return http.Header{}, err
}
@ -532,39 +535,56 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) error {
ctx.buildSignature() // depends on string to sign
if ctx.isPresign {
ctx.Request.URL.RawQuery += "&X-Amz-Signature=" + ctx.signature
ctx.Request.URL.RawQuery += "&" + signatureQueryKey + "=" + ctx.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders,
"Signature=" + ctx.signature,
authHeaderSignatureElem + ctx.signature,
}
ctx.Request.Header.Set("Authorization", strings.Join(parts, ", "))
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
}
return nil
}
func (ctx *signingCtx) buildTime() {
ctx.formattedTime = ctx.Time.UTC().Format(timeFormat)
ctx.formattedShortTime = ctx.Time.UTC().Format(shortTimeFormat)
// GetSignedRequestSignature attempts to extract the signature of the request.
// Returning an error if the request is unsigned, or unable to extract the
// signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
if ctx.isPresign {
duration := int64(ctx.ExpireTime / time.Second)
ctx.Query.Set("X-Amz-Date", ctx.formattedTime)
ctx.Query.Set("X-Amz-Date", formatTime(ctx.Time))
ctx.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
ctx.Request.Header.Set("X-Amz-Date", ctx.formattedTime)
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}
}
func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = strings.Join([]string{
ctx.formattedShortTime,
ctx.Region,
ctx.ServiceName,
"aws4_request",
}, "/")
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
if ctx.isPresign {
ctx.Query.Set("X-Amz-Credential", ctx.credValues.AccessKeyID+"/"+ctx.credentialString)
@ -588,8 +608,7 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
var headers []string
headers = append(headers, "host")
for k, v := range header {
canonicalKey := http.CanonicalHeaderKey(k)
if !r.IsValid(canonicalKey) {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
@ -653,19 +672,15 @@ func (ctx *signingCtx) buildCanonicalString() {
func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{
authHeaderPrefix,
ctx.formattedTime,
formatTime(ctx.Time),
ctx.credentialString,
hex.EncodeToString(makeSha256([]byte(ctx.canonicalString))),
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n")
}
func (ctx *signingCtx) buildSignature() {
secret := ctx.credValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(ctx.formattedShortTime))
region := makeHmac(date, []byte(ctx.Region))
service := makeHmac(region, []byte(ctx.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(ctx.stringToSign))
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature)
}
@ -687,7 +702,11 @@ func (ctx *signingCtx) buildBodyDigest() error {
if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
}
hash = hex.EncodeToString(makeSha256Reader(ctx.Body))
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
if includeSHA256Header {
@ -722,22 +741,28 @@ func (ctx *signingCtx) removePresign() {
ctx.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
func hashSHA256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, _ := reader.Seek(0, sdkio.SeekCurrent)
defer reader.Seek(start, sdkio.SeekStart)
start, err := reader.Seek(0, sdkio.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, sdkio.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
@ -748,7 +773,7 @@ func makeSha256Reader(reader io.ReadSeeker) []byte {
io.CopyN(hash, reader, size)
}
return hash.Sum(nil)
return hash.Sum(nil), nil
}
const doubleSpace = " "
@ -794,3 +819,28 @@ func stripExcessSpaces(vals []string) {
vals[i] = string(buf[:m])
}
}
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(service))
signingKey := hmacSHA256(kService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View File

@ -2,6 +2,7 @@ package aws
import (
"io"
"strings"
"sync"
"github.com/aws/aws-sdk-go/internal/sdkio"
@ -205,3 +206,59 @@ func (b *WriteAtBuffer) Bytes() []byte {
defer b.m.Unlock()
return b.buf
}
// MultiCloser is a utility to close multiple io.Closers within a single
// statement.
type MultiCloser []io.Closer
// Close closes all of the io.Closers making up the MultiClosers. Any
// errors that occur while closing will be returned in the order they
// occur.
func (m MultiCloser) Close() error {
var errs errors
for _, c := range m {
err := c.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
return errs
}
return nil
}
type errors []error
func (es errors) Error() string {
var parts []string
for _, e := range es {
parts = append(parts, e.Error())
}
return strings.Join(parts, "\n")
}
// CopySeekableBody copies the seekable body to an io.Writer
func CopySeekableBody(dst io.Writer, src io.ReadSeeker) (int64, error) {
curPos, err := src.Seek(0, sdkio.SeekCurrent)
if err != nil {
return 0, err
}
// copy errors may be assumed to be from the body.
n, err := io.Copy(dst, src)
if err != nil {
return n, err
}
// seek back to the first position after reading to reset
// the body for transmission.
_, err = src.Seek(curPos, sdkio.SeekStart)
if err != nil {
return n, err
}
return n, nil
}

View File

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "1.19.47"
const SDKVersion = "1.32.5"

View File

@ -0,0 +1,40 @@
// +build !go1.7
package context
import "time"
// An emptyCtx is a copy of the Go 1.7 context.emptyCtx type. This is copied to
// provide a 1.6 and 1.5 safe version of context that is compatible with Go
// 1.7's Context.
//
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case BackgroundCtx:
return "aws.BackgroundContext"
}
return "unknown empty Context"
}
// BackgroundCtx is the common base context.
var BackgroundCtx = new(emptyCtx)

View File

@ -162,7 +162,7 @@ loop:
if len(tokens) == 0 {
break loop
}
// if should skip is true, we skip the tokens until should skip is set to false.
step = SkipTokenState
}
@ -218,7 +218,7 @@ loop:
// S -> equal_expr' expr_stmt'
switch k.Kind {
case ASTKindEqualExpr:
// assiging a value to some key
// assigning a value to some key
k.AppendChild(newExpression(tok))
stack.Push(newExprStatement(k))
case ASTKindExpr:
@ -250,6 +250,13 @@ loop:
if !runeCompare(tok.Raw(), openBrace) {
return nil, NewParseError("expected '['")
}
// If OpenScopeState is not at the start, we must mark the previous ast as complete
//
// for example: if previous ast was a skip statement;
// we should mark it as complete before we create a new statement
if k.Kind != ASTKindStart {
stack.MarkComplete(k)
}
stmt := newStatement()
stack.Push(stmt)

View File

@ -22,24 +22,24 @@ func newSkipper() skipper {
}
func (s *skipper) ShouldSkip(tok Token) bool {
// should skip state will be modified only if previous token was new line (NL);
// and the current token is not WhiteSpace (WS).
if s.shouldSkip &&
s.prevTok.Type() == TokenNL &&
tok.Type() != TokenWS {
s.Continue()
return false
}
s.prevTok = tok
return s.shouldSkip
}
func (s *skipper) Skip() {
s.shouldSkip = true
s.prevTok = emptyToken
}
func (s *skipper) Continue() {
s.shouldSkip = false
// empty token is assigned as we return to default state, when should skip is false
s.prevTok = emptyToken
}

View File

@ -0,0 +1,12 @@
package sdkio
const (
// Byte is 8 bits
Byte int64 = 1
// KibiByte (KiB) is 1024 Bytes
KibiByte = Byte * 1024
// MebiByte (MiB) is 1024 KiB
MebiByte = KibiByte * 1024
// GibiByte (GiB) is 1024 MiB
GibiByte = MebiByte * 1024
)

View File

@ -0,0 +1,15 @@
// +build go1.10
package sdkmath
import "math"
// Round returns the nearest integer, rounding half away from zero.
//
// Special cases are:
// Round(±0) = ±0
// Round(±Inf) = ±Inf
// Round(NaN) = NaN
func Round(x float64) float64 {
return math.Round(x)
}

View File

@ -0,0 +1,56 @@
// +build !go1.10
package sdkmath
import "math"
// Copied from the Go standard library's (Go 1.12) math/floor.go for use in
// Go version prior to Go 1.10.
const (
uvone = 0x3FF0000000000000
mask = 0x7FF
shift = 64 - 11 - 1
bias = 1023
signMask = 1 << 63
fracMask = 1<<shift - 1
)
// Round returns the nearest integer, rounding half away from zero.
//
// Special cases are:
// Round(±0) = ±0
// Round(±Inf) = ±Inf
// Round(NaN) = NaN
//
// Copied from the Go standard library's (Go 1.12) math/floor.go for use in
// Go version prior to Go 1.10.
func Round(x float64) float64 {
// Round is a faster implementation of:
//
// func Round(x float64) float64 {
// t := Trunc(x)
// if Abs(x-t) >= 0.5 {
// return t + Copysign(1, x)
// }
// return t
// }
bits := math.Float64bits(x)
e := uint(bits>>shift) & mask
if e < bias {
// Round abs(x) < 1 including denormals.
bits &= signMask // +-0
if e == bias-1 {
bits |= uvone // +-1
}
} else if e < bias+shift {
// Round any abs(x) >= 1 containing a fractional component [0,1).
//
// Numbers with larger exponents are returned unchanged since they
// must be either an integer, infinity, or NaN.
const half = 1 << (shift - 1)
e -= bias
bits += half >> e
bits &^= fracMask >> e
}
return math.Float64frombits(bits)
}

View File

@ -0,0 +1,11 @@
// +build go1.6
package sdkrand
import "math/rand"
// Read provides the stub for math.Rand.Read method support for go version's
// 1.6 and greater.
func Read(r *rand.Rand, p []byte) (int, error) {
return r.Read(p)
}

View File

@ -0,0 +1,24 @@
// +build !go1.6
package sdkrand
import "math/rand"
// Read backfills Go 1.6's math.Rand.Reader for Go 1.5
func Read(r *rand.Rand, p []byte) (n int, err error) {
// Copy of Go standard libraries math package's read function not added to
// standard library until Go 1.6.
var pos int8
var val int64
for n = 0; n < len(p); n++ {
if pos == 0 {
val = r.Int63()
pos = 7
}
p[n] = byte(val)
val >>= 8
pos--
}
return n, err
}

View File

@ -0,0 +1,11 @@
package strings
import (
"strings"
)
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func HasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}

View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,120 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// forgotten indicates whether Forget was called with this call's key
// while the call was still in flight.
forgotten bool
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Shared bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
if !c.forgotten {
delete(g.m, key)
}
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
g.mu.Unlock()
}
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group) Forget(key string) {
g.mu.Lock()
if c, ok := g.m[key]; ok {
c.forgotten = true
}
delete(g.m, key)
g.mu.Unlock()
}

View File

@ -0,0 +1,53 @@
package checksum
import (
"crypto/md5"
"encoding/base64"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
const contentMD5Header = "Content-Md5"
// AddBodyContentMD5Handler computes and sets the HTTP Content-MD5 header for requests that
// require it.
func AddBodyContentMD5Handler(r *request.Request) {
// if Content-MD5 header is already present, return
if v := r.HTTPRequest.Header.Get(contentMD5Header); len(v) != 0 {
return
}
// if S3DisableContentMD5Validation flag is set, return
if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
return
}
// if request is presigned, return
if r.IsPresigned() {
return
}
// if body is not seekable, return
if !aws.IsReaderSeekable(r.Body) {
if r.Config.Logger != nil {
r.Config.Logger.Log(fmt.Sprintf(
"Unable to compute Content-MD5 for unseekable body, S3.%s",
r.Operation.Name))
}
return
}
h := md5.New()
if _, err := aws.CopySeekableBody(h, r.Body); err != nil {
r.Error = awserr.New("ContentMD5", "failed to compute body MD5", err)
return
}
// encode the md5 checksum in base64 and set the request header.
v := base64.StdEncoding.EncodeToString(h.Sum(nil))
r.HTTPRequest.Header.Set(contentMD5Header, v)
}

View File

@ -101,7 +101,7 @@ func (hs *decodedHeaders) UnmarshalJSON(b []byte) error {
}
headers.Set(h.Name, value)
}
(*hs) = decodedHeaders(headers)
*hs = decodedHeaders(headers)
return nil
}

View File

@ -21,10 +21,24 @@ type Decoder struct {
// NewDecoder initializes and returns a Decoder for decoding event
// stream messages from the reader provided.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{
func NewDecoder(r io.Reader, opts ...func(*Decoder)) *Decoder {
d := &Decoder{
r: r,
}
for _, opt := range opts {
opt(d)
}
return d
}
// DecodeWithLogger adds a logger to be used by the decoder when decoding
// stream events.
func DecodeWithLogger(logger aws.Logger) func(*Decoder) {
return func(d *Decoder) {
d.logger = logger
}
}
// Decode attempts to decode a single message from the event stream reader.
@ -40,6 +54,15 @@ func (d *Decoder) Decode(payloadBuf []byte) (m Message, err error) {
}()
}
m, err = Decode(reader, payloadBuf)
return m, err
}
// Decode attempts to decode a single message from the event stream reader.
// Will return the event stream message, or error if Decode fails to read
// the message from the reader.
func Decode(reader io.Reader, payloadBuf []byte) (m Message, err error) {
crc := crc32.New(crc32IEEETable)
hashReader := io.TeeReader(reader, crc)
@ -72,12 +95,6 @@ func (d *Decoder) Decode(payloadBuf []byte) (m Message, err error) {
return m, nil
}
// UseLogger specifies the Logger that that the decoder should use to log the
// message decode to.
func (d *Decoder) UseLogger(logger aws.Logger) {
d.logger = logger
}
func logMessageDecode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, decodeErr error) {
w := bytes.NewBuffer(nil)
defer func() { logger.Log(w.String()) }()

View File

@ -3,61 +3,107 @@ package eventstream
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"hash"
"hash/crc32"
"io"
"github.com/aws/aws-sdk-go/aws"
)
// Encoder provides EventStream message encoding.
type Encoder struct {
w io.Writer
logger aws.Logger
headersBuf *bytes.Buffer
}
// NewEncoder initializes and returns an Encoder to encode Event Stream
// messages to an io.Writer.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
func NewEncoder(w io.Writer, opts ...func(*Encoder)) *Encoder {
e := &Encoder{
w: w,
headersBuf: bytes.NewBuffer(nil),
}
for _, opt := range opts {
opt(e)
}
return e
}
// EncodeWithLogger adds a logger to be used by the encode when decoding
// stream events.
func EncodeWithLogger(logger aws.Logger) func(*Encoder) {
return func(d *Encoder) {
d.logger = logger
}
}
// Encode encodes a single EventStream message to the io.Writer the Encoder
// was created with. An error is returned if writing the message fails.
func (e *Encoder) Encode(msg Message) error {
func (e *Encoder) Encode(msg Message) (err error) {
e.headersBuf.Reset()
err := encodeHeaders(e.headersBuf, msg.Headers)
if err != nil {
writer := e.w
if e.logger != nil {
encodeMsgBuf := bytes.NewBuffer(nil)
writer = io.MultiWriter(writer, encodeMsgBuf)
defer func() {
logMessageEncode(e.logger, encodeMsgBuf, msg, err)
}()
}
if err = EncodeHeaders(e.headersBuf, msg.Headers); err != nil {
return err
}
crc := crc32.New(crc32IEEETable)
hashWriter := io.MultiWriter(e.w, crc)
hashWriter := io.MultiWriter(writer, crc)
headersLen := uint32(e.headersBuf.Len())
payloadLen := uint32(len(msg.Payload))
if err := encodePrelude(hashWriter, crc, headersLen, payloadLen); err != nil {
if err = encodePrelude(hashWriter, crc, headersLen, payloadLen); err != nil {
return err
}
if headersLen > 0 {
if _, err := io.Copy(hashWriter, e.headersBuf); err != nil {
if _, err = io.Copy(hashWriter, e.headersBuf); err != nil {
return err
}
}
if payloadLen > 0 {
if _, err := hashWriter.Write(msg.Payload); err != nil {
if _, err = hashWriter.Write(msg.Payload); err != nil {
return err
}
}
msgCRC := crc.Sum32()
return binary.Write(e.w, binary.BigEndian, msgCRC)
return binary.Write(writer, binary.BigEndian, msgCRC)
}
func logMessageEncode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, encodeErr error) {
w := bytes.NewBuffer(nil)
defer func() { logger.Log(w.String()) }()
fmt.Fprintf(w, "Message to encode:\n")
encoder := json.NewEncoder(w)
if err := encoder.Encode(msg); err != nil {
fmt.Fprintf(w, "Failed to get encoded message, %v\n", err)
}
if encodeErr != nil {
fmt.Fprintf(w, "Encode error: %v\n", encodeErr)
return
}
fmt.Fprintf(w, "Raw message:\n%s\n", hex.Dump(msgBuf.Bytes()))
}
func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32) error {
@ -86,7 +132,9 @@ func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32)
return nil
}
func encodeHeaders(w io.Writer, headers Headers) error {
// EncodeHeaders writes the header values to the writer encoded in the event
// stream format. Returns an error if a header fails to encode.
func EncodeHeaders(w io.Writer, headers Headers) error {
for _, h := range headers {
hn := headerName{
Len: uint8(len(h.Name)),

View File

@ -1,6 +1,9 @@
package eventstreamapi
import "fmt"
import (
"fmt"
"sync"
)
type messageError struct {
code string
@ -22,3 +25,53 @@ func (e messageError) Error() string {
func (e messageError) OrigErr() error {
return nil
}
// OnceError wraps the behavior of recording an error
// once and signal on a channel when this has occurred.
// Signaling is done by closing of the channel.
//
// Type is safe for concurrent usage.
type OnceError struct {
mu sync.RWMutex
err error
ch chan struct{}
}
// NewOnceError return a new OnceError
func NewOnceError() *OnceError {
return &OnceError{
ch: make(chan struct{}, 1),
}
}
// Err acquires a read-lock and returns an
// error if one has been set.
func (e *OnceError) Err() error {
e.mu.RLock()
err := e.err
e.mu.RUnlock()
return err
}
// SetError acquires a write-lock and will set
// the underlying error value if one has not been set.
func (e *OnceError) SetError(err error) {
if err == nil {
return
}
e.mu.Lock()
if e.err == nil {
e.err = err
close(e.ch)
}
e.mu.Unlock()
}
// ErrorSet returns a channel that will be used to signal
// that an error has been set. This channel will be closed
// when the error value has been set for OnceError.
func (e *OnceError) ErrorSet() <-chan struct{} {
return e.ch
}

View File

@ -2,9 +2,7 @@ package eventstreamapi
import (
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)
@ -15,27 +13,8 @@ type Unmarshaler interface {
UnmarshalEvent(protocol.PayloadUnmarshaler, eventstream.Message) error
}
// EventStream headers with specific meaning to async API functionality.
const (
MessageTypeHeader = `:message-type` // Identifies type of message.
EventMessageType = `event`
ErrorMessageType = `error`
ExceptionMessageType = `exception`
// Message Events
EventTypeHeader = `:event-type` // Identifies message event type e.g. "Stats".
// Message Error
ErrorCodeHeader = `:error-code`
ErrorMessageHeader = `:error-message`
// Message Exception
ExceptionTypeHeader = `:exception-type`
)
// EventReader provides reading from the EventStream of an reader.
type EventReader struct {
reader io.ReadCloser
decoder *eventstream.Decoder
unmarshalerForEventType func(string) (Unmarshaler, error)
@ -47,27 +26,18 @@ type EventReader struct {
// NewEventReader returns a EventReader built from the reader and unmarshaler
// provided. Use ReadStream method to start reading from the EventStream.
func NewEventReader(
reader io.ReadCloser,
decoder *eventstream.Decoder,
payloadUnmarshaler protocol.PayloadUnmarshaler,
unmarshalerForEventType func(string) (Unmarshaler, error),
) *EventReader {
return &EventReader{
reader: reader,
decoder: eventstream.NewDecoder(reader),
decoder: decoder,
payloadUnmarshaler: payloadUnmarshaler,
unmarshalerForEventType: unmarshalerForEventType,
payloadBuf: make([]byte, 10*1024),
}
}
// UseLogger instructs the EventReader to use the logger and log level
// specified.
func (r *EventReader) UseLogger(logger aws.Logger, logLevel aws.LogLevelType) {
if logger != nil && logLevel.Matches(aws.LogDebugWithEventStreamBody) {
r.decoder.UseLogger(logger)
}
}
// ReadEvent attempts to read a message from the EventStream and return the
// unmarshaled event value that the message is for.
//
@ -95,8 +65,7 @@ func (r *EventReader) ReadEvent() (event interface{}, err error) {
case EventMessageType:
return r.unmarshalEventMessage(msg)
case ExceptionMessageType:
err = r.unmarshalEventException(msg)
return nil, err
return nil, r.unmarshalEventException(msg)
case ErrorMessageType:
return nil, r.unmarshalErrorMessage(msg)
default:
@ -174,11 +143,6 @@ func (r *EventReader) unmarshalErrorMessage(msg eventstream.Message) (err error)
return msgErr
}
// Close closes the EventReader's EventStream reader.
func (r *EventReader) Close() error {
return r.reader.Close()
}
// GetHeaderString returns the value of the header as a string. If the header
// is not set or the value is not a string an error will be returned.
func GetHeaderString(msg eventstream.Message, headerName string) (string, error) {

View File

@ -0,0 +1,23 @@
package eventstreamapi
// EventStream headers with specific meaning to async API functionality.
const (
ChunkSignatureHeader = `:chunk-signature` // chunk signature for message
DateHeader = `:date` // Date header for signature
// Message header and values
MessageTypeHeader = `:message-type` // Identifies type of message.
EventMessageType = `event`
ErrorMessageType = `error`
ExceptionMessageType = `exception`
// Message Events
EventTypeHeader = `:event-type` // Identifies message event type e.g. "Stats".
// Message Error
ErrorCodeHeader = `:error-code`
ErrorMessageHeader = `:error-message`
// Message Exception
ExceptionTypeHeader = `:exception-type`
)

View File

@ -0,0 +1,123 @@
package eventstreamapi
import (
"bytes"
"strings"
"time"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)
var timeNow = time.Now
// StreamSigner defines an interface for the implementation of signing of event stream payloads
type StreamSigner interface {
GetSignature(headers, payload []byte, date time.Time) ([]byte, error)
}
// SignEncoder envelopes event stream messages
// into an event stream message payload with included
// signature headers using the provided signer and encoder.
type SignEncoder struct {
signer StreamSigner
encoder Encoder
bufEncoder *BufferEncoder
closeErr error
closed bool
}
// NewSignEncoder returns a new SignEncoder using the provided stream signer and
// event stream encoder.
func NewSignEncoder(signer StreamSigner, encoder Encoder) *SignEncoder {
// TODO: Need to pass down logging
return &SignEncoder{
signer: signer,
encoder: encoder,
bufEncoder: NewBufferEncoder(),
}
}
// Close encodes a final event stream signing envelope with an empty event stream
// payload. This final end-frame is used to mark the conclusion of the stream.
func (s *SignEncoder) Close() error {
if s.closed {
return s.closeErr
}
if err := s.encode([]byte{}); err != nil {
if strings.Contains(err.Error(), "on closed pipe") {
return nil
}
s.closeErr = err
s.closed = true
return s.closeErr
}
return nil
}
// Encode takes the provided message and add envelopes the message
// with the required signature.
func (s *SignEncoder) Encode(msg eventstream.Message) error {
payload, err := s.bufEncoder.Encode(msg)
if err != nil {
return err
}
return s.encode(payload)
}
func (s SignEncoder) encode(payload []byte) error {
date := timeNow()
var msg eventstream.Message
msg.Headers.Set(DateHeader, eventstream.TimestampValue(date))
msg.Payload = payload
var headers bytes.Buffer
if err := eventstream.EncodeHeaders(&headers, msg.Headers); err != nil {
return err
}
sig, err := s.signer.GetSignature(headers.Bytes(), msg.Payload, date)
if err != nil {
return err
}
msg.Headers.Set(ChunkSignatureHeader, eventstream.BytesValue(sig))
return s.encoder.Encode(msg)
}
// BufferEncoder is a utility that provides a buffered
// event stream encoder
type BufferEncoder struct {
encoder Encoder
buffer *bytes.Buffer
}
// NewBufferEncoder returns a new BufferEncoder initialized
// with a 1024 byte buffer.
func NewBufferEncoder() *BufferEncoder {
buf := bytes.NewBuffer(make([]byte, 1024))
return &BufferEncoder{
encoder: eventstream.NewEncoder(buf),
buffer: buf,
}
}
// Encode returns the encoded message as a byte slice.
// The returned byte slice will be modified on the next encode call
// and should not be held onto.
func (e *BufferEncoder) Encode(msg eventstream.Message) ([]byte, error) {
e.buffer.Reset()
if err := e.encoder.Encode(msg); err != nil {
return nil, err
}
return e.buffer.Bytes(), nil
}

View File

@ -0,0 +1,129 @@
package eventstreamapi
import (
"fmt"
"io"
"sync"
"github.com/aws/aws-sdk-go/aws"
)
// StreamWriter provides concurrent safe writing to an event stream.
type StreamWriter struct {
eventWriter *EventWriter
stream chan eventWriteAsyncReport
done chan struct{}
closeOnce sync.Once
err *OnceError
streamCloser io.Closer
}
// NewStreamWriter returns a StreamWriter for the event writer, and stream
// closer provided.
func NewStreamWriter(eventWriter *EventWriter, streamCloser io.Closer) *StreamWriter {
w := &StreamWriter{
eventWriter: eventWriter,
streamCloser: streamCloser,
stream: make(chan eventWriteAsyncReport),
done: make(chan struct{}),
err: NewOnceError(),
}
go w.writeStream()
return w
}
// Close terminates the writers ability to write new events to the stream. Any
// future call to Send will fail with an error.
func (w *StreamWriter) Close() error {
w.closeOnce.Do(w.safeClose)
return w.Err()
}
func (w *StreamWriter) safeClose() {
close(w.done)
}
// ErrorSet returns a channel which will be closed
// if an error occurs.
func (w *StreamWriter) ErrorSet() <-chan struct{} {
return w.err.ErrorSet()
}
// Err returns any error that occurred while attempting to write an event to the
// stream.
func (w *StreamWriter) Err() error {
return w.err.Err()
}
// Send writes a single event to the stream returning an error if the write
// failed.
//
// Send may be called concurrently. Events will be written to the stream
// safely.
func (w *StreamWriter) Send(ctx aws.Context, event Marshaler) error {
if err := w.Err(); err != nil {
return err
}
resultCh := make(chan error)
wrapped := eventWriteAsyncReport{
Event: event,
Result: resultCh,
}
select {
case w.stream <- wrapped:
case <-ctx.Done():
return ctx.Err()
case <-w.done:
return fmt.Errorf("stream closed, unable to send event")
}
select {
case err := <-resultCh:
return err
case <-ctx.Done():
return ctx.Err()
case <-w.done:
return fmt.Errorf("stream closed, unable to send event")
}
}
func (w *StreamWriter) writeStream() {
defer w.Close()
for {
select {
case wrapper := <-w.stream:
err := w.eventWriter.WriteEvent(wrapper.Event)
wrapper.ReportResult(w.done, err)
if err != nil {
w.err.SetError(err)
return
}
case <-w.done:
if err := w.streamCloser.Close(); err != nil {
w.err.SetError(err)
}
return
}
}
}
type eventWriteAsyncReport struct {
Event Marshaler
Result chan<- error
}
func (e eventWriteAsyncReport) ReportResult(cancel <-chan struct{}, err error) bool {
select {
case e.Result <- err:
return true
case <-cancel:
return false
}
}

View File

@ -0,0 +1,109 @@
package eventstreamapi
import (
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
)
// Marshaler provides a marshaling interface for event types to event stream
// messages.
type Marshaler interface {
MarshalEvent(protocol.PayloadMarshaler) (eventstream.Message, error)
}
// Encoder is an stream encoder that will encode an event stream message for
// the transport.
type Encoder interface {
Encode(eventstream.Message) error
}
// EventWriter provides a wrapper around the underlying event stream encoder
// for an io.WriteCloser.
type EventWriter struct {
encoder Encoder
payloadMarshaler protocol.PayloadMarshaler
eventTypeFor func(Marshaler) (string, error)
}
// NewEventWriter returns a new event stream writer, that will write to the
// writer provided. Use the WriteEvent method to write an event to the stream.
func NewEventWriter(encoder Encoder, pm protocol.PayloadMarshaler, eventTypeFor func(Marshaler) (string, error),
) *EventWriter {
return &EventWriter{
encoder: encoder,
payloadMarshaler: pm,
eventTypeFor: eventTypeFor,
}
}
// WriteEvent writes an event to the stream. Returns an error if the event
// fails to marshal into a message, or writing to the underlying writer fails.
func (w *EventWriter) WriteEvent(event Marshaler) error {
msg, err := w.marshal(event)
if err != nil {
return err
}
return w.encoder.Encode(msg)
}
func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
eventType, err := w.eventTypeFor(event)
if err != nil {
return eventstream.Message{}, err
}
msg, err := event.MarshalEvent(w.payloadMarshaler)
if err != nil {
return eventstream.Message{}, err
}
msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
return msg, nil
}
//type EventEncoder struct {
// encoder Encoder
// ppayloadMarshaler protocol.PayloadMarshaler
// eventTypeFor func(Marshaler) (string, error)
//}
//
//func (e EventEncoder) Encode(event Marshaler) error {
// msg, err := e.marshal(event)
// if err != nil {
// return err
// }
//
// return w.encoder.Encode(msg)
//}
//
//func (e EventEncoder) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//
//func (w *EventWriter) marshal(event Marshaler) (eventstream.Message, error) {
// eventType, err := w.eventTypeFor(event)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg, err := event.MarshalEvent(w.payloadMarshaler)
// if err != nil {
// return eventstream.Message{}, err
// }
//
// msg.Headers.Set(EventTypeHeader, eventstream.StringValue(eventType))
// return msg, nil
//}
//

View File

@ -461,6 +461,11 @@ func (v *TimestampValue) decode(r io.Reader) error {
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (v TimestampValue) MarshalJSON() ([]byte, error) {
return []byte(v.String()), nil
}
func timeFromEpochMilli(t int64) time.Time {
secs := t / 1e3
msec := t % 1e3

View File

@ -27,7 +27,7 @@ func (m *Message) rawMessage() (rawMessage, error) {
if len(m.Headers) > 0 {
var headers bytes.Buffer
if err := encodeHeaders(&headers, m.Headers); err != nil {
if err := EncodeHeaders(&headers, m.Headers); err != nil {
return rawMessage{}, err
}
raw.Headers = headers.Bytes()

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"reflect"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
@ -45,10 +46,31 @@ func UnmarshalJSON(v interface{}, stream io.Reader) error {
return err
}
return unmarshalAny(reflect.ValueOf(v), out, "")
return unmarshaler{}.unmarshalAny(reflect.ValueOf(v), out, "")
}
func unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
// UnmarshalJSONCaseInsensitive reads a stream and unmarshals the result into the
// object v. Ignores casing for structure members.
func UnmarshalJSONCaseInsensitive(v interface{}, stream io.Reader) error {
var out interface{}
err := json.NewDecoder(stream).Decode(&out)
if err == io.EOF {
return nil
} else if err != nil {
return err
}
return unmarshaler{
caseInsensitive: true,
}.unmarshalAny(reflect.ValueOf(v), out, "")
}
type unmarshaler struct {
caseInsensitive bool
}
func (u unmarshaler) unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
vtype := value.Type()
if vtype.Kind() == reflect.Ptr {
vtype = vtype.Elem() // check kind of actual element type
@ -80,17 +102,17 @@ func unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag)
if field, ok := vtype.FieldByName("_"); ok {
tag = field.Tag
}
return unmarshalStruct(value, data, tag)
return u.unmarshalStruct(value, data, tag)
case "list":
return unmarshalList(value, data, tag)
return u.unmarshalList(value, data, tag)
case "map":
return unmarshalMap(value, data, tag)
return u.unmarshalMap(value, data, tag)
default:
return unmarshalScalar(value, data, tag)
return u.unmarshalScalar(value, data, tag)
}
}
func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
func (u unmarshaler) unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
@ -114,7 +136,7 @@ func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTa
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return unmarshalAny(value.FieldByName(payload), data, field.Tag)
return u.unmarshalAny(value.FieldByName(payload), data, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
@ -128,9 +150,19 @@ func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTa
if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
if u.caseInsensitive {
if _, ok := mapData[name]; !ok {
// Fallback to uncased name search if the exact name didn't match.
for kn, v := range mapData {
if strings.EqualFold(kn, name) {
mapData[name] = v
}
}
}
}
member := value.FieldByIndex(field.Index)
err := unmarshalAny(member, mapData[name], field.Tag)
err := u.unmarshalAny(member, mapData[name], field.Tag)
if err != nil {
return err
}
@ -138,7 +170,7 @@ func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTa
return nil
}
func unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
func (u unmarshaler) unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
@ -153,7 +185,7 @@ func unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag)
}
for i, c := range listData {
err := unmarshalAny(value.Index(i), c, "")
err := u.unmarshalAny(value.Index(i), c, "")
if err != nil {
return err
}
@ -162,7 +194,7 @@ func unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag)
return nil
}
func unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
func (u unmarshaler) unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
@ -179,14 +211,14 @@ func unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag)
kvalue := reflect.ValueOf(k)
vvalue := reflect.New(value.Type().Elem()).Elem()
unmarshalAny(vvalue, v, "")
u.unmarshalAny(vvalue, v, "")
value.SetMapIndex(kvalue, vvalue)
}
return nil
}
func unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
switch d := data.(type) {
case nil:

View File

@ -2,12 +2,10 @@
// requests and responses.
package jsonrpc
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/input/json.json build_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/output/json.json unmarshal_test.go
import (
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
@ -16,17 +14,26 @@ import (
var emptyJSON = []byte("{}")
// BuildHandler is a named request handler for building jsonrpc protocol requests
var BuildHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Build", Fn: Build}
// BuildHandler is a named request handler for building jsonrpc protocol
// requests
var BuildHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.Build",
Fn: Build,
}
// UnmarshalHandler is a named request handler for unmarshaling jsonrpc protocol requests
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Unmarshal", Fn: Unmarshal}
// UnmarshalHandler is a named request handler for unmarshaling jsonrpc
// protocol requests
var UnmarshalHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.Unmarshal",
Fn: Unmarshal,
}
// UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalMeta", Fn: UnmarshalMeta}
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalError", Fn: UnmarshalError}
// UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc
// protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.UnmarshalMeta",
Fn: UnmarshalMeta,
}
// Build builds a JSON payload for a JSON RPC request.
func Build(req *request.Request) {
@ -79,32 +86,3 @@ func Unmarshal(req *request.Request) {
func UnmarshalMeta(req *request.Request) {
rest.UnmarshalMeta(req)
}
// UnmarshalError unmarshals an error response for a JSON RPC service.
func UnmarshalError(req *request.Request) {
defer req.HTTPResponse.Body.Close()
var jsonErr jsonErrorResponse
err := jsonutil.UnmarshalJSONError(&jsonErr, req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal error message", err),
req.HTTPResponse.StatusCode,
req.RequestID,
)
return
}
codes := strings.SplitN(jsonErr.Code, "#", 2)
req.Error = awserr.NewRequestFailure(
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}

View File

@ -0,0 +1,107 @@
package jsonrpc
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
)
// UnmarshalTypedError provides unmarshaling errors API response errors
// for both typed and untyped errors.
type UnmarshalTypedError struct {
exceptions map[string]func(protocol.ResponseMetadata) error
}
// NewUnmarshalTypedError returns an UnmarshalTypedError initialized for the
// set of exception names to the error unmarshalers
func NewUnmarshalTypedError(exceptions map[string]func(protocol.ResponseMetadata) error) *UnmarshalTypedError {
return &UnmarshalTypedError{
exceptions: exceptions,
}
}
// UnmarshalError attempts to unmarshal the HTTP response error as a known
// error type. If unable to unmarshal the error type, the generic SDK error
// type will be used.
func (u *UnmarshalTypedError) UnmarshalError(
resp *http.Response,
respMeta protocol.ResponseMetadata,
) (error, error) {
var buf bytes.Buffer
var jsonErr jsonErrorResponse
teeReader := io.TeeReader(resp.Body, &buf)
err := jsonutil.UnmarshalJSONError(&jsonErr, teeReader)
if err != nil {
return nil, err
}
body := ioutil.NopCloser(&buf)
// Code may be separated by hash(#), with the last element being the code
// used by the SDK.
codeParts := strings.SplitN(jsonErr.Code, "#", 2)
code := codeParts[len(codeParts)-1]
msg := jsonErr.Message
if fn, ok := u.exceptions[code]; ok {
// If exception code is know, use associated constructor to get a value
// for the exception that the JSON body can be unmarshaled into.
v := fn(respMeta)
err := jsonutil.UnmarshalJSONCaseInsensitive(v, body)
if err != nil {
return nil, err
}
return v, nil
}
// fallback to unmodeled generic exceptions
return awserr.NewRequestFailure(
awserr.New(code, msg, nil),
respMeta.StatusCode,
respMeta.RequestID,
), nil
}
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc
// protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{
Name: "awssdk.jsonrpc.UnmarshalError",
Fn: UnmarshalError,
}
// UnmarshalError unmarshals an error response for a JSON RPC service.
func UnmarshalError(req *request.Request) {
defer req.HTTPResponse.Body.Close()
var jsonErr jsonErrorResponse
err := jsonutil.UnmarshalJSONError(&jsonErr, req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal error message", err),
req.HTTPResponse.StatusCode,
req.RequestID,
)
return
}
codes := strings.SplitN(jsonErr.Code, "#", 2)
req.Error = awserr.NewRequestFailure(
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}

View File

@ -64,7 +64,7 @@ func (h HandlerPayloadMarshal) MarshalPayload(w io.Writer, v interface{}) error
metadata.ClientInfo{},
request.Handlers{},
nil,
&request.Operation{HTTPMethod: "GET"},
&request.Operation{HTTPMethod: "PUT"},
v,
nil,
)

View File

@ -0,0 +1,49 @@
package protocol
import (
"fmt"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// RequireHTTPMinProtocol request handler is used to enforce that
// the target endpoint supports the given major and minor HTTP protocol version.
type RequireHTTPMinProtocol struct {
Major, Minor int
}
// Handler will mark the request.Request with an error if the
// target endpoint did not connect with the required HTTP protocol
// major and minor version.
func (p RequireHTTPMinProtocol) Handler(r *request.Request) {
if r.Error != nil || r.HTTPResponse == nil {
return
}
if !strings.HasPrefix(r.HTTPResponse.Proto, "HTTP") {
r.Error = newMinHTTPProtoError(p.Major, p.Minor, r)
}
if r.HTTPResponse.ProtoMajor < p.Major || r.HTTPResponse.ProtoMinor < p.Minor {
r.Error = newMinHTTPProtoError(p.Major, p.Minor, r)
}
}
// ErrCodeMinimumHTTPProtocolError error code is returned when the target endpoint
// did not match the required HTTP major and minor protocol version.
const ErrCodeMinimumHTTPProtocolError = "MinimumHTTPProtocolError"
func newMinHTTPProtoError(major, minor int, r *request.Request) error {
return awserr.NewRequestFailure(
awserr.New("MinimumHTTPProtocolError",
fmt.Sprintf(
"operation requires minimum HTTP protocol of HTTP/%d.%d, but was %s",
major, minor, r.HTTPResponse.Proto,
),
nil,
),
r.HTTPResponse.StatusCode, r.RequestID,
)
}

View File

@ -1,7 +1,7 @@
// Package query provides serialization of AWS query requests, and responses.
package query
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/input/query.json build_test.go
import (
"net/url"

View File

@ -1,6 +1,6 @@
package query
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/output/query.json unmarshal_test.go
import (
"encoding/xml"

View File

@ -15,6 +15,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
awsStrings "github.com/aws/aws-sdk-go/internal/strings"
"github.com/aws/aws-sdk-go/private/protocol"
)
@ -28,7 +29,9 @@ var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta
func Unmarshal(r *request.Request) {
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v)
if err := unmarshalBody(r, v); err != nil {
r.Error = err
}
}
}
@ -40,12 +43,21 @@ func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
}
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalLocationElements(r, v)
if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil {
r.Error = err
}
}
}
func unmarshalBody(r *request.Request, v reflect.Value) {
// UnmarshalResponse attempts to unmarshal the REST response headers to
// the data type passed in. The type must be a pointer. An error is returned
// with any error unmarshaling the response into the target datatype.
func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
v := reflect.Indirect(reflect.ValueOf(data))
return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps)
}
func unmarshalBody(r *request.Request, v reflect.Value) error {
if field, ok := v.Type().FieldByName("_"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
@ -57,35 +69,38 @@ func unmarshalBody(r *request.Request, v reflect.Value) {
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
}
payload.Set(reflect.ValueOf(b))
case *string:
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
} else {
return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
}
str := string(b)
payload.Set(reflect.ValueOf(&str))
}
default:
switch payload.Type().String() {
case "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
case "io.ReadSeeker":
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization,
return awserr.New(request.ErrCodeSerialization,
"failed to read response body", err)
return
}
payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
default:
io.Copy(ioutil.Discard, r.HTTPResponse.Body)
defer r.HTTPResponse.Body.Close()
r.Error = awserr.New(request.ErrCodeSerialization,
r.HTTPResponse.Body.Close()
return awserr.New(request.ErrCodeSerialization,
"failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
@ -94,9 +109,11 @@ func unmarshalBody(r *request.Request, v reflect.Value) {
}
}
}
return nil
}
func unmarshalLocationElements(r *request.Request, v reflect.Value) {
func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
@ -111,26 +128,25 @@ func unmarshalLocationElements(r *request.Request, v reflect.Value) {
switch field.Tag.Get("location") {
case "statusCode":
unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
unmarshalStatusCode(m, resp.StatusCode)
case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag)
err := unmarshalHeader(m, resp.Header.Get(name), field.Tag)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
break
awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
}
}
}
if r.Error != nil {
return
}
}
return nil
}
func unmarshalStatusCode(v reflect.Value, statusCode int) {
@ -145,30 +161,46 @@ func unmarshalStatusCode(v reflect.Value, statusCode int) {
}
}
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
if len(headers) == 0 {
return nil
}
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
if awsStrings.HasPrefixFold(k, prefix) {
if normalize == true {
k = strings.ToLower(k)
} else {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
}
out[k[len(prefix):]] = &v[0]
}
}
if len(out) != 0 {
r.Set(reflect.ValueOf(out))
}
}
return nil
}
func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
isJSONValue := tag.Get("type") == "jsonvalue"
if isJSONValue {
switch tag.Get("type") {
case "jsonvalue":
if len(header) == 0 {
return nil
}
} else if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
case "blob":
if len(header) == 0 {
return nil
}
default:
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
return nil
}
}
switch v.Interface().(type) {
case *string:
@ -178,7 +210,7 @@ func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) erro
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
v.Set(reflect.ValueOf(b))
case *bool:
b, err := strconv.ParseBool(header)
if err != nil {

View File

@ -2,8 +2,8 @@
// requests and responses.
package restxml
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-xml.json build_test.go
//go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/input/rest-xml.json build_test.go
//go:generate go run -tags codegen ../../../private/model/cli/gen-protocol-tests ../../../models/protocol_tests/output/rest-xml.json unmarshal_test.go
import (
"bytes"
@ -39,7 +39,7 @@ func Build(r *request.Request) {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to encode rest XML request", err),
r.HTTPResponse.StatusCode,
0,
r.RequestID,
)
return

View File

@ -1,8 +1,11 @@
package protocol
import (
"math"
"strconv"
"time"
"github.com/aws/aws-sdk-go/internal/sdkmath"
)
// Names of time formats supported by the SDK
@ -13,12 +16,19 @@ const (
)
// Time formats supported by the SDK
// Output time is intended to not contain decimals
const (
// RFC 7231#section-7.1.1.1 timetamp format. e.g Tue, 29 Apr 2014 18:30:38 GMT
RFC822TimeFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
// This format is used for output time without seconds precision
RFC822OutputTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// RFC3339 a subset of the ISO8601 timestamp format. e.g 2014-04-29T18:30:38Z
ISO8601TimeFormat = "2006-01-02T15:04:05Z"
ISO8601TimeFormat = "2006-01-02T15:04:05.999999999Z"
// This format is used for output time without seconds precision
ISO8601OutputTimeFormat = "2006-01-02T15:04:05Z"
)
// IsKnownTimestampFormat returns if the timestamp format name
@ -42,11 +52,12 @@ func FormatTime(name string, t time.Time) string {
switch name {
case RFC822TimeFormatName:
return t.Format(RFC822TimeFormat)
return t.Format(RFC822OutputTimeFormat)
case ISO8601TimeFormatName:
return t.Format(ISO8601TimeFormat)
return t.Format(ISO8601OutputTimeFormat)
case UnixTimeFormatName:
return strconv.FormatInt(t.Unix(), 10)
ms := t.UnixNano() / int64(time.Millisecond)
return strconv.FormatFloat(float64(ms)/1e3, 'f', -1, 64)
default:
panic("unknown timestamp format name, " + name)
}
@ -62,10 +73,12 @@ func ParseTime(formatName, value string) (time.Time, error) {
return time.Parse(ISO8601TimeFormat, value)
case UnixTimeFormatName:
v, err := strconv.ParseFloat(value, 64)
_, dec := math.Modf(v)
dec = sdkmath.Round(dec*1e3) / 1e3 //Rounds 0.1229999 to 0.123
if err != nil {
return time.Time{}, err
}
return time.Unix(int64(v), 0), nil
return time.Unix(int64(v), int64(dec*(1e9))), nil
default:
panic("unknown timestamp format name, " + formatName)
}

View File

@ -19,3 +19,9 @@ func UnmarshalDiscardBody(r *request.Request) {
io.Copy(ioutil.Discard, r.HTTPResponse.Body)
r.HTTPResponse.Body.Close()
}
// ResponseMetadata provides the SDK response metadata attributes.
type ResponseMetadata struct {
StatusCode int
RequestID string
}

View File

@ -0,0 +1,65 @@
package protocol
import (
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// UnmarshalErrorHandler provides unmarshaling errors API response errors for
// both typed and untyped errors.
type UnmarshalErrorHandler struct {
unmarshaler ErrorUnmarshaler
}
// ErrorUnmarshaler is an abstract interface for concrete implementations to
// unmarshal protocol specific response errors.
type ErrorUnmarshaler interface {
UnmarshalError(*http.Response, ResponseMetadata) (error, error)
}
// NewUnmarshalErrorHandler returns an UnmarshalErrorHandler
// initialized for the set of exception names to the error unmarshalers
func NewUnmarshalErrorHandler(unmarshaler ErrorUnmarshaler) *UnmarshalErrorHandler {
return &UnmarshalErrorHandler{
unmarshaler: unmarshaler,
}
}
// UnmarshalErrorHandlerName is the name of the named handler.
const UnmarshalErrorHandlerName = "awssdk.protocol.UnmarshalError"
// NamedHandler returns a NamedHandler for the unmarshaler using the set of
// errors the unmarshaler was initialized for.
func (u *UnmarshalErrorHandler) NamedHandler() request.NamedHandler {
return request.NamedHandler{
Name: UnmarshalErrorHandlerName,
Fn: u.UnmarshalError,
}
}
// UnmarshalError will attempt to unmarshal the API response's error message
// into either a generic SDK error type, or a typed error corresponding to the
// errors exception name.
func (u *UnmarshalErrorHandler) UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
respMeta := ResponseMetadata{
StatusCode: r.HTTPResponse.StatusCode,
RequestID: r.RequestID,
}
v, err := u.unmarshaler.UnmarshalError(r.HTTPResponse, respMeta)
if err != nil {
r.Error = awserr.NewRequestFailure(
awserr.New(request.ErrCodeSerialization,
"failed to unmarshal response error", err),
respMeta.StatusCode,
respMeta.RequestID,
)
return
}
r.Error = v
}

View File

@ -8,6 +8,7 @@ import (
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/private/protocol"
@ -60,6 +61,14 @@ func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag refle
return nil
}
xml := tag.Get("xml")
if len(xml) != 0 {
name := strings.SplitAfterN(xml, ",", 2)[0]
if name == "-" {
return nil
}
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {

View File

@ -0,0 +1,32 @@
package xmlutil
import (
"encoding/xml"
"strings"
)
type xmlAttrSlice []xml.Attr
func (x xmlAttrSlice) Len() int {
return len(x)
}
func (x xmlAttrSlice) Less(i, j int) bool {
spaceI, spaceJ := x[i].Name.Space, x[j].Name.Space
localI, localJ := x[i].Name.Local, x[j].Name.Local
valueI, valueJ := x[i].Value, x[j].Value
spaceCmp := strings.Compare(spaceI, spaceJ)
localCmp := strings.Compare(localI, localJ)
valueCmp := strings.Compare(valueI, valueJ)
if spaceCmp == -1 || (spaceCmp == 0 && (localCmp == -1 || (localCmp == 0 && valueCmp == -1))) {
return true
}
return false
}
func (x xmlAttrSlice) Swap(i, j int) {
x[i], x[j] = x[j], x[i]
}

View File

@ -64,6 +64,14 @@ func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
xml := tag.Get("xml")
if len(xml) != 0 {
name := strings.SplitAfterN(xml, ",", 2)[0]
if name == "-" {
return nil
}
}
rtype := r.Type()
if rtype.Kind() == reflect.Ptr {
rtype = rtype.Elem() // check kind of actual element type

View File

@ -119,7 +119,18 @@ func (n *XMLNode) findElem(name string) (string, bool) {
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
// Sort Attributes
attrs := node.Attr
if sorted {
sortedAttrs := make([]xml.Attr, len(attrs))
for _, k := range node.Attr {
sortedAttrs = append(sortedAttrs, k)
}
sort.Sort(xmlAttrSlice(sortedAttrs))
attrs = sortedAttrs
}
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: attrs})
if node.Text != "" {
e.EncodeToken(xml.CharData([]byte(node.Text)))

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,10 @@
package ecr
import (
"github.com/aws/aws-sdk-go/private/protocol"
)
const (
// ErrCodeEmptyUploadException for service response error code
@ -23,6 +27,13 @@ const (
// The image requested does not exist in the specified repository.
ErrCodeImageNotFoundException = "ImageNotFoundException"
// ErrCodeImageTagAlreadyExistsException for service response error code
// "ImageTagAlreadyExistsException".
//
// The specified image is tagged with a tag that already exists. The repository
// is configured for tag immutability.
ErrCodeImageTagAlreadyExistsException = "ImageTagAlreadyExistsException"
// ErrCodeInvalidLayerException for service response error code
// "InvalidLayerException".
//
@ -101,11 +112,16 @@ const (
// "LimitExceededException".
//
// The operation did not succeed because it would have exceeded a service limit
// for your account. For more information, see Amazon ECR Default Service Limits
// (http://docs.aws.amazon.com/AmazonECR/latest/userguide/service_limits.html)
// for your account. For more information, see Amazon ECR Service Quotas (https://docs.aws.amazon.com/AmazonECR/latest/userguide/service-quotas.html)
// in the Amazon Elastic Container Registry User Guide.
ErrCodeLimitExceededException = "LimitExceededException"
// ErrCodeReferencedImagesNotFoundException for service response error code
// "ReferencedImagesNotFoundException".
//
// The manifest list is referencing an image that does not exist.
ErrCodeReferencedImagesNotFoundException = "ReferencedImagesNotFoundException"
// ErrCodeRepositoryAlreadyExistsException for service response error code
// "RepositoryAlreadyExistsException".
//
@ -133,6 +149,13 @@ const (
// repository policy.
ErrCodeRepositoryPolicyNotFoundException = "RepositoryPolicyNotFoundException"
// ErrCodeScanNotFoundException for service response error code
// "ScanNotFoundException".
//
// The specified image scan could not be found. Ensure that image scanning is
// enabled on the repository and try again.
ErrCodeScanNotFoundException = "ScanNotFoundException"
// ErrCodeServerException for service response error code
// "ServerException".
//
@ -146,6 +169,12 @@ const (
// of tags that can be applied to a repository is 50.
ErrCodeTooManyTagsException = "TooManyTagsException"
// ErrCodeUnsupportedImageTypeException for service response error code
// "UnsupportedImageTypeException".
//
// The image is of a type that cannot be scanned.
ErrCodeUnsupportedImageTypeException = "UnsupportedImageTypeException"
// ErrCodeUploadNotFoundException for service response error code
// "UploadNotFoundException".
//
@ -153,3 +182,32 @@ const (
// this repository.
ErrCodeUploadNotFoundException = "UploadNotFoundException"
)
var exceptionFromCode = map[string]func(protocol.ResponseMetadata) error{
"EmptyUploadException": newErrorEmptyUploadException,
"ImageAlreadyExistsException": newErrorImageAlreadyExistsException,
"ImageNotFoundException": newErrorImageNotFoundException,
"ImageTagAlreadyExistsException": newErrorImageTagAlreadyExistsException,
"InvalidLayerException": newErrorInvalidLayerException,
"InvalidLayerPartException": newErrorInvalidLayerPartException,
"InvalidParameterException": newErrorInvalidParameterException,
"InvalidTagParameterException": newErrorInvalidTagParameterException,
"LayerAlreadyExistsException": newErrorLayerAlreadyExistsException,
"LayerInaccessibleException": newErrorLayerInaccessibleException,
"LayerPartTooSmallException": newErrorLayerPartTooSmallException,
"LayersNotFoundException": newErrorLayersNotFoundException,
"LifecyclePolicyNotFoundException": newErrorLifecyclePolicyNotFoundException,
"LifecyclePolicyPreviewInProgressException": newErrorLifecyclePolicyPreviewInProgressException,
"LifecyclePolicyPreviewNotFoundException": newErrorLifecyclePolicyPreviewNotFoundException,
"LimitExceededException": newErrorLimitExceededException,
"ReferencedImagesNotFoundException": newErrorReferencedImagesNotFoundException,
"RepositoryAlreadyExistsException": newErrorRepositoryAlreadyExistsException,
"RepositoryNotEmptyException": newErrorRepositoryNotEmptyException,
"RepositoryNotFoundException": newErrorRepositoryNotFoundException,
"RepositoryPolicyNotFoundException": newErrorRepositoryPolicyNotFoundException,
"ScanNotFoundException": newErrorScanNotFoundException,
"ServerException": newErrorServerException,
"TooManyTagsException": newErrorTooManyTagsException,
"UnsupportedImageTypeException": newErrorUnsupportedImageTypeException,
"UploadNotFoundException": newErrorUploadNotFoundException,
}

View File

@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
@ -31,7 +32,7 @@ var initRequest func(*request.Request)
const (
ServiceName = "ecr" // Name of service.
EndpointsID = "api.ecr" // ID to lookup a service endpoint with.
ServiceID = "ECR" // ServiceID is a unique identifer of a specific service.
ServiceID = "ECR" // ServiceID is a unique identifier of a specific service.
)
// New creates a new instance of the ECR client with a session.
@ -39,6 +40,8 @@ const (
// aws.Config parameter to add your extra config.
//
// Example:
// mySession := session.Must(session.NewSession())
//
// // Create a ECR client from just a session.
// svc := ecr.New(mySession)
//
@ -49,11 +52,11 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *ECR {
if c.SigningNameDerived || len(c.SigningName) == 0 {
c.SigningName = "ecr"
}
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
return newClient(*c.Config, c.Handlers, c.PartitionID, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *ECR {
func newClient(cfg aws.Config, handlers request.Handlers, partitionID, endpoint, signingRegion, signingName string) *ECR {
svc := &ECR{
Client: client.New(
cfg,
@ -62,6 +65,7 @@ func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
ServiceID: ServiceID,
SigningName: signingName,
SigningRegion: signingRegion,
PartitionID: partitionID,
Endpoint: endpoint,
APIVersion: "2015-09-21",
JSONVersion: "1.1",
@ -76,7 +80,9 @@ func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
svc.Handlers.UnmarshalError.PushBackNamed(
protocol.NewUnmarshalErrorHandler(jsonrpc.NewUnmarshalTypedError(exceptionFromCode)).NamedHandler(),
)
// Run custom client initialization if present
if initClient != nil {

View File

@ -0,0 +1,112 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package ecr
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// WaitUntilImageScanComplete uses the Amazon ECR API operation
// DescribeImageScanFindings to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *ECR) WaitUntilImageScanComplete(input *DescribeImageScanFindingsInput) error {
return c.WaitUntilImageScanCompleteWithContext(aws.BackgroundContext(), input)
}
// WaitUntilImageScanCompleteWithContext is an extended version of WaitUntilImageScanComplete.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *ECR) WaitUntilImageScanCompleteWithContext(ctx aws.Context, input *DescribeImageScanFindingsInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilImageScanComplete",
MaxAttempts: 60,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathWaiterMatch, Argument: "imageScanStatus.status",
Expected: "COMPLETE",
},
{
State: request.FailureWaiterState,
Matcher: request.PathWaiterMatch, Argument: "imageScanStatus.status",
Expected: "FAILED",
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *DescribeImageScanFindingsInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.DescribeImageScanFindingsRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilLifecyclePolicyPreviewComplete uses the Amazon ECR API operation
// GetLifecyclePolicyPreview to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *ECR) WaitUntilLifecyclePolicyPreviewComplete(input *GetLifecyclePolicyPreviewInput) error {
return c.WaitUntilLifecyclePolicyPreviewCompleteWithContext(aws.BackgroundContext(), input)
}
// WaitUntilLifecyclePolicyPreviewCompleteWithContext is an extended version of WaitUntilLifecyclePolicyPreviewComplete.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *ECR) WaitUntilLifecyclePolicyPreviewCompleteWithContext(ctx aws.Context, input *GetLifecyclePolicyPreviewInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilLifecyclePolicyPreviewComplete",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.PathWaiterMatch, Argument: "status",
Expected: "COMPLETE",
},
{
State: request.FailureWaiterState,
Matcher: request.PathWaiterMatch, Argument: "status",
Expected: "FAILED",
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *GetLifecyclePolicyPreviewInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.GetLifecyclePolicyPreviewRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}

Some files were not shown because too many files have changed in this diff Show More