Skip to content

Commit 8c7d71b

Browse files
committed
Allow INPUT traffic on the compat iptables filter table for nftables
1 parent a2313a5 commit 8c7d71b

File tree

6 files changed

+136
-66
lines changed

6 files changed

+136
-66
lines changed

client/firewall/manager/firewall.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ type Manager interface {
100100
//
101101
// If comment argument is empty firewall manager should set
102102
// rule ID as comment for the rule
103+
//
104+
// Note: Callers should call Flush() after adding rules to ensure
105+
// they are applied to the kernel and rule handles are refreshed.
103106
AddPeerFiltering(
104107
id []byte,
105108
ip net.IP,

client/firewall/nftables/acl_linux.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,12 @@ func (m *AclManager) addIOFiltering(
357357
}
358358

359359
if err := m.rConn.Flush(); err != nil {
360-
return nil, fmt.Errorf(flushError, err)
360+
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
361361
}
362362

363363
ruleStruct := &Rule{
364-
nftRule: nftRule,
364+
nftRule: nftRule,
365+
// best effort mangle rule
365366
mangleRule: m.createPreroutingRule(expressions, userData),
366367
nftSet: ipset,
367368
ruleID: ruleId,
@@ -420,12 +421,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
420421
},
421422
)
422423

423-
return m.rConn.AddRule(&nftables.Rule{
424+
nfRule := m.rConn.AddRule(&nftables.Rule{
424425
Table: m.workTable,
425426
Chain: m.chainPrerouting,
426427
Exprs: preroutingExprs,
427428
UserData: userData,
428429
})
430+
431+
if err := m.rConn.Flush(); err != nil {
432+
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
433+
return nil
434+
}
435+
436+
return nfRule
429437
}
430438

431439
func (m *AclManager) createDefaultChains() (err error) {

client/firewall/nftables/manager_linux.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net"
88
"net/netip"
9+
"os"
910
"sync"
1011

1112
"github.com/google/nftables"
@@ -19,13 +20,22 @@ import (
1920
)
2021

2122
const (
22-
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
23+
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
2324
tableNameNetbird = "netbird"
25+
// envTableName is the environment variable to override the table name
26+
envTableName = "NB_NFTABLES_TABLE"
2427

2528
tableNameFilter = "filter"
2629
chainNameInput = "INPUT"
2730
)
2831

32+
func getTableName() string {
33+
if name := os.Getenv(envTableName); name != "" {
34+
return name
35+
}
36+
return tableNameNetbird
37+
}
38+
2939
// iFaceMapper defines subset methods of interface required for manager
3040
type iFaceMapper interface {
3141
Name() string
@@ -50,7 +60,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
5060
wgIface: wgIface,
5161
}
5262

53-
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
63+
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
5464

5565
var err error
5666
m.router, err = newRouter(workTable, wgIface)
@@ -314,8 +324,9 @@ func (m *Manager) cleanupNetbirdTables() error {
314324
return fmt.Errorf("list tables: %w", err)
315325
}
316326

327+
tableName := getTableName()
317328
for _, t := range tables {
318-
if t.Name == tableNameNetbird {
329+
if t.Name == tableName {
319330
m.rConn.DelTable(t)
320331
}
321332
}
@@ -398,13 +409,14 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
398409
return nil, fmt.Errorf("list of tables: %w", err)
399410
}
400411

412+
tableName := getTableName()
401413
for _, t := range tables {
402-
if t.Name == tableNameNetbird {
414+
if t.Name == tableName {
403415
m.rConn.DelTable(t)
404416
}
405417
}
406418

407-
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
419+
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
408420
err = m.rConn.Flush()
409421
return table, err
410422
}

client/firewall/nftables/router_linux.go

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ const (
3535

3636
userDataAcceptForwardRuleIif = "frwacceptiif"
3737
userDataAcceptForwardRuleOif = "frwacceptoif"
38+
userDataAcceptInputRule = "inputaccept"
3839

3940
dnatSuffix = "_dnat"
4041
snatSuffix = "_snat"
@@ -96,8 +97,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
9697
func (r *router) init(workTable *nftables.Table) error {
9798
r.workTable = workTable
9899

99-
if err := r.removeAcceptForwardRules(); err != nil {
100-
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
100+
if err := r.removeAcceptFilterRules(); err != nil {
101+
log.Errorf("failed to clean up rules from filter table: %s", err)
101102
}
102103

103104
if err := r.createContainers(); err != nil {
@@ -111,15 +112,15 @@ func (r *router) init(workTable *nftables.Table) error {
111112
return nil
112113
}
113114

114-
// Reset cleans existing nftables default forward rules from the system
115+
// Reset cleans existing nftables filter table rules from the system
115116
func (r *router) Reset() error {
116117
// clear without deleting the ipsets, the nf table will be deleted by the caller
117118
r.ipsetCounter.Clear()
118119

119120
var merr *multierror.Error
120121

121-
if err := r.removeAcceptForwardRules(); err != nil {
122-
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
122+
if err := r.removeAcceptFilterRules(); err != nil {
123+
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
123124
}
124125

125126
if err := r.removeNatPreroutingRules(); err != nil {
@@ -840,6 +841,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
840841
// that our traffic is not dropped by existing rules there.
841842
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
842843
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
844+
// This method also adds INPUT chain rules to allow traffic to the local interface.
843845
func (r *router) acceptForwardRules() error {
844846
if r.filterTable == nil {
845847
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
@@ -849,7 +851,7 @@ func (r *router) acceptForwardRules() error {
849851
fw := "iptables"
850852

851853
defer func() {
852-
log.Debugf("Used %s to add accept forward rules", fw)
854+
log.Debugf("Used %s to add accept forward and input rules", fw)
853855
}()
854856

855857
// Try iptables first and fallback to nftables if iptables is not available
@@ -859,22 +861,30 @@ func (r *router) acceptForwardRules() error {
859861
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
860862

861863
fw = "nftables"
862-
return r.acceptForwardRulesNftables()
864+
return r.acceptFilterRulesNftables()
863865
}
864866

865-
return r.acceptForwardRulesIptables(ipt)
867+
return r.acceptFilterRulesIptables(ipt)
866868
}
867869

868-
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
870+
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
869871
var merr *multierror.Error
872+
870873
for _, rule := range r.getAcceptForwardRules() {
871874
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
872-
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
875+
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
873876
} else {
874-
log.Debugf("added iptables rule: %v", rule)
877+
log.Debugf("added iptables forward rule: %v", rule)
875878
}
876879
}
877880

881+
inputRule := r.getAcceptInputRule()
882+
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
883+
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
884+
} else {
885+
log.Debugf("added iptables input rule: %v", inputRule)
886+
}
887+
878888
return nberrors.FormatErrorOrNil(merr)
879889
}
880890

@@ -886,10 +896,13 @@ func (r *router) getAcceptForwardRules() [][]string {
886896
}
887897
}
888898

889-
func (r *router) acceptForwardRulesNftables() error {
899+
func (r *router) getAcceptInputRule() []string {
900+
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
901+
}
902+
903+
func (r *router) acceptFilterRulesNftables() error {
890904
intf := ifname(r.wgIface.Name())
891905

892-
// Rule for incoming interface (iif) with counter
893906
iifRule := &nftables.Rule{
894907
Table: r.filterTable,
895908
Chain: &nftables.Chain{
@@ -922,11 +935,10 @@ func (r *router) acceptForwardRulesNftables() error {
922935
},
923936
}
924937

925-
// Rule for outgoing interface (oif) with counter
926938
oifRule := &nftables.Rule{
927939
Table: r.filterTable,
928940
Chain: &nftables.Chain{
929-
Name: "FORWARD",
941+
Name: chainNameForward,
930942
Table: r.filterTable,
931943
Type: nftables.ChainTypeFilter,
932944
Hooknum: nftables.ChainHookForward,
@@ -935,35 +947,60 @@ func (r *router) acceptForwardRulesNftables() error {
935947
Exprs: append(oifExprs, getEstablishedExprs(2)...),
936948
UserData: []byte(userDataAcceptForwardRuleOif),
937949
}
938-
939950
r.conn.InsertRule(oifRule)
940951

952+
inputRule := &nftables.Rule{
953+
Table: r.filterTable,
954+
Chain: &nftables.Chain{
955+
Name: chainNameInput,
956+
Table: r.filterTable,
957+
Type: nftables.ChainTypeFilter,
958+
Hooknum: nftables.ChainHookInput,
959+
Priority: nftables.ChainPriorityFilter,
960+
},
961+
Exprs: []expr.Any{
962+
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
963+
&expr.Cmp{
964+
Op: expr.CmpOpEq,
965+
Register: 1,
966+
Data: intf,
967+
},
968+
&expr.Counter{},
969+
&expr.Verdict{Kind: expr.VerdictAccept},
970+
},
971+
UserData: []byte(userDataAcceptInputRule),
972+
}
973+
r.conn.InsertRule(inputRule)
974+
941975
return nil
942976
}
943977

944-
func (r *router) removeAcceptForwardRules() error {
978+
func (r *router) removeAcceptFilterRules() error {
945979
if r.filterTable == nil {
946980
return nil
947981
}
948982

949-
// Try iptables first and fallback to nftables if iptables is not available
950983
ipt, err := iptables.New()
951984
if err != nil {
952985
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
953-
return r.removeAcceptForwardRulesNftables()
986+
return r.removeAcceptFilterRulesNftables()
954987
}
955988

956-
return r.removeAcceptForwardRulesIptables(ipt)
989+
return r.removeAcceptFilterRulesIptables(ipt)
957990
}
958991

959-
func (r *router) removeAcceptForwardRulesNftables() error {
992+
func (r *router) removeAcceptFilterRulesNftables() error {
960993
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
961994
if err != nil {
962995
return fmt.Errorf("list chains: %v", err)
963996
}
964997

965998
for _, chain := range chains {
966-
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
999+
if chain.Table.Name != r.filterTable.Name {
1000+
continue
1001+
}
1002+
1003+
if chain.Name != chainNameForward && chain.Name != chainNameInput {
9671004
continue
9681005
}
9691006

@@ -974,7 +1011,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
9741011

9751012
for _, rule := range rules {
9761013
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
977-
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
1014+
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
1015+
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
9781016
if err := r.conn.DelRule(rule); err != nil {
9791017
return fmt.Errorf("delete rule: %v", err)
9801018
}
@@ -989,14 +1027,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
9891027
return nil
9901028
}
9911029

992-
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
1030+
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
9931031
var merr *multierror.Error
1032+
9941033
for _, rule := range r.getAcceptForwardRules() {
9951034
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
996-
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
1035+
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
9971036
}
9981037
}
9991038

1039+
inputRule := r.getAcceptInputRule()
1040+
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
1041+
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
1042+
}
1043+
10001044
return nberrors.FormatErrorOrNil(merr)
10011045
}
10021046

client/internal/dnsfwd/manager.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
nberrors "github.com/netbirdio/netbird/client/errors"
1616
firewall "github.com/netbirdio/netbird/client/firewall/manager"
1717
"github.com/netbirdio/netbird/client/iface/wgaddr"
18+
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
1819
"github.com/netbirdio/netbird/client/internal/peer"
1920
nbdns "github.com/netbirdio/netbird/dns"
2021
"github.com/netbirdio/netbird/route"
@@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error {
134135
}
135136
}
136137

138+
m.unregisterNetstackServices()
139+
137140
if err := m.dropDNSFirewall(); err != nil {
138141
mErr = multierror.Append(mErr, err)
139142
}
@@ -170,9 +173,40 @@ func (m *Manager) allowDNSFirewall() error {
170173
}
171174
m.tcpRules = tcpRules
172175

176+
if err := m.firewall.Flush(); err != nil {
177+
log.Errorf("failed to flush DNS firewall rules: %v", err)
178+
return err
179+
}
180+
181+
m.registerNetstackServices()
182+
173183
return nil
174184
}
175185

186+
func (m *Manager) registerNetstackServices() {
187+
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
188+
if registrar, ok := m.firewall.(interface {
189+
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
190+
}); ok {
191+
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
192+
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
193+
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
194+
}
195+
}
196+
}
197+
198+
func (m *Manager) unregisterNetstackServices() {
199+
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
200+
if registrar, ok := m.firewall.(interface {
201+
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
202+
}); ok {
203+
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
204+
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
205+
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
206+
}
207+
}
208+
}
209+
176210
func (m *Manager) dropDNSFirewall() error {
177211
var mErr *multierror.Error
178212
for _, rule := range m.fwRules {

0 commit comments

Comments
 (0)