Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions client/firewall/manager/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type Manager interface {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
ip net.IP,
Expand Down
41 changes: 14 additions & 27 deletions client/firewall/nftables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ const (
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"

allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)

const flushError = "flush: %w"
Expand Down Expand Up @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
Expand Down Expand Up @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
Expand Down Expand Up @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
}

if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}

ruleStruct := &Rule{
nftRule: nftRule,
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
Expand Down Expand Up @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
},
)

return m.rConn.AddRule(&nftables.Rule{
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})

if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}

return nfRule
}

func (m *AclManager) createDefaultChains() (err error) {
Expand Down Expand Up @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
return nil
}

func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
Expand Down
141 changes: 21 additions & 120 deletions client/firewall/nftables/manager_linux.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package nftables

import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"

"github.com/google/nftables"
Expand All @@ -19,13 +19,22 @@ import (
)

const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
// envTableName is the environment variable to override the table name
envTableName = "NB_NFTABLES_TABLE"

tableNameFilter = "filter"
chainNameInput = "INPUT"
)

func getTableName() string {
if name := os.Getenv(envTableName); name != "" {
return name
}
return tableNameNetbird
}

// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Expand All @@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface,
}

workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}

var err error
m.router, err = newRouter(workTable, wgIface, mtu)
Expand Down Expand Up @@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()

err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
}

chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}

var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}

if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}

rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}

if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}

m.applyAllowNetbirdRules(chain)

err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}

return nil
Expand All @@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()

if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}

if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
Expand All @@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nil
}

func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}

m.deleteNetbirdInputRules(chains)

return nil
}

func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}

m.deleteMatchingRules(rules)
}
}
}

func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}

func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
}

tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
Expand Down Expand Up @@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
return nil, fmt.Errorf("list of tables: %w", err)
}

tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}

table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}

func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}

func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}

func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Expand Down
Loading
Loading