fixa aumento no serial do soa
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"pdns_admin/internal/auth"
|
||||
"pdns_admin/internal/config"
|
||||
@@ -20,6 +22,12 @@ func main() {
|
||||
}
|
||||
|
||||
pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient)
|
||||
metadataCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := pdnsClient.EnsureAllZonesSOAEditAPI(metadataCtx); err != nil {
|
||||
logger.Fatalf("failed to ensure SOA-EDIT-API setting: %v", err)
|
||||
}
|
||||
|
||||
var authenticator server.Authenticator
|
||||
if !cfg.Auth.Disabled {
|
||||
authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const soaEditAPIIncrease = "INCREASE"
|
||||
|
||||
type HTTPClient interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
@@ -26,25 +28,20 @@ type Client struct {
|
||||
|
||||
type Server struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DaemonType string `json:"daemon_type"`
|
||||
Version string `json:"version"`
|
||||
URL string `json:"url"`
|
||||
ConfigURL string `json:"config_url"`
|
||||
ZonesURL string `json:"zones_url"`
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Kind string `json:"kind,omitempty"`
|
||||
Serial uint64 `json:"serial,omitempty"`
|
||||
EditedSerial uint64 `json:"edited_serial,omitempty"`
|
||||
SOAEditAPI string `json:"soa_edit_api,omitempty"`
|
||||
Nameservers []string `json:"nameservers,omitempty"`
|
||||
Masters []string `json:"masters,omitempty"`
|
||||
RRSets []RRSet `json:"rrsets,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Kind string `json:"kind,omitempty"`
|
||||
Serial uint64 `json:"serial,omitempty"`
|
||||
SOAEditAPI string `json:"soa_edit_api,omitempty"`
|
||||
Nameservers []string `json:"nameservers,omitempty"`
|
||||
Masters []string `json:"masters,omitempty"`
|
||||
RRSets []RRSet `json:"rrsets,omitempty"`
|
||||
}
|
||||
|
||||
type RRSet struct {
|
||||
@@ -100,9 +97,50 @@ func (c *Client) ListZones(ctx context.Context) ([]Zone, error) {
|
||||
}
|
||||
|
||||
func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) {
|
||||
zone.SOAEditAPI = soaEditAPIIncrease
|
||||
var created Zone
|
||||
err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created)
|
||||
return created, err
|
||||
if err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created); err != nil {
|
||||
return Zone{}, err
|
||||
}
|
||||
|
||||
zoneID := created.ID
|
||||
if zoneID == "" {
|
||||
zoneID = created.Name
|
||||
}
|
||||
if zoneID == "" {
|
||||
zoneID = zone.Name
|
||||
}
|
||||
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
|
||||
return Zone{}, fmt.Errorf("set SOA-EDIT-API for new zone %s: %w", zoneID, err)
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (c *Client) EnsureAllZonesSOAEditAPI(ctx context.Context) error {
|
||||
zones, err := c.ListZones(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, zone := range zones {
|
||||
zoneID := zone.ID
|
||||
if zoneID == "" {
|
||||
zoneID = zone.Name
|
||||
}
|
||||
if zoneID == "" {
|
||||
return fmt.Errorf("zone without id or name cannot be updated")
|
||||
}
|
||||
if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil {
|
||||
return fmt.Errorf("set SOA-EDIT-API for %s: %w", zoneID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SetZoneSOAEditAPI(ctx context.Context, zoneID string) error {
|
||||
body := Zone{SOAEditAPI: soaEditAPIIncrease}
|
||||
return c.do(ctx, http.MethodPut, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteZone(ctx context.Context, zoneID string) error {
|
||||
@@ -116,19 +154,17 @@ func (c *Client) GetZone(ctx context.Context, zoneID string) (Zone, error) {
|
||||
}
|
||||
|
||||
func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error {
|
||||
var before *Zone
|
||||
if allowsMultipleRecords(rrset.Type) {
|
||||
zone, err := c.GetZone(ctx, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read zone before merging records: %w", err)
|
||||
}
|
||||
before = &zone
|
||||
if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok {
|
||||
rrset.Records = mergeRecords(existing.Records, rrset.Records)
|
||||
}
|
||||
}
|
||||
|
||||
return c.patchZoneWithSerialBump(ctx, zoneID, before, []changeRRSet{{
|
||||
return c.patchZone(ctx, zoneID, []changeRRSet{{
|
||||
Name: rrset.Name,
|
||||
Type: rrset.Type,
|
||||
TTL: rrset.TTL,
|
||||
@@ -175,7 +211,7 @@ func recordKey(record Record) string {
|
||||
}
|
||||
|
||||
func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error {
|
||||
return c.patchZoneWithSerialBump(ctx, zoneID, nil, []changeRRSet{{
|
||||
return c.patchZone(ctx, zoneID, []changeRRSet{{
|
||||
Name: name,
|
||||
Type: recordType,
|
||||
ChangeType: "DELETE",
|
||||
@@ -189,40 +225,6 @@ func (c *Client) patchZone(ctx context.Context, zoneID string, rrsets []changeRR
|
||||
return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil)
|
||||
}
|
||||
|
||||
func (c *Client) patchZoneWithSerialBump(ctx context.Context, zoneID string, before *Zone, rrsets []changeRRSet) error {
|
||||
if before == nil {
|
||||
zone, err := c.GetZone(ctx, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read zone before change: %w", err)
|
||||
}
|
||||
before = &zone
|
||||
}
|
||||
|
||||
if err := c.patchZone(ctx, zoneID, rrsets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
after, err := c.GetZone(ctx, zoneID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read zone after change: %w", err)
|
||||
}
|
||||
if serialIncreased(*before, after) {
|
||||
return nil
|
||||
}
|
||||
|
||||
soa, err := bumpedSOA(after, before.Serial+1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bump SOA serial: %w", err)
|
||||
}
|
||||
return c.patchZone(ctx, zoneID, []changeRRSet{{
|
||||
Name: soa.Name,
|
||||
Type: soa.Type,
|
||||
TTL: soa.TTL,
|
||||
ChangeType: "REPLACE",
|
||||
Records: soa.Records,
|
||||
}})
|
||||
}
|
||||
|
||||
func (c *Client) path(suffix string) string {
|
||||
return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix
|
||||
}
|
||||
@@ -275,53 +277,3 @@ func (c *Client) do(ctx context.Context, method, apiPath string, in any, out any
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialIncreased(before, after Zone) bool {
|
||||
if after.Serial > before.Serial {
|
||||
return true
|
||||
}
|
||||
if (before.EditedSerial != 0 || after.EditedSerial != 0) && after.EditedSerial > before.EditedSerial {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bumpedSOA(zone Zone, minimum uint64) (RRSet, error) {
|
||||
for _, rrset := range zone.RRSets {
|
||||
if !strings.EqualFold(rrset.Type, "SOA") {
|
||||
continue
|
||||
}
|
||||
if len(rrset.Records) != 1 {
|
||||
return RRSet{}, fmt.Errorf("expected exactly one SOA record, got %d", len(rrset.Records))
|
||||
}
|
||||
|
||||
record := rrset.Records[0]
|
||||
fields := strings.Fields(record.Content)
|
||||
if len(fields) != 7 {
|
||||
return RRSet{}, fmt.Errorf("SOA record must have 7 fields")
|
||||
}
|
||||
|
||||
current, err := strconv.ParseUint(fields[2], 10, 32)
|
||||
if err != nil {
|
||||
return RRSet{}, fmt.Errorf("parse SOA serial: %w", err)
|
||||
}
|
||||
if current == 1<<32-1 {
|
||||
return RRSet{}, fmt.Errorf("SOA serial is already at maximum uint32 value")
|
||||
}
|
||||
|
||||
next := current + 1
|
||||
if next < minimum {
|
||||
next = minimum
|
||||
}
|
||||
if next > 1<<32-1 {
|
||||
return RRSet{}, fmt.Errorf("next SOA serial exceeds maximum uint32 value")
|
||||
}
|
||||
|
||||
fields[2] = strconv.FormatUint(next, 10)
|
||||
record.Content = strings.Join(fields, " ")
|
||||
rrset.Records[0] = record
|
||||
return rrset, nil
|
||||
}
|
||||
|
||||
return RRSet{}, fmt.Errorf("SOA record not found")
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ func TestListZonesSendsAPIKey(t *testing.T) {
|
||||
|
||||
func TestCreateRRSetPatchesZone(t *testing.T) {
|
||||
var patchCount int
|
||||
var getCount int
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
|
||||
@@ -43,12 +42,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
getCount++
|
||||
serial := uint64(10)
|
||||
if getCount == 2 {
|
||||
serial = 11
|
||||
}
|
||||
writeZone(t, w, serial)
|
||||
writeZoneWithRRSets(t, w, 10, nil)
|
||||
case http.MethodPatch:
|
||||
patchCount++
|
||||
var payload struct {
|
||||
@@ -80,62 +74,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) {
|
||||
t.Fatalf("CreateRRSet returned error: %v", err)
|
||||
}
|
||||
if patchCount != 1 {
|
||||
t.Fatalf("expected one patch when serial increases, got %d", patchCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRRSetBumpsSOAWhenSerialDoesNotIncrease(t *testing.T) {
|
||||
var patchCount int
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeZone(t, w, 10)
|
||||
case http.MethodPatch:
|
||||
patchCount++
|
||||
var payload struct {
|
||||
RRSets []changeRRSet `json:"rrsets"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if len(payload.RRSets) != 1 {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
if patchCount == 2 {
|
||||
rrset := payload.RRSets[0]
|
||||
if rrset.Type != "SOA" {
|
||||
t.Fatalf("expected SOA fallback patch, got %#v", rrset)
|
||||
}
|
||||
if got := rrset.Records[0].Content; !strings.Contains(got, " 11 ") {
|
||||
t.Fatalf("expected bumped SOA serial, got %q", got)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "secret", "localhost", server.Client())
|
||||
err := client.CreateRRSet(context.Background(), "example.org.", RRSet{
|
||||
Name: "www.example.org.",
|
||||
Type: "A",
|
||||
TTL: 300,
|
||||
Records: []Record{{
|
||||
Content: "192.0.2.10",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateRRSet returned error: %v", err)
|
||||
}
|
||||
if patchCount != 2 {
|
||||
t.Fatalf("expected requested patch plus SOA fallback patch, got %d", patchCount)
|
||||
t.Fatalf("expected one patch, got %d", patchCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,23 +220,45 @@ func TestGetServerUsesConfiguredServerID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateZonePostsZone(t *testing.T) {
|
||||
var posted bool
|
||||
var put bool
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
posted = true
|
||||
if r.URL.Path != "/api/v1/servers/localhost/zones" {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var payload Zone
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if payload.Name != "example.org." || payload.Kind != "Native" {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
if payload.SOAEditAPI != soaEditAPIIncrease {
|
||||
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
case http.MethodPut:
|
||||
put = true
|
||||
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
var payload Zone
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if payload.SOAEditAPI != soaEditAPIIncrease {
|
||||
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/v1/servers/localhost/zones" {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var payload Zone
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if payload.Name != "example.org." || payload.Kind != "Native" {
|
||||
t.Fatalf("unexpected payload: %#v", payload)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -309,6 +270,78 @@ func TestCreateZonePostsZone(t *testing.T) {
|
||||
if created.Name != "example.org." {
|
||||
t.Fatalf("unexpected zone: %#v", created)
|
||||
}
|
||||
if !posted || !put {
|
||||
t.Fatalf("expected POST and follow-up PUT, posted=%v put=%v", posted, put)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAllZonesSOAEditAPIUpdatesAllZones(t *testing.T) {
|
||||
var updated []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if r.URL.Path != "/api/v1/servers/localhost/zones" {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode([]Zone{
|
||||
{ID: "already.example.org.", Name: "already.example.org.", SOAEditAPI: soaEditAPIIncrease},
|
||||
{ID: "missing.example.org.", Name: "missing.example.org."},
|
||||
{ID: "", Name: "fallback.example.org.", SOAEditAPI: "DEFAULT"},
|
||||
})
|
||||
case http.MethodPut:
|
||||
var payload Zone
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if payload.SOAEditAPI != soaEditAPIIncrease {
|
||||
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
|
||||
}
|
||||
updated = append(updated, strings.TrimPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/"))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "secret", "localhost", server.Client())
|
||||
if err := client.EnsureAllZonesSOAEditAPI(context.Background()); err != nil {
|
||||
t.Fatalf("EnsureAllZonesSOAEditAPI returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(updated) != 3 {
|
||||
t.Fatalf("expected three updates, got %#v", updated)
|
||||
}
|
||||
if updated[0] != "already.example.org." || updated[1] != "missing.example.org." || updated[2] != "fallback.example.org." {
|
||||
t.Fatalf("unexpected updated zones: %#v", updated)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetZoneSOAEditAPIUsesPUT(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var payload Zone
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if payload.SOAEditAPI != soaEditAPIIncrease {
|
||||
t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "secret", "localhost", server.Client())
|
||||
if err := client.SetZoneSOAEditAPI(context.Background(), "example.org."); err != nil {
|
||||
t.Fatalf("SetZoneSOAEditAPI returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteZoneDeletesZone(t *testing.T) {
|
||||
@@ -329,12 +362,6 @@ func TestDeleteZoneDeletesZone(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func writeZone(t *testing.T, w http.ResponseWriter, serial uint64) {
|
||||
t.Helper()
|
||||
|
||||
writeZoneWithRRSets(t, w, serial, nil)
|
||||
}
|
||||
|
||||
func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@@ -23,7 +24,8 @@ import (
|
||||
var assets embed.FS
|
||||
|
||||
const (
|
||||
sessionCookieName = "pdns_admin_session"
|
||||
csrfFieldName = "csrf_token"
|
||||
sessionCookieName = "__Host-pdns_admin_session"
|
||||
sessionTTL = 12 * time.Hour
|
||||
)
|
||||
|
||||
@@ -62,6 +64,7 @@ type pageData struct {
|
||||
Error string
|
||||
AuthEnabled bool
|
||||
CurrentUser string
|
||||
CSRFToken string
|
||||
Next string
|
||||
Server pdns.Server
|
||||
ZoneID string
|
||||
@@ -72,21 +75,19 @@ type pageData struct {
|
||||
}
|
||||
|
||||
type session struct {
|
||||
Username string
|
||||
Expires time.Time
|
||||
Username string
|
||||
CSRFToken string
|
||||
Expires time.Time
|
||||
}
|
||||
|
||||
type recordForm struct {
|
||||
Name string
|
||||
Type string
|
||||
TTL uint32
|
||||
Records string
|
||||
OriginalName string
|
||||
OriginalType string
|
||||
IsEdit bool
|
||||
IsSOA bool
|
||||
Title string
|
||||
SubmitLabel string
|
||||
Name string
|
||||
Type string
|
||||
TTL uint32
|
||||
Records string
|
||||
IsEdit bool
|
||||
Title string
|
||||
SubmitLabel string
|
||||
}
|
||||
|
||||
func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
|
||||
@@ -106,9 +107,8 @@ func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) {
|
||||
|
||||
templates := make(map[string]*template.Template)
|
||||
funcs := template.FuncMap{
|
||||
"isSOA": isSOA,
|
||||
"recordValues": recordValues,
|
||||
"urlQuery": url.QueryEscape,
|
||||
"isSOA": isSOA,
|
||||
"urlQuery": url.QueryEscape,
|
||||
}
|
||||
for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} {
|
||||
tmpl, err := template.New("base.html").Funcs(funcs).ParseFS(assets, "templates/base.html", "templates/"+page)
|
||||
@@ -139,7 +139,7 @@ func (s *Server) routes() http.Handler {
|
||||
mux.HandleFunc("GET /healthz", s.healthz)
|
||||
mux.HandleFunc("GET /login", s.login)
|
||||
mux.HandleFunc("POST /login", s.loginPost)
|
||||
mux.HandleFunc("GET /logout", s.logout)
|
||||
mux.HandleFunc("POST /logout", s.logout)
|
||||
mux.HandleFunc("GET /", s.dashboard)
|
||||
mux.HandleFunc("GET /zones", s.listZones)
|
||||
mux.HandleFunc("POST /zones", s.createZone)
|
||||
@@ -150,7 +150,7 @@ func (s *Server) routes() http.Handler {
|
||||
mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet)
|
||||
mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet)
|
||||
mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet)
|
||||
return s.withLogging(s.withSessionAuth(mux))
|
||||
return s.withLogging(s.withSecurityHeaders(s.withSessionAuth(mux)))
|
||||
}
|
||||
|
||||
func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -226,8 +226,10 @@ func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
|
||||
Value: token,
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(sessionTTL),
|
||||
MaxAge: int(sessionTTL.Seconds()),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther)
|
||||
@@ -244,7 +246,8 @@ func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
}
|
||||
@@ -363,7 +366,6 @@ func (s *Server) editRRSet(w http.ResponseWriter, r *http.Request) {
|
||||
TTL: rrset.TTL,
|
||||
Records: recordValues(rrset),
|
||||
IsEdit: true,
|
||||
IsSOA: isSOA(rrset.Type),
|
||||
Title: "Edit record",
|
||||
SubmitLabel: "Save record",
|
||||
}
|
||||
@@ -445,7 +447,7 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
name := ensureTrailingDot(r.FormValue("name"))
|
||||
name := dnsrecord.EnsureTrailingDot(r.FormValue("name"))
|
||||
recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type")))
|
||||
if name == "." || recordType == "" {
|
||||
s.redirectZoneError(w, r, zoneID, "record name and type are required")
|
||||
@@ -466,8 +468,9 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) {
|
||||
data.AuthEnabled = s.auth != nil
|
||||
if user, ok := s.currentUser(r); ok {
|
||||
data.CurrentUser = user
|
||||
if sess, ok := s.currentSession(r); ok {
|
||||
data.CurrentUser = sess.Username
|
||||
data.CSRFToken = sess.CSRFToken
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl, ok := s.templates[name]
|
||||
@@ -498,46 +501,64 @@ func (s *Server) withSessionAuth(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if _, ok := s.currentUser(r); !ok {
|
||||
sess, ok := s.currentSession(r)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if isUnsafeMethod(r.Method) && !validCSRFToken(r, sess.CSRFToken) {
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) currentUser(r *http.Request) (string, bool) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
sess, ok := s.currentSession(r)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return sess.Username, true
|
||||
}
|
||||
|
||||
func (s *Server) currentSession(r *http.Request) (session, bool) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil || cookie.Value == "" || len(cookie.Value) > 128 {
|
||||
return session{}, false
|
||||
}
|
||||
|
||||
s.sessionsM.Lock()
|
||||
defer s.sessionsM.Unlock()
|
||||
|
||||
sess, ok := s.sessions[cookie.Value]
|
||||
if !ok {
|
||||
return "", false
|
||||
return session{}, false
|
||||
}
|
||||
if time.Now().After(sess.Expires) {
|
||||
delete(s.sessions, cookie.Value)
|
||||
return "", false
|
||||
return session{}, false
|
||||
}
|
||||
return sess.Username, true
|
||||
return sess, true
|
||||
}
|
||||
|
||||
func (s *Server) createSession(username string) (string, error) {
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
token, err := randomToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
csrfToken, err := randomToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
s.sessionsM.Lock()
|
||||
defer s.sessionsM.Unlock()
|
||||
s.pruneExpiredSessionsLocked(time.Now())
|
||||
s.sessions[token] = session{
|
||||
Username: username,
|
||||
Expires: time.Now().Add(sessionTTL),
|
||||
Username: username,
|
||||
CSRFToken: csrfToken,
|
||||
Expires: time.Now().Add(sessionTTL),
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
@@ -548,8 +569,16 @@ func (s *Server) deleteSession(token string) {
|
||||
delete(s.sessions, token)
|
||||
}
|
||||
|
||||
func (s *Server) pruneExpiredSessionsLocked(now time.Time) {
|
||||
for token, sess := range s.sessions {
|
||||
if now.After(sess.Expires) {
|
||||
delete(s.sessions, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPublicPath(path string) bool {
|
||||
return path == "/login" || path == "/logout" || path == "/healthz" || strings.HasPrefix(path, "/static/")
|
||||
return path == "/login" || path == "/healthz" || strings.HasPrefix(path, "/static/")
|
||||
}
|
||||
|
||||
func safeRedirectPath(value string) string {
|
||||
@@ -563,6 +592,52 @@ func safeRedirectPath(value string) string {
|
||||
return parsed.RequestURI()
|
||||
}
|
||||
|
||||
func isUnsafeMethod(method string) bool {
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func validCSRFToken(r *http.Request, expected string) bool {
|
||||
if expected == "" {
|
||||
return false
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return false
|
||||
}
|
||||
token := r.FormValue(csrfFieldName)
|
||||
if token == "" {
|
||||
token = r.Header.Get("X-CSRF-Token")
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
func randomToken() (string, error) {
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(tokenBytes), nil
|
||||
}
|
||||
|
||||
func (s *Server) withSecurityHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; font-src 'self' data:; base-uri 'self'; form-action 'self'; frame-ancestors 'none'")
|
||||
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
||||
w.Header().Set("Referrer-Policy", "same-origin")
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
if !strings.HasPrefix(r.URL.Path, "/static/") {
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) withLogging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
@@ -584,10 +659,6 @@ func parseLines(raw string) []string {
|
||||
return values
|
||||
}
|
||||
|
||||
func ensureTrailingDot(value string) string {
|
||||
return dnsrecord.EnsureTrailingDot(value)
|
||||
}
|
||||
|
||||
func parseFQDNLines(raw, label string) ([]string, error) {
|
||||
values := parseLines(raw)
|
||||
for i, value := range values {
|
||||
|
||||
@@ -154,6 +154,15 @@ func TestLoginCreatesSession(t *testing.T) {
|
||||
if len(cookies) == 0 || cookies[0].Name != sessionCookieName {
|
||||
t.Fatalf("expected session cookie, got %#v", cookies)
|
||||
}
|
||||
if !cookies[0].HttpOnly {
|
||||
t.Fatal("session cookie must be HttpOnly")
|
||||
}
|
||||
if !cookies[0].Secure {
|
||||
t.Fatal("session cookie must be Secure")
|
||||
}
|
||||
if cookies[0].SameSite != http.SameSiteStrictMode {
|
||||
t.Fatalf("unexpected SameSite policy: %v", cookies[0].SameSite)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/zones", nil)
|
||||
req.AddCookie(cookies[0])
|
||||
@@ -196,9 +205,12 @@ func TestLogoutClearsSession(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("createSession returned error: %v", err)
|
||||
}
|
||||
csrfToken := srv.sessions[token].CSRFToken
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logout", nil)
|
||||
body := strings.NewReader("csrf_token=" + csrfToken)
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", body)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.routes().ServeHTTP(rec, req)
|
||||
@@ -211,6 +223,51 @@ func TestLogoutClearsSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutRequiresCSRF(t *testing.T) {
|
||||
srv, err := New(Config{Authenticator: &fakeAuth{allowed: true}}, &fakeClient{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("New returned error: %v", err)
|
||||
}
|
||||
token, err := srv.createSession("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("createSession returned error: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.routes().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersAreSet(t *testing.T) {
|
||||
srv, err := New(Config{}, &fakeClient{}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("New returned error: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.routes().ServeHTTP(rec, req)
|
||||
|
||||
for _, header := range []string{
|
||||
"Content-Security-Policy",
|
||||
"Referrer-Policy",
|
||||
"Strict-Transport-Security",
|
||||
"X-Content-Type-Options",
|
||||
"X-Frame-Options",
|
||||
} {
|
||||
if rec.Header().Get(header) == "" {
|
||||
t.Fatalf("expected %s header", header)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLinesSkipsBlankLines(t *testing.T) {
|
||||
values := parseLines("192.0.2.1\n\n192.0.2.2\n")
|
||||
if len(values) != 2 {
|
||||
|
||||
@@ -31,9 +31,12 @@
|
||||
{{ if .CurrentUser }}
|
||||
<div class="navbar-nav ms-auto">
|
||||
<span class="nav-link text-secondary">{{ .CurrentUser }}</span>
|
||||
<a class="nav-link" href="/logout">
|
||||
<span class="nav-link-title">Logout</span>
|
||||
</a>
|
||||
<form method="post" action="/logout" class="nav-item">
|
||||
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
|
||||
<button class="nav-link btn btn-link px-0" type="submit">
|
||||
<span class="nav-link-title">Logout</span>
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
{{ end }}
|
||||
|
||||
<form method="post" action="{{ if .RecordForm.IsEdit }}/zones/{{ .ZoneID }}/rrsets/edit?name={{ urlQuery .RecordForm.Name }}&type={{ urlQuery .RecordForm.Type }}{{ else }}/zones/{{ .ZoneID }}/rrsets{{ end }}">
|
||||
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
|
||||
{{ if not .RecordForm.IsEdit }}
|
||||
<div class="mb-3">
|
||||
<label class="form-label">Name</label>
|
||||
|
||||
@@ -49,6 +49,7 @@
|
||||
<a class="btn btn-outline-primary btn-sm" href="/zones/{{ $.ZoneID }}/rrsets/edit?name={{ urlQuery .Name }}&type={{ urlQuery .Type }}">Edit</a>
|
||||
{{ if not (isSOA .Type) }}
|
||||
<form method="post" action="/zones/{{ $.ZoneID }}/rrsets/delete">
|
||||
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
|
||||
<input type="hidden" name="name" value="{{ .Name }}">
|
||||
<input type="hidden" name="type" value="{{ .Type }}">
|
||||
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
<td>{{ .Serial }}</td>
|
||||
<td>
|
||||
<form method="post" action="/zones/{{ .ID }}/delete">
|
||||
<input type="hidden" name="csrf_token" value="{{ $.CSRFToken }}">
|
||||
<button class="btn btn-outline-danger btn-sm" type="submit">Delete</button>
|
||||
</form>
|
||||
</td>
|
||||
@@ -55,6 +56,7 @@
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<form method="post" action="/zones">
|
||||
<input type="hidden" name="csrf_token" value="{{ .CSRFToken }}">
|
||||
<div class="mb-3">
|
||||
<label class="form-label">Zone name</label>
|
||||
<input class="form-control" name="name" placeholder="example.org." required>
|
||||
|
||||
Reference in New Issue
Block a user