Files
pdns-admin/internal/dnsrecord/validator.go
2026-06-18 22:32:42 -03:00

224 lines
5.9 KiB
Go

package dnsrecord
import (
"fmt"
"net/netip"
"regexp"
"strconv"
"strings"
"github.com/go-playground/validator/v10"
"pdns_admin/internal/pdns"
)
// txtValuePattern matches one or more double-quoted character strings that
// are separated by exactly one space. Inside the quotes any character other
// than an unescaped quote is allowed, and a backslash escapes the next
// character.
var txtValuePattern = regexp.MustCompile(`^"([^"\\]|\\.)*"( "([^"\\]|\\.)*")*$`)

// caaTagPattern matches a CAA property tag: an ASCII letter followed by any
// number of ASCII letters, digits, or hyphens.
var caaTagPattern = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9-]*$`)
// Validator validates DNS record input and builds pdns.RRSet values from it.
// Obtain one with NewValidator so that the custom "dns_fqdn" rule is
// registered on the underlying go-playground validator.
type Validator struct {
	validate *validator.Validate
}

// NewValidator creates a Validator whose go-playground validator has the
// "dns_fqdn" tag registered, implemented by IsFQDN on the field's string
// value. The error from registering the rule, if any, is returned as-is.
func NewValidator() (*Validator, error) {
	fqdnRule := func(fl validator.FieldLevel) bool {
		return IsFQDN(fl.Field().String())
	}

	engine := validator.New()
	if err := engine.RegisterValidation("dns_fqdn", fqdnRule); err != nil {
		return nil, err
	}
	return &Validator{validate: engine}, nil
}
// SupportedTypes returns the record types accepted by ValidateRRSet, in
// alphabetical order. Every call returns a newly allocated slice, so callers
// may modify the result without affecting later calls.
func SupportedTypes() []string {
	return strings.Fields("A AAAA CAA CNAME MX NS SOA SRV TXT")
}
// ValidateRRSet normalizes and validates a record set and converts it into a
// pdns.RRSet.
//
// Rules:
//   - name is trimmed, given a trailing dot, and must satisfy the "dns_fqdn"
//     rule (IsFQDN).
//   - recordType is trimmed, upper-cased, and must be one of SupportedTypes.
//   - ttl must be in [1, 4294967295] so that it fits the uint32 TTL field.
//   - Each entry of contents is trimmed. Blank entries are skipped, and every
//     remaining entry must be valid content for recordType.
//   - At least one non-blank value is required, and no value may repeat.
//   - CNAME and SOA sets may contain only a single value.
//
// On error the returned RRSet is the zero value.
func (v *Validator) ValidateRRSet(name, recordType string, ttl uint64, contents []string) (pdns.RRSet, error) {
	name = EnsureTrailingDot(name)
	recordType = strings.ToUpper(strings.TrimSpace(recordType))
	if ttl == 0 || ttl > 1<<32-1 {
		return pdns.RRSet{}, fmt.Errorf("ttl must be between 1 and 4294967295")
	}
	if err := v.validate.Var(name, "required,dns_fqdn"); err != nil {
		return pdns.RRSet{}, fmt.Errorf("record name must be a fully qualified domain name")
	}
	if !supported(recordType) {
		return pdns.RRSet{}, fmt.Errorf("unsupported record type %q", recordType)
	}

	records := make([]pdns.Record, 0, len(contents))
	seen := make(map[string]struct{}, len(contents))
	for _, content := range contents {
		content = strings.TrimSpace(content)
		if content == "" {
			continue
		}
		if err := validateContent(recordType, content); err != nil {
			return pdns.RRSet{}, err
		}
		// An RRset is a set: a repeated value adds nothing and is presumably
		// rejected by the PowerDNS API anyway, so report it here with a clear
		// message.
		if _, dup := seen[content]; dup {
			return pdns.RRSet{}, fmt.Errorf("duplicate %s record value %q", recordType, content)
		}
		seen[content] = struct{}{}
		records = append(records, pdns.Record{Content: content})
	}
	if len(records) == 0 {
		return pdns.RRSet{}, fmt.Errorf("at least one record value is required")
	}
	// A name can own at most one CNAME (RFC 1034 section 3.6.2), and a zone
	// has exactly one SOA.
	if (recordType == "CNAME" || recordType == "SOA") && len(records) > 1 {
		return pdns.RRSet{}, fmt.Errorf("%s record set must contain exactly one value", recordType)
	}

	return pdns.RRSet{
		Name:    name,
		Type:    recordType,
		TTL:     uint32(ttl),
		Records: records,
	}, nil
}
// EnsureTrailingDot trims surrounding whitespace from value and appends "."
// unless the trimmed result is empty or already ends with a dot. An empty or
// all-whitespace input yields "".
func EnsureTrailingDot(value string) string {
	trimmed := strings.TrimSpace(value)
	switch {
	case trimmed == "", strings.HasSuffix(trimmed, "."):
		return trimmed
	default:
		return trimmed + "."
	}
}
// IsFQDN reports whether value, after trimming surrounding whitespace, is an
// absolute domain name in presentation form.
//
// The name must end with "." and must not be the root name ".". Excluding
// the trailing dot it may be at most 253 characters, which is the longest
// name whose wire form fits in 255 octets.
//
// Each label must be 1-63 characters long and consist of ASCII letters,
// digits, hyphens, or underscores, and must not start or end with a hyphen.
// The first label may instead be the wildcard "*". Underscores are accepted
// because service and policy owner names such as "_sip._tcp.example.com."
// (SRV) or "_dmarc.example.com." (TXT) depend on them.
func IsFQDN(value string) bool {
	value = strings.TrimSpace(value)
	if !strings.HasSuffix(value, ".") {
		return false
	}
	host := strings.TrimSuffix(value, ".")
	// The length limit applies to the name without its trailing dot.
	if host == "" || len(host) > 253 {
		return false
	}

	for i, label := range strings.Split(host, ".") {
		if label == "" || len(label) > 63 {
			return false
		}
		if i == 0 && label == "*" {
			continue
		}
		if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
			return false
		}
		for _, r := range label {
			switch {
			case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-', r == '_':
			default:
				return false
			}
		}
	}
	return true
}
// validateContent checks that content is well-formed presentation data for
// recordType, returning a descriptive error if it is not. Record types with
// no case below are accepted unchecked; ValidateRRSet filters types through
// supported before calling this.
func validateContent(recordType, content string) error {
	switch recordType {
	case "A":
		addr, err := netip.ParseAddr(content)
		if err != nil || !addr.Is4() {
			return fmt.Errorf("A record content must be an IPv4 address")
		}
	case "AAAA":
		addr, err := netip.ParseAddr(content)
		// ParseAddr accepts scoped IPv6 text such as "fe80::1%eth0". A zone
		// cannot be stored in AAAA data, so treat it as invalid.
		if err != nil || !addr.Is6() || addr.Zone() != "" {
			return fmt.Errorf("AAAA record content must be an IPv6 address")
		}
	case "CAA":
		return validateCAA(content)
	case "CNAME", "NS":
		if !IsFQDN(content) {
			return fmt.Errorf("%s record content must be a fully qualified domain name", recordType)
		}
	case "MX":
		return validateMX(content)
	case "SOA":
		return validateSOA(content)
	case "SRV":
		return validateSRV(content)
	case "TXT":
		if !txtValuePattern.MatchString(content) {
			return fmt.Errorf("TXT record content must be one or more quoted strings")
		}
	}
	return nil
}
// validateMX checks MX content of the form "priority target".
//
// priority must be an unsigned 16-bit integer. target must be a fully
// qualified domain name, or "." to publish a null MX (RFC 7505). A null MX
// is accepted only with priority 0. RFC 7505 also requires a null MX to be
// the only MX record; that is not checked here because only one value is
// visible.
func validateMX(content string) error {
	fields := strings.Fields(content)
	if len(fields) != 2 {
		return fmt.Errorf("MX record content must be: priority target")
	}
	priority, err := parseUint(fields[0], 16)
	if err != nil {
		return fmt.Errorf("MX priority must be between 0 and 65535")
	}
	target := fields[1]
	if target == "." {
		if priority != 0 {
			return fmt.Errorf("null MX (target .) must use priority 0")
		}
		return nil
	}
	if !IsFQDN(target) {
		return fmt.Errorf("MX target must be a fully qualified domain name")
	}
	return nil
}
// validateSRV checks SRV content of the form "priority weight port target".
// priority, weight, and port must each parse as unsigned 16-bit integers.
// target must be a fully qualified domain name or the single dot ".".
func validateSRV(content string) error {
	parts := strings.Fields(content)
	if len(parts) != 4 {
		return fmt.Errorf("SRV record content must be: priority weight port target")
	}

	numeric := [...]string{"priority", "weight", "port"}
	for idx := range numeric {
		if _, err := parseUint(parts[idx], 16); err != nil {
			return fmt.Errorf("SRV %s must be between 0 and 65535", numeric[idx])
		}
	}

	// Tokens from strings.Fields carry no surrounding whitespace, so the
	// target can be compared directly.
	if target := parts[3]; target == "." || IsFQDN(target) {
		return nil
	}
	return fmt.Errorf("SRV target must be a fully qualified domain name or .")
}
// validateSOA checks SOA content of the form
// "primary hostmaster serial refresh retry expire minimum".
// primary and hostmaster must be fully qualified domain names. The five
// trailing fields must each parse as unsigned 32-bit integers.
func validateSOA(content string) error {
	parts := strings.Fields(content)
	if len(parts) != 7 {
		return fmt.Errorf("SOA record content must be: primary hostmaster serial refresh retry expire minimum")
	}

	primary, hostmaster, counters := parts[0], parts[1], parts[2:]
	if !IsFQDN(primary) {
		return fmt.Errorf("SOA primary nameserver must be a fully qualified domain name")
	}
	if !IsFQDN(hostmaster) {
		return fmt.Errorf("SOA hostmaster must be a fully qualified domain name")
	}

	names := [...]string{"serial", "refresh", "retry", "expire", "minimum"}
	for idx, raw := range counters {
		if _, err := parseUint(raw, 32); err != nil {
			return fmt.Errorf("SOA %s must be between 0 and 4294967295", names[idx])
		}
	}
	return nil
}
// caaValuePattern matches exactly one double-quoted string. Inside the quotes
// any character other than an unescaped quote is allowed, and a backslash
// escapes the next character.
var caaValuePattern = regexp.MustCompile(`^"([^"\\]|\\.)*"$`)

