11package nftables
22
33import (
4- "bytes"
54 "context"
65 "fmt"
76 "net"
87 "net/netip"
8+ "os"
99 "sync"
1010
1111 "github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
1919)
2020
2121const (
22- // tableNameNetbird is the name of the table that is used for filtering by the Netbird client
22+ // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
2323 tableNameNetbird = "netbird"
24+ // envTableName is the environment variable to override the table name
25+ envTableName = "NB_NFTABLES_TABLE"
2426
2527 tableNameFilter = "filter"
2628 chainNameInput = "INPUT"
2729)
2830
31+ func getTableName () string {
32+ if name := os .Getenv (envTableName ); name != "" {
33+ return name
34+ }
35+ return tableNameNetbird
36+ }
37+
2938// iFaceMapper defines subset methods of interface required for manager
3039type iFaceMapper interface {
3140 Name () string
@@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
5059 wgIface : wgIface ,
5160 }
5261
53- workTable := & nftables.Table {Name : tableNameNetbird , Family : nftables .TableFamilyIPv4 }
62+ workTable := & nftables.Table {Name : getTableName () , Family : nftables .TableFamilyIPv4 }
5463
5564 var err error
5665 m .router , err = newRouter (workTable , wgIface , mtu )
@@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
198207 m .mutex .Lock ()
199208 defer m .mutex .Unlock ()
200209
201- err := m .aclManager .createDefaultAllowRules ()
202- if err != nil {
203- return fmt .Errorf ("failed to create default allow rules: %v" , err )
204- }
205-
206- chains , err := m .rConn .ListChainsOfTableFamily (nftables .TableFamilyIPv4 )
207- if err != nil {
208- return fmt .Errorf ("list of chains: %w" , err )
209- }
210-
211- var chain * nftables.Chain
212- for _ , c := range chains {
213- if c .Table .Name == tableNameFilter && c .Name == chainNameInput {
214- chain = c
215- break
216- }
210+ if err := m .aclManager .createDefaultAllowRules (); err != nil {
211+ return fmt .Errorf ("create default allow rules: %w" , err )
217212 }
218-
219- if chain == nil {
220- log .Debugf ("chain INPUT not found. Skipping add allow netbird rule" )
221- return nil
222- }
223-
224- rules , err := m .rConn .GetRules (chain .Table , chain )
225- if err != nil {
226- return fmt .Errorf ("failed to get rules for the INPUT chain: %v" , err )
227- }
228-
229- if rule := m .detectAllowNetbirdRule (rules ); rule != nil {
230- log .Debugf ("allow netbird rule already exists: %v" , rule )
231- return nil
232- }
233-
234- m .applyAllowNetbirdRules (chain )
235-
236- err = m .rConn .Flush ()
237- if err != nil {
238- return fmt .Errorf ("failed to flush allow input netbird rules: %v" , err )
213+ if err := m .rConn .Flush (); err != nil {
214+ return fmt .Errorf ("flush allow input netbird rules: %w" , err )
239215 }
240216
241217 return nil
@@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
251227 m .mutex .Lock ()
252228 defer m .mutex .Unlock ()
253229
254- if err := m .resetNetbirdInputRules (); err != nil {
255- return fmt .Errorf ("reset netbird input rules: %v" , err )
256- }
257-
258230 if err := m .router .Reset (); err != nil {
259231 return fmt .Errorf ("reset router: %v" , err )
260232 }
@@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
274246 return nil
275247}
276248
277- func (m * Manager ) resetNetbirdInputRules () error {
278- chains , err := m .rConn .ListChains ()
279- if err != nil {
280- return fmt .Errorf ("list chains: %w" , err )
281- }
282-
283- m .deleteNetbirdInputRules (chains )
284-
285- return nil
286- }
287-
288- func (m * Manager ) deleteNetbirdInputRules (chains []* nftables.Chain ) {
289- for _ , c := range chains {
290- if c .Table .Name == tableNameFilter && c .Name == chainNameInput {
291- rules , err := m .rConn .GetRules (c .Table , c )
292- if err != nil {
293- log .Errorf ("get rules for chain %q: %v" , c .Name , err )
294- continue
295- }
296-
297- m .deleteMatchingRules (rules )
298- }
299- }
300- }
301-
302- func (m * Manager ) deleteMatchingRules (rules []* nftables.Rule ) {
303- for _ , r := range rules {
304- if bytes .Equal (r .UserData , []byte (allowNetbirdInputRuleID )) {
305- if err := m .rConn .DelRule (r ); err != nil {
306- log .Errorf ("delete rule: %v" , err )
307- }
308- }
309- }
310- }
311-
312249func (m * Manager ) cleanupNetbirdTables () error {
313250 tables , err := m .rConn .ListTables ()
314251 if err != nil {
315252 return fmt .Errorf ("list tables: %w" , err )
316253 }
317254
255+ tableName := getTableName ()
318256 for _ , t := range tables {
319- if t .Name == tableNameNetbird {
257+ if t .Name == tableName {
320258 m .rConn .DelTable (t )
321259 }
322260 }
@@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
399337 return nil , fmt .Errorf ("list of tables: %w" , err )
400338 }
401339
340+ tableName := getTableName ()
402341 for _ , t := range tables {
403- if t .Name == tableNameNetbird {
342+ if t .Name == tableName {
404343 m .rConn .DelTable (t )
405344 }
406345 }
407346
408- table := m .rConn .AddTable (& nftables.Table {Name : tableNameNetbird , Family : nftables .TableFamilyIPv4 })
347+ table := m .rConn .AddTable (& nftables.Table {Name : getTableName () , Family : nftables .TableFamilyIPv4 })
409348 err = m .rConn .Flush ()
410349 return table , err
411350}
412351
413- func (m * Manager ) applyAllowNetbirdRules (chain * nftables.Chain ) {
414- rule := & nftables.Rule {
415- Table : chain .Table ,
416- Chain : chain ,
417- Exprs : []expr.Any {
418- & expr.Meta {Key : expr .MetaKeyIIFNAME , Register : 1 },
419- & expr.Cmp {
420- Op : expr .CmpOpEq ,
421- Register : 1 ,
422- Data : ifname (m .wgIface .Name ()),
423- },
424- & expr.Verdict {
425- Kind : expr .VerdictAccept ,
426- },
427- },
428- UserData : []byte (allowNetbirdInputRuleID ),
429- }
430- _ = m .rConn .InsertRule (rule )
431- }
432-
433- func (m * Manager ) detectAllowNetbirdRule (existedRules []* nftables.Rule ) * nftables.Rule {
434- ifName := ifname (m .wgIface .Name ())
435- for _ , rule := range existedRules {
436- if rule .Table .Name == tableNameFilter && rule .Chain .Name == chainNameInput {
437- if len (rule .Exprs ) < 4 {
438- if e , ok := rule .Exprs [0 ].(* expr.Meta ); ! ok || e .Key != expr .MetaKeyIIFNAME {
439- continue
440- }
441- if e , ok := rule .Exprs [1 ].(* expr.Cmp ); ! ok || e .Op != expr .CmpOpEq || ! bytes .Equal (e .Data , ifName ) {
442- continue
443- }
444- return rule
445- }
446- }
447- }
448- return nil
449- }
450-
451352func insertReturnTrafficRule (conn * nftables.Conn , table * nftables.Table , chain * nftables.Chain ) {
452353 rule := & nftables.Rule {
453354 Table : table ,
0 commit comments