Skip to content

Commit 350ec16

Browse files
committed
Improve flag handling.
1 parent 4dc927c commit 350ec16

2 files changed

Lines changed: 125 additions & 35 deletions

File tree

cmd/runner/main.go

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"log"
1111
"os"
12+
"strconv"
1213
"time"
1314

1415
"github.com/alitto/pond"
@@ -20,30 +21,82 @@ import (
2021
"go.temporal.io/sdk/client"
2122
)
2223

23-
var nWorfklows = flag.Int("c", 10, "concurrent workflows")
24+
var nWorkflows = flag.Int("c", 10, "concurrent workflows")
2425
var sWorkflow = flag.String("t", "", "workflow type")
2526
var sSignalType = flag.String("s", "", "signal type")
2627
var bWait = flag.Bool("w", true, "wait for workflows to complete")
2728
var sNamespace = flag.String("n", "default", "namespace")
2829
var sTaskQueue = flag.String("tq", "benchmark", "task queue")
2930

31+
// Track which flags were explicitly set
32+
var flagsSet = make(map[string]bool)
33+
34+
// flagValue helps implement precedence: command line > environment variable > default
35+
func getStringValue(flagName, envName, flagValue, defaultValue string) string {
36+
if flagsSet[flagName] {
37+
return flagValue
38+
}
39+
if envValue := os.Getenv(envName); envValue != "" {
40+
return envValue
41+
}
42+
return defaultValue
43+
}
44+
45+
func getIntValue(flagName, envName string, flagValue, defaultValue int) int {
46+
if flagsSet[flagName] {
47+
return flagValue
48+
}
49+
if envValue := os.Getenv(envName); envValue != "" {
50+
if parsed, err := strconv.Atoi(envValue); err == nil {
51+
return parsed
52+
}
53+
}
54+
return defaultValue
55+
}
56+
57+
func getBoolValue(flagName, envName string, flagValue, defaultValue bool) bool {
58+
if flagsSet[flagName] {
59+
return flagValue
60+
}
61+
if envValue := os.Getenv(envName); envValue != "" {
62+
if parsed, err := strconv.ParseBool(envValue); err == nil {
63+
return parsed
64+
}
65+
}
66+
return defaultValue
67+
}
68+
3069
func main() {
3170
flag.Usage = func() {
3271
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [flags] [workflow input] ...\n", os.Args[0])
3372
flag.PrintDefaults()
73+
fmt.Fprintf(flag.CommandLine.Output(), "\nEnvironment variables (used if flag not set):\n")
74+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_CONCURRENT_WORKFLOWS\n")
75+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_WORKFLOW_TYPE\n")
76+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_SIGNAL_TYPE\n")
77+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_WAIT\n")
78+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_NAMESPACE\n")
79+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_TASK_QUEUE\n")
3480
}
3581

3682
flag.Parse()
3783

84+
// Track which flags were explicitly set by the user
85+
flag.Visit(func(f *flag.Flag) {
86+
flagsSet[f.Name] = true
87+
})
88+
3889
if _, err := maxprocs.Set(); err != nil {
3990
log.Printf("WARNING: failed to set GOMAXPROCS: %v.\n", err)
4091
}
4192

42-
namespace := *sNamespace
43-
envNamespace := os.Getenv("TEMPORAL_NAMESPACE")
44-
if envNamespace != "" && envNamespace != "default" {
45-
namespace = envNamespace
46-
}
93+
// Apply precedence: command line > environment variable > default
94+
concurrentWorkflows := getIntValue("c", "TEMPORAL_CONCURRENT_WORKFLOWS", *nWorkflows, 10)
95+
workflowType := getStringValue("t", "TEMPORAL_WORKFLOW_TYPE", *sWorkflow, "")
96+
signalType := getStringValue("s", "TEMPORAL_SIGNAL_TYPE", *sSignalType, "")
97+
waitForCompletion := getBoolValue("w", "TEMPORAL_WAIT", *bWait, true)
98+
namespace := getStringValue("n", "TEMPORAL_NAMESPACE", *sNamespace, "default")
99+
taskQueue := getStringValue("tq", "TEMPORAL_TASK_QUEUE", *sTaskQueue, "benchmark")
47100

48101
log.Printf("Using namespace: %s", namespace)
49102

@@ -111,23 +164,23 @@ func main() {
111164
input = append(input, i)
112165
}
113166

114-
pool := pond.New(*nWorfklows, 0)
167+
pool := pond.New(concurrentWorkflows, 0)
115168

116169
var starter func() (client.WorkflowRun, error)
117170

118-
if *sSignalType != "" {
171+
if signalType != "" {
119172
starter = func() (client.WorkflowRun, error) {
120173
wID := uuid.New()
121174
return c.SignalWithStartWorkflow(
122175
context.Background(),
123176
wID,
124-
*sSignalType,
177+
signalType,
125178
nil,
126179
client.StartWorkflowOptions{
127180
ID: wID,
128-
TaskQueue: *sTaskQueue,
181+
TaskQueue: taskQueue,
129182
},
130-
*sWorkflow,
183+
workflowType,
131184
input...,
132185
)
133186
}
@@ -136,9 +189,9 @@ func main() {
136189
return c.ExecuteWorkflow(
137190
context.Background(),
138191
client.StartWorkflowOptions{
139-
TaskQueue: *sTaskQueue,
192+
TaskQueue: taskQueue,
140193
},
141-
*sWorkflow,
194+
workflowType,
142195
input...,
143196
)
144197
}
@@ -153,7 +206,7 @@ func main() {
153206
return
154207
}
155208

156-
if *bWait {
209+
if waitForCompletion {
157210
err = wf.Get(context.Background(), nil)
158211
if err != nil {
159212
log.Println("Workflow failed", err)

cmd/worker/main.go

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package main
33
import (
44
"crypto/tls"
55
"crypto/x509"
6+
"flag"
7+
"fmt"
68
"log"
79
"os"
810
"strconv"
@@ -19,23 +21,66 @@ import (
1921
"go.temporal.io/sdk/workflow"
2022
)
2123

24+
var sNamespace = flag.String("n", "default", "namespace")
25+
var sTaskQueue = flag.String("tq", "benchmark", "task queue")
26+
var nWorkflowPollers = flag.Int("wp", -1, "max concurrent workflow task pollers (-1 = use default, 0 = disable)")
27+
var nActivityPollers = flag.Int("ap", -1, "max concurrent activity task pollers (-1 = use default, 0 = disable)")
28+
29+
// Track which flags were explicitly set
30+
var flagsSet = make(map[string]bool)
31+
32+
func getStringValue(flagName, envName, flagValue, defaultValue string) string {
33+
if flagsSet[flagName] {
34+
return flagValue
35+
}
36+
if envValue := os.Getenv(envName); envValue != "" {
37+
return envValue
38+
}
39+
return defaultValue
40+
}
41+
42+
func getIntValue(flagName, envName string, flagValue, defaultValue int) int {
43+
if flagsSet[flagName] {
44+
return flagValue
45+
}
46+
if envValue := os.Getenv(envName); envValue != "" {
47+
if parsed, err := strconv.Atoi(envValue); err == nil {
48+
return parsed
49+
}
50+
}
51+
return defaultValue
52+
}
53+
2254
func main() {
55+
flag.Usage = func() {
56+
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [flags]\n", os.Args[0])
57+
flag.PrintDefaults()
58+
fmt.Fprintf(flag.CommandLine.Output(), "\nEnvironment variables (used if flag not set):\n")
59+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_NAMESPACE\n")
60+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_TASK_QUEUE\n")
61+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_WORKFLOW_TASK_POLLERS\n")
62+
fmt.Fprintf(flag.CommandLine.Output(), " TEMPORAL_ACTIVITY_TASK_POLLERS\n")
63+
}
64+
65+
flag.Parse()
66+
67+
// Track which flags were explicitly set by the user
68+
flag.Visit(func(f *flag.Flag) {
69+
flagsSet[f.Name] = true
70+
})
71+
2372
if _, err := maxprocs.Set(); err != nil {
2473
log.Printf("WARNING: failed to set GOMAXPROCS: %v.\n", err)
2574
}
2675

27-
namespace := os.Getenv("TEMPORAL_NAMESPACE")
28-
if namespace == "" {
29-
namespace = "default"
30-
}
76+
// Apply precedence: command line > environment variable > default
77+
namespace := getStringValue("n", "TEMPORAL_NAMESPACE", *sNamespace, "default")
78+
taskQueue := getStringValue("tq", "TEMPORAL_TASK_QUEUE", *sTaskQueue, "benchmark")
79+
workflowPollers := getIntValue("wp", "TEMPORAL_WORKFLOW_TASK_POLLERS", *nWorkflowPollers, -1)
80+
activityPollers := getIntValue("ap", "TEMPORAL_ACTIVITY_TASK_POLLERS", *nActivityPollers, -1)
3181

3282
log.Printf("Creating worker for namespace: %s", namespace)
3383

34-
taskQueue := os.Getenv("TEMPORAL_TASK_QUEUE")
35-
if taskQueue == "" {
36-
taskQueue = "benchmark"
37-
}
38-
3984
clientOptions := client.Options{
4085
HostPort: os.Getenv("TEMPORAL_GRPC_ENDPOINT"),
4186
Namespace: namespace,
@@ -89,20 +134,12 @@ func main() {
89134

90135
workerOptions := worker.Options{}
91136

92-
if os.Getenv("TEMPORAL_WORKFLOW_TASK_POLLERS") != "" {
93-
pollers, err := strconv.Atoi(os.Getenv("TEMPORAL_WORKFLOW_TASK_POLLERS"))
94-
if err != nil {
95-
log.Fatalf("TEMPORAL_WORKFLOW_TASK_POLLERS is invalid: %v", err)
96-
}
97-
workerOptions.MaxConcurrentWorkflowTaskPollers = pollers
137+
if workflowPollers >= 0 {
138+
workerOptions.MaxConcurrentWorkflowTaskPollers = workflowPollers
98139
}
99140

100-
if os.Getenv("TEMPORAL_ACTIVITY_TASK_POLLERS") != "" {
101-
pollers, err := strconv.Atoi(os.Getenv("TEMPORAL_ACTIVITY_TASK_POLLERS"))
102-
if err != nil {
103-
log.Fatalf("TEMPORAL_ACTIVITY_TASK_POLLERS is invalid: %v", err)
104-
}
105-
workerOptions.MaxConcurrentActivityTaskPollers = pollers
141+
if activityPollers >= 0 {
142+
workerOptions.MaxConcurrentActivityTaskPollers = activityPollers
106143
}
107144

108145
// TODO: Support more worker options

0 commit comments

Comments
 (0)