fixa aumento no serial do soa

This commit is contained in:
2026-06-19 18:47:34 -03:00
parent 968f4ef5d9
commit 1901055e25
9 changed files with 353 additions and 231 deletions

View File

@@ -1,9 +1,11 @@
package main package main
import ( import (
"context"
"log" "log"
"net/http" "net/http"
"os" "os"
"time"
"pdns_admin/internal/auth" "pdns_admin/internal/auth"
"pdns_admin/internal/config" "pdns_admin/internal/config"
@@ -20,6 +22,12 @@ func main() {
} }
pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient) pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient)
metadataCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := pdnsClient.EnsureAllZonesSOAEditAPI(metadataCtx); err != nil {
logger.Fatalf("failed to ensure SOA-EDIT-API setting: %v", err)
}
var authenticator server.Authenticator var authenticator server.Authenticator
if !cfg.Auth.Disabled { if !cfg.Auth.Disabled {
authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{ authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{

View File

@@ -13,6 +13,8 @@ import (
"time" "time"
) )
const soaEditAPIIncrease = "INCREASE"
type HTTPClient interface { type HTTPClient interface {
Do(*http.Request) (*http.Response, error) Do(*http.Request) (*http.Response, error)
} }
@@ -26,25 +28,20 @@ type Client struct {
type Server struct { type Server struct {
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"`
DaemonType string `json:"daemon_type"` DaemonType string `json:"daemon_type"`
Version string `json:"version"` Version string `json:"version"`
URL string `json:"url"`
ConfigURL string `json:"config_url"`
ZonesURL string `json:"zones_url"`
} }
type Zone struct { type Zone struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name"` Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Kind string `json:"kind,omitempty"` Kind string `json:"kind,omitempty"`
Serial uint64 `json:"serial,omitempty"` Serial uint64 `json:"serial,omitempty"`
EditedSerial uint64 `json:"edited_serial,omitempty"` SOAEditAPI string `json:"soa_edit_api,omitempty"`
SOAEditAPI string `json:"soa_edit_api,omitempty"` Nameservers []string `json:"nameservers,omitempty"`
Nameservers []string `json:"nameservers,omitempty"` Masters []string `json:"masters,omitempty"`
Masters []string `json:"masters,omitempty"` RRSets []RRSet `json:"rrsets,omitempty"`
RRSets []RRSet `json:"rrsets,omitempty"`
} }
type RRSet struct { type RRSet struct {
@@ -100,9 +97,50 @@ func (c *Client) ListZones(ctx context.Context) ([]Zone, error) {
} }
func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) { func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) {
zone.SOAEditAPI = soaEditAPIIncrease
var created Zone var created Zone
err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created) if err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created); err != nil {
return created, err return Zone{}, err
}
zoneID := created.ID
if zoneID == "" {
zoneID = created.Name
}
if zoneID == "" {
zoneID = zone.Name
}
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
return Zone{}, fmt.Errorf("set SOA-EDIT-API for new zone %s: %w", zoneID, err)
}
return created, nil
}
func (c *Client) EnsureAllZonesSOAEditAPI(ctx context.Context) error {
zones, err := c.ListZones(ctx)
if err != nil {
return err
}
for _, zone := range zones {
zoneID := zone.ID
if zoneID == "" {
zoneID = zone.Name
}
if zoneID == "" {
return fmt.Errorf("zone without id or name cannot be updated")
}
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
return fmt.Errorf("set SOA-EDIT-API for %s: %w", zoneID, err)
}
}
return nil
}
func (c *Client) SetZoneSOAEditAPI(ctx context.Context, zoneID string) error {
body := Zone{SOAEditAPI: soaEditAPIIncrease}
return c.do(ctx, http.MethodPut, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
} }
func (c *Client) DeleteZone(ctx context.Context, zoneID string) error { func (c *Client) DeleteZone(ctx context.Context, zoneID string) error {
@@ -116,19 +154,17 @@ func (c *Client) GetZone(ctx context.Context, zoneID string) (Zone, error) {
} }
func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error { func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error {
var before *Zone
if allowsMultipleRecords(rrset.Type) { if allowsMultipleRecords(rrset.Type) {
zone, err := c.GetZone(ctx, zoneID) zone, err := c.GetZone(ctx, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("read zone before merging records: %w", err) return fmt.Errorf("read zone before merging records: %w", err)
} }
before = &zone
if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok { if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok {
rrset.Records = mergeRecords(existing.Records, rrset.Records) rrset.Records = mergeRecords(existing.Records, rrset.Records)
} }
} }
return c.patchZoneWithSerialBump(ctx, zoneID, before, []changeRRSet{{ return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: rrset.Name, Name: rrset.Name,
Type: rrset.Type, Type: rrset.Type,
TTL: rrset.TTL, TTL: rrset.TTL,
@@ -175,7 +211,7 @@ func recordKey(record Record) string {
} }
func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error { func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error {
return c.patchZoneWithSerialBump(ctx, zoneID, nil, []changeRRSet{{ return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: name, Name: name,
Type: recordType, Type: recordType,
ChangeType: "DELETE", ChangeType: "DELETE",
@@ -189,40 +225,6 @@ func (c *Client) patchZone(ctx context.Context, zoneID string, rrsets []changeRR
return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil) return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
} }
func (c *Client) patchZoneWithSerialBump(ctx context.Context, zoneID string, before *Zone, rrsets []changeRRSet) error {
if before == nil {
zone, err := c.GetZone(ctx, zoneID)
if err != nil {
return fmt.Errorf("read zone before change: %w", err)
}
before = &zone
}
if err := c.patchZone(ctx, zoneID, rrsets); err != nil {
return err
}
after, err := c.GetZone(ctx, zoneID)
if err != nil {
return fmt.Errorf("read zone after change: %w", err)
}
if serialIncreased(*before, after) {
return nil
}
soa, err := bumpedSOA(after, before.Serial+1)
if err != nil {
return fmt.Errorf("bump SOA serial: %w", err)
}
return c.patchZone(ctx, zoneID, []changeRRSet{{
Name: soa.Name,
Type: soa.Type,
TTL: soa.TTL,
ChangeType: "REPLACE",
Records: soa.Records,
}})
}
func (c *Client) path(suffix string) string { func (c *Client) path(suffix string) string {
return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix
} }
@@ -275,53 +277,3 @@ func (c *Client) do(ctx context.Context, method, apiPath string, in any, out any
return nil return nil
} }
func serialIncreased(before, after Zone) bool {
if after.Serial > before.Serial {
return true
}
if (before.EditedSerial != 0 || after.EditedSerial != 0) && after.EditedSerial > before.EditedSerial {
return true
}
return false
}
func bumpedSOA(zone Zone, minimum uint64) (RRSet, error) {
for _, rrset := range zone.RRSets {
if !strings.EqualFold(rrset.Type, "SOA") {
continue
}
if len(rrset.Records) != 1 {
return RRSet{}, fmt.Errorf("expected exactly one SOA record, got %d", len(rrset.Records))
}
record := rrset.Records[0]
fields := strings.Fields(record.Content)
if len(fields) != 7 {
return RRSet{}, fmt.Errorf("SOA record must have 7 fields")
}
current, err := strconv.ParseUint(fields[2], 10, 32)
if err != nil {
return RRSet{}, fmt.Errorf("parse SOA serial: %w", err)
}
if current == 1<<32-1 {
return RRSet{}, fmt.Errorf("SOA serial is already at maximum uint32 value")
}
next := current + 1
if next < minimum {
next = minimum
}
if next > 1<<32-1 {
return RRSet{}, fmt.Errorf("next SOA serial exceeds maximum uint32 value")
}
fields[2] = strconv.FormatUint(next, 10)
record.Content = strings.Join(fields, " ")
rrset.Records[0] = record
return rrset, nil
}
return RRSet{}, fmt.Errorf("SOA record not found")
}

View File

@@ -34,7 +34,6 @@ func TestListZonesSendsAPIKey(t *testing.T) {
func TestCreateRRSetPatchesZone(t *testing.T) { func TestCreateRRSetPatchesZone(t *testing.T) {
var patchCount int var patchCount int
var getCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") { if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
@@ -43,12 +42,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
getCount++ writeZoneWithRRSets(t, w, 10, nil)
serial := uint64(10)
if getCount == 2 {
serial = 11
}
writeZone(t, w, serial)
case http.MethodPatch: case http.MethodPatch:
patchCount++ patchCount++
var payload struct { var payload struct {
@@ -80,62 +74,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
t.Fatalf("CreateRRSet returned error: %v", err) t.Fatalf("CreateRRSet returned error: %v", err)
} }
if patchCount != 1 { if patchCount != 1 {
t.Fatalf("expected one patch when serial increases, got %d", patchCount) t.Fatalf("expected one patch, got %d", patchCount)
}
}
func TestCreateRRSetBumpsSOAWhenSerialDoesNotIncrease(t *testing.T) {
var patchCount int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
switch r.Method {
case http.MethodGet:
writeZone(t, w, 10)
case http.MethodPatch:
patchCount++
var payload struct {
RRSets []changeRRSet `json:"rrsets"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if len(payload.RRSets) != 1 {
t.Fatalf("unexpected payload: %#v", payload)
}
if patchCount == 2 {
rrset := payload.RRSets[0]
if rrset.Type != "SOA" {
t.Fatalf("expected SOA fallback patch, got %#v", rrset)
}
if got := rrset.Records[0].Content; !strings.Contains(got, " 11 ") {
t.Fatalf("expected bumped SOA serial, got %q", got)
}
}
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method)
}
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
err := client.CreateRRSet(context.Background(), "example.org.", RRSet{
Name: "www.example.org.",
Type: "A",
TTL: 300,
Records: []Record{{
Content: "192.0.2.10",
}},
})
if err != nil {
t.Fatalf("CreateRRSet returned error: %v", err)
}
if patchCount != 2 {
t.Fatalf("expected requested patch plus SOA fallback patch, got %d", patchCount)
} }
} }
@@ -281,23 +220,45 @@ func TestGetServerUsesConfiguredServerID(t *testing.T) {
} }
func TestCreateZonePostsZone(t *testing.T) { func TestCreateZonePostsZone(t *testing.T) {
var posted bool
var put bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { switch r.Method {
case http.MethodPost:
posted = true
if r.URL.Path != "/api/v1/servers/localhost/zones" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.Name != "example.org." || payload.Kind != "Native" {
t.Fatalf("unexpected payload: %#v", payload)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(payload)
case http.MethodPut:
put = true
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method) t.Fatalf("unexpected method: %s", r.Method)
} }
if r.URL.Path != "/api/v1/servers/localhost/zones" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.Name != "example.org." || payload.Kind != "Native" {
t.Fatalf("unexpected payload: %#v", payload)
}
w.WriteHeader(http.StatusCreated)
_ = json.NewEncoder(w).Encode(payload)
})) }))
defer server.Close() defer server.Close()
@@ -309,6 +270,78 @@ func TestCreateZonePostsZone(t *testing.T) {
if created.Name != "example.org." { if created.Name != "example.org." {
t.Fatalf("unexpected zone: %#v", created) t.Fatalf("unexpected zone: %#v", created)
} }
if !posted || !put {
t.Fatalf("expected POST and follow-up PUT, posted=%v put=%v", posted, put)
}
}
func TestEnsureAllZonesSOAEditAPIUpdatesAllZones(t *testing.T) {
var updated []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
if r.URL.Path != "/api/v1/servers/localhost/zones" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
_ = json.NewEncoder(w).Encode([]Zone{
{ID: "already.example.org.", Name: "already.example.org.", SOAEditAPI: soaEditAPIIncrease},
{ID: "missing.example.org.", Name: "missing.example.org."},
{ID: "", Name: "fallback.example.org.", SOAEditAPI: "DEFAULT"},
})
case http.MethodPut:
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
updated = append(updated, strings.TrimPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/"))
w.WriteHeader(http.StatusNoContent)
default:
t.Fatalf("unexpected method: %s", r.Method)
}
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
if err := client.EnsureAllZonesSOAEditAPI(context.Background()); err != nil {
t.Fatalf("EnsureAllZonesSOAEditAPI returned error: %v", err)
}
if len(updated) != 3 {
t.Fatalf("expected three updates, got %#v", updated)
}
if updated[0] != "already.example.org." || updated[1] != "missing.example.org." || updated[2] != "fallback.example.org." {
t.Fatalf("unexpected updated zones: %#v", updated)
}
}
func TestSetZoneSOAEditAPIUsesPUT(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
t.Fatalf("unexpected method: %s", r.Method)
}
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
var payload Zone
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode request: %v", err)
}
if payload.SOAEditAPI != soaEditAPIIncrease {
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient(server.URL, "secret", "localhost", server.Client())
if err := client.SetZoneSOAEditAPI(context.Background(), "example.org."); err != nil {
t.Fatalf("SetZoneSOAEditAPI returned error: %v", err)
}
} }
func TestDeleteZoneDeletesZone(t *testing.T) { func TestDeleteZoneDeletesZone(t *testing.T) {
@@ -329,12 +362,6 @@ func TestDeleteZoneDeletesZone(t *testing.T) {
} }
} }
func writeZone(t *testing.T, w http.ResponseWriter, serial uint64) {
t.Helper()
writeZoneWithRRSets(t, w, serial, nil)
}
func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) { func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) {
t.Helper() t.Helper()

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"embed" "embed"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@@ -23,7 +24,8 @@ import (
var assets embed.FS var assets embed.FS
const ( const (
sessionCookieName = "pdns_admin_session" csrfFieldName = "csrf_token"
sessionCookieName = "__Host-pdns_admin_session"
sessionTTL = 12 * time.Hour sessionTTL = 12 * time.Hour
) )
@@ -62,6 +64,7 @@ type pageData struct {
Error string Error string
AuthEnabled bool AuthEnabled bool
CurrentUser string CurrentUser string
CSRFToken string
Next string Next string
Server pdns.Server Server pdns.Server
ZoneID string ZoneID string
@@ -72,21 +75,19 @@ type pageData struct {
} }
type session struct { type session struct {
Username string Username string
Expires time.Time CSRFToken string
Expires time.Time
} }
type recordForm struct { type recordForm struct {
Name string Name string
Type string Type string
TTL uint32 TTL uint32
Records string Records string
OriginalName string IsEdit bool
OriginalType string Title string
IsEdit bool SubmitLabel string
IsSOA bool
Title string
SubmitLabel string
} }
func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) { func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
@@ -106,9 +107,8 @@ func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
templates := make(map[string]*template.Template) templates := make(map[string]*template.Template)
funcs := template.FuncMap{ funcs := template.FuncMap{
"isSOA": isSOA, "isSOA": isSOA,
"recordValues": recordValues, "urlQuery": url.QueryEscape,
"urlQuery": url.QueryEscape,
} }
for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} { for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} {
tmpl, err := template.New("base.html").Funcs(funcs).ParseFS(assets, "templates/base.html", "templates/"+page) tmpl, err := template.New("base.html").Funcs(funcs).ParseFS(assets, "templates/base.html", "templates/"+page)
@@ -139,7 +139,7 @@ func (s *Server) routes() http.Handler {
mux.HandleFunc("GET /healthz", s.healthz) mux.HandleFunc("GET /healthz", s.healthz)
mux.HandleFunc("GET /login", s.login) mux.HandleFunc("GET /login", s.login)
mux.HandleFunc("POST /login", s.loginPost) mux.HandleFunc("POST /login", s.loginPost)
mux.HandleFunc("GET /logout", s.logout) mux.HandleFunc("POST /logout", s.logout)
mux.HandleFunc("GET /", s.dashboard) mux.HandleFunc("GET /", s.dashboard)
mux.HandleFunc("GET /zones", s.listZones) mux.HandleFunc("GET /zones", s.listZones)
mux.HandleFunc("POST /zones", s.createZone) mux.HandleFunc("POST /zones", s.createZone)
@@ -150,7 +150,7 @@ func (s *Server) routes() http.Handler {
mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet)
mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet)
mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet)
return s.withLogging(s.withSessionAuth(mux)) return s.withLogging(s.withSecurityHeaders(s.withSessionAuth(mux)))
} }
func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) { func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) {
@@ -226,8 +226,10 @@ func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
Value: token, Value: token,
Path: "/", Path: "/",
Expires: time.Now().Add(sessionTTL), Expires: time.Now().Add(sessionTTL),
MaxAge: int(sessionTTL.Seconds()),
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, Secure: true,
SameSite: http.SameSiteStrictMode,
}) })
http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther) http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther)
@@ -244,7 +246,8 @@ func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
Expires: time.Unix(0, 0), Expires: time.Unix(0, 0),
MaxAge: -1, MaxAge: -1,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, Secure: true,
SameSite: http.SameSiteStrictMode,
}) })
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
@@ -363,7 +366,6 @@ func (s *Server) editRRSet(w http.ResponseWriter, r *http.Request) {
TTL: rrset.TTL, TTL: rrset.TTL,
Records: recordValues(rrset), Records: recordValues(rrset),
IsEdit: true, IsEdit: true,
IsSOA: isSOA(rrset.Type),
Title: "Edit record", Title: "Edit record",
SubmitLabel: "Save record", SubmitLabel: "Save record",
} }
@@ -445,7 +447,7 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
return return
} }
name := ensureTrailingDot(r.FormValue("name")) name := dnsrecord.EnsureTrailingDot(r.FormValue("name"))
recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type"))) recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type")))
if name == "." || recordType == "" { if name == "." || recordType == "" {
s.redirectZoneError(w, r, zoneID, "record name and type are required") s.redirectZoneError(w, r, zoneID, "record name and type are required")
@@ -466,8 +468,9 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) { func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) {
data.AuthEnabled = s.auth != nil data.AuthEnabled = s.auth != nil
if user, ok := s.currentUser(r); ok { if sess, ok := s.currentSession(r); ok {
data.CurrentUser = user data.CurrentUser = sess.Username
data.CSRFToken = sess.CSRFToken
} }
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl, ok := s.templates[name] tmpl, ok := s.templates[name]
@@ -498,46 +501,64 @@ func (s *Server) withSessionAuth(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
if _, ok := s.currentUser(r); !ok { sess, ok := s.currentSession(r)
if !ok {
http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther) http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther)
return return
} }
if isUnsafeMethod(r.Method) && !validCSRFToken(r, sess.CSRFToken) {
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func (s *Server) currentUser(r *http.Request) (string, bool) { func (s *Server) currentUser(r *http.Request) (string, bool) {
cookie, err := r.Cookie(sessionCookieName) sess, ok := s.currentSession(r)
if err != nil || cookie.Value == "" { if !ok {
return "", false return "", false
} }
return sess.Username, true
}
func (s *Server) currentSession(r *http.Request) (session, bool) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil || cookie.Value == "" || len(cookie.Value) > 128 {
return session{}, false
}
s.sessionsM.Lock() s.sessionsM.Lock()
defer s.sessionsM.Unlock() defer s.sessionsM.Unlock()
sess, ok := s.sessions[cookie.Value] sess, ok := s.sessions[cookie.Value]
if !ok { if !ok {
return "", false return session{}, false
} }
if time.Now().After(sess.Expires) { if time.Now().After(sess.Expires) {
delete(s.sessions, cookie.Value) delete(s.sessions, cookie.Value)
return "", false return session{}, false
} }
return sess.Username, true return sess, true
} }
func (s *Server) createSession(username string) (string, error) { func (s *Server) createSession(username string) (string, error) {
tokenBytes := make([]byte, 32) token, err := randomToken()
if _, err := rand.Read(tokenBytes); err != nil { if err != nil {
return "", err
}
csrfToken, err := randomToken()
if err != nil {
return "", err return "", err
} }
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
s.sessionsM.Lock() s.sessionsM.Lock()
defer s.sessionsM.Unlock() defer s.sessionsM.Unlock()
s.pruneExpiredSessionsLocked(time.Now())
s.sessions[token] = session{ s.sessions[token] = session{
Username: username, Username: username,
Expires: time.Now().Add(sessionTTL), CSRFToken: csrfToken,
Expires: time.Now().Add(sessionTTL),
} }
return token, nil return token, nil
} }
@@ -548,8 +569,16 @@ func (s *Server) deleteSession(token string) {
delete(s.sessions, token) delete(s.sessions, token)
} }
func (s *Server) pruneExpiredSessionsLocked(now time.Time) {
for token, sess := range s.sessions {
if now.After(sess.Expires) {
delete(s.sessions, token)
}
}
}
func isPublicPath(path string) bool { func isPublicPath(path string) bool {
return path == "/login" || path == "/logout" || path == "/healthz" || strings.HasPrefix(path, "/static/") return path == "/login" || path == "/healthz" || strings.HasPrefix(path, "/static/")
} }
func safeRedirectPath(value string) string { func safeRedirectPath(value string) string {
@@ -563,6 +592,52 @@ func safeRedirectPath(value string) string {
return parsed.RequestURI() return parsed.RequestURI()
} }
func isUnsafeMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return false
default:
return true
}
}
func validCSRFToken(r *http.Request, expected string) bool {
if expected == "" {
return false
}
if err := r.ParseForm(); err != nil {
return false
}
token := r.FormValue(csrfFieldName)
if token == "" {
token = r.Header.Get("X-CSRF-Token")
}
return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1
}
func randomToken() (string, error) {
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(tokenBytes), nil
}
func (s *Server) withSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; font-src 'self' data:; base-uri 'self'; form-action 'self'; frame-ancestors 'none'")
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
w.Header().Set("Referrer-Policy", "same-origin")
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
if !strings.HasPrefix(r.URL.Path, "/static/") {
w.Header().Set("Cache-Control", "no-store")
}
next.ServeHTTP(w, r)
})
}
func (s *Server) withLogging(next http.Handler) http.Handler { func (s *Server) withLogging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
@@ -584,10 +659,6 @@ func parseLines(raw string) []string {
return values return values
} }
func ensureTrailingDot(value string) string {
return dnsrecord.EnsureTrailingDot(value)
}
func parseFQDNLines(raw, label string) ([]string, error) { func parseFQDNLines(raw, label string) ([]string, error) {
values := parseLines(raw) values := parseLines(raw)
for i, value := range values { for i, value := range values {

View File

@@ -154,6 +154,15 @@ func TestLoginCreatesSession(t *testing.T) {
if len(cookies) == 0 || cookies[0].Name != sessionCookieName { if len(cookies) == 0 || cookies[0].Name != sessionCookieName {
t.Fatalf("expected session cookie, got %#v", cookies) t.Fatalf("expected session cookie, got %#v", cookies)
} }
if !cookies[0].HttpOnly {
t.Fatal("session cookie must be HttpOnly")
}
if !cookies[0].Secure {
t.Fatal("session cookie must be Secure")
}
if cookies[0].SameSite != http.SameSiteStrictMode {
t.Fatalf("unexpected SameSite policy: %v", cookies[0].SameSite)
}
req = httptest.NewRequest(http.MethodGet, "/zones", nil) req = httptest.NewRequest(http.MethodGet, "/zones", nil)
req.AddCookie(cookies[0]) req.AddCookie(cookies[0])
@@ -196,9 +205,12 @@ func TestLogoutClearsSession(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("createSession returned error: %v", err) t.Fatalf("createSession returned error: %v", err)
} }
csrfToken := srv.sessions[token].CSRFToken
req := httptest.NewRequest(http.MethodGet, "/logout", nil) body := strings.NewReader("csrf_token=" + csrfToken)
req := httptest.NewRequest(http.MethodPost, "/logout", body)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token}) req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req) srv.routes().ServeHTTP(rec, req)
@@ -211,6 +223,51 @@ func TestLogoutClearsSession(t *testing.T) {
} }
} }
func TestLogoutRequiresCSRF(t *testing.T) {
srv, err := New(Config{Authenticator: &fakeAuth{allowed: true}}, &fakeClient{}, nil)
if err != nil {
t.Fatalf("New returned error: %v", err)
}
token, err := srv.createSession("alice")
if err != nil {
t.Fatalf("createSession returned error: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("unexpected status: %d", rec.Code)
}
}
func TestSecurityHeadersAreSet(t *testing.T) {
srv, err := New(Config{}, &fakeClient{}, nil)
if err != nil {
t.Fatalf("New returned error: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
rec := httptest.NewRecorder()
srv.routes().ServeHTTP(rec, req)
for _, header := range []string{
"Content-Security-Policy",
"Referrer-Policy",
"Strict-Transport-Security",
"X-Content-Type-Options",
"X-Frame-Options",
} {
if rec.Header().Get(header) == "" {
t.Fatalf("expected %s header", header)
}
}
}
func TestParseLinesSkipsBlankLines(t *testing.T) { func TestParseLinesSkipsBlankLines(t *testing.T) {
values := parseLines("192.0.2.1\n\n192.0.2.2\n") values := parseLines("192.0.2.1\n\n192.0.2.2\n")
if len(values) != 2 { if len(values) != 2 {

View File

@@ -31,9 +31,12 @@
{{ if .CurrentUser }} {{ if .CurrentUser }}
<div class="navbar-nav ms-auto"> <div class="navbar-nav ms-auto">
<span class="nav-link text-secondary">{{ .CurrentUser }}</span> <span class="nav-link text-secondary">{{ .CurrentUser }}</span>
<a class="nav-link" href="/logout"> <form method="post" action="/logout" class="nav-item">
<span class="nav-link-title">Logout</span> <input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
</a> <button class="nav-link btn btn-link px-0" type="submit">
<span class="nav-link-title">Logout</span>
</button>
</form>
</div> </div>
{{ end }} {{ end }}
{{ end }} {{ end }}

View File

@@ -29,6 +29,7 @@
{{ end }} {{ end }}
<form method="post" action="{{ if .RecordForm.IsEdit }}/zones/{{ .ZoneID }}/rrsets/edit?name={{ urlQuery .RecordForm.Name }}&type={{ urlQuery .RecordForm.Type }}{{ else }}/zones/{{ .ZoneID }}/rrsets{{ end }}"> <form method="post" action="{{ if .RecordForm.IsEdit }}/zones/{{ .ZoneID }}/rrsets/edit?name={{ urlQuery .RecordForm.Name }}&type={{ urlQuery .RecordForm.Type }}{{ else }}/zones/{{ .ZoneID }}/rrsets{{ end }}">
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
{{ if not .RecordForm.IsEdit }} {{ if not .RecordForm.IsEdit }}
<div class="mb-3"> <div class="mb-3">
<label class="form-label">Name</label> <label class="form-label">Name</label>

View File

@@ -49,6 +49,7 @@
<a class="btn btn-outline-primary btn-sm" href="/zones/{{ $.ZoneID }}/rrsets/edit?name={{ urlQuery .Name }}&type={{ urlQuery .Type }}">Edit</a> <a class="btn btn-outline-primary btn-sm" href="/zones/{{ $.ZoneID }}/rrsets/edit?name={{ urlQuery .Name }}&type={{ urlQuery .Type }}">Edit</a>
{{ if not (isSOA .Type) }} {{ if not (isSOA .Type) }}
<form method="post" action="/zones/{{ $.ZoneID }}/rrsets/delete"> <form method="post" action="/zones/{{ $.ZoneID }}/rrsets/delete">
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
<input type="hidden" name="name" value="{{ .Name }}"> <input type="hidden" name="name" value="{{ .Name }}">
<input type="hidden" name="type" value="{{ .Type }}"> <input type="hidden" name="type" value="{{ .Type }}">
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button> <button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>

View File

@@ -34,6 +34,7 @@
<td>{{ .Serial }}</td> <td>{{ .Serial }}</td>
<td> <td>
<form method="post" action="/zones/{{ .ID }}/delete"> <form method="post" action="/zones/{{ .ID }}/delete">
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button> <button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>
</form> </form>
</td> </td>
@@ -55,6 +56,7 @@
</div> </div>
<div class="card-body"> <div class="card-body">
<form method="post" action="/zones"> <form method="post" action="/zones">
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
<div class="mb-3"> <div class="mb-3">
<label class="form-label">Zone name</label> <label class="form-label">Zone name</label>
<input class="form-control" name="name" placeholder="example.org." required> <input class="form-control" name="name" placeholder="example.org." required>