Merge pull request #8648 from stonezdj/regular_filter

Normalize LDAP filter for user filter and group filter
This commit is contained in:
stonezdj(Daojun Zhang) 2019-08-20 13:47:32 +08:00 committed by GitHub
commit 4384e11422
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 2 deletions

View File

@ -351,13 +351,13 @@ func (session *Session) createUserFilter(username string) string {
filterTag = goldap.EscapeFilter(username)
}
ldapFilter := session.ldapConfig.LdapFilter
ldapFilter := normalizeFilter(session.ldapConfig.LdapFilter)
ldapUID := session.ldapConfig.LdapUID
if ldapFilter == "" {
ldapFilter = "(" + ldapUID + "=" + filterTag + ")"
} else {
ldapFilter = "(&" + ldapFilter + "(" + ldapUID + "=" + filterTag + "))"
ldapFilter = "(&(" + ldapFilter + ")(" + ldapUID + "=" + filterTag + "))"
}
log.Debug("ldap filter :", ldapFilter)
@ -425,6 +425,7 @@ func createGroupSearchFilter(oldFilter, groupName, groupNameAttribute string) st
filter := ""
groupName = goldap.EscapeFilter(groupName)
groupNameAttribute = goldap.EscapeFilter(groupNameAttribute)
oldFilter = normalizeFilter(oldFilter)
if len(oldFilter) == 0 {
if len(groupName) == 0 {
filter = groupNameAttribute + "=*"
@ -455,3 +456,11 @@ func contains(s []string, e string) bool {
}
return false
}
// normalizeFilter - remove '(' and ')' in ldap filter
func normalizeFilter(filter string) string {
norFilter := strings.TrimSpace(filter)
norFilter = strings.TrimPrefix(norFilter, "(")
norFilter = strings.TrimSuffix(norFilter, ")")
return norFilter
}

View File

@ -369,3 +369,25 @@ func TestSession_SearchGroupByDN(t *testing.T) {
})
}
}
func TestNormalizeFilter(t *testing.T) {
type args struct {
filter string
}
tests := []struct {
name string
args args
want string
}{
{"normal test", args{"(objectclass=user)"}, "objectclass=user"},
{"with space", args{" (objectclass=user) "}, "objectclass=user"},
{"nothing", args{"objectclass=user"}, "objectclass=user"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := normalizeFilter(tt.args.filter); got != tt.want {
t.Errorf("normalizeFilter() = %v, want %v", got, tt.want)
}
})
}
}