package auth

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"regexp"
	"sort"
	"strings"
	"time"

	"github.com/go-ldap/ldap/v3"
)

// Authenticator verifies a username/password pair.
type Authenticator interface {
	Authenticate(context.Context, string, string) (bool, error)
}

// LDAPConfig configures an LDAPAuthenticator.
type LDAPConfig struct {
	// URL is the server address handed to ldap.DialURL.
	URL string

	// StartTLS upgrades the connection with the StartTLS operation after dialing.
	StartTLS bool

	// InsecureSkipVerify disables TLS certificate verification.
	InsecureSkipVerify bool

	// BindDN and BindPassword are the service credentials used for the user
	// and group searches. Leave both empty to search anonymously.
	BindDN       string
	BindPassword string

	// UserBaseDN is the subtree searched for the user entry.
	UserBaseDN string

	// UsernameAttribute is substituted for {username_attribute}; defaults to "uid".
	UsernameAttribute string

	// UserFilter is the user search template. Placeholders: {username_attribute}
	// and {username} (filter-escaped). Defaults to "({username_attribute}={username})".
	UserFilter string

	// GroupBaseDN is the subtree searched when GroupFilter is set.
	GroupBaseDN string

	// GroupFilter, when non-empty, must match at least one entry under
	// GroupBaseDN for the user to be accepted. Placeholders:
	// {username_attribute}, {username} and {user_dn} (both filter-escaped).
	GroupFilter string
}

// LDAPAuthenticator authenticates users against an LDAP directory using a
// search-then-bind flow. Construct it with NewLDAPAuthenticator.
type LDAPAuthenticator struct {
	cfg LDAPConfig
}

var _ Authenticator = (*LDAPAuthenticator)(nil)

const (
	ldapDialTimeout    = 5 * time.Second
	ldapRequestTimeout = 10 * time.Second

	// ldapSearchTimeLimit is the server-side search time limit, in seconds.
	ldapSearchTimeLimit = 10
)

