Skip to content

Commit c63a36b

Browse files
committed
Address review
1 parent 6215aac commit c63a36b

File tree

7 files changed

+50
-99
lines changed

7 files changed

+50
-99
lines changed

client/firewall/iptables/manager_linux.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,15 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
260260
return m.router.UpdateSet(set, prefixes)
261261
}
262262

263-
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
263+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
264264
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
265265
m.mutex.Lock()
266266
defer m.mutex.Unlock()
267267

268268
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
269269
}
270270

271-
// RemoveInboundDNAT removes inbound DNAT rule
271+
// RemoveInboundDNAT removes an inbound DNAT rule.
272272
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
273273
m.mutex.Lock()
274274
defer m.mutex.Unlock()

client/firewall/iptables/router_linux.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
880880
return nberrors.FormatErrorOrNil(merr)
881881
}
882882

883-
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
883+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
884884
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
885885
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
886886

@@ -913,7 +913,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
913913
return nil
914914
}
915915

916-
// RemoveInboundDNAT removes inbound DNAT rule
916+
// RemoveInboundDNAT removes an inbound DNAT rule.
917917
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
918918
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
919919

client/firewall/nftables/manager_linux.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,15 +376,15 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
376376
return m.router.UpdateSet(set, prefixes)
377377
}
378378

379-
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
379+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
380380
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
381381
m.mutex.Lock()
382382
defer m.mutex.Unlock()
383383

384384
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
385385
}
386386

387-
// RemoveInboundDNAT removes inbound DNAT rule
387+
// RemoveInboundDNAT removes an inbound DNAT rule.
388388
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
389389
m.mutex.Lock()
390390
defer m.mutex.Unlock()

client/firewall/nftables/router_linux.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,7 +1350,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
13501350
return nil
13511351
}
13521352

1353-
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
1353+
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
13541354
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
13551355
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
13561356

@@ -1426,7 +1426,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
14261426
return nil
14271427
}
14281428

1429-
// RemoveInboundDNAT removes inbound DNAT rule
1429+
// RemoveInboundDNAT removes an inbound DNAT rule.
14301430
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
14311431
if err := r.refreshRulesMap(); err != nil {
14321432
return fmt.Errorf(refreshRulesMapError, err)

client/firewall/uspfilter/nat.go

Lines changed: 29 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ var (
2222

2323
const (
2424
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
25+
26+
// Port offsets in TCP/UDP headers
27+
sourcePortOffset = 0
28+
destinationPortOffset = 2
2529
)
2630

2731
// ipv4Checksum calculates IPv4 header checksum.
@@ -748,8 +752,8 @@ func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNAT
748752
return true
749753
}
750754

751-
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
752-
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
755+
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
756+
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
753757
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
754758
return ErrIPv4Only
755759
}
@@ -768,9 +772,9 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
768772
return fmt.Errorf("packet too short for TCP header")
769773
}
770774

771-
oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4])
772-
773-
binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort)
775+
portStart := tcpStart + portOffset
776+
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
777+
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
774778

775779
if len(packetData) >= tcpStart+18 {
776780
checksumOffset := tcpStart + 16
@@ -787,45 +791,17 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
787791
return nil
788792
}
789793

794+
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
795+
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
796+
return m.rewriteTCPPort(packetData, d, newPort, destinationPortOffset)
797+
}
798+
790799
// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum.
791800
func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
792-
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
793-
return ErrIPv4Only
794-
}
795-
796-
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
797-
return fmt.Errorf("not a TCP packet")
798-
}
799-
800-
ipHeaderLen := int(d.ip4.IHL) * 4
801-
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
802-
return errInvalidIPHeaderLength
803-
}
804-
805-
tcpStart := ipHeaderLen
806-
if len(packetData) < tcpStart+4 {
807-
return fmt.Errorf("packet too short for TCP header")
808-
}
809-
810-
oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2])
811-
812-
binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort)
813-
814-
if len(packetData) >= tcpStart+18 {
815-
checksumOffset := tcpStart + 16
816-
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
817-
818-
var oldPortBytes, newPortBytes [2]byte
819-
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
820-
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
821-
822-
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
823-
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
824-
}
825-
826-
return nil
801+
return m.rewriteTCPPort(packetData, d, newPort, sourcePortOffset)
827802
}
828803

804+
// applyInboundUDPPortDNAT applies port DNAT to inbound UDP packets.
829805
func (m *Manager) applyInboundUDPPortDNAT(packetData []byte, d *decoder, dstIP netip.Addr, dstPort uint16) bool {
830806
m.portDNATMutex.RLock()
831807
defer m.portDNATMutex.RUnlock()
@@ -849,7 +825,8 @@ func (m *Manager) applyInboundUDPPortDNAT(packetData []byte, d *decoder, dstIP n
849825
return false
850826
}
851827

852-
func (m *Manager) rewriteUDPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
828+
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
829+
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
853830
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
854831
return ErrIPv4Only
855832
}
@@ -868,8 +845,9 @@ func (m *Manager) rewriteUDPDestinationPort(packetData []byte, d *decoder, newPo
868845
return fmt.Errorf("packet too short for UDP header")
869846
}
870847

871-
oldPort := binary.BigEndian.Uint16(packetData[udpStart+2 : udpStart+4])
872-
binary.BigEndian.PutUint16(packetData[udpStart+2:udpStart+4], newPort)
848+
portStart := udpStart + portOffset
849+
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
850+
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
873851

