package dnsrecord import ( "fmt" "net/netip" "regexp" "strconv" "strings" "github.com/go-playground/validator/v10" "pdns_admin/internal/pdns" ) var ( txtValuePattern = regexp.MustCompile(`^"([^"\\]|\\.)*"( "([^"\\]|\\.)*")*$`) caaTagPattern = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9-]*$`) ) type Validator struct { validate *validator.Validate } func NewValidator() (*Validator, error) { v := validator.New() if err := v.RegisterValidation("dns_fqdn", func(fl validator.FieldLevel) bool { return IsFQDN(fl.Field().String()) }); err != nil { return nil, err } return &Validator{validate: v}, nil } func SupportedTypes() []string { return []string{"A", "AAAA", "CAA", "CNAME", "MX", "NS", "SOA", "SRV", "TXT"} } func (v *Validator) ValidateRRSet(name, recordType string, ttl uint64, contents []string) (pdns.RRSet, error) { name = EnsureTrailingDot(name) recordType = strings.ToUpper(strings.TrimSpace(recordType)) if ttl == 0 || ttl > 1<<32-1 { return pdns.RRSet{}, fmt.Errorf("ttl must be between 1 and 4294967295") } if err := v.validate.Var(name, "required,dns_fqdn"); err != nil { return pdns.RRSet{}, fmt.Errorf("record name must be a fully qualified domain name") } if !supported(recordType) { return pdns.RRSet{}, fmt.Errorf("unsupported record type %q", recordType) } records := make([]pdns.Record, 0, len(contents)) for _, content := range contents { content = strings.TrimSpace(content) if content == "" { continue } if err := validateContent(recordType, content); err != nil { return pdns.RRSet{}, err } records = append(records, pdns.Record{Content: content}) } if len(records) == 0 { return pdns.RRSet{}, fmt.Errorf("at least one record value is required") } return pdns.RRSet{ Name: name, Type: recordType, TTL: uint32(ttl), Records: records, }, nil } func EnsureTrailingDot(value string) string { value = strings.TrimSpace(value) if value == "" || strings.HasSuffix(value, ".") { return value } return value + "." } func IsFQDN(value string) bool { value = strings.TrimSpace(value) if value == "." || value == "" || !strings.HasSuffix(value, ".") || len(value) > 253 { return false } labels := strings.Split(strings.TrimSuffix(value, "."), ".") for i, label := range labels { if label == "" || len(label) > 63 { return false } if label == "*" && i == 0 { continue } if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") { return false } for _, r := range label { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' { continue } return false } } return true } func validateContent(recordType, content string) error { switch recordType { case "A": addr, err := netip.ParseAddr(content) if err != nil || !addr.Is4() { return fmt.Errorf("A record content must be an IPv4 address") } case "AAAA": addr, err := netip.ParseAddr(content) if err != nil || !addr.Is6() { return fmt.Errorf("AAAA record content must be an IPv6 address") } case "CAA": return validateCAA(content) case "CNAME", "NS": if !IsFQDN(content) { return fmt.Errorf("%s record content must be a fully qualified domain name", recordType) } case "MX": return validateMX(content) case "SOA": return validateSOA(content) case "SRV": return validateSRV(content) case "TXT": if !txtValuePattern.MatchString(content) { return fmt.Errorf("TXT record content must be one or more quoted strings") } } return nil } func validateMX(content string) error { fields := strings.Fields(content) if len(fields) != 2 { return fmt.Errorf("MX record content must be: priority target") } if _, err := parseUint(fields[0], 16); err != nil { return fmt.Errorf("MX priority must be between 0 and 65535") } if !IsFQDN(fields[1]) { return fmt.Errorf("MX target must be a fully qualified domain name") } return nil } func validateSRV(content string) error { fields := strings.Fields(content) if len(fields) != 4 { return fmt.Errorf("SRV record content must be: priority weight port target") } for i, label := range []string{"priority", "weight", "port"} { if _, err := parseUint(fields[i], 16); err != nil { return fmt.Errorf("SRV %s must be between 0 and 65535", label) } } target := strings.TrimSpace(fields[3]) if target != "." && !IsFQDN(target) { return fmt.Errorf("SRV target must be a fully qualified domain name or .") } return nil } func validateSOA(content string) error { fields := strings.Fields(content) if len(fields) != 7 { return fmt.Errorf("SOA record content must be: primary hostmaster serial refresh retry expire minimum") } if !IsFQDN(fields[0]) { return fmt.Errorf("SOA primary nameserver must be a fully qualified domain name") } if !IsFQDN(fields[1]) { return fmt.Errorf("SOA hostmaster must be a fully qualified domain name") } for i, label := range []string{"serial", "refresh", "retry", "expire", "minimum"} { if _, err := parseUint(fields[i+2], 32); err != nil { return fmt.Errorf("SOA %s must be between 0 and 4294967295", label) } } return nil } func validateCAA(content string) error { fields := strings.Fields(content) if len(fields) < 3 { return fmt.Errorf("CAA record content must be: flags tag \"value\"") } flags, err := parseUint(fields[0], 8) if err != nil || flags > 255 { return fmt.Errorf("CAA flags must be between 0 and 255") } if !caaTagPattern.MatchString(fields[1]) { return fmt.Errorf("CAA tag must start with a letter and contain only letters, numbers, and hyphens") } value := strings.Join(fields[2:], " ") if !txtValuePattern.MatchString(value) { return fmt.Errorf("CAA value must be quoted") } return nil } func parseUint(value string, bits int) (uint64, error) { return strconv.ParseUint(value, 10, bits) } func supported(recordType string) bool { for _, candidate := range SupportedTypes() { if candidate == recordType { return true } } return false }