diff --git a/src/common/const.go b/src/common/const.go index 5e30b78fc..d94edf1af 100644 --- a/src/common/const.go +++ b/src/common/const.go @@ -226,4 +226,6 @@ const ( UIMaxLengthLimitedOfNumber = 10 // ExecutionStatusRefreshIntervalSeconds is the interval seconds for refreshing execution status ExecutionStatusRefreshIntervalSeconds = "execution_status_refresh_interval_seconds" + // QuotaUpdateProvider is the provider for updating quota, currently support Redis and DB + QuotaUpdateProvider = "quota_update_provider" ) diff --git a/src/controller/blob/controller.go b/src/controller/blob/controller.go index 112f0d8e1..f59a457ff 100644 --- a/src/controller/blob/controller.go +++ b/src/controller/blob/controller.go @@ -323,7 +323,12 @@ func (c *controller) Sync(ctx context.Context, references []distribution.Descrip func (c *controller) SetAcceptedBlobSize(ctx context.Context, sessionID string, size int64) error { key := blobSizeKey(sessionID) - err := libredis.Instance().Set(ctx, key, size, c.blobSizeExpiration).Err() + rc, err := libredis.GetRegistryClient() + if err != nil { + return err + } + + err = rc.Set(ctx, key, size, c.blobSizeExpiration).Err() if err != nil { log.Errorf("failed to set accepted blob size for session %s in redis, error: %v", sessionID, err) return err @@ -334,7 +339,12 @@ func (c *controller) SetAcceptedBlobSize(ctx context.Context, sessionID string, func (c *controller) GetAcceptedBlobSize(ctx context.Context, sessionID string) (int64, error) { key := blobSizeKey(sessionID) - size, err := libredis.Instance().Get(ctx, key).Int64() + rc, err := libredis.GetRegistryClient() + if err != nil { + return 0, err + } + + size, err := rc.Get(ctx, key).Int64() if err != nil { if err == redis.Nil { return 0, nil diff --git a/src/controller/quota/controller.go b/src/controller/quota/controller.go index 7e6d8f095..3c4b857d5 100644 --- a/src/controller/quota/controller.go +++ b/src/controller/quota/controller.go @@ -19,20 +19,48 @@ import ( "fmt" "time" + "github.com/go-redis/redis/v8" + "golang.org/x/sync/singleflight" + // quota driver _ "github.com/goharbor/harbor/src/controller/quota/driver" + "github.com/goharbor/harbor/src/lib/cache" + "github.com/goharbor/harbor/src/lib/config" "github.com/goharbor/harbor/src/lib/errors" + "github.com/goharbor/harbor/src/lib/gtask" "github.com/goharbor/harbor/src/lib/log" "github.com/goharbor/harbor/src/lib/orm" "github.com/goharbor/harbor/src/lib/q" + libredis "github.com/goharbor/harbor/src/lib/redis" "github.com/goharbor/harbor/src/lib/retry" "github.com/goharbor/harbor/src/pkg/quota" "github.com/goharbor/harbor/src/pkg/quota/driver" "github.com/goharbor/harbor/src/pkg/quota/types" + + // init the db config + _ "github.com/goharbor/harbor/src/pkg/config/db" ) +func init() { + // register the async task for flushing quota to db when enable update quota by redis + if provider := config.GetQuotaUpdateProvider(); provider == updateQuotaProviderRedis.String() { + gtask.DefaultPool().AddTask(flushQuota, 30*time.Second) + } +} + +type updateQuotaProviderType string + +func (t updateQuotaProviderType) String() string { + return string(t) +} + var ( defaultRetryTimeout = time.Minute * 5 + // quotaExpireTimeout is the expire time for quota when update quota by redis + quotaExpireTimeout = time.Minute * 5 + + updateQuotaProviderRedis updateQuotaProviderType = "Redis" + updateQuotaProviderDB updateQuotaProviderType = "DB" ) var ( @@ -87,6 +115,31 @@ type controller struct { reservedExpiration time.Duration quotaMgr quota.Manager + g singleflight.Group +} + +// flushQuota flushes the quota info from redis to db asynchronously. +func flushQuota(ctx context.Context) { + iter, err := cache.Default().Scan(ctx, "quota:*") + if err != nil { + log.Errorf("failed to scan out the quota records from redis") + } + + for iter.Next(ctx) { + key := iter.Val() + q := "a.Quota{} + err = cache.Default().Fetch(ctx, key, q) + if err != nil { + log.Errorf("failed to fetch quota: %s, error: %v", key, err) + continue + } + + if err = Ctl.Update(ctx, q); err != nil { + log.Errorf("failed to refresh quota: %s, error: %v", key, err) + } else { + log.Debugf("successfully refreshed quota: %s", key) + } + } } func (c *controller) Count(ctx context.Context, query *q.Query) (int64, error) { @@ -163,13 +216,83 @@ func (c *controller) List(ctx context.Context, query *q.Query, options ...Option return quotas, nil } -func (c *controller) updateUsageWithRetry(ctx context.Context, reference, referenceID string, op func(hardLimits, used types.ResourceList) (types.ResourceList, error), retryOpts ...retry.Option) error { - f := func() error { - q, err := c.quotaMgr.GetByRef(ctx, reference, referenceID) - if err != nil { +// updateUsageByDB updates the quota usage by the database which updates the quota usage immediately. +func (c *controller) updateUsageByDB(ctx context.Context, reference, referenceID string, op func(hardLimits, used types.ResourceList) (types.ResourceList, error)) error { + q, err := c.quotaMgr.GetByRef(ctx, reference, referenceID) + if err != nil { + return retry.Abort(err) + } + + hardLimits, err := q.GetHard() + if err != nil { + return retry.Abort(err) + } + + used, err := q.GetUsed() + if err != nil { + return retry.Abort(err) + } + + newUsed, err := op(hardLimits, used) + if err != nil { + return retry.Abort(err) + } + + // The PR https://github.com/goharbor/harbor/pull/17392 optimized the logic for post upload blob which use size 0 + // for checking quota, this will increase the pressure of optimistic lock, so here return earlier + // if the quota usage has not changed to reduce the probability of optimistic lock. + if types.Equals(used, newUsed) { + return nil + } + + q.SetUsed(newUsed) + + err = c.quotaMgr.Update(ctx, q) + if err != nil && !errors.Is(err, orm.ErrOptimisticLock) { + return retry.Abort(err) + } + + return err +} + +// updateUsageByRedis updates the quota usage by the redis and flush the quota usage to db asynchronously. +func (c *controller) updateUsageByRedis(ctx context.Context, reference, referenceID string, op func(hardLimits, used types.ResourceList) (types.ResourceList, error)) error { + // earlier abort if context is error such as context canceled + if ctx.Err() != nil { + return retry.Abort(ctx.Err()) + } + + client, err := libredis.GetCoreClient() + if err != nil { + return retry.Abort(err) + } + // normally use cache.Save will append prefix "cache:", in order to keep consistent + // here adopts raw redis client should also pad the prefix manually. + key := fmt.Sprintf("%s:quota:%s:%s", "cache", reference, referenceID) + return client.Watch(ctx, func(tx *redis.Tx) error { + data, err := tx.Get(ctx, key).Result() + if err != nil && err != redis.Nil { return retry.Abort(err) } + q := "a.Quota{} + // calc the quota usage in real time if no key found + if err == redis.Nil { + // use singleflight to prevent cache penetration and cause pressure on the database. + realQuota, err, _ := c.g.Do(key, func() (interface{}, error) { + return c.calcQuota(ctx, reference, referenceID) + }) + if err != nil { + return retry.Abort(err) + } + + q = realQuota.(*quota.Quota) + } else { + if err = cache.DefaultCodec().Decode([]byte(data), q); err != nil { + return retry.Abort(err) + } + } + hardLimits, err := q.GetHard() if err != nil { return retry.Abort(err) @@ -185,21 +308,42 @@ func (c *controller) updateUsageWithRetry(ctx context.Context, reference, refere return retry.Abort(err) } - // The PR https://github.com/goharbor/harbor/pull/17392 optimized the logic for post upload blob which use size 0 - // for checking quota, this will increase the pressure of optimistic lock, so here return earlier - // if the quota usage has not changed to reduce the probability of optimistic lock. - if types.Equals(used, newUsed) { - return nil - } - q.SetUsed(newUsed) - err = c.quotaMgr.Update(ctx, q) - if err != nil && !errors.Is(err, orm.ErrOptimisticLock) { + val, err := cache.DefaultCodec().Encode(q) + if err != nil { + return retry.Abort(err) + } + + _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { + _, err = p.Set(ctx, key, val, quotaExpireTimeout).Result() + return err + }) + + if err != nil && err != redis.TxFailedErr { return retry.Abort(err) } return err + }, key) +} + +func (c *controller) updateUsageWithRetry(ctx context.Context, reference, referenceID string, op func(hardLimits, used types.ResourceList) (types.ResourceList, error), provider updateQuotaProviderType, retryOpts ...retry.Option) error { + var f func() error + switch provider { + case updateQuotaProviderDB: + f = func() error { + return c.updateUsageByDB(ctx, reference, referenceID, op) + } + case updateQuotaProviderRedis: + f = func() error { + return c.updateUsageByRedis(ctx, reference, referenceID, op) + } + default: + // by default is update quota by db + f = func() error { + return c.updateUsageByDB(ctx, reference, referenceID, op) + } } options := []retry.Option{ @@ -235,7 +379,8 @@ func (c *controller) Refresh(ctx context.Context, reference, referenceID string, return newUsed, err } - return c.updateUsageWithRetry(ctx, reference, referenceID, refreshResources(calculateUsage, opts.IgnoreLimitation), opts.RetryOptions...) + // update quota usage by db for refresh operation + return c.updateUsageWithRetry(ctx, reference, referenceID, refreshResources(calculateUsage, opts.IgnoreLimitation), updateQuotaProviderType(config.GetQuotaUpdateProvider()), opts.RetryOptions...) } func (c *controller) Request(ctx context.Context, reference, referenceID string, resources types.ResourceList, f func() error) error { @@ -243,7 +388,8 @@ func (c *controller) Request(ctx context.Context, reference, referenceID string, return f() } - if err := c.updateUsageWithRetry(ctx, reference, referenceID, reserveResources(resources)); err != nil { + provider := updateQuotaProviderType(config.GetQuotaUpdateProvider()) + if err := c.updateUsageWithRetry(ctx, reference, referenceID, reserveResources(resources), provider); err != nil { log.G(ctx).Errorf("reserve resources %s for %s %s failed, error: %v", resources.String(), reference, referenceID, err) return err } @@ -251,7 +397,7 @@ func (c *controller) Request(ctx context.Context, reference, referenceID string, err := f() if err != nil { - if er := c.updateUsageWithRetry(ctx, reference, referenceID, rollbackResources(resources)); er != nil { + if er := c.updateUsageWithRetry(ctx, reference, referenceID, rollbackResources(resources), provider); er != nil { // ignore this error, the quota usage will be correct when users do operations which will call refresh quota log.G(ctx).Warningf("rollback resources %s for %s %s failed, error: %v", resources.String(), reference, referenceID, er) } @@ -260,6 +406,29 @@ func (c *controller) Request(ctx context.Context, reference, referenceID string, return err } +// calcQuota calculates the quota and usage in real time. +func (c *controller) calcQuota(ctx context.Context, reference, referenceID string) (*quota.Quota, error) { + // get quota and usage from db + q, err := c.quotaMgr.GetByRef(ctx, reference, referenceID) + if err != nil { + return nil, err + } + // the usage in the db maybe outdated, calc it in real time + driver, err := Driver(ctx, reference) + if err != nil { + return nil, err + } + + newUsed, err := driver.CalculateUsage(ctx, referenceID) + if err != nil { + log.G(ctx).Errorf("failed to calculate quota usage for %s %s, error: %v", reference, referenceID, err) + return nil, err + } + + q.SetUsed(newUsed) + return q, nil +} + func (c *controller) Update(ctx context.Context, u *quota.Quota) error { f := func() error { q, err := c.quotaMgr.GetByRef(ctx, u.Reference, u.ReferenceID) @@ -267,15 +436,19 @@ func (c *controller) Update(ctx context.Context, u *quota.Quota) error { return err } - if q.Hard != u.Hard { - if hard, err := u.GetHard(); err == nil { - q.SetHard(hard) + if oldHard, err := q.GetHard(); err == nil { + if newHard, err := u.GetHard(); err == nil { + if !types.Equals(oldHard, newHard) { + q.SetHard(newHard) + } } } - if q.Used != u.Used { - if used, err := u.GetUsed(); err == nil { - q.SetUsed(used) + if oldUsed, err := q.GetUsed(); err == nil { + if newUsed, err := u.GetUsed(); err == nil { + if !types.Equals(oldUsed, newUsed) { + q.SetUsed(newUsed) + } } } diff --git a/src/go.mod b/src/go.mod index a51c7130a..336667943 100644 --- a/src/go.mod +++ b/src/go.mod @@ -64,6 +64,7 @@ require ( golang.org/x/crypto v0.5.0 golang.org/x/net v0.9.0 golang.org/x/oauth2 v0.5.0 + golang.org/x/sync v0.3.0 golang.org/x/text v0.9.0 golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 gopkg.in/h2non/gock.v1 v1.0.16 @@ -162,7 +163,6 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.19.0 // indirect - golang.org/x/sync v0.3.0 golang.org/x/sys v0.7.0 // indirect golang.org/x/term v0.7.0 // indirect google.golang.org/api v0.110.0 // indirect diff --git a/src/lib/cache/redis/redis_test.go b/src/lib/cache/redis/redis_test.go index 1170543aa..17b6616ef 100644 --- a/src/lib/cache/redis/redis_test.go +++ b/src/lib/cache/redis/redis_test.go @@ -126,12 +126,12 @@ func (suite *CacheTestSuite) TestScan() { } } { - // no match should return all keys + // return all keys with test-scan-* expect := []string{"test-scan-0", "test-scan-1", "test-scan-2"} // seed data seed(3) // test scan - iter, err := suite.cache.Scan(suite.ctx, "") + iter, err := suite.cache.Scan(suite.ctx, "test-scan-*") suite.NoError(err) got := []string{} for iter.Next(suite.ctx) { @@ -143,12 +143,12 @@ func (suite *CacheTestSuite) TestScan() { } { - // with match should return matched keys + // return matched keys with test-scan-1* expect := []string{"test-scan-1", "test-scan-10"} // seed data seed(11) // test scan - iter, err := suite.cache.Scan(suite.ctx, "*test-scan-1*") + iter, err := suite.cache.Scan(suite.ctx, "test-scan-1*") suite.NoError(err) got := []string{} for iter.Next(suite.ctx) { diff --git a/src/lib/config/metadata/metadatalist.go b/src/lib/config/metadata/metadatalist.go index bb3285e5f..535226bc8 100644 --- a/src/lib/config/metadata/metadatalist.go +++ b/src/lib/config/metadata/metadatalist.go @@ -191,5 +191,6 @@ var ( {Name: common.ExecutionStatusRefreshIntervalSeconds, Scope: SystemScope, Group: BasicGroup, EnvKey: "EXECUTION_STATUS_REFRESH_INTERVAL_SECONDS", DefaultValue: "30", ItemType: &Int64Type{}, Editable: false, Description: `The interval seconds to refresh the execution status`}, {Name: common.BannerMessage, Scope: UserScope, Group: BasicGroup, EnvKey: "BANNER_MESSAGE", DefaultValue: "", ItemType: &StringType{}, Editable: true, Description: `The customized banner message for the UI`}, + {Name: common.QuotaUpdateProvider, Scope: SystemScope, Group: BasicGroup, EnvKey: "QUOTA_UPDATE_PROVIDER", DefaultValue: "db", ItemType: &StringType{}, Editable: false, Description: `The provider for updating quota, 'db' or 'redis' is supported`}, } ) diff --git a/src/lib/config/systemconfig.go b/src/lib/config/systemconfig.go index ac835fff9..5babc6550 100644 --- a/src/lib/config/systemconfig.go +++ b/src/lib/config/systemconfig.go @@ -132,6 +132,11 @@ func GetExecutionStatusRefreshIntervalSeconds() int64 { return DefaultMgr().Get(backgroundCtx, common.ExecutionStatusRefreshIntervalSeconds).GetInt64() } +// GetQuotaUpdateProvider returns the provider for updating quota. +func GetQuotaUpdateProvider() string { + return DefaultMgr().Get(backgroundCtx, common.QuotaUpdateProvider).GetString() +} + // WithTrivy returns a bool value to indicate if Harbor's deployed with Trivy. func WithTrivy() bool { return DefaultMgr().Get(backgroundCtx, common.WithTrivy).GetBool() diff --git a/src/lib/redis/client.go b/src/lib/redis/client.go new file mode 100644 index 000000000..d05471857 --- /dev/null +++ b/src/lib/redis/client.go @@ -0,0 +1,85 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "errors" + "os" + "sync" + + "github.com/go-redis/redis/v8" + + "github.com/goharbor/harbor/src/lib/cache" + libredis "github.com/goharbor/harbor/src/lib/cache/redis" + "github.com/goharbor/harbor/src/lib/log" +) + +var ( + // registry is a global redis client for registry db + registry *redis.Client + registryOnce = &sync.Once{} + + // core is a global redis client for core db + core *redis.Client + coreOnce = &sync.Once{} +) + +// GetRegistryClient returns the registry redis client. +func GetRegistryClient() (*redis.Client, error) { + registryOnce.Do(func() { + url := os.Getenv("_REDIS_URL_REG") + c, err := libredis.New(cache.Options{Address: url}) + if err != nil { + log.Errorf("failed to initialize redis client for registry, error: %v", err) + // reset the once to support retry if error occurred + registryOnce = &sync.Once{} + return + } + + if c != nil { + registry = c.(*libredis.Cache).Client + } + }) + + if registry == nil { + return nil, errors.New("no registry redis client initialized") + } + + return registry, nil +} + +// GetCoreClient returns the core redis client. +func GetCoreClient() (*redis.Client, error) { + coreOnce.Do(func() { + url := os.Getenv("_REDIS_URL_CORE") + c, err := libredis.New(cache.Options{Address: url}) + if err != nil { + log.Errorf("failed to initialize redis client for core, error: %v", err) + // reset the once to support retry if error occurred + coreOnce = &sync.Once{} + return + } + + if c != nil { + core = c.(*libredis.Cache).Client + } + }) + + if core == nil { + return nil, errors.New("no core redis client initialized") + } + + return core, nil +} diff --git a/src/lib/redis/client_test.go b/src/lib/redis/client_test.go new file mode 100644 index 000000000..1e288ba44 --- /dev/null +++ b/src/lib/redis/client_test.go @@ -0,0 +1,63 @@ +// Copyright Project Harbor Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetRegistryClient(t *testing.T) { + // failure case with invalid address + t.Setenv("_REDIS_URL_REG", "invalid-address") + client, err := GetRegistryClient() + assert.Error(t, err) + assert.Nil(t, client) + + // normal case with valid address + t.Setenv("_REDIS_URL_REG", "redis://localhost:6379/1") + client, err = GetRegistryClient() + assert.NoError(t, err) + assert.NotNil(t, client) + + // multiple calls should return the same client + for i := 0; i < 10; i++ { + newClient, err := GetRegistryClient() + assert.NoError(t, err) + assert.Equal(t, client, newClient) + } +} + +func TestGetCoreClient(t *testing.T) { + // failure case with invalid address + t.Setenv("_REDIS_URL_CORE", "invalid-address") + client, err := GetCoreClient() + assert.Error(t, err) + assert.Nil(t, client) + + // normal case with valid address + t.Setenv("_REDIS_URL_CORE", "redis://localhost:6379/0") + client, err = GetCoreClient() + assert.NoError(t, err) + assert.NotNil(t, client) + + // multiple calls should return the same client + for i := 0; i < 10; i++ { + newClient, err := GetCoreClient() + assert.NoError(t, err) + assert.Equal(t, client, newClient) + } +} diff --git a/src/lib/redis/instance.go b/src/lib/redis/instance.go deleted file mode 100644 index 9e59c1a0a..000000000 --- a/src/lib/redis/instance.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Project Harbor Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis - -import ( - "os" - "sync" - - "github.com/go-redis/redis/v8" - - "github.com/goharbor/harbor/src/lib/cache" - libredis "github.com/goharbor/harbor/src/lib/cache/redis" -) - -var ( - // instance is a global redis client. - _instance *redis.Client - _once sync.Once -) - -// Instance returns the redis instance. -func Instance() *redis.Client { - _once.Do(func() { - url := os.Getenv("_REDIS_URL_REG") - if url == "" { - url = "redis://localhost:6379/1" - } - - c, err := libredis.New(cache.Options{Address: url}) - if err != nil { - panic(err) - } - - _instance = c.(*libredis.Cache).Client - }) - - return _instance -} diff --git a/src/lib/redis/instance_test.go b/src/lib/redis/instance_test.go deleted file mode 100644 index 64d064612..000000000 --- a/src/lib/redis/instance_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Project Harbor Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestInstance(t *testing.T) { - ins := Instance() - assert.NotNil(t, ins, "should get instance") - - ctx := context.TODO() - // Test set - err := ins.Set(ctx, "foo", "bar", 0).Err() - assert.NoError(t, err, "redis set should be success") - // Test get - val := ins.Get(ctx, "foo").Val() - assert.Equal(t, "bar", val, "redis get should be success") - // Test delete - err = ins.Del(ctx, "foo").Err() - assert.NoError(t, err, "redis delete should be success") - exist := ins.Exists(ctx, "foo").Val() - assert.Equal(t, int64(0), exist, "key should not exist") -} diff --git a/src/lib/redis/redisclient.go b/src/lib/redis/pool.go similarity index 100% rename from src/lib/redis/redisclient.go rename to src/lib/redis/pool.go diff --git a/src/vendor/golang.org/x/sync/singleflight/singleflight.go b/src/vendor/golang.org/x/sync/singleflight/singleflight.go new file mode 100644 index 000000000..8473fb792 --- /dev/null +++ b/src/vendor/golang.org/x/sync/singleflight/singleflight.go @@ -0,0 +1,205 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight // import "golang.org/x/sync/singleflight" + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val interface{} + err error + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result struct { + Val interface{} + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result { + ch := make(chan Result, 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call{chans: []chan<- Result{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group) Forget(key string) { + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() +} diff --git a/src/vendor/modules.txt b/src/vendor/modules.txt index 408846e62..61be24ea2 100644 --- a/src/vendor/modules.txt +++ b/src/vendor/modules.txt @@ -699,6 +699,7 @@ golang.org/x/oauth2/jwt # golang.org/x/sync v0.3.0 ## explicit; go 1.17 golang.org/x/sync/errgroup +golang.org/x/sync/singleflight # golang.org/x/sys v0.7.0 ## explicit; go 1.17 golang.org/x/sys/internal/unsafeheader