fixa aumento no serial do soa

This commit is contained in:
2026-06-19 18:47:34 -03:00
parent 968f4ef5d9
commit 1901055e25
9 changed files with 353 additions and 231 deletions

View File

@@ -1,9 +1,11 @@
package main package main
import ( import (
"context"
"log" "log"
"net/http" "net/http"
"os" "os"
"time"
"pdns_admin/internal/auth" "pdns_admin/internal/auth"
"pdns_admin/internal/config" "pdns_admin/internal/config"
@@ -20,6 +22,12 @@ func main() {
} }
pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient) pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient)
metadataCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := pdnsClient.EnsureAllZonesSOAEditAPI(metadataCtx); err != nil {
logger.Fatalf("failed to ensure SOA-EDIT-API setting: %v", err)
}
var authenticator server.Authenticator var authenticator server.Authenticator
if !cfg.Auth.Disabled { if !cfg.Auth.Disabled {
authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{ authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{

View File

@@ -13,6 +13,8 @@ import (
"time" "time"
) )
const soaEditAPIIncrease = "INCREASE"
type HTTPClient interface { type HTTPClient interface {
Do(*http.Request) (*http.Response, error) Do(*http.Request) (*http.Response, error)
} }
@@ -26,21 +28,16 @@ type Client struct {
type Server struct { type Server struct {
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"`
DaemonType string `json:"daemon_type"` DaemonType string `json:"daemon_type"`
Version string `json:"version"` Version string `json:"version"`
URL string `json:"url"`
ConfigURL string `json:"config_url"`
ZonesURL string `json:"zones_url"`
} }
type Zone struct { type Zone struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name"` Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Kind string `json:"kind,omitempty"` Kind string `json:"kind,omitempty"`
Serial uint64 `json:"serial,omitempty"` Serial uint64 `json:"serial,omitempty"`
EditedSerial uint64 `json:"edited_serial,omitempty"`
SOAEditAPI string `json:"soa_edit_api,omitempty"` SOAEditAPI string `json:"soa_edit_api,omitempty"`
Nameservers []string `json:"nameservers,omitempty"` Nameservers []string `json:"nameservers,omitempty"`
Masters []string `json:"masters,omitempty"` Masters []string `json:"masters,omitempty"`
@@ -100,9 +97,50 @@ func (c *Client) ListZones(ctx context.Context) ([]Zone, error) {
} }
func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) { func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) {
zone.SOAEditAPI = soaEditAPIIncrease
var created Zone var created Zone
err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created) if err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created); err != nil {
return created, err return Zone{}, err
}
zoneID := created.ID
if zoneID == "" {
zoneID = created.Name
}
if zoneID == "" {
zoneID = zone.Name
}
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
return Zone{}, fmt.Errorf("set SOA-EDIT-API for new zone %s: %w", zoneID, err)
}
return created, nil
}
func (c *Client) EnsureAllZonesSOAEditAPI(ctx context.Context) error {
zones, err := c.ListZones(ctx)
if err != nil {
return err
}
for _, zone := range zones {
zoneID := zone.ID
if zoneID == "" {
zoneID = zone.Name
}
if zoneID == "" {
return fmt.Errorf("zone without id or name cannot be updated")
}
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
return fmt.Errorf("set SOA-EDIT-API for %s: %w", zoneID, err)
}
}
return nil
}
func (c *Client) SetZoneSOAEditAPI(ctx context.Context, zoneID string) error {
body := Zone{SOAEditAPI: soaEditAPIIncrease}
return c.do(ctx, http.MethodPut, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
} }
func (c *Client) DeleteZone(ctx context.Context, zoneID string) error { func (c *Client) DeleteZone(ctx context.Context, zoneID string) error {
@@ -116,19 +154,17 @@ func (c *Client) GetZone(ctx context.Context, zoneID string) (Zone, error) {
} }
func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error { func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error {
var before *Zone
if allowsMultipleRecords(rrset.Type) { if allowsMultipleRecords(rrset.Type) {
zone, err := c.GetZone(ctx, zoneID) zone, err := c.GetZone(ctx, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("read zone before merging records: %w", err) return fmt.Errorf("read zone before merging records: %w", err)
} }
before = &zone
if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok { if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok {
rrset.Records = mergeRecords(existing.Records, rrset.Records) rrset.Records = mergeRecords(existing.Records, rrset.Records)
} }
} }
return c.patchZoneWithSerialBump(ctx, zoneID, before, []changeRRSet{{ return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: rrset.Name, Name: rrset.Name,
Type: rrset.Type, Type: rrset.Type,
TTL: rrset.TTL, TTL: rrset.TTL,
@@ -175,7 +211,7 @@ func recordKey(record Record) string {
} }
func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error { func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error {
return c.patchZoneWithSerialBump(ctx, zoneID, nil, []changeRRSet{{ return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: name, Name: name,
Type: recordType, Type: recordType,
ChangeType: "DELETE", ChangeType: "DELETE",
@@ -189,40 +225,6 @@ func (c *Client) patchZone(ctx context.Context, zoneID string, rrsets []changeRR
return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil) return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
} }
func (c *Client) patchZoneWithSerialBump(ctx context.Context, zoneID string, before *Zone, rrsets []changeRRSet) error {
if before == nil {
zone, err := c.GetZone(ctx, zoneID)
if err != nil {
return fmt.Errorf("read zone before change: %w", err)
}
before = &zone
}
if err := c.patchZone(ctx, zoneID, rrsets); err != nil {
return err
}
after, err := c.GetZone(ctx, zoneID)
if err != nil {
return fmt.Errorf("read zone after change: %w", err)
}
if serialIncreased(*before, after) {
return nil
}
soa, err := bumpedSOA(after, before.Serial+1)
if err != nil {
return fmt.Errorf("bump SOA serial: %w", err)
}
return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: soa.Name,
Type: soa.Type,
TTL: soa.TTL,
ChangeType: "REPLACE",
Records: soa.Records,
}})
}
func (c *Client) path(suffix string) string { func (c *Client) path(suffix string) string {
return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix
} }
@@ -275,53 +277,3 @@ func (c *Client) do(ctx context.Context, method, apiPath string, in any, out any
return nil return nil
} }
func serialIncreased(before, after Zone) bool {
if after.Serial > before.Serial {
return true
}
if (before.EditedSerial != 0 || after.EditedSerial != 0) && after.EditedSerial > before.EditedSerial {
return true
}
return false
}
func bumpedSOA(zone Zone, minimum uint64) (RRSet, error) {
for _, rrset := range zone.RRSets {
if !strings.EqualFold(rrset.Type, "SOA") {
continue
}
if len(rrset.Records) != 1 {
return RRSet{}, fmt.Errorf("expected exactly one SOA record, got %d", len(rrset.Records))
}
record := rrset.Records[0]
fields := strings.Fields(record.Content)
if len(fields) != 7 {
return RRSet{}, fmt.Errorf("SOA record must have 7 fields")
}
current, err := strconv.ParseUint(fields[2], 10, 32)
if err != nil {
return RRSet{}, fmt.Errorf("parse SOA serial: %w", err)
}
if current == 1<<32-1 {
return RRSet{}, fmt.Errorf("SOA serial is already at maximum uint32 value")
}
next := current + 1
if next < minimum {
next = minimum
}
if next > 1<<32-1 {
return RRSet{}, fmt.Errorf("next SOA serial exceeds maximum uint32 value")
}
fields[2] = strconv.FormatUint(next, 10)
record.Content = strings.Join(fields, " ")
rrset.Records[0] = record
return rrset, nil
}
return RRSet{}, fmt.Errorf("SOA record not found")
}

View File

@@ -34,7 +34,6 @@ func TestListZonesSendsAPIKey(t *testing.T) {
func TestCreateRRSetPatchesZone(t *testing.T) { func TestCreateRRSetPatchesZone(t *testing.T) {
var patchCount int var patchCount int
var getCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") { if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
@@ -43,12 +42,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
getCount++ writeZoneWithRRSets(t, w, 10, nil)
serial := uint64(10)
if getCount == 2 {
serial = 11
}
writeZone(t, w, serial)
case http.MethodPatch: case http.MethodPatch:
patchCount++ patchCount++
var payload struct { var payload struct {
@@ -80,62 +74,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
t.Fatalf("CreateRRSet returned error: %v", err) t.Fatalf("CreateRRSet returned error: %v", err)
} }
if patchCount != 1 { if patchCount != 1 {
t.Fatalf("expected one patch when serial increases, got %d", patchCount) t.Fatalf("expected one patch, got %d", patchCount)
}
}
func TestCreateRRSetBumpsSOAWhenSerialDoesNotIncrease(t *testing.T) {
var patchCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
switch r.Method {
case http.MethodGet:
writeZone(t, w, 10)
case http.MethodPatch:
patchCount++
var payload struct {
RRSets []changeRRSet `json:"rrsets"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if len(payload.RRSets) != 1 {
t.Fatalf("unexpected payload: %#v", payload)
}
if patchCount == 2 {
rrset := payload.RRSets[0]
if rrset.Type != "SOA" {
t.Fatalf("expected SOA fallback patch, got %#v", rrset)
}
if got := rrset.Records[0].Content; !strings.Contains(got, " 11 ") {
t.Fatalf("expected bumped SOA serial, got %q", got)
}
}
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method)
}
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
err := client.CreateRRSet(context.Background(), "example.org.", RRSet{
Name: "www.example.org.",
Type: "A",
TTL: 300,
Records: []Record{{
Content: "192.0.2.10",
}},
})
if err != nil {
t.Fatalf("CreateRRSet returned error: %v", err)
}
if patchCount != 2 {
t.Fatalf("expected requested patch plus SOA fallback patch, got %d", patchCount)
} }
} }
@@ -281,10 +220,13 @@ func TestGetServerUsesConfiguredServerID(t *testing.T) {
} }
func TestCreateZonePostsZone(t *testing.T) { func TestCreateZonePostsZone(t *testing.T) {
var posted bool
var put bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { switch r.Method {
t.Fatalf("unexpected method: %s", r.Method) case http.MethodPost:
} posted = true
if r.URL.Path != "/api/v1/servers/localhost/zones" { if r.URL.Path != "/api/v1/servers/localhost/zones" {
t.Fatalf("unexpected path: %s", r.URL.Path) t.Fatalf("unexpected path: %s", r.URL.Path)
} }
@@ -296,8 +238,27 @@ func TestCreateZonePostsZone(t *testing.T) {
if payload.Name != "example.org." || payload.Kind != "Native" { if payload.Name != "example.org." || payload.Kind != "Native" {
t.Fatalf("unexpected payload: %#v", payload) t.Fatalf("unexpected payload: %#v", payload)
} }
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(payload) _ = json.NewEncoder(w).Encode(payload)
case http.MethodPut:
put = true
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method)
}
})) }))
defer server.Close() defer server.Close()
@@ -309,6 +270,78 @@ func TestCreateZonePostsZone(t *testing.T) {
if created.Name != "example.org." { if created.Name != "example.org." {
t.Fatalf("unexpected zone: %#v", created) t.Fatalf("unexpected zone: %#v", created)
} }
if !posted || !put {
t.Fatalf("expected POST and follow-up PUT, posted=%v put=%v", posted, put)
}
}
func TestEnsureAllZonesSOAEditAPIUpdatesAllZones(t *testing.T) {
var updated []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
if r.URL.Path != "/api/v1/servers/localhost/zones" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
_ = json.NewEncoder(w).Encode([]Zone{
{ID: "already.example.org.", Name: "already.example.org.", SOAEditAPI: soaEditAPIIncrease},
{ID: "missing.example.org.", Name: "missing.example.org."},
{ID: "", Name: "fallback.example.org.", SOAEditAPI: "DEFAULT"},
})
case http.MethodPut:
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
updated = append(updated, strings.TrimPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/"))
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method)
}
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
if err := client.EnsureAllZonesSOAEditAPI(context.Background()); err != nil {
t.Fatalf("EnsureAllZonesSOAEditAPI returned error: %v", err)
}
if len(updated) != 3 {
t.Fatalf("expected three updates, got %#v", updated)
}
if updated[0] != "already.example.org." || updated[1] != "missing.example.org." || updated[2] != "fallback.example.org." {
t.Fatalf("unexpected updated zones: %#v", updated)
}
}
func TestSetZoneSOAEditAPIUsesPUT(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
t.Fatalf("unexpected method: %s", r.Method)
}
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
if err := client.SetZoneSOAEditAPI(context.Background(), "example.org."); err != nil {
t.Fatalf("SetZoneSOAEditAPI returned error: %v", err)
}
} }
func TestDeleteZoneDeletesZone(t *testing.T) { func TestDeleteZoneDeletesZone(t *testing.T) {
@@ -329,12 +362,6 @@ func TestDeleteZoneDeletesZone(t *testing.T) {
} }
} }
func writeZone(t *testing.T, w http.ResponseWriter, serial uint64) {
t.Helper()
writeZoneWithRRSets(t, w, serial, nil)
}
func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) { func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) {
t.Helper() t.Helper()

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"embed" "embed"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@@ -23,7 +24,8 @@ import (
var assets embed.FS var assets embed.FS
const ( const (
sessionCookieName = "pdns_admin_session" csrfFieldName = "csrf_token"
sessionCookieName = "__Host-pdns_admin_session"
sessionTTL = 12 * time.Hour sessionTTL = 12 * time.Hour
) )
@@ -62,6 +64,7 @@ type pageData struct {
Error string Error string
AuthEnabled bool AuthEnabled bool
CurrentUser string CurrentUser string
CSRFToken string
Next string Next string
Server pdns.Server Server pdns.Server
ZoneID string ZoneID string
@@ -73,6 +76,7 @@ type pageData struct {
type session struct { type session struct {
Username string Username string
CSRFToken string
Expires time.Time Expires time.Time
} }
@@ -81,10 +85,7 @@ type recordForm struct {
Type string Type string
TTL uint32 TTL uint32
Records string Records string
OriginalName string
OriginalType string
IsEdit bool IsEdit bool
IsSOA bool
Title string Title string
SubmitLabel string SubmitLabel string
} }
@@ -107,7 +108,6 @@ func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
templates := make(map[string]*template.Template) templates := make(map[string]*template.Template)
funcs := template.FuncMap{ funcs := template.FuncMap{
"isSOA": isSOA, "isSOA": isSOA,
"recordValues": recordValues,
"urlQuery": url.QueryEscape, "urlQuery": url.QueryEscape,
} }
for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} { for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} {
@@ -139,7 +139,7 @@ func (s *Server) routes() http.Handler {
mux.HandleFunc("GET /healthz", s.healthz) mux.HandleFunc("GET /healthz", s.healthz)
mux.HandleFunc("GET /login", s.login) mux.HandleFunc("GET /login", s.login)
mux.HandleFunc("POST /login", s.loginPost) mux.HandleFunc("POST /login", s.loginPost)
mux.HandleFunc("GET /logout", s.logout) mux.HandleFunc("POST /logout", s.logout)
mux.HandleFunc("GET /", s.dashboard) mux.HandleFunc("GET /", s.dashboard)
mux.HandleFunc("GET /zones", s.listZones) mux.HandleFunc("GET /zones", s.listZones)
mux.HandleFunc("POST /zones", s.createZone) mux.HandleFunc("POST /zones", s.createZone)
@@ -150,7 +150,7 @@ func (s *Server) routes() http.Handler {
mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet)
mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet)
mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet)
return s.withLogging(s.withSessionAuth(mux)) return s.withLogging(s.withSecurityHeaders(s.withSessionAuth(mux)))
} }
func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) { func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) {
@@ -226,8 +226,10 @@ func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
Value: token, Value: token,
Path: "/", Path: "/",
Expires: time.Now().Add(sessionTTL), Expires: time.Now().Add(sessionTTL),
MaxAge: int(sessionTTL.Seconds()),
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, Secure: true,
SameSite: http.SameSiteStrictMode,
}) })
http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther) http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther)
@@ -244,7 +246,8 @@ func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
Expires: time.Unix(0, 0), Expires: time.Unix(0, 0),
MaxAge: -1, MaxAge: -1,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, Secure: true,
SameSite: http.SameSiteStrictMode,
}) })
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
@@ -363,7 +366,6 @@ func (s *Server) editRRSet(w http.ResponseWriter, r *http.Request) {
TTL: rrset.TTL, TTL: rrset.TTL,
Records: recordValues(rrset), Records: recordValues(rrset),
IsEdit: true, IsEdit: true,
IsSOA: isSOA(rrset.Type),
Title: "Edit record", Title: "Edit record",
SubmitLabel: "Save record", SubmitLabel: "Save record",
} }
@@ -445,7 +447,7 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
return return
} }
name := ensureTrailingDot(r.FormValue("name")) name := dnsrecord.EnsureTrailingDot(r.FormValue("name"))
recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type"))) recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type")))
if name == "." || recordType == "" { if name == "." || recordType == "" {
s.redirectZoneError(w, r, zoneID, "record name and type are required") s.redirectZoneError(w, r, zoneID, "record name and type are required")
@@ -466,8 +468,9 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) { func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) {
data.AuthEnabled = s.auth != nil data.AuthEnabled = s.auth != nil
if user, ok := s.currentUser(r); ok { if sess, ok := s.currentSession(r); ok {
data.CurrentUser = user data.CurrentUser = sess.Username
data.CSRFToken = sess.CSRFToken
} }
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl, ok := s.templates[name] tmpl, ok := s.templates[name]
@@ -498,45 +501,63 @@ func (s *Server) withSessionAuth(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
if _, ok := s.currentUser(r); !ok { sess, ok := s.currentSession(r)
if !ok {
http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther) http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther)
return return
} }
if isUnsafeMethod(r.Method) && !validCSRFToken(r, sess.CSRFToken) {
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func (s *Server) currentUser(r *http.Request) (string, bool) { func (s *Server) currentUser(r *http.Request) (string, bool) {
cookie, err := r.Cookie(sessionCookieName) sess, ok := s.currentSession(r)
if err != nil || cookie.Value == "" { if !ok {
return "", false return "", false
} }
return sess.Username, true
}
func (s *Server) currentSession(r *http.Request) (session, bool) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil || cookie.Value == "" || len(cookie.Value) > 128 {
return session{}, false
}
s.sessionsM.Lock() s.sessionsM.Lock()
defer s.sessionsM.Unlock() defer s.sessionsM.Unlock()
sess, ok := s.sessions[cookie.Value] sess, ok := s.sessions[cookie.Value]
if !ok { if !ok {
return "", false return session{}, false
} }
if time.Now().After(sess.Expires) { if time.Now().After(sess.Expires) {
delete(s.sessions, cookie.Value) delete(s.sessions, cookie.Value)
return "", false return session{}, false
} }
return sess.Username, true return sess, true
} }
func (s *Server) createSession(username string) (string, error) { func (s *Server) createSession(username string) (string, error) {
tokenBytes := make([]byte, 32) token, err := randomToken()
if _, err := rand.Read(tokenBytes); err != nil { if err != nil {
return "", err
}
csrfToken, err := randomToken()
if err != nil {
return "", err return "", err
} }
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
s.sessionsM.Lock() s.sessionsM.Lock()
defer s.sessionsM.Unlock() defer s.sessionsM.Unlock()
s.pruneExpiredSessionsLocked(time.Now())
s.sessions[token] = session{ s.sessions[token] = session{
Username: username, Username: username,
CSRFToken: csrfToken,
Expires: time.Now().Add(sessionTTL), Expires: time.Now().Add(sessionTTL),
} }
return token, nil return token, nil
@@ -548,8 +569,16 @@ func (s *Server) deleteSession(token string) {
delete(s.sessions, token) delete(s.sessions, token)
} }
func (s *Server) pruneExpiredSessionsLocked(now time.Time) {
for token, sess := range s.sessions {
if now.After(sess.Expires) {
delete(s.sessions, token)
}
}
}
func isPublicPath(path string) bool { func isPublicPath(path string) bool {
return path == "/login" || path == "/logout" || path == "/healthz" || strings.HasPrefix(path, "/static/") return path == "/login" || path == "/healthz" || strings.HasPrefix(path, "/static/")
} }
func safeRedirectPath(value string) string { func safeRedirectPath(value string) string {
@@ -563,6 +592,52 @@ func safeRedirectPath(value string) string {
return parsed.RequestURI() return parsed.RequestURI()
} }
func isUnsafeMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return false
default:
return true
}
}
func validCSRFToken(r *http.Request, expected string) bool {
if expected == "" {
return false
}
if err := r.ParseForm(); err != nil {
return false
}
token := r.FormValue(csrfFieldName)
if token == "" {
token = r.Header.Get("X-CSRF-Token")
}
return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1
}
func randomToken() (string, error) {
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(tokenBytes), nil
}
func (s *Server) withSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; font-src 'self' data:; base-uri 'self'; form-action 'self'; frame-ancestors 'none'")
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
w.Header().Set("Referrer-Policy", "same-origin")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
if !strings.HasPrefix(r.URL.Path, "/static/") {
w.Header().Set("Cache-Control", "no-store")
}
next.ServeHTTP(w, r)
})
}
func (s *Server) withLogging(next http.Handler) http.Handler { func (s *Server) withLogging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
@@ -584,10 +659,6 @@ func parseLines(raw string) []string {
return values return values
} }
func ensureTrailingDot(value string) string {
return dnsrecord.EnsureTrailingDot(value)
}
func parseFQDNLines(raw, label string) ([]string, error) { func parseFQDNLines(raw, label string) ([]string, error) {
values := parseLines(raw) values := parseLines(raw)
for i, value := range values { for i, value := range values {

View File

@@ -154,6 +154,15 @@ func TestLoginCreatesSession(t *testing.T) {
if len(cookies) == 0 || cookies[0].Name != sessionCookieName { if len(cookies) == 0 || cookies[0].Name != sessionCookieName {
t.Fatalf("expected session cookie, got %#v", cookies) t.Fatalf("expected session cookie, got %#v", cookies)
} }
if !cookies[0].HttpOnly {
t.Fatal("session cookie must be HttpOnly")
}
if !cookies[0].Secure {
t.Fatal("session cookie must be Secure")
}
if cookies[0].SameSite != http.SameSiteStrictMode {
t.Fatalf("unexpected SameSite policy: %v", cookies[0].SameSite)
}
req = httptest.NewRequest(http.MethodGet, "/zones", nil) req = httptest.NewRequest(http.MethodGet, "/zones", nil)
req.AddCookie(cookies[0]) req.AddCookie(cookies[0])
@@ -196,9 +205,12 @@ func TestLogoutClearsSession(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("createSession returned error: %v", err) t.Fatalf("createSession returned error: %v", err)
} }
csrfToken := srv.sessions[token].CSRFToken
req := httptest.NewRequest(http.MethodGet, "/logout", nil) body := strings.NewReader("csrf_token=" + csrfToken)
req := httptest.NewRequest(http.MethodPost, "/logout", body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token}) req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req) srv.routes().ServeHTTP(rec, req)
@@ -211,6 +223,51 @@ func TestLogoutClearsSession(t *testing.T) {
} }
} }
func TestLogoutRequiresCSRF(t *testing.T) {
srv, err := New(Config{Authenticator: &fakeAuth{allowed: true}}, &fakeClient{}, nil)
if err != nil {
t.Fatalf("New returned error: %v", err)
}
token, err := srv.createSession("alice")
if err != nil {
t.Fatalf("createSession returned error: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("unexpected status: %d", rec.Code)
}
}
func TestSecurityHeadersAreSet(t *testing.T) {
srv, err := New(Config{}, &fakeClient{}, nil)
if err != nil {
t.Fatalf("New returned error: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req)
for _, header := range []string{
"Content-Security-Policy",
"Referrer-Policy",
"Strict-Transport-Security",
"X-Content-Type-Options",
"X-Frame-Options",
} {
if rec.Header().Get(header) == "" {
t.Fatalf("expected %s header", header)
}
}
}
func TestParseLinesSkipsBlankLines(t *testing.T) { func TestParseLinesSkipsBlankLines(t *testing.T) {
values := parseLines("192.0.2.1\n\n192.0.2.2\n") values := parseLines("192.0.2.1\n\n192.0.2.2\n")
if len(values) != 2 { if len(values) != 2 {

View File

@@ -31,9 +31,12 @@
{{ if .CurrentUser }} {{ if .CurrentUser }}
<div class="navbar-nav ms-auto"> <div class="navbar-nav ms-auto">
<span class="nav-link text-secondary">{{ .CurrentUser }}</span> <span class="nav-link text-secondary">{{ .CurrentUser }}</span>
<a class="nav-link" href="/logout"> <form method="post" action="/logout" class="nav-item">
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
<button class="nav-link btn btn-link px-0" type="submit">
<span class="nav-link-title">Logout</span> <span class="nav-link-title">Logout</span>
</a> </button>
</form>
</div> </div>
{{ end }} {{ end }}
{{ end }} {{ end }}

View File

@@ -29,6 +29,7 @@
{{ end }} {{ end }}
<form method="post" action="{{ if .RecordForm.IsEdit }}/zones/{{ .ZoneID }}/rrsets/edit?name={{ urlQuery .RecordForm.Name }}&type={{ urlQuery .RecordForm.Type }}{{ else }}/zones/{{ .ZoneID }}/rrsets{{ end }}"> <form method="post" action="{{ if .RecordForm.IsEdit }}/zones/{{ .ZoneID }}/rrsets/edit?name={{ urlQuery .RecordForm.Name }}&type={{ urlQuery .RecordForm.Type }}{{ else }}/zones/{{ .ZoneID }}/rrsets{{ end }}">
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
{{ if not .RecordForm.IsEdit }} {{ if not .RecordForm.IsEdit }}
<div class="mb-3"> <div class="mb-3">
<label class="form-label">Name</label> <label class="form-label">Name</label>

View File

@@ -49,6 +49,7 @@
<a class="btn btn-outline-primary btn-sm" href="/zones/{{ $.ZoneID }}/rrsets/edit?name={{ urlQuery .Name }}&type={{ urlQuery .Type }}">Edit</a> <a class="btn btn-outline-primary btn-sm" href="/zones/{{ $.ZoneID }}/rrsets/edit?name={{ urlQuery .Name }}&type={{ urlQuery .Type }}">Edit</a>
{{ if not (isSOA .Type) }} {{ if not (isSOA .Type) }}
<form method="post" action="/zones/{{ $.ZoneID }}/rrsets/delete"> <form method="post" action="/zones/{{ $.ZoneID }}/rrsets/delete">
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
<input type="hidden" name="name" value="{{ .Name }}"> <input type="hidden" name="name" value="{{ .Name }}">
<input type="hidden" name="type" value="{{ .Type }}"> <input type="hidden" name="type" value="{{ .Type }}">
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button> <button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>

View File

@@ -34,6 +34,7 @@
<td>{{ .Serial }}</td> <td>{{ .Serial }}</td>
<td> <td>
<form method="post" action="/zones/{{ .ID }}/delete"> <form method="post" action="/zones/{{ .ID }}/delete">
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button> <button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>
</form> </form>
</td> </td>
@@ -55,6 +56,7 @@
</div> </div>
<div class="card-body"> <div class="card-body">
<form method="post" action="/zones"> <form method="post" action="/zones">
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
<div class="mb-3"> <div class="mb-3">
<label class="form-label">Zone name</label> <label class="form-label">Zone name</label>
<input class="form-control" name="name" placeholder="example.org." required> <input class="form-control" name="name" placeholder="example.org." required>