diff --git a/src/common/utils/ldap/ldap.go b/src/common/utils/ldap/ldap.go index 1cc7cfc6e..6cc05c3c8 100644 --- a/src/common/utils/ldap/ldap.go +++ b/src/common/utils/ldap/ldap.go @@ -294,9 +294,9 @@ func (session *Session) SearchLdapAttribute(baseDN, filter string, attributes [] if err := session.Bind(session.ldapConfig.LdapSearchDn, session.ldapConfig.LdapSearchPassword); err != nil { return nil, fmt.Errorf("Can not bind search dn, error: %v", err) } - filter = strings.TrimSpace(filter) - if !(strings.HasPrefix(filter, "(") || strings.HasSuffix(filter, ")")) { - filter = "(" + filter + ")" + filter = normalizeFilter(filter) + if len(filter) == 0 { + return nil, ErrInvalidFilter } if _, err := goldap.CompileFilter(filter); err != nil { log.Errorf("Wrong filter format, filter:%v", filter) @@ -371,7 +371,11 @@ func (session *Session) SearchGroupByDN(groupDN string) ([]models.LdapGroup, err if _, err := goldap.ParseDN(groupDN); err != nil { return nil, ErrDNSyntax } - groupList, err := session.searchGroup(groupDN, session.ldapGroupConfig.LdapGroupFilter, "", session.ldapGroupConfig.LdapGroupNameAttribute) + ldapFilter, err := createGroupSearchFilter(session.ldapGroupConfig.LdapGroupFilter, "", session.ldapGroupConfig.LdapGroupNameAttribute) + if err != nil { + return nil, err + } + groupList, err := session.searchGroup(groupDN, ldapFilter, "", session.ldapGroupConfig.LdapGroupNameAttribute) if serverError, ok := err.(*goldap.Error); ok { log.Debugf("resultCode:%v", serverError.ResultCode) } @@ -388,56 +392,85 @@ func (session *Session) groupBaseDN() string { return session.ldapGroupConfig.LdapGroupBaseDN } -func (session *Session) searchGroup(groupDN, filter, groupName, groupNameAttribute string) ([]models.LdapGroup, error) { +// searchGroup -- Given a group DN and filter, search group +func (session *Session) searchGroup(groupDN, filter, gName, groupNameAttribute string) ([]models.LdapGroup, error) { ldapGroups := make([]models.LdapGroup, 0) - log.Debugf("Groupname: %v, groupDN: %v", groupName, groupDN) - ldapFilter, err := createGroupSearchFilter(filter, groupName, groupNameAttribute) + log.Debugf("Groupname: %v, groupDN: %v", gName, groupDN) + + // Check current group DN is under the LDAP group base DN + isChild, err := UnderBaseDN(session.groupBaseDN(), groupDN) if err != nil { - log.Errorf("wrong filter format: filter:%v, groupName:%v, groupNameAttribute:%v", filter, groupName, groupNameAttribute) - return nil, err + return ldapGroups, err } - attributes := []string{groupNameAttribute} - result, err := session.SearchLdapAttribute(session.groupBaseDN(), ldapFilter, attributes) + if !isChild { + return ldapGroups, nil + } + + // Search the groupDN with LDAP group filter condition + ldapFilter, err := createGroupSearchFilter(filter, gName, groupNameAttribute) if err != nil { - return nil, err + log.Errorf("wrong filter format: filter:%v, gName:%v, groupNameAttribute:%v", filter, gName, groupNameAttribute) + return ldapGroups, err } - for _, ldapEntry := range result.Entries { - var group models.LdapGroup - if groupDN != ldapEntry.DN { - continue - } - group.GroupDN = ldapEntry.DN - for _, attr := range ldapEntry.Attributes { - // OpenLdap sometimes contain leading space in username - val := strings.TrimSpace(attr.Values[0]) - log.Debugf("Current ldap entry attr name: %s\n", attr.Name) - switch strings.ToLower(attr.Name) { - case strings.ToLower(groupNameAttribute): - group.GroupName = val - } - } - ldapGroups = append(ldapGroups, group) + + // There maybe many groups under the LDAP group base DN + // If return all groups in LDAP group base DN, it might get "Size Limit Exceeded" error + // Take the groupDN as the baseDN in the search request to avoid return too many records + result, err := session.SearchLdapAttribute(groupDN, ldapFilter, []string{groupNameAttribute}) + if err != nil { + return ldapGroups, err } + if len(result.Entries) == 0 { + return ldapGroups, nil + } + groupName := "" + if len(result.Entries[0].Attributes) > 0 { + groupName = result.Entries[0].Attributes[0].Values[0] + } + group := models.LdapGroup{ + GroupDN: groupDN, + GroupName: groupName, + } + ldapGroups = append(ldapGroups, group) + return ldapGroups, nil } -func createGroupSearchFilter(oldFilterStr, groupName, groupNameAttribute string) (string, error) { - origFilter, err := NewFilterBuilder(oldFilterStr) +// UnderBaseDN - check if the childDN is under the baseDN, if the baseDN equals current DN, return true +func UnderBaseDN(baseDN, childDN string) (bool, error) { + base, err := goldap.ParseDN(baseDN) if err != nil { - log.Errorf("failed to create group search filter:%v", oldFilterStr) + return false, err + } + child, err := goldap.ParseDN(childDN) + if err != nil { + return false, err + } + return base.AncestorOf(child) || base.Equal(child), nil +} + +// createGroupSearchFilter - Create group search filter with base filter and group name filter condition +func createGroupSearchFilter(baseFilter, groupName, groupNameAttr string) (string, error) { + base, err := NewFilterBuilder(baseFilter) + if err != nil { + log.Errorf("failed to create group search filter:%v", baseFilter) return "", err } groupName = goldap.EscapeFilter(groupName) gFilterStr := "" - if len(groupName) > 0 { - gFilterStr = fmt.Sprintf("(%v=%v)", goldap.EscapeFilter(groupNameAttribute), groupName) + // when groupName is empty, search all groups in current base DN + if len(groupName) == 0 { + groupName = "*" } - gFilter, err := NewFilterBuilder(gFilterStr) + if len(groupNameAttr) == 0 { + groupNameAttr = "cn" + } + gFilter, err := NewFilterBuilder("(" + goldap.EscapeFilter(groupNameAttr) + "=" + groupName + ")") if err != nil { log.Errorf("invalid ldap filter:%v", gFilterStr) return "", err } - fb := origFilter.And(gFilter) + fb := base.And(gFilter) return fb.String() } diff --git a/src/common/utils/ldap/ldap_test.go b/src/common/utils/ldap/ldap_test.go index f01c7affe..2fa66760d 100644 --- a/src/common/utils/ldap/ldap_test.go +++ b/src/common/utils/ldap/ldap_test.go @@ -368,7 +368,7 @@ func TestSession_SearchGroupByDN(t *testing.T) { {"search non-exist group", fields{ldapConfig: ldapConfig, ldapGroupConfig: ldapGroupConfig}, args{groupDN: "cn=harbor_non_users,ou=groups,dc=example,dc=com"}, - []models.LdapGroup{}, false}, + nil, true}, {"search invalid group dn", fields{ldapConfig: ldapConfig, ldapGroupConfig: ldapGroupConfig}, args{groupDN: "random string"}, @@ -475,3 +475,59 @@ func TestNormalizeFilter(t *testing.T) { }) } } + +func TestUnderBaseDN(t *testing.T) { + type args struct { + baseDN string + childDN string + } + cases := []struct { + name string + in args + wantError bool + want bool + }{ + { + name: `normal`, + in: args{"dc=example,dc=com", "cn=admin,dc=example,dc=com"}, + wantError: false, + want: true, + }, + { + name: `false`, + in: args{"dc=vmware,dc=com", "cn=admin,dc=example,dc=com"}, + wantError: false, + want: false, + }, + { + name: `same dn`, + in: args{"cn=admin,dc=example,dc=com", "cn=admin,dc=example,dc=com"}, + wantError: false, + want: true, + }, + { + name: `error format in base`, + in: args{"abc", "cn=admin,dc=example,dc=com"}, + wantError: true, + want: false, + }, + { + name: `error format in child`, + in: args{"dc=vmware,dc=com", "wrong format"}, + wantError: true, + want: false, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := UnderBaseDN(tt.in.baseDN, tt.in.childDN) + if (err != nil) != tt.wantError { + t.Errorf("UnderBaseDN error = %v, wantErr %v", err, tt.wantError) + return + } + if got != tt.want { + t.Errorf(`(%v) = %v; want "%v"`, tt.in, got, tt.want) + } + }) + } +}