diff --git a/src/common/dao/group/usergroup.go b/src/common/dao/group/usergroup.go index 6855c883e..a55b16c45 100644 --- a/src/common/dao/group/usergroup.go +++ b/src/common/dao/group/usergroup.go @@ -19,8 +19,6 @@ import ( "github.com/goharbor/harbor/src/common/utils" - "fmt" - "github.com/goharbor/harbor/src/common" "github.com/goharbor/harbor/src/common/dao" "github.com/goharbor/harbor/src/common/models" @@ -97,21 +95,16 @@ func GetUserGroup(id int) (*models.UserGroup, error) { return nil, nil } -// GetGroupIDByGroupName - Return the group ID by given group name. it is possible less group ID than the given group name if some group doesn't exist. -func GetGroupIDByGroupName(groupName []string, groupType int) ([]int, error) { - var retGroupID []int - if len(groupName) == 0 { - return retGroupID, nil +// PopulateGroup - Return the group ID by given group name. if not exist in Harbor DB, create one +func PopulateGroup(userGroups []models.UserGroup) ([]int, error) { + var ugList []int + for _, group := range userGroups { + OnBoardUserGroup(&group) + if group.ID > 0 { + ugList = append(ugList, group.ID) + } } - sql := fmt.Sprintf("select id from user_group where group_name in ( %s ) and group_type = ? ", dao.ParamPlaceholderForIn(len(groupName))) - log.Debugf("GetGroupIDByGroupName: statement sql is %v", sql) - o := dao.GetOrmer() - cnt, err := o.Raw(sql, groupName, groupType).QueryRows(&retGroupID) - if err != nil { - return retGroupID, err - } - log.Debugf("Found rows %v", cnt) - return retGroupID, nil + return ugList, nil } // DeleteUserGroup ... diff --git a/src/common/dao/group/usergroup_test.go b/src/common/dao/group/usergroup_test.go index 2b7952ef9..d83dc7148 100644 --- a/src/common/dao/group/usergroup_test.go +++ b/src/common/dao/group/usergroup_test.go @@ -17,7 +17,6 @@ package group import ( "fmt" "os" - "reflect" "testing" "github.com/goharbor/harbor/src/common" @@ -25,6 +24,7 @@ import ( "github.com/goharbor/harbor/src/common/dao/project" "github.com/goharbor/harbor/src/common/models" "github.com/goharbor/harbor/src/common/utils/log" + "github.com/stretchr/testify/assert" ) var createdUserGroupID int @@ -52,6 +52,7 @@ func TestMain(m *testing.M) { `insert into project (name, owner_id) values ('group_project2', 1)`, `insert into project (name, owner_id) values ('group_project_private', 1)`, "insert into user_group (group_name, group_type, ldap_group_dn) values ('test_group_01', 1, 'cn=harbor_users,ou=sample,ou=vmware,dc=harbor,dc=com')", + "insert into user_group (group_name, group_type, ldap_group_dn) values ('sync_user_group4', 1, 'cn=sync_user_group4,dc=example,dc=com')", "insert into user_group (group_name, group_type, ldap_group_dn) values ('test_http_group', 2, '')", "insert into user_group (group_name, group_type, ldap_group_dn) values ('test_myhttp_group', 2, '')", "update project set owner_id = (select user_id from harbor_user where username = 'member_test_01') where name = 'member_test_01'", @@ -119,7 +120,7 @@ func TestQueryUserGroup(t *testing.T) { wantErr bool }{ {"Query all user group", args{query: models.UserGroup{GroupName: "test_group_01"}}, 1, false}, - {"Query all ldap group", args{query: models.UserGroup{GroupType: common.LDAPGroupType}}, 2, false}, + {"Query all ldap group", args{query: models.UserGroup{GroupType: common.LDAPGroupType}}, 3, false}, {"Query ldap group with group property", args{query: models.UserGroup{GroupType: common.LDAPGroupType, LdapGroupDN: "CN=harbor_users,OU=sample,OU=vmware,DC=harbor,DC=com"}}, 1, false}, } for _, tt := range tests { @@ -399,7 +400,13 @@ func TestGetRolesByLDAPGroup(t *testing.T) { if err != nil || len(userGroupList) < 1 { t.Errorf("failed to query user group, err %v", err) } - gl2, err2 := GetGroupIDByGroupName([]string{"test_http_group", "test_myhttp_group"}, common.HTTPGroupType) + + userGroups := []models.UserGroup{ + {GroupName: "test_http_group", GroupType: common.HTTPGroupType}, + {GroupName: "test_myhttp_group", GroupType: common.HTTPGroupType}, + } + + gl2, err2 := PopulateGroup(userGroups) if err2 != nil || len(gl2) != 2 { t.Errorf("failed to query http user group, err %v", err) } @@ -439,44 +446,61 @@ func TestGetRolesByLDAPGroup(t *testing.T) { } } -func TestGetGroupIDByGroupName(t *testing.T) { - groupList, err := QueryUserGroup(models.UserGroup{GroupName: "test_http_group", GroupType: 2}) - if err != nil { - t.Error(err) +func TestSyncGroupByGroupKey(t *testing.T) { + type args []models.UserGroup + type result struct { + wantError bool } - if len(groupList) < 0 { - t.Error(err) - } - groupList2, err := QueryUserGroup(models.UserGroup{GroupName: "test_myhttp_group", GroupType: 2}) - if err != nil { - t.Error(err) - } - if len(groupList2) < 0 { - t.Error(err) - } - var expectGroupID []int - type args struct { - groupName []string - } - tests := []struct { - name string - args args - want []int - wantErr bool + cases := []struct { + name string + in args + want result }{ - {"empty query", args{groupName: []string{}}, expectGroupID, false}, - {"normal query", args{groupName: []string{"test_http_group", "test_myhttp_group"}}, []int{groupList[0].ID, groupList2[0].ID}, false}, + { + name: `normal test http group`, + in: args{ + models.UserGroup{GroupName: "orange", GroupType: common.HTTPGroupType}, + models.UserGroup{GroupName: "apple", GroupType: common.HTTPGroupType}, + models.UserGroup{GroupName: "pearl", GroupType: common.HTTPGroupType}}, + want: result{false}, + }, + { + name: `normal test oidc group`, + in: args{ + models.UserGroup{GroupName: "dog", GroupType: common.OIDCGroupType}, + models.UserGroup{GroupName: "cat", GroupType: common.OIDCGroupType}, + models.UserGroup{GroupName: "bee", GroupType: common.OIDCGroupType}, + }, + want: result{false}, + }, + { + name: `normal test oidc group`, + in: args{ + models.UserGroup{GroupName: "cn=sync_user_group1,dc=example,dc=com", LdapGroupDN: "cn=sync_user_group1,dc=example,dc=com", GroupType: common.LDAPGroupType}, + models.UserGroup{GroupName: "cn=sync_user_group2,dc=example,dc=com", LdapGroupDN: "cn=sync_user_group2,dc=example,dc=com", GroupType: common.LDAPGroupType}, + models.UserGroup{GroupName: "cn=sync_user_group3,dc=example,dc=com", LdapGroupDN: "cn=sync_user_group3,dc=example,dc=com", GroupType: common.LDAPGroupType}, + models.UserGroup{GroupName: "cn=sync_user_group4,dc=example,dc=com", LdapGroupDN: "cn=sync_user_group4,dc=example,dc=com", GroupType: common.LDAPGroupType}, + }, + want: result{false}, + }, } - for _, tt := range tests { + + for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - got, err := GetGroupIDByGroupName(tt.args.groupName, common.HTTPGroupType) - if (err != nil) != tt.wantErr { - t.Errorf("GetHTTPGroupIDByGroupName() error = %v, wantErr %v", err, tt.wantErr) - return + + got, err := PopulateGroup(tt.in) + + if err != nil && !tt.want.wantError { + t.Errorf("error %v", err) } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetHTTPGroupIDByGroupName() = %#v, want %#v", got, tt.want) + if !assert.Equal(t, len(tt.in), len(got)) { + t.Errorf(`(%v) != %v; want "%v"`, len(tt.in), len(got), len(tt.in)) } + + for _, id := range got { + DeleteUserGroup(id) + } + }) } } diff --git a/src/common/models/usergroup.go b/src/common/models/usergroup.go index ffad09f23..915076083 100644 --- a/src/common/models/usergroup.go +++ b/src/common/models/usergroup.go @@ -29,3 +29,12 @@ type UserGroup struct { func (u *UserGroup) TableName() string { return UserGroupTable } + +// UserGroupsFromName ... +func UserGroupsFromName(groupNames []string, groupType int) []UserGroup { + var groups []UserGroup + for _, name := range groupNames { + groups = append(groups, UserGroup{GroupName: name, GroupType: groupType}) + } + return groups +} diff --git a/src/core/auth/authproxy/auth.go b/src/core/auth/authproxy/auth.go index 1efe42f3e..26467bac6 100644 --- a/src/core/auth/authproxy/auth.go +++ b/src/core/auth/authproxy/auth.go @@ -108,7 +108,8 @@ func (a *Auth) Authenticate(m models.AuthModel) (*models.User, error) { ugList := reviewResponse.Status.User.Groups log.Debugf("user groups %+v", ugList) if len(ugList) > 0 { - groupIDList, err := group.GetGroupIDByGroupName(ugList, common.HTTPGroupType) + userGroups := models.UserGroupsFromName(ugList, common.HTTPGroupType) + groupIDList, err := group.PopulateGroup(userGroups) if err != nil { return nil, err } diff --git a/src/core/auth/authproxy/auth_test.go b/src/core/auth/authproxy/auth_test.go index b1fb4ab22..86cfebb15 100644 --- a/src/core/auth/authproxy/auth_test.go +++ b/src/core/auth/authproxy/auth_test.go @@ -75,7 +75,17 @@ func TestMain(m *testing.M) { } func TestAuth_Authenticate(t *testing.T) { - groupIDs, err := group.GetGroupIDByGroupName([]string{"vsphere.local\\users", "vsphere.local\\administrators"}, common.HTTPGroupType) + userGroups := []models.UserGroup{ + {GroupName: "vsphere.local\\users", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\administrators", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\caadmins", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\systemconfiguration.bashshelladministrators", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\systemconfiguration.administrators", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\licenseservice.administrators", GroupType: common.HTTPGroupType}, + {GroupName: "vsphere.local\\everyone", GroupType: common.HTTPGroupType}, + } + + groupIDs, err := group.PopulateGroup(userGroups) if err != nil { t.Fatal("Failed to get groupIDs") } diff --git a/src/core/auth/ldap/ldap.go b/src/core/auth/ldap/ldap.go index 904aafc66..6a95f8e1a 100644 --- a/src/core/auth/ldap/ldap.go +++ b/src/core/auth/ldap/ldap.go @@ -79,8 +79,6 @@ func (l *Auth) Authenticate(m models.AuthModel) (*models.User, error) { u.Username = ldapUsers[0].Username u.Email = strings.TrimSpace(ldapUsers[0].Email) u.Realname = ldapUsers[0].Realname - ugIDs := []int{} - dn := ldapUsers[0].DN if err = ldapSession.Bind(dn, m.Password); err != nil { log.Warningf("Failed to bind user, username: %s, dn: %s, error: %v", u.Username, dn, err) @@ -90,6 +88,10 @@ func (l *Auth) Authenticate(m models.AuthModel) (*models.User, error) { // Retrieve ldap related info in login to avoid too many traffic with LDAP server. // Get group admin dn groupCfg, err := config.LDAPGroupConf() + if err != nil { + log.Warningf("Failed to fetch ldap group configuration:%v", err) + // most likely user doesn't configure user group info, it should not block user login + } groupAdminDN := utils.TrimLower(groupCfg.LdapGroupAdminDN) // Attach user group for _, groupDN := range ldapUsers[0].GroupDNList { @@ -100,20 +102,15 @@ func (l *Auth) Authenticate(m models.AuthModel) (*models.User, error) { u.HasAdminRole = true } - userGroupQuery := models.UserGroup{ - GroupType: 1, - LdapGroupDN: groupDN, - } - userGroups, err := group.QueryUserGroup(userGroupQuery) - if err != nil { - continue - } - if len(userGroups) == 0 { - continue - } - ugIDs = append(ugIDs, userGroups[0].ID) } - u.GroupIDs = ugIDs + var userGroups []models.UserGroup + for _, dn := range ldapUsers[0].GroupDNList { + userGroups = append(userGroups, models.UserGroup{GroupName: dn, LdapGroupDN: dn, GroupType: common.LDAPGroupType}) + } + u.GroupIDs, err = group.PopulateGroup(userGroups) + if err != nil { + log.Warningf("Failed to fetch ldap group configuration:%v", err) + } return &u, nil } diff --git a/src/core/auth/ldap/ldap_test.go b/src/core/auth/ldap/ldap_test.go index 9002bd8bf..fb086a88b 100644 --- a/src/core/auth/ldap/ldap_test.go +++ b/src/core/auth/ldap/ldap_test.go @@ -28,6 +28,7 @@ import ( "github.com/goharbor/harbor/src/common/utils/test" "github.com/goharbor/harbor/src/core/api" + "github.com/goharbor/harbor/src/common/dao/group" "github.com/goharbor/harbor/src/core/auth" coreConfig "github.com/goharbor/harbor/src/core/config" ) @@ -401,11 +402,12 @@ func TestAddProjectMemberWithLdapGroup(t *testing.T) { if err != nil { t.Errorf("Error occurred when GetProjectByName: %v", err) } + userGroups := []models.UserGroup{{GroupName: "cn=harbor_users,ou=groups,dc=example,dc=com", LdapGroupDN: "cn=harbor_users,ou=groups,dc=example,dc=com", GroupType: common.LDAPGroupType}} + groupIds, err := group.PopulateGroup(userGroups) member := models.MemberReq{ ProjectID: currentProject.ProjectID, MemberGroup: models.UserGroup{ - LdapGroupDN: "cn=harbor_users,ou=groups,dc=example,dc=com", - GroupType: 1, + ID: groupIds[0], }, Role: models.PROJECTADMIN, } diff --git a/src/core/controllers/oidc.go b/src/core/controllers/oidc.go index c43e7ffca..6594721b4 100644 --- a/src/core/controllers/oidc.go +++ b/src/core/controllers/oidc.go @@ -116,7 +116,9 @@ func (oc *OIDCController) Callback() { oc.SendInternalServerError(err) return } - d.GroupIDs, err = group.GetGroupIDByGroupName(oidc.GroupsFromToken(idToken), common.OIDCGroupType) + groupNames := oidc.GroupsFromToken(idToken) + oidcGroups := models.UserGroupsFromName(groupNames, common.OIDCGroupType) + d.GroupIDs, err = group.PopulateGroup(oidcGroups) if err != nil { log.Warningf("Failed to get group ID list, due to error: %v, setting empty list into user model.", err) } diff --git a/src/core/filter/security.go b/src/core/filter/security.go index 891fcb5e4..3e71947f9 100644 --- a/src/core/filter/security.go +++ b/src/core/filter/security.go @@ -289,7 +289,9 @@ func (it *idTokenReqCtxModifier) Modify(ctx *beegoctx.Context) bool { log.Warning("User matches token's claims is not onboarded.") return false } - u.GroupIDs, err = group.GetGroupIDByGroupName(oidc.GroupsFromToken(claims), common.OIDCGroupType) + groupNames := oidc.GroupsFromToken(claims) + groups := models.UserGroupsFromName(groupNames, common.OIDCGroupType) + u.GroupIDs, err = group.PopulateGroup(groups) if err != nil { log.Errorf("Failed to get group ID list for OIDC user: %s, error: %v", u.Username, err) }