Skip to content

Commit a531de1

Browse files
committed
palo_alto_local_rulestack_rule - support port ranges and 'any' keyword in protocol_ports
- Accept port ranges in format 'TCP:1024-1206' and 'UDP:5000-5100' - Accept 'any' and 'application-default' as standalone values - Add unit tests for ProtocolWithPort validation function Fixes #25907
1 parent 4a747f7 commit a531de1

File tree

3 files changed

+131
-3
lines changed

3 files changed

+131
-3
lines changed

internal/services/paloalto/palo_alto_local_rulestack_rule_resource.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func (r LocalRuleStackRule) Arguments() map[string]*pluginsdk.Schema {
157157
Optional: true,
158158
ValidateFunc: validation.Any(
159159
validate.ProtocolWithPort,
160-
validation.StringInSlice([]string{protocolApplicationDefault}, false),
160+
validation.StringInSlice([]string{protocolApplicationDefault, "any"}, false),
161161
),
162162
ExactlyOneOf: []string{"protocol", "protocol_ports"},
163163
},

internal/services/paloalto/validate/protocol.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,46 @@ func ProtocolWithPort(input interface{}, k string) (warnings []string, errors []
1616
return
1717
}
1818

19+
if v == "any" || v == "application-default" {
20+
return
21+
}
22+
1923
parts := strings.Split(v, ":")
2024
if len(parts) != 2 {
21-
errors = append(errors, fmt.Errorf("expected %s to be a two part string separated by a `:`, e.g. TCP:80", k))
25+
errors = append(errors, fmt.Errorf("expected %s to be a two part string separated by a `:`, e.g. TCP:80, or a supported keyword like `any`", k))
2226
return
2327
}
2428

2529
if parts[0] != "TCP" && parts[0] != "UDP" {
2630
errors = append(errors, fmt.Errorf("protocol portion of %s must be one of `TCP` or `UDP`, got %q", k, parts[0]))
2731
}
2832

33+
if strings.Contains(parts[1], "-") {
34+
rangeParts := strings.Split(parts[1], "-")
35+
if len(rangeParts) != 2 {
36+
errors = append(errors, fmt.Errorf("port range in %s must be in format START-END, e.g. TCP:1024-1206", k))
37+
return
38+
}
39+
startPort, err := strconv.Atoi(rangeParts[0])
40+
if err != nil || startPort < 1 || startPort > 65535 {
41+
errors = append(errors, fmt.Errorf("start port in %s must be an integer between 1 and 65535, got %q", k, rangeParts[0]))
42+
return
43+
}
44+
endPort, err := strconv.Atoi(rangeParts[1])
45+
if err != nil || endPort < 1 || endPort > 65535 {
46+
errors = append(errors, fmt.Errorf("end port in %s must be an integer between 1 and 65535, got %q", k, rangeParts[1]))
47+
return
48+
}
49+
if startPort > endPort {
50+
errors = append(errors, fmt.Errorf("start port must be less than or equal to end port in %s", k))
51+
return
52+
}
53+
return
54+
}
55+
2956
port, err := strconv.Atoi(parts[1])
3057
if err != nil || port == 0 || port > 65535 {
31-
errors = append(errors, fmt.Errorf("port in %s must me an integer value between 1 and 65535, got %q", k, parts[1]))
58+
errors = append(errors, fmt.Errorf("port in %s must be an integer value between 1 and 65535, got %q", k, parts[1]))
3259
}
3360

3461
return
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright IBM Corp. 2014, 2025
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package validate
5+
6+
import (
7+
"testing"
8+
)
9+
10+
func TestProtocolWithPort(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
input string
14+
wantErrors int
15+
}{
16+
{
17+
name: "single port TCP",
18+
input: "TCP:80",
19+
wantErrors: 0,
20+
},
21+
{
22+
name: "single port UDP",
23+
input: "UDP:443",
24+
wantErrors: 0,
25+
},
26+
{
27+
name: "port range TCP",
28+
input: "TCP:1024-1206",
29+
wantErrors: 0,
30+
},
31+
{
32+
name: "port range UDP",
33+
input: "UDP:5000-5100",
34+
wantErrors: 0,
35+
},
36+
{
37+
name: "any keyword",
38+
input: "any",
39+
wantErrors: 0,
40+
},
41+
{
42+
name: "application-default keyword",
43+
input: "application-default",
44+
wantErrors: 0,
45+
},
46+
{
47+
name: "invalid single port zero",
48+
input: "TCP:0",
49+
wantErrors: 1,
50+
},
51+
{
52+
name: "invalid port too high",
53+
input: "TCP:70000",
54+
wantErrors: 1,
55+
},
56+
{
57+
name: "invalid start greater than end",
58+
input: "TCP:2000-1000",
59+
wantErrors: 1,
60+
},
61+
{
62+
name: "invalid start port zero in range",
63+
input: "TCP:0-100",
64+
wantErrors: 1,
65+
},
66+
{
67+
name: "invalid end port zero in range",
68+
input: "TCP:100-0",
69+
wantErrors: 1,
70+
},
71+
{
72+
name: "invalid protocol",
73+
input: "ICMP:80",
74+
wantErrors: 1,
75+
},
76+
{
77+
name: "invalid missing port",
78+
input: "TCP:",
79+
wantErrors: 1,
80+
},
81+
{
82+
name: "invalid missing colon",
83+
input: "TCP80",
84+
wantErrors: 1,
85+
},
86+
{
87+
name: "invalid malformed range",
88+
input: "TCP:100-200-300",
89+
wantErrors: 1,
90+
},
91+
}
92+
93+
for _, tt := range tests {
94+
t.Run(tt.name, func(t *testing.T) {
95+
_, errors := ProtocolWithPort(tt.input, "test_field")
96+
if len(errors) != tt.wantErrors {
97+
t.Errorf("ProtocolWithPort(%q) got %d errors, want %d", tt.input, len(errors), tt.wantErrors)
98+
}
99+
})
100+
}
101+
}

0 commit comments

Comments
 (0)