Skip to content

Commit 2539ddd

Browse files
fix: handle nil chain type in FromChainToRulesArray
1 parent de96cea commit 2539ddd

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

pkg/firewall/chain.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package firewall
1616

1717
import (
18+
"fmt"
19+
1820
"github.com/google/nftables"
1921
"k8s.io/klog/v2"
2022

@@ -214,14 +216,22 @@ func isChainModified(nftChain *nftables.Chain, chain *firewallapi.Chain) bool {
214216
}
215217

216218
// FromChainToRulesArray converts a chain to an array of rules.
217-
func FromChainToRulesArray(chain *firewallapi.Chain) (rules []firewallutils.Rule) {
219+
func FromChainToRulesArray(chain *firewallapi.Chain) (rules []firewallutils.Rule, err error) {
220+
if chain == nil {
221+
return nil, fmt.Errorf("chain is nil")
222+
}
223+
224+
if chain.Type == nil {
225+
return nil, fmt.Errorf("chain type is required")
226+
}
227+
218228
switch *chain.Type {
219229
case firewallapi.ChainTypeFilter:
220230
rules = make([]firewallutils.Rule, len(chain.Rules.FilterRules))
221231
for i := range chain.Rules.FilterRules {
222232
rules[i] = &firewallutils.FilterRuleWrapper{FilterRule: &chain.Rules.FilterRules[i]}
223233
}
224-
return rules
234+
return rules, nil
225235
case firewallapi.ChainTypeNAT:
226236
rules = make([]firewallutils.Rule, len(chain.Rules.NatRules))
227237
for i := range chain.Rules.NatRules {
@@ -237,7 +247,7 @@ func FromChainToRulesArray(chain *firewallapi.Chain) (rules []firewallutils.Rule
237247
rules = []firewallutils.Rule{}
238248
}
239249
// It is not necessary, but linter complains
240-
return rules
250+
return rules, nil
241251
}
242252

243253
// cleanChain removes all the rules that are not present in the firewall configuration or that have been modified.
@@ -246,7 +256,11 @@ func cleanChain(nftconn *nftables.Conn, chain *firewallapi.Chain, nftChain *nfta
246256
if err != nil {
247257
return err
248258
}
249-
rules := FromChainToRulesArray(chain)
259+
rules, err := FromChainToRulesArray(chain)
260+
if err != nil {
261+
return err
262+
}
263+
250264
for i := range nftRules {
251265
// If the rule is outdated, delete it.
252266
outdated, ruleName := isRuleOutdated(nftRules[i], rules)

pkg/firewall/rule.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ import (
2323
)
2424

2525
func addRules(nftconn *nftables.Conn, chain *firewallapi.Chain, nftchain *nftables.Chain) error {
26-
apirules := FromChainToRulesArray(chain)
26+
apirules, err := FromChainToRulesArray(chain)
27+
if err != nil {
28+
return err
29+
}
30+
2731
nftrules, err := nftconn.GetRules(nftchain.Table, nftchain)
2832
if err != nil {
2933
return err

pkg/webhooks/firewallconfiguration/firewallconfiguration.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ func (w *webhookMutate) Handle(_ context.Context, req admission.Request) admissi
9090
return admission.Errored(http.StatusBadRequest, err)
9191
}
9292

93-
generateRuleNames(firewallConfiguration.Spec.Table.Chains)
93+
err = generateRuleNames(firewallConfiguration.Spec.Table.Chains)
94+
if err != nil {
95+
klog.Errorf("Failed generating rule names: %v", err)
96+
return admission.Errored(http.StatusBadRequest, err)
97+
}
9498

9599
return w.CreatePatchResponse(&req, firewallConfiguration)
96100
}

pkg/webhooks/firewallconfiguration/rule.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ import (
2525
)
2626

2727
func checkRulesInChain(chain *firewallapi.Chain) error {
28-
rules := firewall.FromChainToRulesArray(chain)
28+
rules, err := firewall.FromChainToRulesArray(chain)
29+
if err != nil {
30+
return forgeChainError(chain, err)
31+
}
32+
2933
if err := checkVoidRuleName(rules); err != nil {
3034
return forgeChainError(chain, err)
3135
}
@@ -60,13 +64,19 @@ func checkUniqueRuleNames(rules []firewallutils.Rule) error {
6064
return nil
6165
}
6266

63-
func generateRuleNames(chains []firewallapi.Chain) {
67+
func generateRuleNames(chains []firewallapi.Chain) error {
6468
for i := range chains {
65-
rules := firewall.FromChainToRulesArray(&chains[i])
69+
rules, err := firewall.FromChainToRulesArray(&chains[i])
70+
if err != nil {
71+
return err
72+
}
73+
6674
for j := range rules {
6775
if rules[j].GetName() == nil || *rules[j].GetName() == "" {
6876
rules[j].SetName(uuid.NewString())
6977
}
7078
}
7179
}
80+
81+
return nil
7282
}

0 commit comments

Comments
 (0)