From ab636fe3da0b9e81232ab69ae2da63bd0d296e8d Mon Sep 17 00:00:00 2001 From: "stonezdj(Daojun Zhang)" Date: Sat, 11 Mar 2023 15:45:50 +0800 Subject: [PATCH] Remove go routine to reloadSetting (#18318) Config is cached in redis Fixes #18189 Signed-off-by: stonezdj --- src/core/controllers/oidc.go | 2 +- src/pkg/oidc/helper.go | 98 +++++++++++++++++++----------------- src/pkg/oidc/helper_test.go | 21 +++----- 3 files changed, 59 insertions(+), 62 deletions(-) diff --git a/src/core/controllers/oidc.go b/src/core/controllers/oidc.go index 1f4f040b3..6758da332 100644 --- a/src/core/controllers/oidc.go +++ b/src/core/controllers/oidc.go @@ -58,7 +58,7 @@ func (oc *OIDCController) Prepare() { // RedirectLogin redirect user's browser to OIDC provider's login page func (oc *OIDCController) RedirectLogin() { state := utils.GenerateRandomString() - url, err := oidc.AuthCodeURL(state) + url, err := oidc.AuthCodeURL(oc.Context(), state) if err != nil { oc.SendInternalServerError(err) return diff --git a/src/pkg/oidc/helper.go b/src/pkg/oidc/helper.go index ec828f97f..53bc5690f 100644 --- a/src/pkg/oidc/helper.go +++ b/src/pkg/oidc/helper.go @@ -17,7 +17,6 @@ package oidc import ( "context" "crypto/tls" - "errors" "fmt" "net/http" "regexp" @@ -50,14 +49,13 @@ type claimsProvider interface { type providerHelper struct { sync.Mutex instance atomic.Value - setting atomic.Value creationTime time.Time } -func (p *providerHelper) get() (*gooidc.Provider, error) { +func (p *providerHelper) get(ctx context.Context) (*gooidc.Provider, error) { if p.instance.Load() != nil { if time.Since(p.creationTime) > 3*time.Second { - if err := p.create(); err != nil { + if err := p.create(ctx); err != nil { return nil, err } } @@ -65,42 +63,23 @@ func (p *providerHelper) get() (*gooidc.Provider, error) { p.Lock() defer p.Unlock() if p.instance.Load() == nil { - if err := p.reloadSetting(); err != nil { + if err := p.create(ctx); err != nil { return nil, err } - if err := p.create(); err != nil { - return nil, err - } - go func() { - for { - if err := p.reloadSetting(); err != nil { - log.Warningf("Failed to refresh configuration, error: %v", err) - } - time.Sleep(3 * time.Second) - } - }() } } return p.instance.Load().(*gooidc.Provider), nil } -func (p *providerHelper) reloadSetting() error { - conf, err := config.OIDCSetting(orm.Context()) +func (p *providerHelper) create(ctx context.Context) error { + s, err := config.OIDCSetting(ctx) if err != nil { - return fmt.Errorf("failed to load OIDC setting: %v", err) + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return err } - p.setting.Store(*conf) - return nil -} - -func (p *providerHelper) create() error { - if p.setting.Load() == nil { - return errors.New("the configuration is not loaded") - } - s := p.setting.Load().(cfgModels.OIDCSetting) - ctx := clientCtx(context.Background(), s.VerifyCert) - provider, err := gooidc.NewProvider(ctx, s.Endpoint) + c := clientCtx(ctx, s.VerifyCert) + provider, err := gooidc.NewProvider(c, s.Endpoint) if err != nil { return fmt.Errorf("failed to create OIDC provider, error: %v", err) } @@ -137,12 +116,16 @@ type UserInfo struct { hasGroupClaim bool } -func getOauthConf() (*oauth2.Config, error) { - p, err := provider.get() +func getOauthConf(ctx context.Context) (*oauth2.Config, error) { + p, err := provider.get(ctx) if err != nil { return nil, err } - setting := provider.setting.Load().(cfgModels.OIDCSetting) + setting, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return nil, err + } scopes := make([]string, 0) for _, sc := range setting.Scope { if strings.HasPrefix(p.Endpoint().AuthURL, googleEndpoint) && sc == gooidc.ScopeOfflineAccess { @@ -162,14 +145,18 @@ func getOauthConf() (*oauth2.Config, error) { // AuthCodeURL returns the URL for OIDC provider's consent page. The state should be verified when user is redirected // back to Harbor. -func AuthCodeURL(state string) (string, error) { - conf, err := getOauthConf() +func AuthCodeURL(ctx context.Context, state string) (string, error) { + conf, err := getOauthConf(ctx) if err != nil { log.Errorf("Failed to get OAuth configuration, error: %v", err) return "", err } var options []oauth2.AuthCodeOption - setting := provider.setting.Load().(cfgModels.OIDCSetting) + setting, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return "", err + } for k, v := range setting.ExtraRedirectParms { options = append(options, oauth2.SetAuthURLParam(k, v)) } @@ -182,12 +169,16 @@ func AuthCodeURL(state string) (string, error) { // ExchangeToken get the token from token provider via the code func ExchangeToken(ctx context.Context, code string) (*Token, error) { - oauth, err := getOauthConf() + oauth, err := getOauthConf(ctx) if err != nil { log.Errorf("Failed to get OAuth configuration, error: %v", err) return nil, err } - setting := provider.setting.Load().(cfgModels.OIDCSetting) + setting, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return nil, err + } ctx = clientCtx(ctx, setting.VerifyCert) oauthToken, err := oauth.Exchange(ctx, code) if err != nil { @@ -208,11 +199,15 @@ func VerifyToken(ctx context.Context, rawIDToken string) (*gooidc.IDToken, error func verifyTokenWithConfig(ctx context.Context, rawIDToken string, conf *gooidc.Config) (*gooidc.IDToken, error) { log.Debugf("Raw ID token for verification: %s", rawIDToken) - p, err := provider.get() + p, err := provider.get(ctx) if err != nil { return nil, err } - settings := provider.setting.Load().(cfgModels.OIDCSetting) + settings, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return nil, err + } if conf == nil { conf = &gooidc.Config{ClientID: settings.ClientID} } @@ -236,11 +231,15 @@ func clientCtx(ctx context.Context, verifyCert bool) context.Context { // refreshToken tries to refresh the token if it's expired, if it doesn't the // original one will be returned. func refreshToken(ctx context.Context, token *Token) (*Token, error) { - oauthCfg, err := getOauthConf() + oauthCfg, err := getOauthConf(ctx) if err != nil { return nil, err } - setting := provider.setting.Load().(cfgModels.OIDCSetting) + setting, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return nil, err + } cctx := clientCtx(ctx, setting.VerifyCert) ts := oauthCfg.TokenSource(cctx, &token.Token) nt, err := ts.Token() @@ -258,16 +257,20 @@ func refreshToken(ctx context.Context, token *Token) (*Token, error) { // to generate a UserInfo object, if the ID token is not in the input token struct, some attributes will be empty func UserInfoFromToken(ctx context.Context, token *Token) (*UserInfo, error) { // #10913: preload the configuration, in case it was not previously loaded by the UI - _, err := provider.get() + _, err := provider.get(ctx) if err != nil { return nil, err } - setting := provider.setting.Load().(cfgModels.OIDCSetting) - local, err := UserInfoFromIDToken(ctx, token, setting) + setting, err := config.OIDCSetting(ctx) + if err != nil { + log.Errorf("Failed to get OIDC configuration, error: %v", err) + return nil, err + } + local, err := UserInfoFromIDToken(ctx, token, *setting) if err != nil { return nil, err } - remote, err := userInfoFromRemote(ctx, token, setting) + remote, err := userInfoFromRemote(ctx, token, *setting) if err != nil { log.Warningf("Failed to get userInfo by calling remote userinfo endpoint, error: %v ", err) } @@ -319,7 +322,7 @@ func mergeUserInfo(remote, local *UserInfo) *UserInfo { } func userInfoFromRemote(ctx context.Context, token *Token, setting cfgModels.OIDCSetting) (*UserInfo, error) { - p, err := provider.get() + p, err := provider.get(ctx) if err != nil { return nil, err } @@ -409,6 +412,7 @@ func populateGroupsDB(groupNames []string) ([]int, error) { ctx := orm.Context() cfg, err := config.OIDCSetting(ctx) if err != nil { + log.Errorf("failed to get OIDC config, error: %v", err) return nil, err } log.Debugf("populateGroupsDB, group filter %v", cfg.GroupFilter) diff --git a/src/pkg/oidc/helper_test.go b/src/pkg/oidc/helper_test.go index a71173077..0f96365c9 100644 --- a/src/pkg/oidc/helper_test.go +++ b/src/pkg/oidc/helper_test.go @@ -55,20 +55,11 @@ func TestMain(m *testing.M) { os.Exit(result) } } -func TestHelperLoadConf(t *testing.T) { - testP := &providerHelper{} - assert.Nil(t, testP.setting.Load()) - err := testP.reloadSetting() - assert.Nil(t, err) - assert.Equal(t, "test", testP.setting.Load().(cfgModels.OIDCSetting).Name) -} func TestHelperCreate(t *testing.T) { testP := &providerHelper{} - err := testP.reloadSetting() - assert.Nil(t, err) assert.Nil(t, testP.instance.Load()) - err = testP.create() + err := testP.create(orm.Context()) assert.Nil(t, err) assert.NotNil(t, testP.instance.Load()) assert.True(t, time.Now().Sub(testP.creationTime) < 2*time.Second) @@ -76,7 +67,8 @@ func TestHelperCreate(t *testing.T) { func TestHelperGet(t *testing.T) { testP := &providerHelper{} - p, err := testP.get() + ctx := orm.Context() + p, err := testP.get(ctx) assert.Nil(t, err) assert.Equal(t, "https://oauth2.googleapis.com/token", p.Endpoint().TokenURL) @@ -89,12 +81,13 @@ func TestHelperGet(t *testing.T) { common.OIDCClientSecret: "new-secret", common.ExtEndpoint: "https://harbor.test", } - ctx := orm.Context() config.GetCfgManager(ctx).UpdateConfig(ctx, update) t.Log("Sleep for 5 seconds") time.Sleep(5 * time.Second) - assert.Equal(t, "new-secret", testP.setting.Load().(cfgModels.OIDCSetting).ClientSecret) + oidcSetting, err := config.OIDCSetting(ctx) + assert.Nil(t, err) + assert.Equal(t, "new-secret", oidcSetting.ClientSecret) } func TestAuthCodeURL(t *testing.T) { @@ -110,7 +103,7 @@ func TestAuthCodeURL(t *testing.T) { } ctx := orm.Context() config.GetCfgManager(ctx).UpdateConfig(ctx, conf) - res, err := AuthCodeURL("random") + res, err := AuthCodeURL(ctx, "random") assert.Nil(t, err) u, err := url.ParseRequestURI(res) assert.Nil(t, err)