Files
pdns-admin/internal/server/server.go

720 lines
19 KiB
Go

package server
import (
"context"
"crypto/rand"
"crypto/subtle"
"embed"
"encoding/base64"
"fmt"
"html/template"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"pdns_admin/internal/dnsrecord"
"pdns_admin/internal/pdns"
)
// assets bundles the page templates (templates/*.html) and the static
// directory into the binary. New parses templates from it, and routes serves
// it under GET /static/.
//
//go:embed templates/*.html static
var assets embed.FS
const (
	// csrfFieldName is the form field that carries the per-session CSRF token
	// on unsafe requests; validCSRFToken falls back to the X-CSRF-Token header.
	csrfFieldName = "csrf_token"
	// sessionCookieName uses the __Host- cookie prefix, which browsers only
	// accept for Secure cookies with Path=/ and no Domain attribute (matching
	// how loginPost and logout set it).
	sessionCookieName = "__Host-pdns_admin_session"
	// sessionTTL is both the cookie lifetime and the server-side session
	// expiry set by createSession.
	sessionTTL = 12 * time.Hour
)
// PDNSClient is the set of PDNS backend operations the web UI relies on.
// Zones are addressed by the zone ID taken from the request path.
type PDNSClient interface {
	// GetServer returns the backend server information shown on the dashboard.
	GetServer(ctx context.Context) (pdns.Server, error)
	// ListZones returns every zone known to the backend.
	ListZones(ctx context.Context) ([]pdns.Zone, error)
	// CreateZone creates zone and returns the backend's view of it.
	CreateZone(ctx context.Context, zone pdns.Zone) (pdns.Zone, error)
	// DeleteZone removes the zone identified by zoneID.
	DeleteZone(ctx context.Context, zoneID string) error
	// GetZone returns the zone identified by zoneID, including its RRSets.
	GetZone(ctx context.Context, zoneID string) (pdns.Zone, error)
	// CreateRRSet stores rrset in the zone. The UI uses it both for new
	// record sets and for saving edits to existing ones.
	CreateRRSet(ctx context.Context, zoneID string, rrset pdns.RRSet) error
	// DeleteRRSet removes the record set with the given owner name and type.
	DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error
}
// Authenticator checks login credentials. It reports false with a nil
// error for rejected credentials, and a non-nil error when the check itself
// could not be performed.
type Authenticator interface {
	Authenticate(ctx context.Context, username, password string) (bool, error)
}
// Config holds the optional settings accepted by New.
type Config struct {
	// Addr is the listen address; New substitutes ":8080" when it is empty.
	Addr string
	// Authenticator validates logins. When nil, login and session/CSRF
	// checks are disabled and every route is reachable without signing in.
	Authenticator Authenticator
}
// Server is the HTTP front end of the admin UI. Build it with New.
type Server struct {
	// addr is the listen address used by ListenAndServe.
	addr string
	// client talks to the PDNS backend.
	client PDNSClient
	// logger receives request logs and handler errors.
	logger *log.Logger
	// templates maps a page file name (e.g. "zones.html") to its own parsed
	// template set, which also contains base.html.
	templates map[string]*template.Template
	// validator checks submitted record sets before they are stored.
	validator *dnsrecord.Validator
	// auth is nil when authentication is disabled.
	auth Authenticator
	// sessions is the in-memory session store keyed by session token;
	// guarded by sessionsM.
	sessions  map[string]session
	sessionsM sync.Mutex
}
// pageData is the view model passed to every page template. Handlers fill
// only the fields their page uses; render adds the session-derived fields
// before executing the template.
type pageData struct {
	// Title is the page title.
	Title string
	// Error is a user-facing message, taken from the "error" query parameter
	// or a failed backend call, depending on the handler.
	Error string
	// AuthEnabled, CurrentUser, and CSRFToken are set by render: whether an
	// Authenticator is configured, plus the user name and CSRF token of the
	// request's session when it has a valid one.
	AuthEnabled bool
	CurrentUser string
	CSRFToken   string
	// Next is the sanitized post-login redirect target (login page).
	Next string
	// Server is the backend server information (dashboard).
	Server pdns.Server
	// ZoneID is the zone identifier taken from the request path.
	ZoneID string
	// Zones is the zone list (dashboard and zones pages).
	Zones []pdns.Zone
	// Zone is the zone being viewed or edited.
	Zone pdns.Zone
	// RecordForm pre-fills the add/edit record form.
	RecordForm recordForm
	// RecordTypes lists the record types offered by the record form.
	RecordTypes []string
}
// session is one logged-in browser session held in Server.sessions.
type session struct {
	// Username is the name the user authenticated with.
	Username string
	// CSRFToken must accompany every unsafe request made with this session.
	CSRFToken string
	// Expires is the instant after which currentSession rejects and deletes
	// the session.
	Expires time.Time
}
// recordForm carries the values and labels for record_form.html, which is
// shared by the "add record" and "edit record" pages.
type recordForm struct {
	Name string
	Type string
	TTL  uint32
	// Records holds one record content value per line.
	Records string
	// IsEdit is true when an existing record set is being edited.
	IsEdit bool
	// Title and SubmitLabel are the form heading and submit-button text.
	Title       string
	SubmitLabel string
}
// New constructs a Server backed by client.
//
// Inputs:
//   - client must be non-nil.
//   - A nil logger falls back to log.Default(), and an empty cfg.Addr to ":8080".
//   - cfg.Authenticator may be nil, which disables login and session checks.
//
// Each page template is parsed into its own set together with
// templates/base.html and the isSOA/urlQuery helpers. The separate sets are
// presumably there so pages can redefine the same blocks from base.html
// (TODO confirm against the templates). A template parse failure, or a
// failure to build the record validator, is returned as an error.
func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
	if client == nil {
		return nil, fmt.Errorf("pdns client is required")
	}
	if logger == nil {
		logger = log.Default()
	}
	listenAddr := cfg.Addr
	if listenAddr == "" {
		listenAddr = ":8080"
	}

	rrValidator, err := dnsrecord.NewValidator()
	if err != nil {
		return nil, fmt.Errorf("create record validator: %w", err)
	}

	pages := []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"}
	helpers := template.FuncMap{
		"isSOA":    isSOA,
		"urlQuery": url.QueryEscape,
	}
	pageTemplates := make(map[string]*template.Template, len(pages))
	for _, page := range pages {
		set, parseErr := template.New("base.html").
			Funcs(helpers).
			ParseFS(assets, "templates/base.html", "templates/"+page)
		if parseErr != nil {
			return nil, fmt.Errorf("parse template %s: %w", page, parseErr)
		}
		pageTemplates[page] = set
	}

	srv := &Server{
		addr:      listenAddr,
		client:    client,
		logger:    logger,
		templates: pageTemplates,
		validator: rrValidator,
		auth:      cfg.Authenticator,
		sessions:  make(map[string]session),
	}
	return srv, nil
}
// ListenAndServe serves the admin UI on the configured address. It blocks
// until the listener fails and returns that error.
//
// Unlike the package-level http.ListenAndServe, it uses an http.Server with
// explicit timeouts, so slow or idle clients cannot hold connections open
// indefinitely:
//   - ReadHeaderTimeout is the main guard against slow-header clients.
//   - WriteTimeout is generous because handlers make round-trips to the PDNS
//     client before rendering.
//
// Internal server errors are logged through the Server's logger.
func (s *Server) ListenAndServe() error {
	srv := &http.Server{
		Addr:              s.addr,
		Handler:           s.routes(),
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       2 * time.Minute,
		ErrorLog:          s.logger,
	}
	return srv.ListenAndServe()
}
// routes builds the request multiplexer and wraps it in the middleware chain.
//
// Outermost first, the chain is:
//  1. request logging
//  2. security headers
//  3. the session/CSRF gate
//
// "GET /" is a catch-all; dashboard answers 404 for any path other than "/".
func (s *Server) routes() http.Handler {
	mux := http.NewServeMux()
	mux.Handle("GET /static/", http.FileServerFS(assets))

	table := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"GET /healthz", s.healthz},
		{"GET /login", s.login},
		{"POST /login", s.loginPost},
		{"POST /logout", s.logout},
		{"GET /", s.dashboard},
		{"GET /zones", s.listZones},
		{"POST /zones", s.createZone},
		{"GET /zones/{zoneID}", s.showZone},
		{"POST /zones/{zoneID}/delete", s.deleteZone},
		{"GET /zones/{zoneID}/rrsets/new", s.newRRSet},
		{"GET /zones/{zoneID}/rrsets/edit", s.editRRSet},
		{"POST /zones/{zoneID}/rrsets", s.saveRRSet},
		{"POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet},
		{"POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet},
	}
	for _, route := range table {
		mux.HandleFunc(route.pattern, route.handler)
	}

	var handler http.Handler = s.withSessionAuth(mux)
	handler = s.withSecurityHeaders(handler)
	return s.withLogging(handler)
}
// healthz is the liveness endpoint: it answers 204 No Content with an empty
// body and does not contact the PDNS backend.
func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) {
	const healthyStatus = http.StatusNoContent
	w.WriteHeader(healthyStatus)
}
// dashboard renders the landing page with backend server information and
// the zone list.
//
// Because "GET /" is the catch-all route, any other path is answered with
// 404. Backend failures do not abort the page; the first error message is
// shown instead.
func (s *Server) dashboard(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}
	ctx := r.Context()
	info, infoErr := s.client.GetServer(ctx)
	zones, listErr := s.client.ListZones(ctx)
	s.render(w, r, "dashboard.html", pageData{
		Title:  "Dashboard",
		Server: info,
		Zones:  zones,
		Error:  firstNonEmpty(errorText(infoErr), errorText(listErr)),
	})
}
// login shows the login form.
//
// It redirects to "/" instead when authentication is disabled or the
// request already carries a valid session. The form shows any "error" query
// parameter and carries the sanitized "next" target through to loginPost.
func (s *Server) login(w http.ResponseWriter, r *http.Request) {
	signedIn := func() bool {
		_, ok := s.currentUser(r)
		return ok
	}
	// Short-circuit keeps the session lookup out of the auth-disabled path.
	if s.auth == nil || signedIn() {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	query := r.URL.Query()
	s.render(w, r, "login.html", pageData{
		Title: "Login",
		Error: query.Get("error"),
		Next:  safeRedirectPath(query.Get("next")),
	})
}
// loginPost checks the submitted credentials.
//
// On success it starts a new session, sets the session cookie and
// redirects to the sanitized "next" form value (default "/").
//
// Failed attempts redirect back to the login form with an error message and
// keep the sanitized "next" target, so a mistyped password no longer loses
// the page the user originally asked for. Backend errors are logged together
// with the submitted user name.
//
// Any session the browser already presented is discarded before a fresh one
// is issued, so a re-login does not leave the old token valid.
func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
	if s.auth == nil {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	if err := r.ParseForm(); err != nil {
		http.Redirect(w, r, loginErrorURL("invalid form data", "/"), http.StatusSeeOther)
		return
	}
	next := safeRedirectPath(r.FormValue("next"))
	username := strings.TrimSpace(r.FormValue("username"))
	password := r.FormValue("password")
	allowed, err := s.auth.Authenticate(r.Context(), username, password)
	if err != nil {
		s.logger.Printf("authentication failed for %q: %v", username, err)
		http.Redirect(w, r, loginErrorURL("authentication backend failed", next), http.StatusSeeOther)
		return
	}
	if !allowed {
		http.Redirect(w, r, loginErrorURL("invalid username or password", next), http.StatusSeeOther)
		return
	}
	if previous, cookieErr := r.Cookie(sessionCookieName); cookieErr == nil {
		s.deleteSession(previous.Value)
	}
	token, err := s.createSession(username)
	if err != nil {
		s.logger.Printf("create session: %v", err)
		http.Error(w, "session creation failed", http.StatusInternalServerError)
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     sessionCookieName,
		Value:    token,
		Path:     "/",
		Expires:  time.Now().Add(sessionTTL),
		MaxAge:   int(sessionTTL.Seconds()),
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	})
	http.Redirect(w, r, next, http.StatusSeeOther)
}

