Skip to content

Commit 641eb51

Browse files
authored
[client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742)
1 parent 45c25dc commit 641eb51

File tree

6 files changed

+146
-210
lines changed

6 files changed

+146
-210
lines changed

client/firewall/manager/firewall.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ type Manager interface {
100100
//
101101
// If comment argument is empty firewall manager should set
102102
// rule ID as comment for the rule
103+
//
104+
// Note: Callers should call Flush() after adding rules to ensure
105+
// they are applied to the kernel and rule handles are refreshed.
103106
AddPeerFiltering(
104107
id []byte,
105108
ip net.IP,

client/firewall/nftables/acl_linux.go

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ const (
2929
chainNameForwardFilter = "netbird-acl-forward-filter"
3030
chainNameManglePrerouting = "netbird-mangle-prerouting"
3131
chainNameManglePostrouting = "netbird-mangle-postrouting"
32-
33-
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
3432
)
3533

3634
const flushError = "flush: %w"
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
195193
// createDefaultAllowRules creates default allow rules for the input and output chains
196194
func (m *AclManager) createDefaultAllowRules() error {
197195
expIn := []expr.Any{
198-
&expr.Payload{
199-
DestRegister: 1,
200-
Base: expr.PayloadBaseNetworkHeader,
201-
Offset: 12,
202-
Len: 4,
203-
},
204-
// mask
205-
&expr.Bitwise{
206-
SourceRegister: 1,
207-
DestRegister: 1,
208-
Len: 4,
209-
Mask: []byte{0, 0, 0, 0},
210-
Xor: []byte{0, 0, 0, 0},
211-
},
212-
// net address
213-
&expr.Cmp{
214-
Register: 1,
215-
Data: []byte{0, 0, 0, 0},
216-
},
217196
&expr.Verdict{
218197
Kind: expr.VerdictAccept,
219198
},
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
258237
action firewall.Action,
259238
ipset *nftables.Set,
260239
) (*Rule, error) {
261-
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
240+
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
262241
if r, ok := m.rules[ruleId]; ok {
263242
return &Rule{
264243
nftRule: r.nftRule,
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
357336
}
358337

359338
if err := m.rConn.Flush(); err != nil {
360-
return nil, fmt.Errorf(flushError, err)
339+
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
361340
}
362341

363342
ruleStruct := &Rule{
364-
nftRule: nftRule,
343+
nftRule: nftRule,
344+
// best effort mangle rule
365345
mangleRule: m.createPreroutingRule(expressions, userData),
366346
nftSet: ipset,
367347
ruleID: ruleId,
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
420400
},
421401
)
422402

423-
return m.rConn.AddRule(&nftables.Rule{
403+
nfRule := m.rConn.AddRule(&nftables.Rule{
424404
Table: m.workTable,
425405
Chain: m.chainPrerouting,
426406
Exprs: preroutingExprs,
427407
UserData: userData,
428408
})
409+
410+
if err := m.rConn.Flush(); err != nil {
411+
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
412+
return nil
413+
}
414+
415+
return nfRule
429416
}
430417

431418
func (m *AclManager) createDefaultChains() (err error) {
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
697684
return nil
698685
}
699686

700-
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
701-
rulesetID := ":"
687+
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
688+
rulesetID := ":" + string(proto) + ":"
702689
if sPort != nil {
703690
rulesetID += sPort.String()
704691
}

client/firewall/nftables/manager_linux.go

Lines changed: 21 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package nftables
22

33
import (
4-
"bytes"
54
"context"
65
"fmt"
76
"net"
87
"net/netip"
8+
"os"
99
"sync"
1010

1111
"github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
1919
)
2020

2121
const (
22-
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
22+
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
2323
tableNameNetbird = "netbird"
24+
// envTableName is the environment variable to override the table name
25+
envTableName = "NB_NFTABLES_TABLE"
2426

2527
tableNameFilter = "filter"
2628
chainNameInput = "INPUT"
2729
)
2830

31+
func getTableName() string {
32+
if name := os.Getenv(envTableName); name != "" {
33+
return name
34+
}
35+
return tableNameNetbird
36+
}
37+
2938
// iFaceMapper defines subset methods of interface required for manager
3039
type iFaceMapper interface {
3140
Name() string
@@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
5059
wgIface: wgIface,
5160
}
5261

53-
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
62+
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
5463

5564
var err error
5665
m.router, err = newRouter(workTable, wgIface, mtu)
@@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
198207
m.mutex.Lock()
199208
defer m.mutex.Unlock()
200209

201-
err := m.aclManager.createDefaultAllowRules()
202-
if err != nil {
203-
return fmt.Errorf("failed to create default allow rules: %v", err)
204-
}
205-
206-
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
207-
if err != nil {
208-
return fmt.Errorf("list of chains: %w", err)
209-
}
210-
211-
var chain *nftables.Chain
212-
for _, c := range chains {
213-
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
214-
chain = c
215-
break
216-
}
210+
if err := m.aclManager.createDefaultAllowRules(); err != nil {
211+
return fmt.Errorf("create default allow rules: %w", err)
217212
}
218-
219-
if chain == nil {
220-
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
221-
return nil
222-
}
223-
224-
rules, err := m.rConn.GetRules(chain.Table, chain)
225-
if err != nil {
226-
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
227-
}
228-
229-
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
230-
log.Debugf("allow netbird rule already exists: %v", rule)
231-
return nil
232-
}
233-
234-
m.applyAllowNetbirdRules(chain)
235-
236-
err = m.rConn.Flush()
237-
if err != nil {
238-
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
213+
if err := m.rConn.Flush(); err != nil {
214+
return fmt.Errorf("flush allow input netbird rules: %w", err)
239215
}
240216

241217
return nil
@@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
251227
m.mutex.Lock()
252228
defer m.mutex.Unlock()
253229

254-
if err := m.resetNetbirdInputRules(); err != nil {
255-
return fmt.Errorf("reset netbird input rules: %v", err)
256-
}
257-
258230
if err := m.router.Reset(); err != nil {
259231
return fmt.Errorf("reset router: %v", err)
260232
}
@@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
274246
return nil
275247
}
276248

277-
func (m *Manager) resetNetbirdInputRules() error {
278-
chains, err := m.rConn.ListChains()
279-
if err != nil {
280-
return fmt.Errorf("list chains: %w", err)
281-
}
282-
283-
m.deleteNetbirdInputRules(chains)
284-
285-
return nil
286-
}
287-
288-
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
289-
for _, c := range chains {
290-
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
291-
rules, err := m.rConn.GetRules(c.Table, c)
292-
if err != nil {
293-
log.Errorf("get rules for chain %q: %v", c.Name, err)
294-
continue
295-
}
296-
297-
m.deleteMatchingRules(rules)
298-
}
299-
}
300-
}
301-
302-
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
303-
for _, r := range rules {
304-
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
305-
if err := m.rConn.DelRule(r); err != nil {
306-
log.Errorf("delete rule: %v", err)
307-
}
308-
}
309-
}
310-
}
311-
312249
func (m *Manager) cleanupNetbirdTables() error {
313250
tables, err := m.rConn.ListTables()
314251
if err != nil {
315252
return fmt.Errorf("list tables: %w", err)
316253
}
317254

255+
tableName := getTableName()
318256
for _, t := range tables {
319-
if t.Name == tableNameNetbird {
257+
if t.Name == tableName {
320258
m.rConn.DelTable(t)
321259
}
322260
}
@@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
399337
return nil, fmt.Errorf("list of tables: %w", err)
400338
}
401339

340+
tableName := getTableName()
402341
for _, t := range tables {
403-
if t.Name == tableNameNetbird {
342+
if t.Name == tableName {
404343
m.rConn.DelTable(t)
405344
}
406345
}
407346

408-
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
347+
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
409348
err = m.rConn.Flush()
410349
return table, err
411350
}
412351

413-
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
414-
rule := &nftables.Rule{
415-
Table: chain.Table,
416-
Chain: chain,
417-
Exprs: []expr.Any{
418-
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
419-
&expr.Cmp{
420-
Op: expr.CmpOpEq,
421-
Register: 1,
422-
Data: ifname(m.wgIface.Name()),
423-
},
424-
&expr.Verdict{
425-
Kind: expr.VerdictAccept,
426-
},
427-
},
428-
UserData: []byte(allowNetbirdInputRuleID),
429-
}
430-
_ = m.rConn.InsertRule(rule)
431-
}
432-
433-
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
434-
ifName := ifname(m.wgIface.Name())
435-
for _, rule := range existedRules {
436-
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
437-
if len(rule.Exprs) < 4 {
438-
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
439-
continue
440-
}
441-
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
442-
continue
443-
}
444-
return rule
445-
}
446-
}
447-
}
448-
return nil
449-
}
450-
451352
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
452353
rule := &nftables.Rule{
453354
Table: table,

0 commit comments

Comments
 (0)