Files
pdns-admin/internal/auth/ldap.go
2026-06-18 22:32:42 -03:00

183 lines
4.6 KiB
Go

package auth
import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"regexp"
	"strings"
	"time"

	"github.com/go-ldap/ldap/v3"
)
// Authenticator checks a username/password pair.
//
// Authenticate reports whether the credentials were accepted. LDAPAuthenticator,
// the implementation in this file, returns (false, nil) for rejected
// credentials and reserves a non-nil error for failures talking to the backend.
type Authenticator interface {
	Authenticate(ctx context.Context, username, password string) (bool, error)
}
// LDAPConfig holds the settings for LDAPAuthenticator. NewLDAPAuthenticator
// trims surrounding whitespace from every field and fills in defaults for
// UsernameAttribute and UserFilter.
type LDAPConfig struct {
	// URL is the directory address passed to ldap.DialURL
	// (presumably ldap://host[:port] or ldaps://host[:port] — confirm which
	// schemes deployments use).
	URL string

	// StartTLS upgrades the connection with the StartTLS operation right
	// after dialing.
	StartTLS bool

	// InsecureSkipVerify disables TLS certificate verification. It applies to
	// both the TLS config handed to ldap.DialURL and the one used for StartTLS.
	InsecureSkipVerify bool

	// BindDN and BindPassword are the service-account credentials bound
	// before the user and group searches.
	BindDN       string
	BindPassword string

	// UserBaseDN is the search base for the user lookup (whole-subtree scope).
	UserBaseDN string

	// UsernameAttribute replaces {username_attribute} in the filter templates.
	// It defaults to "uid". It is inserted into filters unescaped, so it must
	// match attributePattern.
	UsernameAttribute string

	// UserFilter is the filter template for the user lookup. It defaults to
	// "({username_attribute}={username})". {username} is filter-escaped
	// before substitution. The lookup must match exactly one entry.
	UserFilter string

	// GroupBaseDN is the search base for the group check (whole-subtree scope).
	GroupBaseDN string

	// GroupFilter, when non-empty, enables the group check. The user is only
	// accepted if this template finds an entry under GroupBaseDN.
	// Placeholders are {username_attribute}, {username} and {user_dn}; the
	// last two are filter-escaped before substitution. Empty disables the check.
	GroupFilter string
}
// LDAPAuthenticator verifies credentials against an LDAP directory with a
// search-then-bind flow:
//   - a service bind and a search locate the user's DN;
//   - a bind as that DN checks the password.
//
// Construct it with NewLDAPAuthenticator so that cfg is normalized and validated.
type LDAPAuthenticator struct {
	cfg LDAPConfig
}
// attributePattern accepts an ASCII letter followed by any run of ASCII
// letters, digits, dots or hyphens. NewLDAPAuthenticator uses it to vet
// UsernameAttribute, which is spliced into search filters without escaping.
// The pattern therefore admits no filter metacharacters such as '(', ')', '*'
// or '\'.
var attributePattern = regexp.MustCompile("^[A-Za-z][A-Za-z0-9.-]*$")
// NewLDAPAuthenticator normalizes cfg and returns an authenticator that uses it.
//
// Normalization does the following:
//   - Every string field (including BindPassword) has leading and trailing
//     whitespace removed.
//   - An empty UsernameAttribute defaults to "uid".
//   - An empty UserFilter defaults to "({username_attribute}={username})".
//
// The resulting UsernameAttribute is substituted into search filters without
// escaping. It must therefore match attributePattern; otherwise an error is
// returned.
//
// NOTE(review): trimming BindPassword also strips leading/trailing spaces that
// could be part of a real password — confirm this is intended.
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
	for _, field := range []*string{
		&cfg.URL,
		&cfg.BindDN,
		&cfg.BindPassword,
		&cfg.UserBaseDN,
		&cfg.UsernameAttribute,
		&cfg.UserFilter,
		&cfg.GroupBaseDN,
		&cfg.GroupFilter,
	} {
		*field = strings.TrimSpace(*field)
	}

	if cfg.UsernameAttribute == "" {
		cfg.UsernameAttribute = "uid"
	}
	if cfg.UserFilter == "" {
		cfg.UserFilter = "({username_attribute}={username})"
	}

	if attributePattern.MatchString(cfg.UsernameAttribute) {
		return &LDAPAuthenticator{cfg: cfg}, nil
	}
	return nil, fmt.Errorf("ldap username attribute %q is not safe for filters", cfg.UsernameAttribute)
}
// Authenticate reports whether username and password are valid directory
// credentials, using search-then-bind:
//
//  1. dial the server (ctx's deadline bounds the connect),
//  2. bind as the configured service account, or bind anonymously when
//     neither BindDN nor BindPassword is set,
//  3. look up the user's DN with UserFilter,
//  4. if GroupFilter is set, require a matching group entry,
//  5. bind as the user's DN with password.
//
// Rejected credentials yield (false, nil). That covers:
//   - an empty username (after trimming whitespace) or an empty password;
//   - no matching user;
//   - no matching group;
//   - an invalidCredentials result from the user bind.
//
// Infrastructure problems yield (false, err): an already-done ctx, dial
// errors, service-bind errors, search errors, or an unexpected user-bind
// result.
func (a *LDAPAuthenticator) Authenticate(ctx context.Context, username, password string) (bool, error) {
	username = strings.TrimSpace(username)
	// Never forward an empty password. A simple bind with a DN but no password
	// is an "unauthenticated" bind (RFC 4513 §5.1.2), and some servers accept it.
	if username == "" || password == "" {
		return false, nil
	}
	if err := ctx.Err(); err != nil {
		return false, err
	}

	conn, err := a.dial(ctx)
	if err != nil {
		return false, err
	}
	defer conn.Close()

	if a.cfg.BindDN == "" && a.cfg.BindPassword == "" {
		// No service account is configured, so search anonymously. go-ldap's
		// Conn.Bind refuses an empty password on the client side and cannot
		// express this.
		err = conn.UnauthenticatedBind("")
	} else {
		err = conn.Bind(a.cfg.BindDN, a.cfg.BindPassword)
	}
	if err != nil {
		return false, fmt.Errorf("ldap service bind failed: %w", err)
	}

	userDN, err := a.findUser(conn, username)
	if err != nil || userDN == "" {
		return false, err
	}

	if a.cfg.GroupFilter != "" {
		ok, err := a.userInAllowedGroup(conn, username, userDN)
		if err != nil || !ok {
			return false, err
		}
	}

	// Binding as the user's DN is what actually verifies the password.
	if err := conn.Bind(userDN, password); err != nil {
		if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
			return false, nil
		}
		return false, fmt.Errorf("ldap user bind failed: %w", err)
	}
	return true, nil
}
// dial connects to cfg.URL. When cfg.StartTLS is set, it upgrades the
// connection before returning it.
//
// The TCP connect uses a 5s timeout. If ctx has an earlier deadline, that
// deadline applies instead, because net.Dialer uses whichever of Timeout and
// Deadline expires first. After connecting, go-ldap's request timeout is set
// to 10s. On StartTLS failure the connection is closed. On success the caller
// owns the connection and must Close it.
func (a *LDAPAuthenticator) dial(ctx context.Context) (*ldap.Conn, error) {
	d := net.Dialer{Timeout: 5 * time.Second}
	if dl, hasDeadline := ctx.Deadline(); hasDeadline {
		d.Deadline = dl
	}
	// One config serves both ldaps:// dialing and StartTLS. tls.Config values
	// may be reused as long as they are not mutated.
	tlsCfg := a.tlsConfig()

	conn, err := ldap.DialURL(a.cfg.URL, ldap.DialWithDialer(&d), ldap.DialWithTLSConfig(tlsCfg))
	if err != nil {
		return nil, fmt.Errorf("ldap dial failed: %w", err)
	}
	conn.SetTimeout(10 * time.Second)

	if !a.cfg.StartTLS {
		return conn, nil
	}
	if err := conn.StartTLS(tlsCfg); err != nil {
		conn.Close()
		return nil, fmt.Errorf("ldap starttls failed: %w", err)
	}
	return conn, nil
}
// tlsConfig builds the TLS settings used for ldaps:// dialing and for StartTLS.
//
// ServerName is taken from the host in cfg.URL. When dialing ldaps://,
// crypto/tls infers it from the address. StartTLS, however, wraps an existing
// socket with tls.Client, which does not. Without an explicit ServerName,
// every StartTLS handshake with verification enabled fails with "either
// ServerName or InsecureSkipVerify must be specified". If cfg.URL cannot be
// parsed, ServerName is left empty, as before.
func (a *LDAPAuthenticator) tlsConfig() *tls.Config {
	tc := &tls.Config{InsecureSkipVerify: a.cfg.InsecureSkipVerify} //nolint:gosec // Explicit user-controlled LDAP option.
	if u, err := url.Parse(a.cfg.URL); err == nil {
		// Hostname strips the port and any IPv6 brackets.
		tc.ServerName = u.Hostname()
	}
	return tc
}
// findUser resolves username to a DN. It searches UserBaseDN (whole subtree)
// with the rendered UserFilter. The username is escaped with
// ldap.EscapeFilter. UsernameAttribute is inserted as-is, because
// NewLDAPAuthenticator validates it against attributePattern.
//
// It returns ("", nil) when no entry matches. It returns an error when the
// search fails or when more than one entry matches (an ambiguous filter).
func (a *LDAPAuthenticator) findUser(conn *ldap.Conn, username string) (string, error) {
	filter := renderFilter(a.cfg.UserFilter, map[string]string{
		"username_attribute": a.cfg.UsernameAttribute,
		"username":           ldap.EscapeFilter(username),
	})
	result, err := conn.Search(ldap.NewSearchRequest(
		a.cfg.UserBaseDN,
		ldap.ScopeWholeSubtree,
		ldap.NeverDerefAliases,
		2,  // size limit: two entries are enough to detect ambiguity
		10, // time limit in seconds
		false,
		filter,
		[]string{"dn"},
		nil,
	))
	if err != nil {
		// Three or more matches exceed the size limit of 2. Report them like
		// the two-match case rather than as a generic search failure.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) {
			return "", fmt.Errorf("ldap user search returned multiple entries")
		}
		return "", fmt.Errorf("ldap user search failed: %w", err)
	}

	switch len(result.Entries) {
	case 0:
		return "", nil
	case 1:
		return result.Entries[0].DN, nil
	default:
		return "", fmt.Errorf("ldap user search returned multiple entries")
	}
}
// userInAllowedGroup reports whether at least one entry under GroupBaseDN
// (whole subtree) matches the rendered GroupFilter. The filter can use
// {username_attribute}, {username} and {user_dn}. The last two are escaped
// with ldap.EscapeFilter before substitution.
func (a *LDAPAuthenticator) userInAllowedGroup(conn *ldap.Conn, username, userDN string) (bool, error) {
	filter := renderFilter(a.cfg.GroupFilter, map[string]string{
		"username_attribute": a.cfg.UsernameAttribute,
		"username":           ldap.EscapeFilter(username),
		"user_dn":            ldap.EscapeFilter(userDN),
	})
	result, err := conn.Search(ldap.NewSearchRequest(
		a.cfg.GroupBaseDN,
		ldap.ScopeWholeSubtree,
		ldap.NeverDerefAliases,
		1,  // size limit: one match is enough to grant access
		10, // time limit in seconds
		false,
		filter,
		[]string{"dn"},
		nil,
	))
	if err != nil {
		// If the user matches more than one allowed group, the server stops
		// after the first entry and ends the search with sizeLimitExceeded.
		// That still proves membership. go-ldap returns the partial result
		// alongside the error; requiring a received entry keeps this
		// fail-closed.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) && result != nil && len(result.Entries) > 0 {
			return true, nil
		}
		return false, fmt.Errorf("ldap group search failed: %w", err)
	}
	return len(result.Entries) > 0, nil
}
// renderFilter replaces each "{key}" placeholder in template with values[key].
//
// Substitution is a single left-to-right pass over template. Inserted values
// are never re-scanned for placeholders: a value containing "{...}" is
// inserted literally, and the result does not depend on map iteration order.
// Placeholders are brace-delimited, so for keys without '}' no two can match
// at the same position. Placeholders with no matching key are left unchanged.
//
// Values are inserted verbatim. Callers must escape them first (for example
// with ldap.EscapeFilter).
func renderFilter(template string, values map[string]string) string {
	pairs := make([]string, 0, 2*len(values))
	for key, value := range values {
		pairs = append(pairs, "{"+key+"}", value)
	}
	return strings.NewReplacer(pairs...).Replace(template)
}