Merge pull request #8220 from reasonerjt/oidc-rotation-fix

Reload OIDC provider older than 3 seconds
This commit is contained in:
Wenkai Yin(尹文开) 2019-07-05 10:12:33 +08:00 committed by GitHub
commit c01bedb740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 21 deletions

View File

@ -35,20 +35,14 @@ const googleEndpoint = "https://accounts.google.com"
type providerHelper struct {
sync.Mutex
ep endpoint
instance atomic.Value
setting atomic.Value
}
type endpoint struct {
url string
VerifyCert bool
instance atomic.Value
setting atomic.Value
creationTime time.Time
}
func (p *providerHelper) get() (*gooidc.Provider, error) {
if p.instance.Load() != nil {
s := p.setting.Load().(models.OIDCSetting)
if s.Endpoint != p.ep.url || s.VerifyCert != p.ep.VerifyCert { // relevant settings have changed, need to re-create provider.
if time.Now().Sub(p.creationTime) > 3*time.Second {
if err := p.create(); err != nil {
return nil, err
}
@ -57,7 +51,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
p.Lock()
defer p.Unlock()
if p.instance.Load() == nil {
if err := p.reload(); err != nil {
if err := p.reloadSetting(); err != nil {
return nil, err
}
if err := p.create(); err != nil {
@ -65,7 +59,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
}
go func() {
for {
if err := p.reload(); err != nil {
if err := p.reloadSetting(); err != nil {
log.Warningf("Failed to refresh configuration, error: %v", err)
}
time.Sleep(3 * time.Second)
@ -73,10 +67,11 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
}()
}
}
return p.instance.Load().(*gooidc.Provider), nil
}
func (p *providerHelper) reload() error {
func (p *providerHelper) reloadSetting() error {
conf, err := config.OIDCSetting()
if err != nil {
return fmt.Errorf("failed to load OIDC setting: %v", err)
@ -96,10 +91,7 @@ func (p *providerHelper) create() error {
return fmt.Errorf("failed to create OIDC provider, error: %v", err)
}
p.instance.Store(provider)
p.ep = endpoint{
url: s.Endpoint,
VerifyCert: s.VerifyCert,
}
p.creationTime = time.Now()
return nil
}

View File

@ -49,21 +49,20 @@ func TestMain(m *testing.M) {
func TestHelperLoadConf(t *testing.T) {
testP := &providerHelper{}
assert.Nil(t, testP.setting.Load())
err := testP.reload()
err := testP.reloadSetting()
assert.Nil(t, err)
assert.Equal(t, "test", testP.setting.Load().(models.OIDCSetting).Name)
assert.Equal(t, endpoint{}, testP.ep)
}
func TestHelperCreate(t *testing.T) {
testP := &providerHelper{}
err := testP.reload()
err := testP.reloadSetting()
assert.Nil(t, err)
assert.Nil(t, testP.instance.Load())
err = testP.create()
assert.Nil(t, err)
assert.EqualValues(t, "https://accounts.google.com", testP.ep.url)
assert.NotNil(t, testP.instance.Load())
assert.True(t, time.Now().Sub(testP.creationTime) < 2*time.Second)
}
func TestHelperGet(t *testing.T) {