Skip to content

Commit 6e8299e

Browse files
committed
Block ability to use '*'. replace origins string to a list to config. enable cors for specific endpoints
1 parent 68ce697 commit 6e8299e

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

config-example.yaml

+8-5
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,15 @@ grpc_listen_addr: 127.0.0.1:50443
4040
# are doing.
4141
grpc_allow_insecure: false
4242

43-
# The Access-Control-Allow-Origin header specifies which origins are allowed to access resources.
43+
# The allow_origins list will allow you to set the Access-Control-Allow-Origin header to the origin in the list.
44+
# This will allow you to enable cors and set headscale without a reverse proxy.
45+
# Multiple origins can be set in the allow_origins list.
4446
# Options:
45-
# - "*" to allow access from any origin (not recommended for sensitive data).
46-
# - "http://example.com" to only allow access from a specific origin.
47-
# - "" to disable Cross-Origin Resource Sharing (CORS).
48-
access_control_allow_origin: ""
47+
# - "*" is disabled (due to security risks).
48+
# - "https://example.com" to only allow access from a specific origin.
49+
# - "https://example.com:1234" to allow access from a specific origin with a port.
50+
cors:
51+
allow_origins: []
4952

5053
# The Noise section includes specific configuration for the
5154
# TS2021 Noise protocol

hscontrol/app.go

+48-2
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,64 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
455455
return os.Remove(h.cfg.UnixSocket)
456456
}
457457

458+
// corsHeaderMiddleware will add an "Access-Control-Allow-Origin" to enable CORS
458459
func (h *Headscale) corsHeadersMiddleware(next http.Handler) http.Handler {
459460
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
460-
w.Header().Set("Access-Control-Allow-Origin", h.cfg.AccessControlAllowOrigins)
461+
462+
// skip disabled CORS endpoints
463+
if !h.enabledCorsRoutes(r.URL.Path) {
464+
next.ServeHTTP(w, r)
465+
return
466+
}
467+
468+
origin := r.Header.Get("Origin")
469+
// we compare origin from the allowed Origins list. Then add the header with origin
470+
for _, allowedOrigin := range h.cfg.AllowedOrigins.Origins {
471+
if allowedOrigin == origin {
472+
w.Header().Set("Vary", "Origin")
473+
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
474+
break
475+
}
476+
}
461477
next.ServeHTTP(w, r)
462478
})
463479
}
464480

481+
func (h *Headscale) enabledCorsRoutes(routerPath string) bool {
482+
483+
// enable all api endpoints
484+
if strings.HasPrefix(routerPath, "/api/") {
485+
return true
486+
}
487+
488+
// A list of enabled CORS endpoints
489+
var enabledRoutes = []string{
490+
"/health",
491+
"/key",
492+
"/register/{registration_id}",
493+
"/oidc/callback",
494+
"/verify",
495+
"/derp",
496+
"/derp/probe",
497+
"/derp/latency-check",
498+
"/bootstrap-dns",
499+
"/machine/register",
500+
"/machine/map",
501+
}
502+
503+
for _, routes := range enabledRoutes {
504+
if routes == routerPath {
505+
return true
506+
}
507+
}
508+
return false
509+
}
510+
465511
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
466512
router := mux.NewRouter()
467513
router.Use(prometheusMiddleware)
468514

469-
if h.cfg.AccessControlAllowOrigins != "" {
515+
if len(h.cfg.AllowedOrigins.Origins) != 0 {
470516
router.Use(h.corsHeadersMiddleware)
471517
}
472518

hscontrol/types/config.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type Config struct {
6666
Log LogConfig
6767
DisableUpdateCheck bool
6868

69-
AccessControlAllowOrigins string
69+
AllowedOrigins CorsConfig
7070

7171
Database DatabaseConfig
7272

@@ -210,6 +210,10 @@ type LogTailConfig struct {
210210
Enabled bool
211211
}
212212

213+
type CorsConfig struct {
214+
Origins []string;
215+
}
216+
213217
type CLIConfig struct {
214218
Address string
215219
APIKey string
@@ -534,6 +538,14 @@ func logtailConfig() LogTailConfig {
534538
}
535539
}
536540

541+
func corsConfig() CorsConfig {
542+
allowedOrigins := viper.GetStringSlice("cors.allowed_origins")
543+
544+
return CorsConfig{
545+
Origins: allowedOrigins,
546+
}
547+
}
548+
537549
func policyConfig() PolicyConfig {
538550
policyPath := viper.GetString("policy.path")
539551
policyMode := viper.GetString("policy.mode")
@@ -907,7 +919,7 @@ func LoadServerConfig() (*Config, error) {
907919
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
908920
DisableUpdateCheck: false,
909921

910-
AccessControlAllowOrigins: viper.GetString("access_control_allow_origin"),
922+
AllowedOrigins: corsConfig(),
911923

912924
PrefixV4: prefix4,
913925
PrefixV6: prefix6,

0 commit comments

Comments
 (0)