// validateCAA checks CAA content of the form `flags tag "value"`.
//
// flags must be an unsigned 8-bit integer and tag must match caaTagPattern.
// value must be a single quoted string, as a CAA value is one string rather
// than a TXT-style list. The value is checked verbatim from content, with
// only surrounding whitespace trimmed, so whitespace inside the quotes is
// validated as written.
func validateCAA(content string) error {
	fields := strings.Fields(content)
	if len(fields) < 3 {
		return fmt.Errorf("CAA record content must be: flags tag \"value\"")
	}
	// bitSize 8 makes ParseUint reject anything above 255.
	if _, err := parseUint(fields[0], 8); err != nil {
		return fmt.Errorf("CAA flags must be between 0 and 255")
	}
	if !caaTagPattern.MatchString(fields[1]) {
		return fmt.Errorf("CAA tag must start with a letter and contain only letters, numbers, and hyphens")
	}

	// Remove the flags and tag tokens from the original text instead of
	// re-joining fields, which would rewrite the whitespace inside the value.
	value := strings.TrimSpace(content)
	value = strings.TrimSpace(strings.TrimPrefix(value, fields[0]))
	value = strings.TrimSpace(strings.TrimPrefix(value, fields[1]))
	if !caaValuePattern.MatchString(value) {
		return fmt.Errorf("CAA value must be quoted")
	}
	return nil
}
// parseUint parses value as a base-10 unsigned integer that must fit in the
// given number of bits. It wraps strconv.ParseUint and returns its result
// and error unchanged.
func parseUint(value string, bits int) (uint64, error) {
	const decimal = 10
	n, err := strconv.ParseUint(value, decimal, bits)
	return n, err
}
// supported reports whether recordType exactly and case-sensitively equals
// one of the names returned by SupportedTypes.
func supported(recordType string) bool {
	types := SupportedTypes()
	for i := range types {
		if types[i] == recordType {
			return true
		}
	}
	return false
}