-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathrun_scenario.go
More file actions
218 lines (196 loc) · 7.37 KB
/
run_scenario.go
File metadata and controls
218 lines (196 loc) · 7.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
package cli
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/antithesishq/antithesis-sdk-go/assert"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/temporalio/omes/cmd/clioptions"
"github.com/temporalio/omes/loadgen"
"go.temporal.io/sdk/client"
"go.uber.org/zap"
)
// runScenarioCmd builds the "run-scenario" cobra command. Its flags are
// registered by scenarioRunner.addCLIFlags, and "scenario" and "run-id" are
// marked required. Before running, the logger is built by preRun. The scenario
// then runs under a context obtained from withCancelOnInterrupt (presumably
// cancelled on an interrupt signal — confirm). Any error returned by the run
// is reported through the runner's logger via Fatal.
func runScenarioCmd() *cobra.Command {
	var runner scenarioRunner
	command := &cobra.Command{
		Use:    "run-scenario",
		Short:  "Run scenario",
		PreRun: func(*cobra.Command, []string) { runner.preRun() },
		Run: func(c *cobra.Command, _ []string) {
			runCtx, stop := withCancelOnInterrupt(c.Context())
			defer stop()
			err := runner.run(runCtx)
			if err != nil {
				runner.logger.Fatal(err)
			}
		},
	}
	runner.addCLIFlags(command.Flags())
	for _, required := range []string{"scenario", "run-id"} {
		// The error is only non-nil for unknown flag names; both are registered above.
		_ = command.MarkFlagRequired(required)
	}
	return command
}
// scenarioRunner holds the parsed command-line state for the run-scenario
// command and performs the run.
type scenarioRunner struct {
	// Run-shaping flags (iterations, duration, concurrency, timeouts, options).
	scenarioRunConfig
	// Selected scenario name (Scenario) and run ID (RunID).
	scenario clioptions.ScenarioID
	// Logger built from loggingOptions in preRun.
	logger *zap.SugaredLogger
	// How long run keeps retrying the initial client dial before failing.
	connectTimeout time.Duration
	// Options used to dial the Temporal client; also supplies the namespace.
	clientOptions clioptions.ClientOptions
	// Options used to create the metrics passed to the dialer and scenario.
	metricsOptions clioptions.MetricsOptions
	// Options used to construct logger.
	loggingOptions clioptions.LoggingOptions
}
// scenarioRunConfig holds the flags that shape a single scenario run. Most
// fields are copied into loadgen.RunConfiguration by run. scenarioOptions and
// verificationTimeout are consumed by run itself.
type scenarioRunConfig struct {
	// Iteration count override; cannot be combined with duration.
	iterations int
	// Window during which new iterations are started; cannot be combined with iterations.
	duration time.Duration
	// Override for the scenario's max-concurrent setting (0 presumably means
	// "use scenario default" — confirm in loadgen).
	maxConcurrent int
	// Maximum rate at which new iterations are started.
	maxIterationsPerSecond float64
	// Maximum attempts per iteration (flag default 1).
	maxIterationAttempts int
	// Raw "key=value" scenario options; a value of "@path" is replaced by the file's contents.
	scenarioOptions []string
	// Overall deadline after which still-running iterations are cancelled.
	timeout time.Duration
	// Skip registering the default search attributes used by scenarios.
	doNotRegisterSearchAttributes bool
	// Deadline for post-scenario verification; run rejects values <= 0.
	verificationTimeout time.Duration
}
// addCLIFlags registers every run-scenario flag on fs, in this order:
//   - the scenario selection flags;
//   - the run configuration flags;
//   - --connect-timeout;
//   - the client, metrics and logging option sets.
func (r *scenarioRunner) addCLIFlags(fs *pflag.FlagSet) {
	r.scenario.AddCLIFlags(fs)
	r.scenarioRunConfig.addCLIFlags(fs)
	fs.DurationVar(&r.connectTimeout, "connect-timeout", 0, "Duration to try to connect to server before failing")

	// Merge the option groups in a fixed order: client, metrics, logging.
	optionSets := []*pflag.FlagSet{
		r.clientOptions.FlagSet(),
		r.metricsOptions.FlagSet(""),
		r.loggingOptions.FlagSet(),
	}
	for _, set := range optionSets {
		fs.AddFlagSet(set)
	}
}
// addCLIFlags registers the flags that control how a scenario is run on fs:
//   - iteration count or duration;
//   - concurrency and rate limit;
//   - attempts per iteration;
//   - overall and verification timeouts;
//   - scenario-specific options;
//   - search-attribute registration.
func (r *scenarioRunConfig) addCLIFlags(fs *pflag.FlagSet) {
	fs.IntVar(&r.iterations, "iterations", 0, "Override default iterations for the scenario (cannot be provided with duration)")
	fs.DurationVar(&r.duration, "duration", 0, "Override duration for the scenario (cannot be provided with iterations)."+
		" This is the amount of time for which we will start new iterations of the scenario.")
	fs.Float64Var(&r.maxIterationsPerSecond, "max-iterations-per-second", 0, "Override iterations per second rate limit for the scenario."+
		" This is the maximum rate at which we will start new iterations of the scenario.")
	fs.IntVar(&r.maxIterationAttempts, "max-iteration-attempts", 1, "Maximum attempts per iteration")
	fs.DurationVar(&r.timeout, "timeout", 0, "If set, the scenario will stop after this amount of"+
		" time has elapsed. Any still-running iterations will be cancelled, and omes will exit nonzero.")
	fs.IntVar(&r.maxConcurrent, "max-concurrent", 0, "Override max-concurrent for the scenario")
	fs.StringSliceVar(&r.scenarioOptions, "option", nil, "Additional options for the scenario, in key=value format."+
		" A value of the form @path is replaced by the contents of the file at path.")
	fs.BoolVar(&r.doNotRegisterSearchAttributes, "do-not-register-search-attributes", false,
		"Do not register the default search attributes used by scenarios. "+
			"If the search attributes are not registered by the scenario they must be registered through some other method")
	// pflag appends "(default 2m0s)" to the usage of any flag with a non-zero
	// default, so the default is deliberately not repeated in the text here.
	fs.DurationVar(&r.verificationTimeout, "verification-timeout", 2*time.Minute,
		"Maximum duration to wait for post-scenario verification")
}
// preRun builds the runner's logger from the logging flags. It is invoked as
// the command's PreRun hook, so r.logger is set before run executes and is
// available for reporting run failures.
func (r *scenarioRunner) preRun() {
	logger := r.loggingOptions.MustCreateLogger()
	r.logger = logger
}
// run validates the runner configuration, connects to the Temporal server,
// executes the selected scenario, and then runs the scenario's post-run
// verification (when it defines one) bounded by verificationTimeout.
//
// Scenario options are given as key=value. A value of the form "@path" is
// replaced by the contents of the file at path.
//
// Dialing is retried every 300ms until connectTimeout has elapsed or ctx is
// cancelled. A zero connectTimeout means a single attempt.
//
// The returned error joins the scenario execution error (if any) with every
// verification error. It is nil when the scenario and all verifications succeed.
func (r *scenarioRunner) run(ctx context.Context) error {
	scenario := loadgen.GetScenario(r.scenario.Scenario)
	switch {
	case scenario == nil:
		return fmt.Errorf("scenario %q not found", r.scenario.Scenario)
	case r.scenario.RunID == "":
		return fmt.Errorf("run ID not found")
	case r.iterations > 0 && r.duration > 0:
		return fmt.Errorf("cannot provide both iterations and duration")
	case r.verificationTimeout <= 0:
		return fmt.Errorf("verification-timeout must be greater than 0")
	}

	// Parse options.
	scenarioOptions := make(map[string]string, len(r.scenarioOptions))
	for _, opt := range r.scenarioOptions {
		key, value, ok := strings.Cut(opt, "=")
		if !ok {
			return fmt.Errorf("option %q does not have '='", opt)
		}
		// If the value starts with '@', read the file and use its contents as the value.
		if filePath, fromFile := strings.CutPrefix(value, "@"); fromFile {
			data, err := os.ReadFile(filePath)
			if err != nil {
				return fmt.Errorf("failed to read file %s: %w", filePath, err)
			}
			value = string(data)
		}
		scenarioOptions[key] = value
	}

	metrics := r.metricsOptions.MustCreateMetrics(r.logger)
	defer metrics.Shutdown(ctx)

	// Named temporalClient so the imported "client" package is not shadowed.
	start := time.Now()
	var temporalClient client.Client
	for {
		var err error
		temporalClient, err = r.clientOptions.Dial(metrics, r.logger)
		if err == nil {
			break
		}
		// Only fail if past wait period.
		if time.Since(start) > r.connectTimeout {
			return fmt.Errorf("failed dialing: %w", err)
		}
		// SugaredLogger.Error formats its args with fmt.Sprint, which would
		// print a zap.Field as a raw struct; Errorw logs a proper key/value.
		r.logger.Errorw("Failed to dial, retrying ...", "error", err)
		// Wait 300ms and try again, but stop promptly if the run is cancelled.
		select {
		case <-ctx.Done():
			return fmt.Errorf("failed dialing (last error: %v): %w", err, ctx.Err())
		case <-time.After(300 * time.Millisecond):
		}
	}
	defer temporalClient.Close()

	repoDir, err := getRepoDir()
	if err != nil {
		return fmt.Errorf("failed to get root directory: %w", err)
	}

	// Generate a random execution ID to ensure no two executions with the same RunID collide.
	executionID, err := generateExecutionID()
	if err != nil {
		return fmt.Errorf("failed to generate execution ID: %w", err)
	}

	scenarioInfo := loadgen.ScenarioInfo{
		ScenarioName:   r.scenario.Scenario,
		RunID:          r.scenario.RunID,
		ExecutionID:    executionID,
		Logger:         r.logger,
		MetricsHandler: metrics.NewHandler(),
		Client:         temporalClient,
		Configuration: loadgen.RunConfiguration{
			Iterations:                    r.iterations,
			Duration:                      r.duration,
			MaxConcurrent:                 r.maxConcurrent,
			MaxIterationsPerSecond:        r.maxIterationsPerSecond,
			MaxIterationAttempts:          r.maxIterationAttempts,
			Timeout:                       r.timeout,
			DoNotRegisterSearchAttributes: r.doNotRegisterSearchAttributes,
		},
		ScenarioOptions: scenarioOptions,
		Namespace:       r.clientOptions.Namespace,
		RootPath:        repoDir,
	}
	executor := scenario.ExecutorFn()

	// Collect all errors so that verification still runs after a failed scenario.
	var allErrors []error

	// 1. Run the scenario.
	if scenarioErr := executor.Run(ctx, scenarioInfo); scenarioErr != nil {
		allErrors = append(allErrors, fmt.Errorf("scenario execution failed: %w", scenarioErr))
		assert.Unreachable("scenario execution failed", map[string]any{"error": scenarioErr})
	}

	// 2. Run verifications, bounded by verificationTimeout.
	if scenario.VerifyFn != nil {
		verifyCtx, verifyCancel := context.WithTimeout(ctx, r.verificationTimeout)
		defer verifyCancel()
		for _, verifyErr := range scenario.VerifyFn(verifyCtx, scenarioInfo, executor) {
			allErrors = append(allErrors, fmt.Errorf("post-scenario verification failed: %w", verifyErr))
			assert.Unreachable("post-scenario verification failed", map[string]any{"error": verifyErr})
		}
	}

	// Aggregate all errors; errors.Join returns nil when the slice is empty.
	return errors.Join(allErrors...)
}
// generateExecutionID returns 16 lowercase hex characters encoding 8 bytes
// read from crypto/rand. The ID distinguishes this particular execution of a
// scenario so that two executions sharing the same RunID do not collide. An
// error is returned only when reading from the random source fails.
func generateExecutionID() (string, error) {
	var raw [8]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}