Skip to content

Commit 0d3f052

Browse files
authored
fix(#1660): add validation for start command parameters (#1663)
1 parent 4251b14 commit 0d3f052

File tree

10 files changed

+801
-8
lines changed

10 files changed

+801
-8
lines changed

internal/cmd/params_validation.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/dagu-org/dagu/internal/core"
7+
)
8+
9+
func validateStartArgumentSeparator(ctx *Context, args []string) error {
10+
return core.ValidateStartArgs(ctx.Command.ArgsLenAtDash() != -1, args)
11+
}
12+
13+
func validateStartPositionalParamCount(ctx *Context, args []string, dag *core.DAG) error {
14+
input, err := buildStartValidationInput(ctx, args)
15+
if err != nil {
16+
return err
17+
}
18+
return core.ValidateStartParams(dag.DefaultParams, input)
19+
}
20+
21+
func buildStartValidationInput(ctx *Context, args []string) (core.StartParamInput, error) {
22+
if argsLenAtDash := ctx.Command.ArgsLenAtDash(); argsLenAtDash != -1 {
23+
if argsLenAtDash >= len(args) {
24+
return core.StartParamInput{}, nil
25+
}
26+
return core.StartParamInput{DashArgs: args[argsLenAtDash:]}, nil
27+
}
28+
29+
raw, err := ctx.Command.Flags().GetString("params")
30+
if err != nil {
31+
return core.StartParamInput{}, fmt.Errorf("failed to get parameters: %w", err)
32+
}
33+
return core.StartParamInput{RawParams: raw}, nil
34+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package cmd
2+
3+
import (
4+
"testing"
5+
6+
"github.com/dagu-org/dagu/internal/core"
7+
"github.com/spf13/cobra"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestValidateStartArgumentSeparator(t *testing.T) {
12+
t.Parallel()
13+
14+
tests := []struct {
15+
name string
16+
cliArgs []string
17+
wantErr string
18+
}{
19+
{
20+
name: "NoExtraArgs",
21+
cliArgs: []string{"dag.yaml"},
22+
},
23+
{
24+
name: "WithDashSeparator",
25+
cliArgs: []string{"dag.yaml", "--", "p1"},
26+
},
27+
{
28+
name: "ExtraArgsWithoutDash",
29+
cliArgs: []string{"dag.yaml", "p1"},
30+
wantErr: "use '--' before parameters",
31+
},
32+
}
33+
34+
for _, tt := range tests {
35+
t.Run(tt.name, func(t *testing.T) {
36+
t.Parallel()
37+
38+
ctx, args := testValidationContext(t, tt.cliArgs)
39+
err := validateStartArgumentSeparator(ctx, args)
40+
if tt.wantErr == "" {
41+
require.NoError(t, err)
42+
return
43+
}
44+
require.Error(t, err)
45+
require.Contains(t, err.Error(), tt.wantErr)
46+
})
47+
}
48+
}
49+
50+
func TestValidateStartPositionalParamCount(t *testing.T) {
51+
t.Parallel()
52+
53+
tests := []struct {
54+
name string
55+
cliArgs []string
56+
defaultArgs string
57+
wantErr string
58+
}{
59+
{
60+
name: "NoDeclaredParamsAllowsNamedPairsAfterDash",
61+
cliArgs: []string{"dag.yaml", "--", "key1=value1", "key2=value2"},
62+
defaultArgs: "",
63+
},
64+
{
65+
name: "NoDeclaredParamsAllowsPositionalAfterDash",
66+
cliArgs: []string{"dag.yaml", "--", "success"},
67+
defaultArgs: "",
68+
},
69+
{
70+
name: "AllowsFewerThanDeclaredPositionalAfterDash",
71+
cliArgs: []string{"dag.yaml", "--", "only-one"},
72+
defaultArgs: `p1 p2`,
73+
},
74+
{
75+
name: "RejectsTooManyPositionalAfterDash",
76+
cliArgs: []string{"dag.yaml", "--", "one", "two", "three"},
77+
defaultArgs: `p1 p2`,
78+
wantErr: "too many positional params: expected at most 2, got 3",
79+
},
80+
{
81+
name: "NamedOnlyDoesNotTriggerPositionalValidation",
82+
cliArgs: []string{"--params", "KEY1=value1 KEY2=value2", "dag.yaml"},
83+
defaultArgs: `p1 p2`,
84+
},
85+
{
86+
name: "JSONAfterDashSkipsPositionalValidation",
87+
cliArgs: []string{"dag.yaml", "--", `{"REGION":"us-east","VERSION":"2.0"}`},
88+
defaultArgs: `REGION=us-east-1 VERSION=1.0.0`,
89+
},
90+
{
91+
name: "JSONInParamsFlagSkipsPositionalValidation",
92+
cliArgs: []string{"--params", `{"KEY":"value"}`, "dag.yaml"},
93+
defaultArgs: `p1 p2`,
94+
},
95+
{
96+
name: "NamedDeclaredParamAllowsSinglePositional",
97+
cliArgs: []string{"dag.yaml", "--", "server1"},
98+
defaultArgs: `ITEM=default`,
99+
},
100+
}
101+
102+
for _, tt := range tests {
103+
t.Run(tt.name, func(t *testing.T) {
104+
t.Parallel()
105+
106+
ctx, args := testValidationContext(t, tt.cliArgs)
107+
dag := &core.DAG{DefaultParams: tt.defaultArgs}
108+
err := validateStartPositionalParamCount(ctx, args, dag)
109+
if tt.wantErr == "" {
110+
require.NoError(t, err)
111+
return
112+
}
113+
require.Error(t, err)
114+
require.Contains(t, err.Error(), tt.wantErr)
115+
})
116+
}
117+
}
118+
119+
func testValidationContext(t *testing.T, cliArgs []string) (*Context, []string) {
120+
t.Helper()
121+
122+
command := &cobra.Command{Use: "start"}
123+
command.Flags().String("params", "", "")
124+
125+
err := command.Flags().Parse(cliArgs)
126+
require.NoError(t, err)
127+
128+
return &Context{Command: command}, command.Flags().Args()
129+
}

internal/cmd/start.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,19 @@ func runStart(ctx *Context, args []string) error {
154154
dag.Name = nameOverride
155155
}
156156
} else {
157+
if err := validateStartArgumentSeparator(ctx, args); err != nil {
158+
return err
159+
}
160+
157161
// Load parameters and DAG
158162
dag, params, err = loadDAGWithParams(ctx, args, isSubDAGRun)
159163
if err != nil {
160164
return err
161165
}
166+
167+
if err := validateStartPositionalParamCount(ctx, args, dag); err != nil {
168+
return err
169+
}
162170
}
163171

164172
root, err := determineRootDAGRun(isSubDAGRun, rootRef, dag, dagRunID)

internal/cmd/start_test.go

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ steps:
6868
}
6969

