Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions cert/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import (
"golang.org/x/crypto/ed25519"
)

// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)

// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error
Expand All @@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
}

if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}

t := &TBSCertificate{
Expand Down Expand Up @@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}

if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}

if len(networks) == 0 {
Expand Down
14 changes: 10 additions & 4 deletions cert_test/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (
"golang.org/x/crypto/ed25519"
)

// testCertNow is the reference "now" used to derive default before/after times
// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the
// test binary keeps CA and leaf defaults aligned at the same second, so a leaf
// signed with default times can never expire after its CA on a rounding race.
var testCertNow = time.Now().Round(time.Second)

// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
var err error
Expand All @@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
}

if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}

t := &cert.TBSCertificate{
Expand Down Expand Up @@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
before = testCertNow.Add(time.Second * -60)
}

if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
after = testCertNow.Add(time.Second * 60)
}

var pub, priv []byte
Expand Down
69 changes: 43 additions & 26 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ type Firewall struct {
routableNetworks *bart.Lite

// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool
assignedNetworks []netip.Prefix
// unsafeNetworks is the list of unsafe networks issued to us in the certificate
unsafeNetworks []netip.Prefix

rules string
rulesVersion uint16
Expand Down Expand Up @@ -158,26 +159,25 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
assignedNetworks = append(assignedNetworks, network)
}

hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
unsafeNetworks := c.UnsafeNetworks()
for _, n := range unsafeNetworks {
routableNetworks.Insert(n)
hasUnsafeNetworks = true
}

return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
unsafeNetworks: unsafeNetworks,
l: l,

incomingMetrics: firewallMetrics{
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
Expand Down Expand Up @@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
}

if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
Expand Down Expand Up @@ -1055,7 +1055,6 @@ func (r *rule) sanity() error {
}

func parsePort(s string) (int32, int32, error) {
var err error
const notAPort int32 = -2
if s == "any" {
return firewall.PortAny, firewall.PortAny, nil
Expand All @@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) {
return firewall.PortFragment, firewall.PortFragment, nil
}
if !strings.Contains(s, `-`) {
rPort, err := strconv.Atoi(s)
rPort, err := parsePortValue("", s)
if err != nil {
return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s)
return notAPort, notAPort, err
}
return int32(rPort), int32(rPort), nil
return rPort, rPort, nil
}

sPorts := strings.SplitN(s, `-`, 2)
Expand All @@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) {
return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}

rStartPort, err := strconv.Atoi(sPorts[0])
startPort, err := parsePortValue("beginning range ", sPorts[0])
if err != nil {
return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
return notAPort, notAPort, err
}

rEndPort, err := strconv.Atoi(sPorts[1])
endPort, err := parsePortValue("ending range ", sPorts[1])
if err != nil {
return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
return notAPort, notAPort, err
}

startPort := int32(rStartPort)
endPort := int32(rEndPort)

if startPort == firewall.PortAny {
endPort = firewall.PortAny
}

return startPort, endPort, nil
}

// parsePortValue accepts a base-10 decimal in [0, 65535] and returns it
// widened to int32. Using strconv.ParseUint with bitSize 16 rejects
// negative input, out-of-range input (>65535), and any non-decimal byte
// by construction, so the int32 widening that follows is provably safe
// and cannot collide with firewall.PortAny (0) or firewall.PortFragment
// (-1) via integer truncation.
//
// prefix is prepended to both error messages so callers can disambiguate
// the single-port path (prefix="") from the range bounds (prefix="beginning
// range " / "ending range "), preserving the historical error strings.
func parsePortValue(prefix, s string) (int32, error) {
n, err := strconv.ParseUint(s, 10, 16)
if err == nil {
return int32(n), nil
}
if errors.Is(err, strconv.ErrRange) {
return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s)
}
return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s)
}
69 changes: 69 additions & 0 deletions firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) {
require.NoError(t, err)
}

// Test_parsePort_invalid covers inputs that must error. The named bug is
// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny,
// silently turning a typo into a match-all-ports rule; the rest are
// representative syntax/range probes.
func Test_parsePort_invalid(t *testing.T) {
tests := []struct {
name string
input string
wantErrContains string
}{
// Numeric overflow (the named bug + boundary).
{"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"},
{"just above max real port", "65536", "out of range"},

// Negatives route through the range branch and hit the empty-half
// guard; included as defense in depth so a future refactor cannot
// accidentally reach the int32 cast.
{"negative", "-1", "could not be parsed"},

// Syntax probes.
{"NUL between digits", "4\x002", "was not a number"},
{"hex notation", "0x10", "was not a number"},
{"scientific notation", "1e3", "was not a number"},
{"leading whitespace", " 42", "was not a number"},
{"fullwidth digits", "42", "was not a number"},

// Range branch.
{"range upper out of range", "1-65536", "ending range out of range"},
{"range lower out of range", "65536-65537", "beginning range out of range"},
{"range with negative upper", "1--1", "ending range was not a number"},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, _, err := parsePort(tc.input)
require.Error(t, err, "input %q must error", tc.input)
require.ErrorContains(t, err, tc.wantErrContains)
})
}
}

// Test_parsePort_valid_boundaries locks in success cases at 0, 1, and 65535
// so a future refactor cannot regress the boundaries.
func Test_parsePort_valid_boundaries(t *testing.T) {
tests := []struct {
name string
input string
wantStart int32
wantEnd int32
}{
{"zero is PortAny", "0", 0, 0},
{"min real port", "1", 1, 1},
{"max real port", "65535", 65535, 65535},
{"range zero to max forces end to zero", "0-65535", 0, 0},
{"range max to max", "65535-65535", 65535, 65535},
{"range one to max", "1-65535", 1, 65535},
{"range with whitespace inside", " 1 - 2 ", 1, 2},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s, e, err := parsePort(tc.input)
require.NoError(t, err)
assert.Equal(t, tc.wantStart, s, "start port")
assert.Equal(t, tc.wantEnd, e, "end port")
})
}
}

func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
Expand Down
17 changes: 14 additions & 3 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"io"
"log/slog"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"

"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"

"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) {
}

func (f *Interface) reloadFirewall(c *config.C) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
cs := f.pki.getCertState()
curCert := cs.getCertificate(cert.Version2)
if curCert == nil {
curCert = cs.getCertificate(cert.Version1)
}

// The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction.
// Check to see if that set has changed, and if so, rebuild the firewall.
certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks)

if !c.HasChanged("firewall") && !certUnsafeChanged {
f.l.Debug("No firewall config change detected")
return
}

fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
fw, err := NewFirewallFromConfig(f.l, cs, c)
if err != nil {
f.l.Error("Error while creating firewall during reload", "error", err)
return
Expand Down
Loading
Loading