@@ -5,10 +5,179 @@ package traffic
55import (
66 "bufio"
77 "fmt"
8+ "log/slog"
9+ "net"
810 "os"
11+ "strconv"
912 "strings"
13+
14+ "github.com/google/nftables"
15+ "github.com/google/nftables/expr"
1016)
1117
18+ const statsTableName = "system_control_stats"
19+
20+ // SetupNodeCounters creates nftables counting rules for each node IP.
21+ // Uses a separate table to avoid interference with NAT rules.
22+ func (e * Engine ) SetupNodeCounters (nodes []NodeInfo ) error {
23+ conn , err := nftables .New ()
24+ if err != nil {
25+ return fmt .Errorf ("nftables connect: %w" , err )
26+ }
27+
28+ // Delete existing stats table (ignore error if not exists)
29+ conn .DelTable (& nftables.Table {Name : statsTableName , Family : nftables .TableFamilyIPv4 })
30+ _ = conn .Flush ()
31+
32+ if len (nodes ) == 0 {
33+ return nil
34+ }
35+
36+ conn , err = nftables .New ()
37+ if err != nil {
38+ return fmt .Errorf ("nftables connect: %w" , err )
39+ }
40+
41+ table := conn .AddTable (& nftables.Table {
42+ Name : statsTableName ,
43+ Family : nftables .TableFamilyIPv4 ,
44+ })
45+
46+ // Forward chain to count traffic passing through this host to/from nodes
47+ chain := conn .AddChain (& nftables.Chain {
48+ Name : "traffic_count" ,
49+ Table : table ,
50+ Type : nftables .ChainTypeFilter ,
51+ Hooknum : nftables .ChainHookForward ,
52+ Priority : nftables .ChainPriorityFilter ,
53+ })
54+
55+ for _ , node := range nodes {
56+ ip := net .ParseIP (node .IP ).To4 ()
57+ if ip == nil {
58+ slog .Error ("traffic: invalid node IP, skipping" , "nodeId" , node .ID , "ip" , node .IP )
59+ continue
60+ }
61+
62+ nodeIDStr := strconv .FormatInt (node .ID , 10 )
63+
64+ // Rule: match dst IP = node.IP → count incoming traffic to node
65+ conn .AddRule (& nftables.Rule {
66+ Table : table ,
67+ Chain : chain ,
68+ Exprs : []expr.Any {
69+ & expr.Payload {DestRegister : 1 , Base : expr .PayloadBaseNetworkHeader , Offset : 16 , Len : 4 },
70+ & expr.Cmp {Op : expr .CmpOpEq , Register : 1 , Data : ip },
71+ & expr.Counter {},
72+ },
73+ UserData : []byte ("node:" + nodeIDStr + ":in" ),
74+ })
75+
76+ // Rule: match src IP = node.IP → count outgoing traffic from node
77+ conn .AddRule (& nftables.Rule {
78+ Table : table ,
79+ Chain : chain ,
80+ Exprs : []expr.Any {
81+ & expr.Payload {DestRegister : 1 , Base : expr .PayloadBaseNetworkHeader , Offset : 12 , Len : 4 },
82+ & expr.Cmp {Op : expr .CmpOpEq , Register : 1 , Data : ip },
83+ & expr.Counter {},
84+ },
85+ UserData : []byte ("node:" + nodeIDStr + ":out" ),
86+ })
87+ }
88+
89+ if err := conn .Flush (); err != nil {
90+ return fmt .Errorf ("nftables flush stats: %w" , err )
91+ }
92+
93+ slog .Info ("traffic: node counters set up" , "nodes" , len (nodes ))
94+ return nil
95+ }
96+
97+ // CollectNodeCounters reads nftables counter values for all node rules.
98+ func (e * Engine ) CollectNodeCounters () ([]NodeTrafficSnapshot , error ) {
99+ conn , err := nftables .New ()
100+ if err != nil {
101+ return nil , fmt .Errorf ("nftables connect: %w" , err )
102+ }
103+
104+ table := & nftables.Table {Name : statsTableName , Family : nftables .TableFamilyIPv4 }
105+ chain := & nftables.Chain {Name : "traffic_count" , Table : table }
106+
107+ rules , err := conn .GetRules (table , chain )
108+ if err != nil {
109+ return nil , fmt .Errorf ("get rules: %w" , err )
110+ }
111+
112+ // Aggregate counters by node ID
113+ type counterPair struct {
114+ bytesIn uint64
115+ bytesOut uint64
116+ }
117+ counters := make (map [int64 ]* counterPair )
118+
119+ for _ , rule := range rules {
120+ if len (rule .UserData ) == 0 {
121+ continue
122+ }
123+ tag := string (rule .UserData )
124+ // Format: "node:{id}:{in|out}"
125+ parts := strings .SplitN (tag , ":" , 3 )
126+ if len (parts ) != 3 || parts [0 ] != "node" {
127+ continue
128+ }
129+
130+ nodeID , err := strconv .ParseInt (parts [1 ], 10 , 64 )
131+ if err != nil {
132+ continue
133+ }
134+ direction := parts [2 ]
135+
136+ // Find counter expression in rule
137+ var bytes uint64
138+ for _ , e := range rule .Exprs {
139+ if c , ok := e .(* expr.Counter ); ok {
140+ bytes = c .Bytes
141+ break
142+ }
143+ }
144+
145+ pair , ok := counters [nodeID ]
146+ if ! ok {
147+ pair = & counterPair {}
148+ counters [nodeID ] = pair
149+ }
150+
151+ switch direction {
152+ case "in" :
153+ pair .bytesIn = bytes
154+ case "out" :
155+ pair .bytesOut = bytes
156+ }
157+ }
158+
159+ results := make ([]NodeTrafficSnapshot , 0 , len (counters ))
160+ for nodeID , pair := range counters {
161+ results = append (results , NodeTrafficSnapshot {
162+ NodeID : nodeID ,
163+ BytesIn : pair .bytesIn ,
164+ BytesOut : pair .bytesOut ,
165+ })
166+ }
167+
168+ return results , nil
169+ }
170+
171+ // CleanupNodeCounters removes the stats table.
172+ func (e * Engine ) CleanupNodeCounters () error {
173+ conn , err := nftables .New ()
174+ if err != nil {
175+ return fmt .Errorf ("nftables connect: %w" , err )
176+ }
177+ conn .DelTable (& nftables.Table {Name : statsTableName , Family : nftables .TableFamilyIPv4 })
178+ return conn .Flush ()
179+ }
180+
12181// CollectInterfaces reads /proc/net/dev and returns per-interface byte counters.
13182func (e * Engine ) CollectInterfaces (names []string ) ([]InterfaceSnapshot , error ) {
14183 f , err := os .Open ("/proc/net/dev" )
0 commit comments