7070
func TestCmdStart_BackwardCompatibility(t *testing.T) {
71-
t.Run("ShouldAcceptParametersAfter", func(t *testing.T) {
71+
t.Run("ShouldRejectParametersAfterWithoutSeparator", func(t *testing.T) {
7272
th := test.SetupCommand(t)
7373
dagContent := `
7474
params: KEY1=default1 KEY2=default2
@@ -78,12 +78,11 @@ steps:
7878
`
7979
dagFile := th.CreateDAGFile(t, "test-params.yaml", dagContent)
8080

81-
cli := cmd.Start()
82-
cli.SetArgs([]string{dagFile, "--", "KEY1=value1", "KEY2=value2"})
83-
84-
// Execute will fail due to missing context setup, but we're testing
85-
// that the command accepts the arguments
86-
_ = cli.Execute()
81+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
82+
Args: []string{"start", dagFile, "KEY1=value1", "KEY2=value2"},
83+
})
84+
require.Error(t, err)
85+
require.Contains(t, err.Error(), "use '--' before parameters")
8786
})
8887

8988
t.Run("ShouldAcceptParamsFlag", func(t *testing.T) {
@@ -105,6 +104,95 @@ steps:
105104
})
106105
}
107106

107+
func TestCmdStart_PositionalParamValidation(t *testing.T) {
108+
th := test.SetupCommand(t)
109+
110+
dagFile := th.CreateDAGFile(t, "test-positional-params.yaml", `
111+
params: "p1 p2"
112+
steps:
113+
- name: step1
114+
command: echo $1 $2
115+
`)
116+
dagNoParamsFile := th.CreateDAGFile(t, "test-no-params.yaml", `
117+
steps:
118+
- name: step1
119+
command: echo $1
120+
`)
121+
122+
t.Run("AllowsTooFewAfterDash", func(t *testing.T) {
123+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
124+
Args: []string{"start", dagFile, "--", "only-one"},
125+
})
126+
require.NoError(t, err)
127+
})
128+
129+
t.Run("RejectsTooManyAfterDash", func(t *testing.T) {
130+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
131+
Args: []string{"start", dagFile, "--", "one", "two", "three"},
132+
})
133+
require.Error(t, err)
134+
require.Contains(t, err.Error(), "too many positional params: expected at most 2, got 3")
135+
})
136+
137+
t.Run("AllowsTooFewWithParamsFlag", func(t *testing.T) {
138+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
139+
Args: []string{"start", "--params", "only-one", dagFile},
140+
})
141+
require.NoError(t, err)
142+
})
143+
144+
t.Run("AllowsNamedOnlyWithPositionalDefaults", func(t *testing.T) {
145+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
146+
Args: []string{"start", "--params", "KEY1=value1 KEY2=value2", dagFile},
147+
})
148+
require.NoError(t, err)
149+
})
150+
151+
t.Run("AllowsJSONParamsWithoutPositionalValidation", func(t *testing.T) {
152+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
153+
Args: []string{"start", "--params", `{"KEY":"value"}`, dagFile},
154+
})
155+
require.NoError(t, err)
156+
})
157+
158+
t.Run("AllowsJSONAfterDashWithoutPositionalValidation", func(t *testing.T) {
159+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
160+
Args: []string{"start", dagFile, "--", `{"KEY":"value"}`},
161+
})
162+
require.NoError(t, err)
163+
})
164+
165+
t.Run("AllowsNamedPairsWhenNoParamsDeclared", func(t *testing.T) {
166+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
167+
Args: []string{"start", dagNoParamsFile, "--", "key1=value1", "key2=value2"},
168+
})
169+
require.NoError(t, err)
170+
})
171+
172+
t.Run("AllowsPositionalWhenNoParamsDeclared", func(t *testing.T) {
173+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
174+
Args: []string{"start", dagNoParamsFile, "--", "success"},
175+
})
176+
require.NoError(t, err)
177+
})
178+
}
179+
180+
func TestCmdStart_NamedParamsIgnorePositionalCount(t *testing.T) {
181+
th := test.SetupCommand(t)
182+
183+
dagFile := th.CreateDAGFile(t, "test-named-params.yaml", `
184+
params: KEY1=default1 KEY2=default2
185+
steps:
186+
- name: step1
187+
command: echo $KEY1 $KEY2
188+
`)
189+
190+
err := th.RunCommandWithError(t, cmd.Start(), test.CmdTest{
191+
Args: []string{"start", "--params", "KEY1=value1 KEY2=value2", dagFile},
192+
})
193+
require.NoError(t, err)
194+
}
195+
108196
func TestCmdStart_FromRunID(t *testing.T) {
109197
t.Run("ReschedulesWithStoredParameters", func(t *testing.T) {
110198
th := test.SetupCommand(t)

internal/core/spec/params.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ func parseMapParams(ctx BuildContext, input []any) ([]paramPair, error) {
252252

253253
// paramRegex is a regex to match the parameters in the command.
254254
var paramRegex = regexp.MustCompile(
255-
`(?:([^\s=]+)=)?("(?:\\"|[^"])*"|` + "`(" + `?:\\"|[^"]*)` + "`" + `|[^"\s]+)`,
255+
`(?:([^\s=]+)=)?("(?:\\"|[^"])*"|` + "`[^`]*`" + `|[^"\s]+)`,
256256
)
257257

258258
// backtickRegex matches backtick-enclosed commands for substitution.

0 commit comments

Comments
 (0)