diff --git a/cmd/pdns-admin/main.go b/cmd/pdns-admin/main.go index 612d7db..c278c4a 100644 --- a/cmd/pdns-admin/main.go +++ b/cmd/pdns-admin/main.go @@ -1,9 +1,11 @@ package main import ( + "context" "log" "net/http" "os" + "time" "pdns_admin/internal/auth" "pdns_admin/internal/config" @@ -20,6 +22,12 @@ func main() { } pdnsClient := pdns.NewClient(cfg.PDNSAPIURL, cfg.PDNSAPIKey, cfg.PDNSServerID, http.DefaultClient) + metadataCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := pdnsClient.EnsureAllZonesSOAEditAPI(metadataCtx); err != nil { + logger.Fatalf("failed to ensure SOA-EDIT-API setting: %v", err) + } + var authenticator server.Authenticator if !cfg.Auth.Disabled { authenticator, err = auth.NewLDAPAuthenticator(auth.LDAPConfig{ diff --git a/internal/pdns/client.go b/internal/pdns/client.go index 49ecbd2..10d1b0b 100644 --- a/internal/pdns/client.go +++ b/internal/pdns/client.go @@ -13,6 +13,8 @@ import ( "time" ) +const soaEditAPIIncrease = "INCREASE" + type HTTPClient interface { Do(*http.Request) (*http.Response, error) } @@ -26,25 +28,20 @@ type Client struct { type Server struct { ID string `json:"id"` - Type string `json:"type"` DaemonType string `json:"daemon_type"` Version string `json:"version"` - URL string `json:"url"` - ConfigURL string `json:"config_url"` - ZonesURL string `json:"zones_url"` } type Zone struct { - ID string `json:"id,omitempty"` - Name string `json:"name"` - Type string `json:"type,omitempty"` - Kind string `json:"kind,omitempty"` - Serial uint64 `json:"serial,omitempty"` - EditedSerial uint64 `json:"edited_serial,omitempty"` - SOAEditAPI string `json:"soa_edit_api,omitempty"` - Nameservers []string `json:"nameservers,omitempty"` - Masters []string `json:"masters,omitempty"` - RRSets []RRSet `json:"rrsets,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Kind string `json:"kind,omitempty"` + Serial uint64 `json:"serial,omitempty"` + SOAEditAPI string `json:"soa_edit_api,omitempty"` + Nameservers []string `json:"nameservers,omitempty"` + Masters []string `json:"masters,omitempty"` + RRSets []RRSet `json:"rrsets,omitempty"` } type RRSet struct { @@ -100,9 +97,50 @@ func (c *Client) ListZones(ctx context.Context) ([]Zone, error) { } func (c *Client) CreateZone(ctx context.Context, zone Zone) (Zone, error) { + zone.SOAEditAPI = soaEditAPIIncrease var created Zone - err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created) - return created, err + if err := c.do(ctx, http.MethodPost, c.path("/zones"), zone, &created); err != nil { + return Zone{}, err + } + + zoneID := created.ID + if zoneID == "" { + zoneID = created.Name + } + if zoneID == "" { + zoneID = zone.Name + } + if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil { + return Zone{}, fmt.Errorf("set SOA-EDIT-API for new zone %s: %w", zoneID, err) + } + + return created, nil +} + +func (c *Client) EnsureAllZonesSOAEditAPI(ctx context.Context) error { + zones, err := c.ListZones(ctx) + if err != nil { + return err + } + + for _, zone := range zones { + zoneID := zone.ID + if zoneID == "" { + zoneID = zone.Name + } + if zoneID == "" { + return fmt.Errorf("zone without id or name cannot be updated") + } + if err := c.SetZoneSOAEditAPI(ctx, zoneID); err != nil { + return fmt.Errorf("set SOA-EDIT-API for %s: %w", zoneID, err) + } + } + return nil +} + +func (c *Client) SetZoneSOAEditAPI(ctx context.Context, zoneID string) error { + body := Zone{SOAEditAPI: soaEditAPIIncrease} + return c.do(ctx, http.MethodPut, c.path("/zones/"+url.PathEscape(zoneID)), body, nil) } func (c *Client) DeleteZone(ctx context.Context, zoneID string) error { @@ -116,19 +154,17 @@ func (c *Client) GetZone(ctx context.Context, zoneID string) (Zone, error) { } func (c *Client) CreateRRSet(ctx context.Context, zoneID string, rrset RRSet) error { - var before *Zone if allowsMultipleRecords(rrset.Type) { zone, err := c.GetZone(ctx, zoneID) if err != nil { return fmt.Errorf("read zone before merging records: %w", err) } - before = &zone if existing, ok := findRRSet(zone, rrset.Name, rrset.Type); ok { rrset.Records = mergeRecords(existing.Records, rrset.Records) } } - return c.patchZoneWithSerialBump(ctx, zoneID, before, []changeRRSet{{ + return c.patchZone(ctx, zoneID, []changeRRSet{{ Name: rrset.Name, Type: rrset.Type, TTL: rrset.TTL, @@ -175,7 +211,7 @@ func recordKey(record Record) string { } func (c *Client) DeleteRRSet(ctx context.Context, zoneID, name, recordType string) error { - return c.patchZoneWithSerialBump(ctx, zoneID, nil, []changeRRSet{{ + return c.patchZone(ctx, zoneID, []changeRRSet{{ Name: name, Type: recordType, ChangeType: "DELETE", @@ -189,40 +225,6 @@ func (c *Client) patchZone(ctx context.Context, zoneID string, rrsets []changeRR return c.do(ctx, http.MethodPatch, c.path("/zones/"+url.PathEscape(zoneID)), body, nil) } -func (c *Client) patchZoneWithSerialBump(ctx context.Context, zoneID string, before *Zone, rrsets []changeRRSet) error { - if before == nil { - zone, err := c.GetZone(ctx, zoneID) - if err != nil { - return fmt.Errorf("read zone before change: %w", err) - } - before = &zone - } - - if err := c.patchZone(ctx, zoneID, rrsets); err != nil { - return err - } - - after, err := c.GetZone(ctx, zoneID) - if err != nil { - return fmt.Errorf("read zone after change: %w", err) - } - if serialIncreased(*before, after) { - return nil - } - - soa, err := bumpedSOA(after, before.Serial+1) - if err != nil { - return fmt.Errorf("bump SOA serial: %w", err) - } - return c.patchZone(ctx, zoneID, []changeRRSet{{ - Name: soa.Name, - Type: soa.Type, - TTL: soa.TTL, - ChangeType: "REPLACE", - Records: soa.Records, - }}) -} - func (c *Client) path(suffix string) string { return "/api/v1/servers/" + url.PathEscape(c.serverID) + suffix } @@ -275,53 +277,3 @@ func (c *Client) do(ctx context.Context, method, apiPath string, in any, out any return nil } - -func serialIncreased(before, after Zone) bool { - if after.Serial > before.Serial { - return true - } - if (before.EditedSerial != 0 || after.EditedSerial != 0) && after.EditedSerial > before.EditedSerial { - return true - } - return false -} - -func bumpedSOA(zone Zone, minimum uint64) (RRSet, error) { - for _, rrset := range zone.RRSets { - if !strings.EqualFold(rrset.Type, "SOA") { - continue - } - if len(rrset.Records) != 1 { - return RRSet{}, fmt.Errorf("expected exactly one SOA record, got %d", len(rrset.Records)) - } - - record := rrset.Records[0] - fields := strings.Fields(record.Content) - if len(fields) != 7 { - return RRSet{}, fmt.Errorf("SOA record must have 7 fields") - } - - current, err := strconv.ParseUint(fields[2], 10, 32) - if err != nil { - return RRSet{}, fmt.Errorf("parse SOA serial: %w", err) - } - if current == 1<<32-1 { - return RRSet{}, fmt.Errorf("SOA serial is already at maximum uint32 value") - } - - next := current + 1 - if next < minimum { - next = minimum - } - if next > 1<<32-1 { - return RRSet{}, fmt.Errorf("next SOA serial exceeds maximum uint32 value") - } - - fields[2] = strconv.FormatUint(next, 10) - record.Content = strings.Join(fields, " ") - rrset.Records[0] = record - return rrset, nil - } - - return RRSet{}, fmt.Errorf("SOA record not found") -} diff --git a/internal/pdns/client_test.go b/internal/pdns/client_test.go index 2f00334..8ca5043 100644 --- a/internal/pdns/client_test.go +++ b/internal/pdns/client_test.go @@ -34,7 +34,6 @@ func TestListZonesSendsAPIKey(t *testing.T) { func TestCreateRRSetPatchesZone(t *testing.T) { var patchCount int - var getCount int server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") { @@ -43,12 +42,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) { switch r.Method { case http.MethodGet: - getCount++ - serial := uint64(10) - if getCount == 2 { - serial = 11 - } - writeZone(t, w, serial) + writeZoneWithRRSets(t, w, 10, nil) case http.MethodPatch: patchCount++ var payload struct { @@ -80,62 +74,7 @@ func TestCreateRRSetPatchesZone(t *testing.T) { t.Fatalf("CreateRRSet returned error: %v", err) } if patchCount != 1 { - t.Fatalf("expected one patch when serial increases, got %d", patchCount) - } -} - -func TestCreateRRSetBumpsSOAWhenSerialDoesNotIncrease(t *testing.T) { - var patchCount int - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.HasPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/example.org.") { - t.Fatalf("unexpected path: %s", r.URL.Path) - } - - switch r.Method { - case http.MethodGet: - writeZone(t, w, 10) - case http.MethodPatch: - patchCount++ - var payload struct { - RRSets []changeRRSet `json:"rrsets"` - } - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - t.Fatalf("decode request: %v", err) - } - if len(payload.RRSets) != 1 { - t.Fatalf("unexpected payload: %#v", payload) - } - if patchCount == 2 { - rrset := payload.RRSets[0] - if rrset.Type != "SOA" { - t.Fatalf("expected SOA fallback patch, got %#v", rrset) - } - if got := rrset.Records[0].Content; !strings.Contains(got, " 11 ") { - t.Fatalf("expected bumped SOA serial, got %q", got) - } - } - w.WriteHeader(http.StatusNoContent) - default: - t.Fatalf("unexpected method: %s", r.Method) - } - })) - defer server.Close() - - client := NewClient(server.URL, "secret", "localhost", server.Client()) - err := client.CreateRRSet(context.Background(), "example.org.", RRSet{ - Name: "www.example.org.", - Type: "A", - TTL: 300, - Records: []Record{{ - Content: "192.0.2.10", - }}, - }) - if err != nil { - t.Fatalf("CreateRRSet returned error: %v", err) - } - if patchCount != 2 { - t.Fatalf("expected requested patch plus SOA fallback patch, got %d", patchCount) + t.Fatalf("expected one patch, got %d", patchCount) } } @@ -281,23 +220,45 @@ func TestGetServerUsesConfiguredServerID(t *testing.T) { } func TestCreateZonePostsZone(t *testing.T) { + var posted bool + var put bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { + switch r.Method { + case http.MethodPost: + posted = true + if r.URL.Path != "/api/v1/servers/localhost/zones" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + + var payload Zone + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if payload.Name != "example.org." || payload.Kind != "Native" { + t.Fatalf("unexpected payload: %#v", payload) + } + if payload.SOAEditAPI != soaEditAPIIncrease { + t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI) + } + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(payload) + case http.MethodPut: + put = true + if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + var payload Zone + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if payload.SOAEditAPI != soaEditAPIIncrease { + t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI) + } + w.WriteHeader(http.StatusNoContent) + default: t.Fatalf("unexpected method: %s", r.Method) } - if r.URL.Path != "/api/v1/servers/localhost/zones" { - t.Fatalf("unexpected path: %s", r.URL.Path) - } - - var payload Zone - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - t.Fatalf("decode request: %v", err) - } - if payload.Name != "example.org." || payload.Kind != "Native" { - t.Fatalf("unexpected payload: %#v", payload) - } - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(payload) })) defer server.Close() @@ -309,6 +270,78 @@ func TestCreateZonePostsZone(t *testing.T) { if created.Name != "example.org." { t.Fatalf("unexpected zone: %#v", created) } + if !posted || !put { + t.Fatalf("expected POST and follow-up PUT, posted=%v put=%v", posted, put) + } +} + +func TestEnsureAllZonesSOAEditAPIUpdatesAllZones(t *testing.T) { + var updated []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + if r.URL.Path != "/api/v1/servers/localhost/zones" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode([]Zone{ + {ID: "already.example.org.", Name: "already.example.org.", SOAEditAPI: soaEditAPIIncrease}, + {ID: "missing.example.org.", Name: "missing.example.org."}, + {ID: "", Name: "fallback.example.org.", SOAEditAPI: "DEFAULT"}, + }) + case http.MethodPut: + var payload Zone + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if payload.SOAEditAPI != soaEditAPIIncrease { + t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI) + } + updated = append(updated, strings.TrimPrefix(r.URL.Path, "/api/v1/servers/localhost/zones/")) + w.WriteHeader(http.StatusNoContent) + default: + t.Fatalf("unexpected method: %s", r.Method) + } + })) + defer server.Close() + + client := NewClient(server.URL, "secret", "localhost", server.Client()) + if err := client.EnsureAllZonesSOAEditAPI(context.Background()); err != nil { + t.Fatalf("EnsureAllZonesSOAEditAPI returned error: %v", err) + } + + if len(updated) != 3 { + t.Fatalf("expected three updates, got %#v", updated) + } + if updated[0] != "already.example.org." || updated[1] != "missing.example.org." || updated[2] != "fallback.example.org." { + t.Fatalf("unexpected updated zones: %#v", updated) + } +} + +func TestSetZoneSOAEditAPIUsesPUT(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Fatalf("unexpected method: %s", r.Method) + } + if r.URL.Path != "/api/v1/servers/localhost/zones/example.org." { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + + var payload Zone + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if payload.SOAEditAPI != soaEditAPIIncrease { + t.Fatalf("unexpected soa_edit_api: %q", payload.SOAEditAPI) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient(server.URL, "secret", "localhost", server.Client()) + if err := client.SetZoneSOAEditAPI(context.Background(), "example.org."); err != nil { + t.Fatalf("SetZoneSOAEditAPI returned error: %v", err) + } } func TestDeleteZoneDeletesZone(t *testing.T) { @@ -329,12 +362,6 @@ func TestDeleteZoneDeletesZone(t *testing.T) { } } -func writeZone(t *testing.T, w http.ResponseWriter, serial uint64) { - t.Helper() - - writeZoneWithRRSets(t, w, serial, nil) -} - func writeZoneWithRRSets(t *testing.T, w http.ResponseWriter, serial uint64, rrsets []RRSet) { t.Helper() diff --git a/internal/server/server.go b/internal/server/server.go index ba93f96..f0c99ca 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/rand" + "crypto/subtle" "embed" "encoding/base64" "fmt" @@ -23,7 +24,8 @@ import ( var assets embed.FS const ( - sessionCookieName = "pdns_admin_session" + csrfFieldName = "csrf_token" + sessionCookieName = "__Host-pdns_admin_session" sessionTTL = 12 * time.Hour ) @@ -62,6 +64,7 @@ type pageData struct { Error string AuthEnabled bool CurrentUser string + CSRFToken string Next string Server pdns.Server ZoneID string @@ -72,21 +75,19 @@ type pageData struct { } type session struct { - Username string - Expires time.Time + Username string + CSRFToken string + Expires time.Time } type recordForm struct { - Name string - Type string - TTL uint32 - Records string - OriginalName string - OriginalType string - IsEdit bool - IsSOA bool - Title string - SubmitLabel string + Name string + Type string + TTL uint32 + Records string + IsEdit bool + Title string + SubmitLabel string } func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) { @@ -106,9 +107,8 @@ func New(cfg Config, client PDNSClient, logger *log.Logger) (*Server, error) { templates := make(map[string]*template.Template) funcs := template.FuncMap{ - "isSOA": isSOA, - "recordValues": recordValues, - "urlQuery": url.QueryEscape, + "isSOA": isSOA, + "urlQuery": url.QueryEscape, } for _, page := range []string{"dashboard.html", "login.html", "zones.html", "zone.html", "record_form.html"} { tmpl, err := template.New("base.html").Funcs(funcs).ParseFS(assets, "templates/base.html", "templates/"+page) @@ -139,7 +139,7 @@ func (s *Server) routes() http.Handler { mux.HandleFunc("GET /healthz", s.healthz) mux.HandleFunc("GET /login", s.login) mux.HandleFunc("POST /login", s.loginPost) - mux.HandleFunc("GET /logout", s.logout) + mux.HandleFunc("POST /logout", s.logout) mux.HandleFunc("GET /", s.dashboard) mux.HandleFunc("GET /zones", s.listZones) mux.HandleFunc("POST /zones", s.createZone) @@ -150,7 +150,7 @@ func (s *Server) routes() http.Handler { mux.HandleFunc("POST /zones/{zoneID}/rrsets", s.saveRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/edit", s.saveEditedRRSet) mux.HandleFunc("POST /zones/{zoneID}/rrsets/delete", s.deleteRRSet) - return s.withLogging(s.withSessionAuth(mux)) + return s.withLogging(s.withSecurityHeaders(s.withSessionAuth(mux))) } func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) { @@ -226,8 +226,10 @@ func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) { Value: token, Path: "/", Expires: time.Now().Add(sessionTTL), + MaxAge: int(sessionTTL.Seconds()), HttpOnly: true, - SameSite: http.SameSiteLaxMode, + Secure: true, + SameSite: http.SameSiteStrictMode, }) http.Redirect(w, r, safeRedirectPath(r.FormValue("next")), http.StatusSeeOther) @@ -244,7 +246,8 @@ func (s *Server) logout(w http.ResponseWriter, r *http.Request) { Expires: time.Unix(0, 0), MaxAge: -1, HttpOnly: true, - SameSite: http.SameSiteLaxMode, + Secure: true, + SameSite: http.SameSiteStrictMode, }) http.Redirect(w, r, "/login", http.StatusSeeOther) } @@ -363,7 +366,6 @@ func (s *Server) editRRSet(w http.ResponseWriter, r *http.Request) { TTL: rrset.TTL, Records: recordValues(rrset), IsEdit: true, - IsSOA: isSOA(rrset.Type), Title: "Edit record", SubmitLabel: "Save record", } @@ -445,7 +447,7 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) { return } - name := ensureTrailingDot(r.FormValue("name")) + name := dnsrecord.EnsureTrailingDot(r.FormValue("name")) recordType := strings.ToUpper(strings.TrimSpace(r.FormValue("type"))) if name == "." || recordType == "" { s.redirectZoneError(w, r, zoneID, "record name and type are required") @@ -466,8 +468,9 @@ func (s *Server) deleteRRSet(w http.ResponseWriter, r *http.Request) { func (s *Server) render(w http.ResponseWriter, r *http.Request, name string, data pageData) { data.AuthEnabled = s.auth != nil - if user, ok := s.currentUser(r); ok { - data.CurrentUser = user + if sess, ok := s.currentSession(r); ok { + data.CurrentUser = sess.Username + data.CSRFToken = sess.CSRFToken } w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl, ok := s.templates[name] @@ -498,46 +501,64 @@ func (s *Server) withSessionAuth(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } - if _, ok := s.currentUser(r); !ok { + sess, ok := s.currentSession(r) + if !ok { http.Redirect(w, r, "/login?next="+url.QueryEscape(r.URL.RequestURI()), http.StatusSeeOther) return } + if isUnsafeMethod(r.Method) && !validCSRFToken(r, sess.CSRFToken) { + http.Error(w, "invalid CSRF token", http.StatusForbidden) + return + } next.ServeHTTP(w, r) }) } func (s *Server) currentUser(r *http.Request) (string, bool) { - cookie, err := r.Cookie(sessionCookieName) - if err != nil || cookie.Value == "" { + sess, ok := s.currentSession(r) + if !ok { return "", false } + return sess.Username, true +} + +func (s *Server) currentSession(r *http.Request) (session, bool) { + cookie, err := r.Cookie(sessionCookieName) + if err != nil || cookie.Value == "" || len(cookie.Value) > 128 { + return session{}, false + } s.sessionsM.Lock() defer s.sessionsM.Unlock() sess, ok := s.sessions[cookie.Value] if !ok { - return "", false + return session{}, false } if time.Now().After(sess.Expires) { delete(s.sessions, cookie.Value) - return "", false + return session{}, false } - return sess.Username, true + return sess, true } func (s *Server) createSession(username string) (string, error) { - tokenBytes := make([]byte, 32) - if _, err := rand.Read(tokenBytes); err != nil { + token, err := randomToken() + if err != nil { + return "", err + } + csrfToken, err := randomToken() + if err != nil { return "", err } - token := base64.RawURLEncoding.EncodeToString(tokenBytes) s.sessionsM.Lock() defer s.sessionsM.Unlock() + s.pruneExpiredSessionsLocked(time.Now()) s.sessions[token] = session{ - Username: username, - Expires: time.Now().Add(sessionTTL), + Username: username, + CSRFToken: csrfToken, + Expires: time.Now().Add(sessionTTL), } return token, nil } @@ -548,8 +569,16 @@ func (s *Server) deleteSession(token string) { delete(s.sessions, token) } +func (s *Server) pruneExpiredSessionsLocked(now time.Time) { + for token, sess := range s.sessions { + if now.After(sess.Expires) { + delete(s.sessions, token) + } + } +} + func isPublicPath(path string) bool { - return path == "/login" || path == "/logout" || path == "/healthz" || strings.HasPrefix(path, "/static/") + return path == "/login" || path == "/healthz" || strings.HasPrefix(path, "/static/") } func safeRedirectPath(value string) string { @@ -563,6 +592,52 @@ func safeRedirectPath(value string) string { return parsed.RequestURI() } +func isUnsafeMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + return false + default: + return true + } +} + +func validCSRFToken(r *http.Request, expected string) bool { + if expected == "" { + return false + } + if err := r.ParseForm(); err != nil { + return false + } + token := r.FormValue(csrfFieldName) + if token == "" { + token = r.Header.Get("X-CSRF-Token") + } + return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1 +} + +func randomToken() (string, error) { + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(tokenBytes), nil +} + +func (s *Server) withSecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; font-src 'self' data:; base-uri 'self'; form-action 'self'; frame-ancestors 'none'") + w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") + w.Header().Set("Referrer-Policy", "same-origin") + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + if !strings.HasPrefix(r.URL.Path, "/static/") { + w.Header().Set("Cache-Control", "no-store") + } + next.ServeHTTP(w, r) + }) +} + func (s *Server) withLogging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -584,10 +659,6 @@ func parseLines(raw string) []string { return values } -func ensureTrailingDot(value string) string { - return dnsrecord.EnsureTrailingDot(value) -} - func parseFQDNLines(raw, label string) ([]string, error) { values := parseLines(raw) for i, value := range values { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 28eb5fd..cffcdad 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -154,6 +154,15 @@ func TestLoginCreatesSession(t *testing.T) { if len(cookies) == 0 || cookies[0].Name != sessionCookieName { t.Fatalf("expected session cookie, got %#v", cookies) } + if !cookies[0].HttpOnly { + t.Fatal("session cookie must be HttpOnly") + } + if !cookies[0].Secure { + t.Fatal("session cookie must be Secure") + } + if cookies[0].SameSite != http.SameSiteStrictMode { + t.Fatalf("unexpected SameSite policy: %v", cookies[0].SameSite) + } req = httptest.NewRequest(http.MethodGet, "/zones", nil) req.AddCookie(cookies[0]) @@ -196,9 +205,12 @@ func TestLogoutClearsSession(t *testing.T) { if err != nil { t.Fatalf("createSession returned error: %v", err) } + csrfToken := srv.sessions[token].CSRFToken - req := httptest.NewRequest(http.MethodGet, "/logout", nil) + body := strings.NewReader("csrf_token=" + csrfToken) + req := httptest.NewRequest(http.MethodPost, "/logout", body) req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rec := httptest.NewRecorder() srv.routes().ServeHTTP(rec, req) @@ -211,6 +223,51 @@ func TestLogoutClearsSession(t *testing.T) { } } +func TestLogoutRequiresCSRF(t *testing.T) { + srv, err := New(Config{Authenticator: &fakeAuth{allowed: true}}, &fakeClient{}, nil) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + token, err := srv.createSession("alice") + if err != nil { + t.Fatalf("createSession returned error: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/logout", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token}) + rec := httptest.NewRecorder() + + srv.routes().ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("unexpected status: %d", rec.Code) + } +} + +func TestSecurityHeadersAreSet(t *testing.T) { + srv, err := New(Config{}, &fakeClient{}, nil) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + + srv.routes().ServeHTTP(rec, req) + + for _, header := range []string{ + "Content-Security-Policy", + "Referrer-Policy", + "Strict-Transport-Security", + "X-Content-Type-Options", + "X-Frame-Options", + } { + if rec.Header().Get(header) == "" { + t.Fatalf("expected %s header", header) + } + } +} + func TestParseLinesSkipsBlankLines(t *testing.T) { values := parseLines("192.0.2.1\n\n192.0.2.2\n") if len(values) != 2 { diff --git a/internal/server/templates/base.html b/internal/server/templates/base.html index d792dc5..37f95e8 100644 --- a/internal/server/templates/base.html +++ b/internal/server/templates/base.html @@ -31,9 +31,12 @@ {{ if .CurrentUser }} {{ end }} {{ end }} diff --git a/internal/server/templates/record_form.html b/internal/server/templates/record_form.html index 68caefb..3361f56 100644 --- a/internal/server/templates/record_form.html +++ b/internal/server/templates/record_form.html @@ -29,6 +29,7 @@ {{ end }}
+ {{ if not .RecordForm.IsEdit }}
diff --git a/internal/server/templates/zone.html b/internal/server/templates/zone.html index 047e59b..c20734c 100644 --- a/internal/server/templates/zone.html +++ b/internal/server/templates/zone.html @@ -49,6 +49,7 @@ Edit {{ if not (isSOA .Type) }} + diff --git a/internal/server/templates/zones.html b/internal/server/templates/zones.html index c8c09cd..a435812 100644 --- a/internal/server/templates/zones.html +++ b/internal/server/templates/zones.html @@ -34,6 +34,7 @@ {{ .Serial }} + @@ -55,6 +56,7 @@
+