874852
checksumOffset := udpStart + 6
875853
if len(packetData) >= udpStart+8 {
@@ -887,6 +865,11 @@ func (m *Manager) rewriteUDPDestinationPort(packetData []byte, d *decoder, newPo
887865
return nil
888866
}
889867

868+
// rewriteUDPDestinationPort rewrites the destination port in a UDP packet and updates checksum.
869+
func (m *Manager) rewriteUDPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
870+
return m.rewriteUDPPort(packetData, d, newPort, destinationPortOffset)
871+
}
872+
890873
// translateOutboundPortReverse applies reverse port DNAT to outbound return traffic.
891874
func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool {
892875
if !m.portDNATEnabled.Load() {
@@ -932,6 +915,7 @@ func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bo
932915
return m.applyOutboundUDPPortReverse(packetData, d, srcIP, srcPort)
933916
}
934917

918+
// applyOutboundUDPPortReverse applies reverse port DNAT to outbound UDP packets.
935919
func (m *Manager) applyOutboundUDPPortReverse(packetData []byte, d *decoder, srcIP netip.Addr, srcPort uint16) bool {
936920
m.portDNATMutex.RLock()
937921
defer m.portDNATMutex.RUnlock()
@@ -955,40 +939,7 @@ func (m *Manager) applyOutboundUDPPortReverse(packetData []byte, d *decoder, src
955939
return false
956940
}
957941

942+
// rewriteUDPSourcePort rewrites the source port in a UDP packet and updates checksum.
958943
func (m *Manager) rewriteUDPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
959-
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
960-
return ErrIPv4Only
961-
}
962-
963-
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeUDP {
964-
return fmt.Errorf("not a UDP packet")
965-
}
966-
967-
ipHeaderLen := int(d.ip4.IHL) * 4
968-
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
969-
return errInvalidIPHeaderLength
970-
}
971-
972-
udpStart := ipHeaderLen
973-
if len(packetData) < udpStart+8 {
974-
return fmt.Errorf("packet too short for UDP header")
975-
}
976-
977-
oldPort := binary.BigEndian.Uint16(packetData[udpStart : udpStart+2])
978-
binary.BigEndian.PutUint16(packetData[udpStart:udpStart+2], newPort)
979-
980-
checksumOffset := udpStart + 6
981-
if len(packetData) >= udpStart+8 {
982-
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
983-
if oldChecksum != 0 {
984-
var oldPortBytes, newPortBytes [2]byte
985-
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
986-
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
987-
988-
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
989-
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
990-
}
991-
}
992-
993-
return nil
944+
return m.rewriteUDPPort(packetData, d, newPort, sourcePortOffset)
994945
}

management/server/dns_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,15 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
394394

395395
b.ResetTimer()
396396
for i := 0; i < b.N; i++ {
397-
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
397+
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
398398
}
399399
})
400400

401401
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
402402
b.ResetTimer()
403403
for i := 0; i < b.N; i++ {
404404
cache := &DNSConfigCache{}
405-
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
405+
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
406406
}
407407
})
408408
}
@@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
455455
}
456456

457457
// First run with config1
458-
result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
458+
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
459459

460460
// Second run with config2
461-
result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort)
461+
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
462462

463463
// Third run with config1 again
464-
result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
464+
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
465465

466466
// Verify that result1 and result3 are identical
467467
if !reflect.DeepEqual(result1, result3) {
@@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) {
486486
// Test with empty peers list
487487
peers := []*nbpeer.Peer{}
488488
result := computeForwarderPort(peers, "v0.59.0")
489-
if result != oldForwarderPort {
489+
if result != int64(oldForwarderPort) {
490490
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
491491
}
492492

@@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) {
504504
},
505505
}
506506
result = computeForwarderPort(peers, "v0.59.0")
507-
if result != oldForwarderPort {
507+
if result != int64(oldForwarderPort) {
508508
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
509509
}
510510

@@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) {
522522
},
523523
}
524524
result = computeForwarderPort(peers, "v0.59.0")
525-
if result != dnsForwarderPort {
525+
if result != int64(dnsForwarderPort) {
526526
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
527527
}
528528

@@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) {
540540
},
541541
}
542542
result = computeForwarderPort(peers, "v0.59.0")
543-
if result != oldForwarderPort {
543+
if result != int64(oldForwarderPort) {
544544
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
545545
}
546546

@@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) {
553553
},
554554
}
555555
result = computeForwarderPort(peers, "v0.59.0")
556-
if result != oldForwarderPort {
556+
if result != int64(oldForwarderPort) {
557557
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
558558
}
559559

@@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) {
565565
},
566566
}
567567
result = computeForwarderPort(peers, "v0.59.0")
568-
if result == oldForwarderPort {
568+
if result == int64(oldForwarderPort) {
569569
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
570570
}
571571

@@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) {
578578
},
579579
}
580580
result = computeForwarderPort(peers, "v0.59.0")
581-
if result != oldForwarderPort {
581+
if result != int64(oldForwarderPort) {
582582
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
583583
}
584584
}

management/server/peer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) {
11611161
}
11621162
dnsCache := &DNSConfigCache{}
11631163
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
1164-
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort)
1164+
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
11651165

11661166
assert.NotNil(t, response)
11671167
// assert peer config

0 commit comments

Comments
 (0)