diff --git a/src/common/dao/base.go b/src/common/dao/base.go index 3e04867da..dc5608334 100644 --- a/src/common/dao/base.go +++ b/src/common/dao/base.go @@ -36,6 +36,12 @@ const ( // ErrDupRows is returned by DAO when inserting failed with error "duplicate key value violates unique constraint" var ErrDupRows = errors.New("sql: duplicate row in DB") +// ErrRollback is returned by DAO when transaction roll back failed +var ErrRollback = errors.New(" transaction roll back error") + +// ErrCommit is returned by DAO when transaction commit failed +var ErrCommit = errors.New(" transaction roll back error") + // Database is an interface of different databases type Database interface { // Name returns the name of database diff --git a/src/common/dao/oidc_user.go b/src/common/dao/oidc_user.go index 6f00726a6..aa26296f6 100644 --- a/src/common/dao/oidc_user.go +++ b/src/common/dao/oidc_user.go @@ -18,27 +18,13 @@ import ( "fmt" "strings" "time" + "errors" "github.com/astaxie/beego/orm" "github.com/goharbor/harbor/src/common/models" "github.com/goharbor/harbor/src/common/utils/log" ) -// AddOIDCUser adds a oidc user -func AddOIDCUser(meta *models.OIDCUser) (int64, error) { - now := time.Now() - meta.CreationTime = now - meta.UpdateTime = now - id, err := GetOrmer().Insert(meta) - if err != nil { - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return 0, ErrDupRows - } - return 0, err - } - return id, nil -} - // GetOIDCUserByID ... func GetOIDCUserByID(id int64) (*models.OIDCUser, error) { oidcUser := &models.OIDCUser{ @@ -54,10 +40,10 @@ func GetOIDCUserByID(id int64) (*models.OIDCUser, error) { return oidcUser, nil } -// GetUserBySub ... -func GetUserBySub(sub string) (*models.User, error) { +// GetUserBySubIss ... +func GetUserBySubIss(sub, issuer string) (*models.User, error) { var oidcUsers []models.OIDCUser - n, err := GetOrmer().Raw(`select * from oidc_user where subiss = ? `, sub).QueryRows(&oidcUsers) + n, err := GetOrmer().Raw(`select * from oidc_user where subiss = ? `, sub+issuer).QueryRows(&oidcUsers) if err != nil { return nil, err } @@ -85,7 +71,6 @@ func GetOIDCUserByUserID(userID int) (*models.OIDCUser, error) { if err != nil { return nil, err } - if n == 0 { return nil, nil } @@ -107,17 +92,22 @@ func DeleteOIDCUser(id int64) error { } // OnBoardOIDCUser onboard OIDC user -func OnBoardOIDCUser(u models.User) error { +func OnBoardOIDCUser(u *models.User) error { + if u.OIDCUserMeta == nil { + return errors.New("unable to onboard as empty oidc user") + } + o := orm.NewOrm() err := o.Begin() if err != nil { return err } - // the password is the required attribute of user table, - // but not required in the oidc user login scenario. - u.Email = "odicpassword" var errInsert error - userID, err := o.Insert(&u) + + // insert user + now := time.Now() + u.CreationTime = now + userID, err := o.Insert(u) if err != nil { errInsert = err log.Errorf("fail to insert user, %v", err) @@ -126,12 +116,18 @@ func OnBoardOIDCUser(u models.User) error { } err := o.Rollback() if err != nil { - return err + log.Errorf("fail to rollback, %v", err) + return ErrRollback } return errInsert } + u.UserID = int(userID) u.OIDCUserMeta.UserID = int(userID) + + // insert oidc user + now = time.Now() + u.OIDCUserMeta.CreationTime = now _, err = o.Insert(u.OIDCUserMeta) if err != nil { errInsert = err @@ -141,14 +137,16 @@ func OnBoardOIDCUser(u models.User) error { } err := o.Rollback() if err != nil { - return err + log.Errorf("fail to rollback, %v", err) + return ErrRollback } return errInsert } err = o.Commit() if err != nil { - return err + log.Errorf("fail to commit, %v", err) + return ErrCommit } return nil -} +} \ No newline at end of file diff --git a/src/common/dao/oidc_user_test.go b/src/common/dao/oidc_user_test.go index 2e27a0716..0da3e1a8c 100644 --- a/src/common/dao/oidc_user_test.go +++ b/src/common/dao/oidc_user_test.go @@ -22,45 +22,57 @@ import ( "github.com/stretchr/testify/require" ) -var ( - user111 = models.User{ +func TestOIDCUserMetaDaoMethods(t *testing.T) { + + user111 := models.User{ Username: "user111", Email: "user111@email.com", } - user222 = models.User{ + user222 := models.User{ Username: "user222", Email: "user222@email.com", } - ou111 = &models.OIDCUser{ + userEmptyOuMeta := models.User{ + Username: "userEmptyOuMeta", + Email: "userEmptyOuMeta@email.com", + } + ou111 := models.OIDCUser{ SubIss: "QWE123123RT1", Secret: "QWEQWE1", } - ou222 = &models.OIDCUser{ + ou222 := models.OIDCUser{ SubIss: "QWE123123RT2", Secret: "QWEQWE2", } -) -func TestOIDCUserMetaDaoMethods(t *testing.T) { - - err := OnBoardUser(&user111) + // onboard OIDC ... + user111.OIDCUserMeta = &ou111 + err := OnBoardOIDCUser(&user111) require.Nil(t, err) - ou111.UserID = user111.UserID - err = OnBoardUser(&user222) - require.Nil(t, err) - ou222.UserID = user222.UserID - - // test add - _, err = AddOIDCUser(ou111) - require.Nil(t, err) - _, err = AddOIDCUser(ou222) + user222.OIDCUserMeta = &ou222 + err = OnBoardOIDCUser(&user222) require.Nil(t, err) - // test get + // empty OIDC user meta ... + err = OnBoardOIDCUser(&userEmptyOuMeta) + require.NotNil(t, err) + assert.Equal(t, "unable to onboard as empty oidc user", err.Error()) + + // test get by ID oidcUser1, err := GetOIDCUserByID(ou111.ID) require.Nil(t, err) assert.Equal(t, ou111.UserID, oidcUser1.UserID) + // test get by userID + oidcUser2, err := GetOIDCUserByUserID(user111.UserID) + require.Nil(t, err) + assert.Equal(t, "QWE123123RT1", oidcUser2.SubIss) + + // test get by sub and iss + userGetBySubIss, err := GetUserBySubIss("QWE123", "123RT1") + require.Nil(t, err) + assert.Equal(t, "user111@email.com", userGetBySubIss.Email) + // test update meta3 := &models.OIDCUser{ ID: ou111.ID, @@ -72,9 +84,15 @@ func TestOIDCUserMetaDaoMethods(t *testing.T) { require.Nil(t, err) assert.Equal(t, "newSub", oidcUser1Update.SubIss) - user, err := GetUserBySub("newSub") + user, err := GetUserBySubIss("new", "Sub") require.Nil(t, err) assert.Equal(t, "user111", user.Username) + + // clear data + defer func() { + _, err := GetOrmer().Raw(`delete from oidc_user`).Exec() + require.Nil(t, err) + }() } func TestOIDCOnboard(t *testing.T) { @@ -86,38 +104,67 @@ func TestOIDCOnboard(t *testing.T) { Username: "user555", Email: "user555@email.com", } + user666 := models.User{ + Username: "user666", + Email: "user666@email.com", + } + userDup := models.User{ + Username: "user333", + Email: "user333@email.com", + } + ou333 := &models.OIDCUser{ - UserID: 333, - SubIss: "QWE123123RT1", + SubIss: "QWE123123RT3", + Secret: "QWEQWE333", + } + ou555 := &models.OIDCUser{ + SubIss: "QWE123123RT5", + Secret: "QWEQWE555", + } + ouDup := &models.OIDCUser{ + SubIss: "QWE123123RT3", Secret: "QWEQWE333", } ouDupSub := &models.OIDCUser{ - UserID: 444, - SubIss: "QWE123123RT1", - Secret: "QWEQWE444", + SubIss: "QWE123123RT3", + Secret: "ouDupSub", } + // data prepare ... + user333.OIDCUserMeta = ou333 + err := OnBoardOIDCUser(&user333) + require.Nil(t, err) + // duplicate user -- ErrDupRows - user111.OIDCUserMeta = ou333 - err := OnBoardOIDCUser(user111) + // userDup is duplicate with user333 + userDup.OIDCUserMeta = ou555 + err = OnBoardOIDCUser(&userDup) require.NotNil(t, err) require.Equal(t, err, ErrDupRows) // duplicate OIDC user -- ErrDupRows - user333.OIDCUserMeta = ou111 - err = OnBoardOIDCUser(user333) + // ouDup is duplicate with ou333 + user555.OIDCUserMeta = ouDup + err = OnBoardOIDCUser(&user555) require.NotNil(t, err) require.Equal(t, err, ErrDupRows) // success - user333.OIDCUserMeta = ou333 - err = OnBoardOIDCUser(user333) + user555.OIDCUserMeta = ou555 + err = OnBoardOIDCUser(&user555) require.Nil(t, err) // duplicate OIDC user's sub -- ErrDupRows - user555.OIDCUserMeta = ouDupSub - err = OnBoardOIDCUser(user555) + // ouDup is duplicate with ou333 + user666.OIDCUserMeta = ouDupSub + err = OnBoardOIDCUser(&user666) require.NotNil(t, err) require.Equal(t, err, ErrDupRows) + // clear data + defer func() { + _, err := GetOrmer().Raw(`delete from oidc_user`).Exec() + require.Nil(t, err) + }() + } diff --git a/src/common/models/user.go b/src/common/models/user.go index d10ae91a3..9b224bd80 100644 --- a/src/common/models/user.go +++ b/src/common/models/user.go @@ -41,7 +41,7 @@ type User struct { CreationTime time.Time `orm:"column(creation_time);auto_now_add" json:"creation_time"` UpdateTime time.Time `orm:"column(update_time);auto_now" json:"update_time"` GroupList []*UserGroup `orm:"-" json:"-"` - OIDCUserMeta *OIDCUser `orm:"-" json:"oidc_user_meta"` + OIDCUserMeta *OIDCUser `orm:"-" json:"oidc_user_meta,omitempty"` } // UserQuery ... diff --git a/src/core/controllers/controllers_test.go b/src/core/controllers/controllers_test.go index 9e3950716..8d442f20b 100644 --- a/src/core/controllers/controllers_test.go +++ b/src/core/controllers/controllers_test.go @@ -50,7 +50,6 @@ func init() { beego.Router("/c/reset", &CommonController{}, "post:ResetPassword") beego.Router("/c/userExists", &CommonController{}, "post:UserExists") beego.Router("/c/sendEmail", &CommonController{}, "get:SendResetEmail") - beego.Router("/c/oidc/onboard", &OIDCController{}, "post:Onboard") beego.Router("/v2/*", &RegistryProxy{}, "*:Handle") } @@ -140,8 +139,4 @@ func TestAll(t *testing.T) { beego.BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(int(404), w.Code, "GET v2/noproject/manifests/1.0 should get a 404 response") - r, _ = http.NewRequest("POST", "/c/oidc/onboard", nil) - w = httptest.NewRecorder() - beego.BeeApp.Handlers.ServeHTTP(w, r) - assert.Equal(int(500), w.Code, "/c/oidc/onboard httpStatusCode should be 500") }