7
7
"log/slog"
8
8
"slices"
9
9
"sync"
10
+ "sync/atomic"
10
11
"time"
11
12
12
13
"github.com/kolide/launcher/ee/agent/flags/keys"
@@ -27,9 +28,9 @@ type Runner struct {
27
28
knapsack types.Knapsack
28
29
serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS
29
30
opts []OsqueryInstanceOption // global options applying to all osquery instances
30
- shutdown chan struct {}
31
- rerunRequired bool
32
- interrupted bool
31
+ shutdown chan struct {} // buffered shutdown channel that enables shutting down in order to restart or exit
32
+ rerunRequired atomic. Bool
33
+ interrupted atomic. Bool
33
34
}
34
35
35
36
func New (k types.Knapsack , serviceClient service.KolideService , opts ... OsqueryInstanceOption ) * Runner {
@@ -39,9 +40,9 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
39
40
slogger : k .Slogger ().With ("component" , "osquery_runner" ),
40
41
knapsack : k ,
41
42
serviceClient : serviceClient ,
42
- shutdown : make ( chan struct {}),
43
- rerunRequired : false ,
44
- opts : opts ,
43
+ // the buffer length is arbitrarily set at 100; this number just needs to be higher than the total possible instances
44
+ shutdown : make ( chan struct {}, 100 ) ,
45
+ opts : opts ,
45
46
}
46
47
47
48
k .RegisterChangeObserver (runner ,
@@ -60,8 +61,8 @@ func (r *Runner) Run() error {
60
61
61
62
// if we're in a state that required re-running all registered instances,
62
63
// reset the field and do that
63
- if r .rerunRequired {
64
- r .rerunRequired = false
64
+ if r .rerunRequired . Load () {
65
+ r .rerunRequired . Store ( false )
65
66
continue
66
67
}
67
68
@@ -214,6 +215,13 @@ func (r *Runner) Query(query string) ([]map[string]string, error) {
214
215
}
215
216
216
217
func (r * Runner ) Interrupt (_ error ) {
218
+ if r .interrupted .Load () {
219
+ // Already shut down, nothing else to do
220
+ return
221
+ }
222
+
223
+ r .interrupted .Store (true )
224
+
217
225
if err := r .Shutdown (); err != nil {
218
226
r .slogger .Log (context .TODO (), slog .LevelWarn ,
219
227
"could not shut down runner on interrupt" ,
@@ -225,13 +233,12 @@ func (r *Runner) Interrupt(_ error) {
225
233
// Shutdown instructs the runner to permanently stop the running instance (no
226
234
// restart will be attempted).
227
235
func (r * Runner ) Shutdown () error {
228
- if r .interrupted {
229
- // Already shut down, nothing else to do
230
- return nil
236
+ // ensure one shutdown is sent for each instance to read
237
+ r .instanceLock .Lock ()
238
+ for range r .instances {
239
+ r .shutdown <- struct {}{}
231
240
}
232
-
233
- r .interrupted = true
234
- close (r .shutdown )
241
+ r .instanceLock .Unlock ()
235
242
236
243
if err := r .triggerShutdownForInstances (); err != nil {
237
244
return fmt .Errorf ("triggering shutdown for instances during runner shutdown: %w" , err )
@@ -385,7 +392,7 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
385
392
r .registrationIds = newRegistrationIDs
386
393
// mark rerun as required so that we can safely shutdown all workers and have the changes
387
394
// picked back up from within the main Run function
388
- r .rerunRequired = true
395
+ r .rerunRequired . Store ( true )
389
396
390
397
if err := r .Shutdown (); err != nil {
391
398
r .slogger .Log (context .TODO (), slog .LevelWarn ,
@@ -396,9 +403,5 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
396
403
return err
397
404
}
398
405
399
- // reset the shutdown channel and interrupted state
400
- r .shutdown = make (chan struct {})
401
- r .interrupted = false
402
-
403
406
return nil
404
407
}
0 commit comments