// loginErrorURL builds the login-form URL that displays message. It keeps
// next as a query parameter unless next is empty or the default "/".
func loginErrorURL(message, next string) string {
	query := url.Values{}
	query.Set("error", message)
	if next != "" && next != "/" {
		query.Set("next", next)
	}
	return "/login?" + query.Encode()
}
// logout ends the current session and sends the browser to the login page.
//
// It drops the server-side session named by the cookie, if any, and
// overwrites the cookie with an already-expired copy carrying the same
// attributes. The route is POST-only and is not a public path, so with
// authentication enabled it must pass the session/CSRF gate first.
func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie(sessionCookieName)
	if err == nil {
		s.deleteSession(cookie.Value)
	}
	expired := &http.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, expired)
	http.Redirect(w, r, "/login", http.StatusSeeOther)
}
// listZones renders the zone list. An "error" query parameter, left there
// by a failed action, takes precedence over a ListZones failure message.
func (s *Server) listZones(w http.ResponseWriter, r *http.Request) {
	zones, listErr := s.client.ListZones(r.Context())
	message := firstNonEmpty(r.URL.Query().Get("error"), errorText(listErr))
	s.render(w, r, "zones.html", pageData{
		Title: "Zones",
		Zones: zones,
		Error: message,
	})
}
// createZone validates the "create zone" form and asks the backend to
// create the zone.
//
// Form fields:
//   - name: must be an FQDN once a trailing dot is ensured.
//   - kind: one of Native, Master, or Slave.
//   - nameservers: one FQDN per line.
//   - masters: one address per line; Slave zones need at least one.
//
// Any failure redirects to the zone list with an error message. On success
// the browser is sent to the new zone's page.
func (s *Server) createZone(w http.ResponseWriter, r *http.Request) {
	fail := func(message string) { s.redirectZonesError(w, r, message) }

	if err := r.ParseForm(); err != nil {
		fail("invalid form data")
		return
	}
	zoneName := dnsrecord.EnsureTrailingDot(r.FormValue("name"))
	if !dnsrecord.IsFQDN(zoneName) {
		fail("zone name must be a fully qualified domain name")
		return
	}
	kind := strings.TrimSpace(r.FormValue("kind"))
	if !validZoneKind(kind) {
		fail("zone kind must be Native, Master, or Slave")
		return
	}
	nameservers, nsErr := parseFQDNLines(r.FormValue("nameservers"), "nameserver")
	if nsErr != nil {
		fail(nsErr.Error())
		return
	}
	masters := parseLines(r.FormValue("masters"))
	if kind == "Slave" && len(masters) == 0 {
		fail("slave zones require at least one master address")
		return
	}

	requested := pdns.Zone{
		Name:        zoneName,
		Kind:        kind,
		Nameservers: nameservers,
		Masters:     masters,
	}
	if _, err := s.client.CreateZone(r.Context(), requested); err != nil {
		fail(err.Error())
		return
	}
	http.Redirect(w, r, "/zones/"+url.PathEscape(zoneName), http.StatusSeeOther)
}
// deleteZone removes the zone named in the URL. On success it returns to
// the zone list; on failure it goes back to the zone page with the error.
func (s *Server) deleteZone(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	err := s.client.DeleteZone(r.Context(), zoneID)
	if err == nil {
		http.Redirect(w, r, "/zones", http.StatusSeeOther)
		return
	}
	s.redirectZoneError(w, r, zoneID, err.Error())
}
// showZone renders one zone with its record sets. An "error" query
// parameter takes precedence over a GetZone failure message.
func (s *Server) showZone(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	zone, fetchErr := s.client.GetZone(r.Context(), zoneID)
	s.render(w, r, "zone.html", pageData{
		Title:  "Zone " + zoneID,
		ZoneID: zoneID,
		Zone:   zone,
		Error:  firstNonEmpty(r.URL.Query().Get("error"), errorText(fetchErr)),
	})
}
// newRRSet renders an empty "add record" form for the zone in the URL,
// pre-selecting type A with a TTL of 300. A GetZone failure is shown as an
// error on the form rather than aborting the page.
func (s *Server) newRRSet(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	zone, fetchErr := s.client.GetZone(r.Context(), zoneID)
	blank := recordForm{
		Type:        "A",
		TTL:         300,
		Title:       "Add record",
		SubmitLabel: "Create record",
	}
	s.render(w, r, "record_form.html", pageData{
		Title:       "Add record",
		ZoneID:      zoneID,
		Zone:        zone,
		Error:       firstNonEmpty(r.URL.Query().Get("error"), errorText(fetchErr)),
		RecordTypes: dnsrecord.SupportedTypes(),
		RecordForm:  blank,
	})
}
// editRRSet renders the record form pre-filled with an existing record set.
//
// The record set is identified by the "name" and "type" query parameters;
// the type is trimmed and upper-cased. If the zone cannot be loaded, or the
// record set is not in it, the browser is sent back to the zone page with
// an error.
func (s *Server) editRRSet(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	query := r.URL.Query()
	name := query.Get("name")
	recordType := strings.ToUpper(strings.TrimSpace(query.Get("type")))

	zone, fetchErr := s.client.GetZone(r.Context(), zoneID)
	if fetchErr != nil {
		s.redirectZoneError(w, r, zoneID, fetchErr.Error())
		return
	}
	existing, found := findRRSet(zone, name, recordType)
	if !found {
		s.redirectZoneError(w, r, zoneID, "record not found")
		return
	}

	s.render(w, r, "record_form.html", pageData{
		Title:       "Edit record",
		ZoneID:      zoneID,
		Zone:        zone,
		Error:       query.Get("error"),
		RecordTypes: dnsrecord.SupportedTypes(),
		RecordForm: recordForm{
			Name:        existing.Name,
			Type:        existing.Type,
			TTL:         existing.TTL,
			Records:     recordValues(existing),
			IsEdit:      true,
			Title:       "Edit record",
			SubmitLabel: "Save record",
		},
	})
}
// saveRRSet stores a record set submitted from the "add record" form. The
// owner name and type come from the form's "name" and "type" fields.
func (s *Server) saveRRSet(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	if err := r.ParseForm(); err != nil {
		s.redirectZoneError(w, r, zoneID, "invalid form data")
		return
	}
	s.storeRRSet(w, r, zoneID, r.FormValue("name"), r.FormValue("type"))
}

