primeiro commit
This commit is contained in:
223
internal/dnsrecord/validator.go
Normal file
223
internal/dnsrecord/validator.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package dnsrecord
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
"pdns_admin/internal/pdns"
|
||||
)
|
||||
|
||||
var (
|
||||
txtValuePattern = regexp.MustCompile(`^"([^"\\]|\\.)*"( "([^"\\]|\\.)*")*$`)
|
||||
caaTagPattern = regexp.MustCompile(`^[A-Za-z][A-Za-z0-9-]*$`)
|
||||
)
|
||||
|
||||
type Validator struct {
|
||||
validate *validator.Validate
|
||||
}
|
||||
|
||||
func NewValidator() (*Validator, error) {
|
||||
v := validator.New()
|
||||
if err := v.RegisterValidation("dns_fqdn", func(fl validator.FieldLevel) bool {
|
||||
return IsFQDN(fl.Field().String())
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Validator{validate: v}, nil
|
||||
}
|
||||
|
||||
func SupportedTypes() []string {
|
||||
return []string{"A", "AAAA", "CAA", "CNAME", "MX", "NS", "SOA", "SRV", "TXT"}
|
||||
}
|
||||
|
||||
func (v *Validator) ValidateRRSet(name, recordType string, ttl uint64, contents []string) (pdns.RRSet, error) {
|
||||
name = EnsureTrailingDot(name)
|
||||
recordType = strings.ToUpper(strings.TrimSpace(recordType))
|
||||
|
||||
if ttl == 0 || ttl > 1<<32-1 {
|
||||
return pdns.RRSet{}, fmt.Errorf("ttl must be between 1 and 4294967295")
|
||||
}
|
||||
if err := v.validate.Var(name, "required,dns_fqdn"); err != nil {
|
||||
return pdns.RRSet{}, fmt.Errorf("record name must be a fully qualified domain name")
|
||||
}
|
||||
if !supported(recordType) {
|
||||
return pdns.RRSet{}, fmt.Errorf("unsupported record type %q", recordType)
|
||||
}
|
||||
|
||||
records := make([]pdns.Record, 0, len(contents))
|
||||
for _, content := range contents {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if err := validateContent(recordType, content); err != nil {
|
||||
return pdns.RRSet{}, err
|
||||
}
|
||||
records = append(records, pdns.Record{Content: content})
|
||||
}
|
||||
if len(records) == 0 {
|
||||
return pdns.RRSet{}, fmt.Errorf("at least one record value is required")
|
||||
}
|
||||
|
||||
return pdns.RRSet{
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
TTL: uint32(ttl),
|
||||
Records: records,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func EnsureTrailingDot(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || strings.HasSuffix(value, ".") {
|
||||
return value
|
||||
}
|
||||
return value + "."
|
||||
}
|
||||
|
||||
func IsFQDN(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "." || value == "" || !strings.HasSuffix(value, ".") || len(value) > 253 {
|
||||
return false
|
||||
}
|
||||
|
||||
labels := strings.Split(strings.TrimSuffix(value, "."), ".")
|
||||
for i, label := range labels {
|
||||
if label == "" || len(label) > 63 {
|
||||
return false
|
||||
}
|
||||
if label == "*" && i == 0 {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
|
||||
return false
|
||||
}
|
||||
for _, r := range label {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func validateContent(recordType, content string) error {
|
||||
switch recordType {
|
||||
case "A":
|
||||
addr, err := netip.ParseAddr(content)
|
||||
if err != nil || !addr.Is4() {
|
||||
return fmt.Errorf("A record content must be an IPv4 address")
|
||||
}
|
||||
case "AAAA":
|
||||
addr, err := netip.ParseAddr(content)
|
||||
if err != nil || !addr.Is6() {
|
||||
return fmt.Errorf("AAAA record content must be an IPv6 address")
|
||||
}
|
||||
case "CAA":
|
||||
return validateCAA(content)
|
||||
case "CNAME", "NS":
|
||||
if !IsFQDN(content) {
|
||||
return fmt.Errorf("%s record content must be a fully qualified domain name", recordType)
|
||||
}
|
||||
case "MX":
|
||||
return validateMX(content)
|
||||
case "SOA":
|
||||
return validateSOA(content)
|
||||
case "SRV":
|
||||
return validateSRV(content)
|
||||
case "TXT":
|
||||
if !txtValuePattern.MatchString(content) {
|
||||
return fmt.Errorf("TXT record content must be one or more quoted strings")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMX(content string) error {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) != 2 {
|
||||
return fmt.Errorf("MX record content must be: priority target")
|
||||
}
|
||||
if _, err := parseUint(fields[0], 16); err != nil {
|
||||
return fmt.Errorf("MX priority must be between 0 and 65535")
|
||||
}
|
||||
if !IsFQDN(fields[1]) {
|
||||
return fmt.Errorf("MX target must be a fully qualified domain name")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSRV(content string) error {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) != 4 {
|
||||
return fmt.Errorf("SRV record content must be: priority weight port target")
|
||||
}
|
||||
for i, label := range []string{"priority", "weight", "port"} {
|
||||
if _, err := parseUint(fields[i], 16); err != nil {
|
||||
return fmt.Errorf("SRV %s must be between 0 and 65535", label)
|
||||
}
|
||||
}
|
||||
target := strings.TrimSpace(fields[3])
|
||||
if target != "." && !IsFQDN(target) {
|
||||
return fmt.Errorf("SRV target must be a fully qualified domain name or .")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSOA(content string) error {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) != 7 {
|
||||
return fmt.Errorf("SOA record content must be: primary hostmaster serial refresh retry expire minimum")
|
||||
}
|
||||
if !IsFQDN(fields[0]) {
|
||||
return fmt.Errorf("SOA primary nameserver must be a fully qualified domain name")
|
||||
}
|
||||
if !IsFQDN(fields[1]) {
|
||||
return fmt.Errorf("SOA hostmaster must be a fully qualified domain name")
|
||||
}
|
||||
for i, label := range []string{"serial", "refresh", "retry", "expire", "minimum"} {
|
||||
if _, err := parseUint(fields[i+2], 32); err != nil {
|
||||
return fmt.Errorf("SOA %s must be between 0 and 4294967295", label)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCAA(content string) error {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) < 3 {
|
||||
return fmt.Errorf("CAA record content must be: flags tag \"value\"")
|
||||
}
|
||||
flags, err := parseUint(fields[0], 8)
|
||||
if err != nil || flags > 255 {
|
||||
return fmt.Errorf("CAA flags must be between 0 and 255")
|
||||
}
|
||||
if !caaTagPattern.MatchString(fields[1]) {
|
||||
return fmt.Errorf("CAA tag must start with a letter and contain only letters, numbers, and hyphens")
|
||||
}
|
||||
value := strings.Join(fields[2:], " ")
|
||||
if !txtValuePattern.MatchString(value) {
|
||||
return fmt.Errorf("CAA value must be quoted")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseUint(value string, bits int) (uint64, error) {
|
||||
return strconv.ParseUint(value, 10, bits)
|
||||
}
|
||||
|
||||
func supported(recordType string) bool {
|
||||
for _, candidate := range SupportedTypes() {
|
||||
if candidate == recordType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
75
internal/dnsrecord/validator_test.go
Normal file
75
internal/dnsrecord/validator_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package dnsrecord
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateRRSetAcceptsSupportedRecords(t *testing.T) {
|
||||
v := newTestValidator(t)
|
||||
|
||||
cases := []struct {
|
||||
recordType string
|
||||
content string
|
||||
}{
|
||||
{"A", "192.0.2.10"},
|
||||
{"AAAA", "2001:db8::1"},
|
||||
{"CAA", `0 issue "letsencrypt.org"`},
|
||||
{"CNAME", "target.example.org."},
|
||||
{"MX", "10 mail.example.org."},
|
||||
{"NS", "ns1.example.org."},
|
||||
{"SOA", "ns1.example.org. hostmaster.example.org. 2026010101 3600 600 604800 300"},
|
||||
{"SRV", "10 20 443 service.example.org."},
|
||||
{"TXT", `"v=spf1 -all"`},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.recordType, func(t *testing.T) {
|
||||
if _, err := v.ValidateRRSet("www.example.org", tc.recordType, 300, []string{tc.content}); err != nil {
|
||||
t.Fatalf("ValidateRRSet returned error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRRSetRejectsInvalidRecords(t *testing.T) {
|
||||
v := newTestValidator(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
recordType string
|
||||
content string
|
||||
}{
|
||||
{"bad name", "A", "192.0.2.10"},
|
||||
{"www.example.org.", "A", "2001:db8::1"},
|
||||
{"www.example.org.", "AAAA", "192.0.2.10"},
|
||||
{"www.example.org.", "MX", "mail.example.org."},
|
||||
{"www.example.org.", "CNAME", "target.example.org"},
|
||||
{"www.example.org.", "TXT", "not quoted"},
|
||||
{"www.example.org.", "SOA", "ns1.example.org. hostmaster.example.org."},
|
||||
{"www.example.org.", "UNSUPPORTED", "value"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.recordType+" "+tc.content, func(t *testing.T) {
|
||||
if _, err := v.ValidateRRSet(tc.name, tc.recordType, 300, []string{tc.content}); err == nil {
|
||||
t.Fatal("expected validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRRSetRequiresTTL(t *testing.T) {
|
||||
v := newTestValidator(t)
|
||||
|
||||
if _, err := v.ValidateRRSet("www.example.org.", "A", 0, []string{"192.0.2.10"}); err == nil {
|
||||
t.Fatal("expected ttl validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func newTestValidator(t *testing.T) *Validator {
|
||||
t.Helper()
|
||||
|
||||
v, err := NewValidator()
|
||||
if err != nil {
|
||||
t.Fatalf("NewValidator returned error: %v", err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
Reference in New Issue
Block a user