183 lines
4.6 KiB
Go
183 lines
4.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-ldap/ldap/v3"
|
|
)
|
|
|
|
// Authenticator checks a username/password pair against some credential store.
type Authenticator interface {
	// Authenticate returns true when the credentials are accepted. The LDAP
	// implementation in this file returns (false, nil) for rejected
	// credentials and reserves a non-nil error for failures to perform the
	// check itself.
	Authenticate(ctx context.Context, username, password string) (bool, error)
}
|
|
|
|
// LDAPConfig carries the connection, service-account and lookup settings for
// LDAPAuthenticator. NewLDAPAuthenticator trims surrounding whitespace from
// every string field before use.
type LDAPConfig struct {
	// URL is the directory address handed to ldap.DialURL.
	URL string

	// StartTLS upgrades the connection with the StartTLS operation right
	// after dialing.
	StartTLS bool

	// InsecureSkipVerify disables TLS certificate verification for both
	// ldaps:// connections and StartTLS.
	InsecureSkipVerify bool

	// BindDN and BindPassword are the service-account credentials bound
	// before the user and group searches run.
	BindDN       string
	BindPassword string

	// UserBaseDN is the root of the subtree searched for the user entry.
	UserBaseDN string

	// UsernameAttribute names the attribute that holds the login name. It
	// defaults to "uid" and must match ^[A-Za-z][A-Za-z0-9.-]*$.
	UsernameAttribute string

	// UserFilter is the search filter that locates the user entry. The
	// placeholders {username_attribute} and {username} (filter-escaped) are
	// substituted. It defaults to "({username_attribute}={username})".
	UserFilter string

	// GroupBaseDN is the root of the subtree searched when GroupFilter is set.
	GroupBaseDN string

	// GroupFilter, when non-empty, must match at least one entry under
	// GroupBaseDN for the user to be accepted. It may reference
	// {username_attribute}, {username} and {user_dn}; the latter two are
	// filter-escaped before substitution.
	GroupFilter string
}
|
|
|
|
type LDAPAuthenticator struct {
|
|
cfg LDAPConfig
|
|
}
|
|
|
|
var (
	// attributePattern accepts attribute names that are safe to splice into a
	// search filter without escaping: an ASCII letter followed by letters,
	// digits, dots or hyphens. Characters that could change filter structure,
	// such as parentheses, '*', '\', '=' or whitespace, never match.
	attributePattern = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9.-]*$`)
)
|
|
|
|
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
|
|
cfg.URL = strings.TrimSpace(cfg.URL)
|
|
cfg.BindDN = strings.TrimSpace(cfg.BindDN)
|
|
cfg.BindPassword = strings.TrimSpace(cfg.BindPassword)
|
|
cfg.UserBaseDN = strings.TrimSpace(cfg.UserBaseDN)
|
|
cfg.UsernameAttribute = strings.TrimSpace(cfg.UsernameAttribute)
|
|
cfg.UserFilter = strings.TrimSpace(cfg.UserFilter)
|
|
cfg.GroupBaseDN = strings.TrimSpace(cfg.GroupBaseDN)
|
|
cfg.GroupFilter = strings.TrimSpace(cfg.GroupFilter)
|
|
|
|
if cfg.UsernameAttribute == "" {
|
|
cfg.UsernameAttribute = "uid"
|
|
}
|
|
if cfg.UserFilter == "" {
|
|
cfg.UserFilter = "({username_attribute}={username})"
|
|
}
|
|
if !attributePattern.MatchString(cfg.UsernameAttribute) {
|
|
return nil, fmt.Errorf("ldap username attribute %q is not safe for filters", cfg.UsernameAttribute)
|
|
}
|
|
|
|
return &LDAPAuthenticator{cfg: cfg}, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) Authenticate(ctx context.Context, username, password string) (bool, error) {
|
|
username = strings.TrimSpace(username)
|
|
if username == "" || password == "" {
|
|
return false, nil
|
|
}
|
|
|
|
conn, err := a.dial(ctx)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := conn.Bind(a.cfg.BindDN, a.cfg.BindPassword); err != nil {
|
|
return false, fmt.Errorf("ldap service bind failed: %w", err)
|
|
}
|
|
|
|
userDN, err := a.findUser(conn, username)
|
|
if err != nil || userDN == "" {
|
|
return false, err
|
|
}
|
|
|
|
if a.cfg.GroupFilter != "" {
|
|
ok, err := a.userInAllowedGroup(conn, username, userDN)
|
|
if err != nil || !ok {
|
|
return ok, err
|
|
}
|
|
}
|
|
|
|
if err := conn.Bind(userDN, password); err != nil {
|
|
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("ldap user bind failed: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) dial(ctx context.Context) (*ldap.Conn, error) {
|
|
dialer := &net.Dialer{Timeout: 5 * time.Second}
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
dialer.Deadline = deadline
|
|
}
|
|
|
|
conn, err := ldap.DialURL(a.cfg.URL, ldap.DialWithDialer(dialer), ldap.DialWithTLSConfig(a.tlsConfig()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ldap dial failed: %w", err)
|
|
}
|
|
conn.SetTimeout(10 * time.Second)
|
|
|
|
if a.cfg.StartTLS {
|
|
if err := conn.StartTLS(a.tlsConfig()); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("ldap starttls failed: %w", err)
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) tlsConfig() *tls.Config {
|
|
return &tls.Config{InsecureSkipVerify: a.cfg.InsecureSkipVerify} //nolint:gosec // Explicit user-controlled LDAP option.
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) findUser(conn *ldap.Conn, username string) (string, error) {
|
|
filter := renderFilter(a.cfg.UserFilter, map[string]string{
|
|
"username_attribute": a.cfg.UsernameAttribute,
|
|
"username": ldap.EscapeFilter(username),
|
|
})
|
|
result, err := conn.Search(ldap.NewSearchRequest(
|
|
a.cfg.UserBaseDN,
|
|
ldap.ScopeWholeSubtree,
|
|
ldap.NeverDerefAliases,
|
|
2,
|
|
10,
|
|
false,
|
|
filter,
|
|
[]string{"dn"},
|
|
nil,
|
|
))
|
|
if err != nil {
|
|
return "", fmt.Errorf("ldap user search failed: %w", err)
|
|
}
|
|
if len(result.Entries) == 0 {
|
|
return "", nil
|
|
}
|
|
if len(result.Entries) > 1 {
|
|
return "", fmt.Errorf("ldap user search returned multiple entries")
|
|
}
|
|
return result.Entries[0].DN, nil
|
|
}
|
|
|
|
func (a *LDAPAuthenticator) userInAllowedGroup(conn *ldap.Conn, username, userDN string) (bool, error) {
|
|
filter := renderFilter(a.cfg.GroupFilter, map[string]string{
|
|
"username_attribute": a.cfg.UsernameAttribute,
|
|
"username": ldap.EscapeFilter(username),
|
|
"user_dn": ldap.EscapeFilter(userDN),
|
|
})
|
|
result, err := conn.Search(ldap.NewSearchRequest(
|
|
a.cfg.GroupBaseDN,
|
|
ldap.ScopeWholeSubtree,
|
|
ldap.NeverDerefAliases,
|
|
1,
|
|
10,
|
|
false,
|
|
filter,
|
|
[]string{"dn"},
|
|
nil,
|
|
))
|
|
if err != nil {
|
|
return false, fmt.Errorf("ldap group search failed: %w", err)
|
|
}
|
|
return len(result.Entries) > 0, nil
|
|
}
|
|
|
|
// renderFilter replaces every {key} placeholder in template with values[key].
//
// Substitution is a single left-to-right pass: text inserted for one
// placeholder is never re-scanned. A value that itself looks like a
// placeholder (for example a username of "{user_dn}") is therefore emitted
// literally, and the result no longer depends on map iteration order.
// Placeholders without a matching key are left untouched. Values are inserted
// verbatim, so callers must escape untrusted input before passing it in.
func renderFilter(template string, values map[string]string) string {
	if len(values) == 0 {
		return template
	}
	pairs := make([]string, 0, 2*len(values))
	for key, value := range values {
		pairs = append(pairs, "{"+key+"}", value)
	}
	return strings.NewReplacer(pairs...).Replace(template)
}
|