// attributePattern restricts UsernameAttribute to a plain attribute name,
// because it is substituted into filters without escaping.
var attributePattern = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9.-]*$`)

// NewLDAPAuthenticator normalizes cfg and returns an authenticator for it.
//
// Surrounding whitespace is trimmed from every string field except
// BindPassword, from which only trailing CR/LF characters are removed.
// UsernameAttribute defaults to "uid" and UserFilter to
// "({username_attribute}={username})". An error is returned when
// UsernameAttribute is not a plain attribute name.
func NewLDAPAuthenticator(cfg LDAPConfig) (*LDAPAuthenticator, error) {
	cfg.URL = strings.TrimSpace(cfg.URL)
	cfg.BindDN = strings.TrimSpace(cfg.BindDN)
	// Spaces are legal in passwords, so only strip a trailing line ending
	// (as left by a secret read from a file); trimming all whitespace would
	// silently change the credential.
	cfg.BindPassword = strings.TrimRight(cfg.BindPassword, "\r\n")
	cfg.UserBaseDN = strings.TrimSpace(cfg.UserBaseDN)
	cfg.UsernameAttribute = strings.TrimSpace(cfg.UsernameAttribute)
	cfg.UserFilter = strings.TrimSpace(cfg.UserFilter)
	cfg.GroupBaseDN = strings.TrimSpace(cfg.GroupBaseDN)
	cfg.GroupFilter = strings.TrimSpace(cfg.GroupFilter)

	if cfg.UsernameAttribute == "" {
		cfg.UsernameAttribute = "uid"
	}
	if cfg.UserFilter == "" {
		cfg.UserFilter = "({username_attribute}={username})"
	}
	if !attributePattern.MatchString(cfg.UsernameAttribute) {
		return nil, fmt.Errorf("ldap username attribute %q is not safe for filters", cfg.UsernameAttribute)
	}
	return &LDAPAuthenticator{cfg: cfg}, nil
}

// Authenticate reports whether password is valid for username.
//
// It binds with the service credentials (or stays anonymous when none are
// configured), searches UserBaseDN for exactly one entry matching UserFilter,
// optionally requires a GroupFilter match, and finally binds as that entry
// with password. An empty username or password, an unknown user, a user
// outside the allowed groups, or invalid credentials yield (false, nil).
// Dial, TLS, search and ambiguous-user failures yield a non-nil error.
// Cancelling ctx aborts in-flight LDAP requests.
func (a *LDAPAuthenticator) Authenticate(ctx context.Context, username, password string) (bool, error) {
	username = strings.TrimSpace(username)
	// An empty password would turn the final user bind into an
	// unauthenticated bind, which many servers accept; refuse it up front.
	if username == "" || password == "" {
		return false, nil
	}
	if err := ctx.Err(); err != nil {
		return false, err
	}

	conn, err := a.dial(ctx)
	if err != nil {
		return false, err
	}
	defer conn.Close()

	// Tear the connection down if the caller gives up, so blocked requests
	// return promptly. The goroutine exits when this function returns.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			conn.Close()
		case <-done:
		}
	}()

	ok, err := a.authenticateOnConn(conn, username, password)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			// The failure was most likely caused by the watcher closing the
			// connection; expose the context error so callers can detect it.
			return false, fmt.Errorf("ldap authentication aborted: %w (%v)", ctxErr, err)
		}
	}
	return ok, err
}

// authenticateOnConn runs the service bind, user lookup, optional group
// check and user bind over an already-established connection.
func (a *LDAPAuthenticator) authenticateOnConn(conn *ldap.Conn, username, password string) (bool, error) {
	if err := a.serviceBind(conn); err != nil {
		return false, err
	}

	userDN, err := a.findUser(conn, username)
	if err != nil || userDN == "" {
		return false, err
	}

	if a.cfg.GroupFilter != "" {
		ok, err := a.userInAllowedGroup(conn, username, userDN)
		if err != nil || !ok {
			return ok, err
		}
	}

	if err := conn.Bind(userDN, password); err != nil {
		if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
			return false, nil
		}
		return false, fmt.Errorf("ldap user bind failed: %w", err)
	}
	return true, nil
}

// serviceBind authenticates the connection with the configured service
// credentials. When both BindDN and BindPassword are empty the bind is
// skipped and the searches run anonymously: go-ldap's Bind rejects an empty
// password on the client side, so binding with empty credentials could never
// succeed.
func (a *LDAPAuthenticator) serviceBind(conn *ldap.Conn) error {
	if a.cfg.BindDN == "" && a.cfg.BindPassword == "" {
		return nil
	}
	if err := conn.Bind(a.cfg.BindDN, a.cfg.BindPassword); err != nil {
		return fmt.Errorf("ldap service bind failed: %w", err)
	}
	return nil
}

// dial connects to the configured URL, applies the per-request timeout and,
// when configured, upgrades the connection with StartTLS. The dial honours
// ctx's deadline.
func (a *LDAPAuthenticator) dial(ctx context.Context) (*ldap.Conn, error) {
	dialer := &net.Dialer{Timeout: ldapDialTimeout}
	if deadline, ok := ctx.Deadline(); ok {
		dialer.Deadline = deadline
	}

	tlsCfg := a.tlsConfig()
	conn, err := ldap.DialURL(a.cfg.URL, ldap.DialWithDialer(dialer), ldap.DialWithTLSConfig(tlsCfg))
	if err != nil {
		return nil, fmt.Errorf("ldap dial failed: %w", err)
	}
	conn.SetTimeout(ldapRequestTimeout)

	if a.cfg.StartTLS {
		if err := conn.StartTLS(tlsCfg); err != nil {
			conn.Close()
			return nil, fmt.Errorf("ldap starttls failed: %w", err)
		}
	}
	return conn, nil
}

// tlsConfig builds the TLS settings used for ldaps:// dialing and StartTLS.
// ServerName is taken from the configured URL: the StartTLS upgrade cannot
// verify the server certificate without it, whereas ldaps:// dialing would
// otherwise infer it from the address.
func (a *LDAPAuthenticator) tlsConfig() *tls.Config {
	tc := &tls.Config{InsecureSkipVerify: a.cfg.InsecureSkipVerify} //nolint:gosec // Explicit user-controlled LDAP option.
	if u, err := url.Parse(a.cfg.URL); err == nil {
		tc.ServerName = u.Hostname()
	}
	return tc
}

// findUser returns the DN of the single entry under UserBaseDN matching
// UserFilter. It returns "" with a nil error when nothing matches, and an
// error when the match is ambiguous or the search fails.
func (a *LDAPAuthenticator) findUser(conn *ldap.Conn, username string) (string, error) {
	filter := renderFilter(a.cfg.UserFilter, map[string]string{
		"username_attribute": a.cfg.UsernameAttribute,
		"username":           ldap.EscapeFilter(username),
	})
	// A size limit of 2 is enough to tell a unique match from an ambiguous one.
	result, err := conn.Search(ldap.NewSearchRequest(
		a.cfg.UserBaseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
		2, ldapSearchTimeLimit, false,
		filter, []string{"dn"}, nil,
	))
	if err != nil {
		// More than two matches stops the search with sizeLimitExceeded,
		// which is the same ambiguity as exactly two matches.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) {
			return "", fmt.Errorf("ldap user search returned multiple entries")
		}
		return "", fmt.Errorf("ldap user search failed: %w", err)
	}

	switch len(result.Entries) {
	case 0:
		return "", nil
	case 1:
		return result.Entries[0].DN, nil
	default:
		return "", fmt.Errorf("ldap user search returned multiple entries")
	}
}

// userInAllowedGroup reports whether GroupFilter, rendered for this user,
// matches at least one entry under GroupBaseDN.
func (a *LDAPAuthenticator) userInAllowedGroup(conn *ldap.Conn, username, userDN string) (bool, error) {
	filter := renderFilter(a.cfg.GroupFilter, map[string]string{
		"username_attribute": a.cfg.UsernameAttribute,
		"username":           ldap.EscapeFilter(username),
		"user_dn":            ldap.EscapeFilter(userDN),
	})
	// Only existence matters, so a single entry is requested.
	result, err := conn.Search(ldap.NewSearchRequest(
		a.cfg.GroupBaseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
		1, ldapSearchTimeLimit, false,
		filter, []string{"dn"}, nil,
	))
	if err != nil {
		// A user matched by several group entries makes the server stop with
		// sizeLimitExceeded; more matches than the limit still proves membership.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) {
			return true, nil
		}
		return false, fmt.Errorf("ldap group search failed: %w", err)
	}
	return len(result.Entries) > 0, nil
}

// renderFilter replaces each "{key}" placeholder in template with
// values[key]. Substitution is a single left-to-right pass, so placeholder
// text appearing inside a substituted value is not expanded again. Keys are
// sorted so the result never depends on map iteration order.
func renderFilter(template string, values map[string]string) string {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	pairs := make([]string, 0, 2*len(keys))
	for _, key := range keys {
		pairs = append(pairs, "{"+key+"}", values[key])
	}
	return strings.NewReplacer(pairs...).Replace(template)
}