Merge pull request #8220 from reasonerjt/oidc-rotation-fix

Reload OIDC provider older than 3 seconds
This commit is contained in:
Wenkai Yin(尹文开) 2019-07-05 10:12:33 +08:00 committed by GitHub
commit c01bedb740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 21 deletions

View File

@ -35,20 +35,14 @@ const googleEndpoint = "https://accounts.google.com"
type providerHelper struct { type providerHelper struct {
sync.Mutex sync.Mutex
ep endpoint instance atomic.Value
instance atomic.Value setting atomic.Value
setting atomic.Value creationTime time.Time
}
type endpoint struct {
url string
VerifyCert bool
} }
func (p *providerHelper) get() (*gooidc.Provider, error) { func (p *providerHelper) get() (*gooidc.Provider, error) {
if p.instance.Load() != nil { if p.instance.Load() != nil {
s := p.setting.Load().(models.OIDCSetting) if time.Now().Sub(p.creationTime) > 3*time.Second {
if s.Endpoint != p.ep.url || s.VerifyCert != p.ep.VerifyCert { // relevant settings have changed, need to re-create provider.
if err := p.create(); err != nil { if err := p.create(); err != nil {
return nil, err return nil, err
} }
@ -57,7 +51,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if p.instance.Load() == nil { if p.instance.Load() == nil {
if err := p.reload(); err != nil { if err := p.reloadSetting(); err != nil {
return nil, err return nil, err
} }
if err := p.create(); err != nil { if err := p.create(); err != nil {
@ -65,7 +59,7 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
} }
go func() { go func() {
for { for {
if err := p.reload(); err != nil { if err := p.reloadSetting(); err != nil {
log.Warningf("Failed to refresh configuration, error: %v", err) log.Warningf("Failed to refresh configuration, error: %v", err)
} }
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
@ -73,10 +67,11 @@ func (p *providerHelper) get() (*gooidc.Provider, error) {
}() }()
} }
} }
return p.instance.Load().(*gooidc.Provider), nil return p.instance.Load().(*gooidc.Provider), nil
} }
func (p *providerHelper) reload() error { func (p *providerHelper) reloadSetting() error {
conf, err := config.OIDCSetting() conf, err := config.OIDCSetting()
if err != nil { if err != nil {
return fmt.Errorf("failed to load OIDC setting: %v", err) return fmt.Errorf("failed to load OIDC setting: %v", err)
@ -96,10 +91,7 @@ func (p *providerHelper) create() error {
return fmt.Errorf("failed to create OIDC provider, error: %v", err) return fmt.Errorf("failed to create OIDC provider, error: %v", err)
} }
p.instance.Store(provider) p.instance.Store(provider)
p.ep = endpoint{ p.creationTime = time.Now()
url: s.Endpoint,
VerifyCert: s.VerifyCert,
}
return nil return nil
} }

View File

@ -49,21 +49,20 @@ func TestMain(m *testing.M) {
func TestHelperLoadConf(t *testing.T) { func TestHelperLoadConf(t *testing.T) {
testP := &providerHelper{} testP := &providerHelper{}
assert.Nil(t, testP.setting.Load()) assert.Nil(t, testP.setting.Load())
err := testP.reload() err := testP.reloadSetting()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test", testP.setting.Load().(models.OIDCSetting).Name) assert.Equal(t, "test", testP.setting.Load().(models.OIDCSetting).Name)
assert.Equal(t, endpoint{}, testP.ep)
} }
func TestHelperCreate(t *testing.T) { func TestHelperCreate(t *testing.T) {
testP := &providerHelper{} testP := &providerHelper{}
err := testP.reload() err := testP.reloadSetting()
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, testP.instance.Load()) assert.Nil(t, testP.instance.Load())
err = testP.create() err = testP.create()
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, "https://accounts.google.com", testP.ep.url)
assert.NotNil(t, testP.instance.Load()) assert.NotNil(t, testP.instance.Load())
assert.True(t, time.Now().Sub(testP.creationTime) < 2*time.Second)
} }
func TestHelperGet(t *testing.T) { func TestHelperGet(t *testing.T) {