// saveEditedRRSet stores an edited record set.
//
// The record identity ("name" and "type") is taken from the query string
// rather than the form body, so an edit cannot rename the set or change its
// type. The type is trimmed and upper-cased before use.
func (s *Server) saveEditedRRSet(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	query := r.URL.Query()
	name := query.Get("name")
	recordType := strings.ToUpper(strings.TrimSpace(query.Get("type")))
	if name == "" || recordType == "" {
		s.redirectZoneError(w, r, zoneID, "record identity is required")
		return
	}
	if err := r.ParseForm(); err != nil {
		s.redirectZoneError(w, r, zoneID, "invalid form data")
		return
	}
	s.storeRRSet(w, r, zoneID, name, recordType)
}

// storeRRSet is the shared tail of saveRRSet and saveEditedRRSet. It expects
// an already-parsed form. It reads the "ttl" and "records" fields (one
// record per line), validates them as a record set called name of type
// recordType, and writes it through the PDNS client.
//
// Any failure redirects to the zone page with the error message. On success
// it redirects to the zone page with the zone ID path-escaped, because
// PathValue returns the decoded segment.
func (s *Server) storeRRSet(w http.ResponseWriter, r *http.Request, zoneID, name, recordType string) {
	ttl, err := strconv.ParseUint(strings.TrimSpace(r.FormValue("ttl")), 10, 32)
	if err != nil {
		s.redirectZoneError(w, r, zoneID, "ttl must be a positive integer")
		return
	}
	rrset, err := s.validator.ValidateRRSet(name, recordType, ttl, parseLines(r.FormValue("records")))
	if err != nil {
		s.redirectZoneError(w, r, zoneID, err.Error())
		return
	}
	if err := s.client.CreateRRSet(r.Context(), zoneID, rrset); err != nil {
		s.redirectZoneError(w, r, zoneID, err.Error())
		return
	}
	http.Redirect(w, r, "/zones/"+url.PathEscape(zoneID), http.StatusSeeOther)
}
// deleteRRSet removes the record set identified by the submitted "name" and
// "type" fields from the zone in the URL.
//
// Normalization and refusals:
//   - The name goes through dnsrecord.EnsureTrailingDot and the type is
//     trimmed and upper-cased.
//   - A missing name (which normalizes to ".") or missing type is rejected.
//   - SOA record sets are refused.
//
// Every failure redirects to the zone page with an error. On success the
// browser is sent back to the zone page, with the zone ID path-escaped
// because PathValue returns the decoded segment.
func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
	zoneID := r.PathValue("zoneID")
	if err := r.ParseForm(); err != nil {
		s.redirectZoneError(w, r, zoneID, "invalid form data")
		return
	}
	name := dnsrecord.EnsureTrailingDot(r.FormValue("name"))
	recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type")))
	if name == "." || recordType == "" {
		s.redirectZoneError(w, r, zoneID, "record name and type are required")
		return
	}
	if isSOA(recordType) {
		s.redirectZoneError(w, r, zoneID, "SOA records are required for zones and cannot be deleted")
		return
	}
	if err := s.client.DeleteRRSet(r.Context(), zoneID, name, recordType); err != nil {
		s.redirectZoneError(w, r, zoneID, err.Error())
		return
	}
	http.Redirect(w, r, "/zones/"+url.PathEscape(zoneID), http.StatusSeeOther)
}
// render executes the named page template with data and writes the HTML
// response.
//
// Before rendering it fills the shared layout fields:
//   - AuthEnabled reports whether an Authenticator is configured.
//   - CurrentUser and CSRFToken come from the request's session when it has
//     a valid one.
//
// The page is rendered into memory first. If execution fails, the error is
// logged and the client gets a 500 instead of a truncated page already sent
// with status 200. An unknown template name is also a 500.
func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) {
	data.AuthEnabled = s.auth != nil
	if sess, ok := s.currentSession(r); ok {
		data.CurrentUser = sess.Username
		data.CSRFToken = sess.CSRFToken
	}
	tmpl, ok := s.templates[name]
	if !ok {
		http.Error(w, "template not found", http.StatusInternalServerError)
		return
	}
	var page strings.Builder
	if err := tmpl.ExecuteTemplate(&page, name, data); err != nil {
		s.logger.Printf("render %s: %v", name, err)
		http.Error(w, "template rendering failed", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	if _, err := w.Write([]byte(page.String())); err != nil {
		s.logger.Printf("render %s: write response: %v", name, err)
	}
}
// redirectZoneError sends the browser (303 See Other) to the page of zoneID,
// passing message in the "error" query parameter.
func (s *Server) redirectZoneError(w http.ResponseWriter, r *http.Request, zoneID, message string) {
	target := "/zones/" + url.PathEscape(zoneID) + "?error=" + url.QueryEscape(message)
	http.Redirect(w, r, target, http.StatusSeeOther)
}
// redirectZonesError sends the browser (303 See Other) to the zone list,
// passing message in the "error" query parameter.
func (s *Server) redirectZonesError(w http.ResponseWriter, r *http.Request, message string) {
	target := "/zones?error=" + url.QueryEscape(message)
	http.Redirect(w, r, target, http.StatusSeeOther)
}
// withSessionAuth gates every non-public route behind a valid session.
//
// Behavior:
//   - Requests without a session are redirected to /login, with the
//     original request URI in "next".
//   - Unsafe methods must carry the session's CSRF token, or they get 403.
//   - Public paths (see isPublicPath) pass through untouched.
//
// When no Authenticator is configured, next is returned unwrapped.
// NOTE(review): in that mode unsafe requests get no CSRF check either.
func (s *Server) withSessionAuth(next http.Handler) http.Handler {
	if s.auth == nil {
		return next
	}
	gate := func(w http.ResponseWriter, r *http.Request) {
		if !isPublicPath(r.URL.Path) {
			sess, ok := s.currentSession(r)
			if !ok {
				http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther)
				return
			}
			if isUnsafeMethod(r.Method) && !validCSRFToken(r, sess.CSRFToken) {
				http.Error(w, "invalid CSRF token", http.StatusForbidden)
				return
			}
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(gate)
}
// currentUser returns the user name of the request's valid session. If
// there is no valid session it returns "" and false.
func (s *Server) currentUser(r *http.Request) (string, bool) {
	if sess, ok := s.currentSession(r); ok {
		return sess.Username, true
	}
	return "", false
}
// currentSession looks up the session named by the request's session cookie.
//
// Empty cookies, and cookies longer than 128 bytes, are rejected without
// touching the store (issued tokens are 43 characters). A session found but
// already expired is deleted and reported as missing. The store is accessed
// under sessionsM.
func (s *Server) currentSession(r *http.Request) (session, bool) {
	cookie, err := r.Cookie(sessionCookieName)
	if err != nil {
		return session{}, false
	}
	token := cookie.Value
	if token == "" || len(token) > 128 {
		return session{}, false
	}

	s.sessionsM.Lock()
	defer s.sessionsM.Unlock()
	sess, found := s.sessions[token]
	switch {
	case !found:
		return session{}, false
	case time.Now().After(sess.Expires):
		delete(s.sessions, token)
		return session{}, false
	default:
		return sess, true
	}
}
// createSession registers a new session for username and returns its token.
//
// The session token and its CSRF token are generated independently before
// the store is locked. Expired sessions are pruned on every call, and the
// new session expires sessionTTL from now.
func (s *Server) createSession(username string) (string, error) {
	var tokens [2]string // [0] session token, [1] CSRF token
	for i := range tokens {
		generated, err := randomToken()
		if err != nil {
			return "", err
		}
		tokens[i] = generated
	}

	s.sessionsM.Lock()
	defer s.sessionsM.Unlock()
	s.pruneExpiredSessionsLocked(time.Now())
	s.sessions[tokens[0]] = session{
		Username:  username,
		CSRFToken: tokens[1],
		Expires:   time.Now().Add(sessionTTL),
	}
	return tokens[0], nil
}
// deleteSession removes the session keyed by token, if present. It takes
// sessionsM itself.
func (s *Server) deleteSession(token string) {
	s.sessionsM.Lock()
	delete(s.sessions, token)
	s.sessionsM.Unlock()
}
// pruneExpiredSessionsLocked deletes every session whose expiry is before
// now. The caller must already hold sessionsM.
func (s *Server) pruneExpiredSessionsLocked(now time.Time) {
	for key, entry := range s.sessions {
		if !now.After(entry.Expires) {
			continue
		}
		delete(s.sessions, key)
	}
}
// isPublicPath reports whether path is reachable without a session: the
// login page, the health check, and anything under /static/.
func isPublicPath(path string) bool {
	switch path {
	case "/login", "/healthz":
		return true
	}
	return strings.HasPrefix(path, "/static/")
}
// safeRedirectPath reduces a user-supplied redirect target to a local path.
//
// A target is kept only when it parses as a non-absolute URL whose path
// starts with a single "/". Everything else collapses to "/", including "",
// parse failures, and "//"-prefixed paths. A kept target is rebuilt with
// url.URL.RequestURI, which keeps only the escaped path and query, so any
// host or fragment component is dropped.
func safeRedirectPath(value string) string {
	const fallback = "/"
	if value == "" {
		return fallback
	}
	target, err := url.Parse(value)
	switch {
	case err != nil, target.IsAbs():
		return fallback
	case !strings.HasPrefix(target.Path, "/"), strings.HasPrefix(target.Path, "//"):
		return fallback
	}
	return target.RequestURI()
}
// isUnsafeMethod reports whether method may change state and therefore
// needs CSRF protection. GET, HEAD, OPTIONS and TRACE are treated as safe;
// every other method, including unknown ones, is unsafe.
func isUnsafeMethod(method string) bool {
	safe := method == http.MethodGet ||
		method == http.MethodHead ||
		method == http.MethodOptions ||
		method == http.MethodTrace
	return !safe
}
// validCSRFToken reports whether r carries the expected CSRF token.
//
// The token is read from the csrfFieldName form field, falling back to the
// X-CSRF-Token header, and compared in constant time. An empty expected
// token, or a form that cannot be parsed, always fails. Parsing the form
// here leaves r.Form populated for the handler.
func validCSRFToken(r *http.Request, expected string) bool {
	if expected == "" || r.ParseForm() != nil {
		return false
	}
	submitted := r.FormValue(csrfFieldName)
	if submitted == "" {
		submitted = r.Header.Get("X-CSRF-Token")
	}
	return subtle.ConstantTimeCompare([]byte(submitted), []byte(expected)) == 1
}
// randomToken returns 32 bytes from crypto/rand encoded as unpadded
// URL-safe base64 (43 characters). Any error from the random source is
// returned as-is.
func randomToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// withSecurityHeaders adds the fixed browser-hardening headers to every
// response: CSP, Permissions-Policy, Referrer-Policy, HSTS, nosniff and
// frame denial. Every response outside /static/ is also marked
// Cache-Control: no-store.
func (s *Server) withSecurityHeaders(next http.Handler) http.Handler {
	fixed := [...]struct{ name, value string }{
		{"Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; font-src 'self' data:; base-uri 'self'; form-action 'self'; frame-ancestors 'none'"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=()"},
		{"Referrer-Policy", "same-origin"},
		{"Strict-Transport-Security", "max-age=31536000; includeSubDomains"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		for _, h := range fixed {
			header.Set(h.name, h.value)
		}
		if !strings.HasPrefix(r.URL.Path, "/static/") {
			header.Set("Cache-Control", "no-store")
		}
		next.ServeHTTP(w, r)
	})
}
// withLogging logs the method, path and handling time of each request,
// rounded to the millisecond, once the wrapped handler returns. The
// response status is not recorded.
func (s *Server) withLogging(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began).Round(time.Millisecond)
		s.logger.Printf("%s %s %s", r.Method, r.URL.Path, elapsed)
	})
}
// parseLines splits raw on newlines and returns the whitespace-trimmed,
// non-empty lines in order. The result is never nil. Trimming also strips
// the "\r" left over from CRLF line endings.
func parseLines(raw string) []string {
	fields := strings.Split(raw, "\n")
	kept := make([]string, 0, len(fields))
	for _, field := range fields {
		if trimmed := strings.TrimSpace(field); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
// parseFQDNLines parses one domain name per line (see parseLines) and makes
// each name end in a dot. The first entry that is not a fully qualified name
// produces an error naming label and the value as typed.
func parseFQDNLines(raw, label string) ([]string, error) {
	names := parseLines(raw)
	for i := range names {
		entered := names[i]
		fqdn := dnsrecord.EnsureTrailingDot(entered)
		if !dnsrecord.IsFQDN(fqdn) {
			return nil, fmt.Errorf("%s %q must be a fully qualified domain name", label, entered)
		}
		names[i] = fqdn
	}
	return names, nil
}
// validZoneKind reports whether kind is one of the accepted zone kinds.
// The comparison is case-sensitive: "Native", "Master" or "Slave".
func validZoneKind(kind string) bool {
	return kind == "Native" || kind == "Master" || kind == "Slave"
}
// findRRSet returns the first record set in zone whose name and type match.
//
// name is passed through dnsrecord.EnsureTrailingDot and recordType is
// trimmed and upper-cased before the exact comparison. It returns false
// when nothing matches.
func findRRSet(zone pdns.Zone, name, recordType string) (pdns.RRSet, bool) {
	wantName := dnsrecord.EnsureTrailingDot(name)
	wantType := strings.ToUpper(strings.TrimSpace(recordType))
	for i := range zone.RRSets {
		candidate := zone.RRSets[i]
		if candidate.Name != wantName || candidate.Type != wantType {
			continue
		}
		return candidate, true
	}
	return pdns.RRSet{}, false
}
// isSOA reports whether recordType names the SOA type, ignoring case. It is
// also exposed to templates as the "isSOA" function.
func isSOA(recordType string) bool {
	const soa = "SOA"
	return strings.EqualFold(soa, recordType)
}
// recordValues joins the Content of each record in rrset with newlines, in
// the form expected by the record form's textarea.
func recordValues(rrset pdns.RRSet) string {
	var joined strings.Builder
	for i, record := range rrset.Records {
		if i > 0 {
			joined.WriteByte('\n')
		}
		joined.WriteString(record.Content)
	}
	return joined.String()
}
// errorText returns err's message, or "" when err is nil.
func errorText(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// firstNonEmpty returns the first value that is not blank (entirely
// whitespace counts as blank), unchanged. If every value is blank, or none
// are given, it returns "".
func firstNonEmpty(values ...string) string {
	for i := range values {
		if strings.TrimSpace(values[i]) == "" {
			continue
		}
		return values[i]
	}
	return ""
}