diff --git a/src/common/utils/oidc/helper.go b/src/common/utils/oidc/helper.go index 32fee3c29..a970d5942 100644 --- a/src/common/utils/oidc/helper.go +++ b/src/common/utils/oidc/helper.go @@ -35,20 +35,14 @@ const googleEndpoint = "https://accounts.google.com" type providerHelper struct { sync.Mutex - ep endpoint - instance atomic.Value - setting atomic.Value -} - -type endpoint struct { - url string - VerifyCert bool + instance atomic.Value + setting atomic.Value + creationTime time.Time } func (p *providerHelper) get() (*gooidc.Provider, error) { if p.instance.Load() != nil { - s := p.setting.Load().(models.OIDCSetting) - if s.Endpoint != p.ep.url || s.VerifyCert != p.ep.VerifyCert { // relevant settings have changed, need to re-create provider. + if time.Now().Sub(p.creationTime) > 3*time.Second { if err := p.create(); err != nil { return nil, err } @@ -57,7 +51,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) { p.Lock() defer p.Unlock() if p.instance.Load() == nil { - if err := p.reload(); err != nil { + if err := p.reloadSetting(); err != nil { return nil, err } if err := p.create(); err != nil { @@ -65,7 +59,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) { } go func() { for { - if err := p.reload(); err != nil { + if err := p.reloadSetting(); err != nil { log.Warningf("Failed to refresh configuration, error: %v", err) } time.Sleep(3 * time.Second) @@ -73,10 +67,11 @@ func (p *providerHelper) get() (*gooidc.Provider, error) { }() } } + return p.instance.Load().(*gooidc.Provider), nil } -func (p *providerHelper) reload() error { +func (p *providerHelper) reloadSetting() error { conf, err := config.OIDCSetting() if err != nil { return fmt.Errorf("failed to load OIDC setting: %v", err) @@ -96,10 +91,7 @@ func (p *providerHelper) create() error { return fmt.Errorf("failed to create OIDC provider, error: %v", err) } p.instance.Store(provider) - p.ep = endpoint{ - url: s.Endpoint, - VerifyCert: s.VerifyCert, - } + p.creationTime = time.Now() return nil } diff --git a/src/common/utils/oidc/helper_test.go b/src/common/utils/oidc/helper_test.go index e1e71a8b9..8586e6301 100644 --- a/src/common/utils/oidc/helper_test.go +++ b/src/common/utils/oidc/helper_test.go @@ -49,21 +49,20 @@ func TestMain(m *testing.M) { func TestHelperLoadConf(t *testing.T) { testP := &providerHelper{} assert.Nil(t, testP.setting.Load()) - err := testP.reload() + err := testP.reloadSetting() assert.Nil(t, err) assert.Equal(t, "test", testP.setting.Load().(models.OIDCSetting).Name) - assert.Equal(t, endpoint{}, testP.ep) } func TestHelperCreate(t *testing.T) { testP := &providerHelper{} - err := testP.reload() + err := testP.reloadSetting() assert.Nil(t, err) assert.Nil(t, testP.instance.Load()) err = testP.create() assert.Nil(t, err) - assert.EqualValues(t, "https://accounts.google.com", testP.ep.url) assert.NotNil(t, testP.instance.Load()) + assert.True(t, time.Now().Sub(testP.creationTime) < 2*time.Second) } func TestHelperGet(t *testing.T) {