Skip to content

Commit 01e4063

Browse files
committed
shift to using buffered shutdown channel, rework
1 parent aa0dc07 commit 01e4063

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

ee/agent/types/registration.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@ type RegistrationTracker interface {
1111
RegistrationIDs() []string
1212
SetRegistrationIDs(registrationIDs []string) error
1313
}
14-

pkg/osquery/runtime/runner.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log/slog"
88
"slices"
99
"sync"
10+
"sync/atomic"
1011
"time"
1112

1213
"github.com/kolide/launcher/ee/agent/flags/keys"
@@ -27,9 +28,9 @@ type Runner struct {
2728
knapsack types.Knapsack
2829
serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS
2930
opts []OsqueryInstanceOption // global options applying to all osquery instances
30-
shutdown chan struct{}
31-
rerunRequired bool
32-
interrupted bool
31+
shutdown chan struct{} // buffered shutdown channel used to signal instances to shut down, either to restart or to exit
32+
rerunRequired atomic.Bool
33+
interrupted atomic.Bool
3334
}
3435

3536
func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryInstanceOption) *Runner {
@@ -39,9 +40,9 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
3940
slogger: k.Slogger().With("component", "osquery_runner"),
4041
knapsack: k,
4142
serviceClient: serviceClient,
42-
shutdown: make(chan struct{}),
43-
rerunRequired: false,
44-
opts: opts,
43+
// the buffer length is arbitrarily set to 100; it just needs to be higher than the total number of possible instances
44+
shutdown: make(chan struct{}, 100),
45+
opts: opts,
4546
}
4647

4748
k.RegisterChangeObserver(runner,
@@ -60,8 +61,8 @@ func (r *Runner) Run() error {
6061

6162
// if we're in a state that required re-running all registered instances,
6263
// reset the field and do that
63-
if r.rerunRequired {
64-
r.rerunRequired = false
64+
if r.rerunRequired.Load() {
65+
r.rerunRequired.Store(false)
6566
continue
6667
}
6768

@@ -214,6 +215,13 @@ func (r *Runner) Query(query string) ([]map[string]string, error) {
214215
}
215216

216217
func (r *Runner) Interrupt(_ error) {
218+
if r.interrupted.Load() {
219+
// Already shut down, nothing else to do
220+
return
221+
}
222+
223+
r.interrupted.Store(true)
224+
217225
if err := r.Shutdown(); err != nil {
218226
r.slogger.Log(context.TODO(), slog.LevelWarn,
219227
"could not shut down runner on interrupt",
@@ -225,13 +233,12 @@ func (r *Runner) Interrupt(_ error) {
225233
// Shutdown instructs the runner to permanently stop the running instance (no
226234
// restart will be attempted).
227235
func (r *Runner) Shutdown() error {
228-
if r.interrupted {
229-
// Already shut down, nothing else to do
230-
return nil
236+
// ensure one shutdown is sent for each instance to read
237+
r.instanceLock.Lock()
238+
for range r.instances {
239+
r.shutdown <- struct{}{}
231240
}
232-
233-
r.interrupted = true
234-
close(r.shutdown)
241+
r.instanceLock.Unlock()
235242

236243
if err := r.triggerShutdownForInstances(); err != nil {
237244
return fmt.Errorf("triggering shutdown for instances during runner shutdown: %w", err)
@@ -385,7 +392,7 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
385392
r.registrationIds = newRegistrationIDs
386393
// mark rerun as required so that we can safely shut down all workers and have the changes
387394
// picked back up from within the main Run function
388-
r.rerunRequired = true
395+
r.rerunRequired.Store(true)
389396

390397
if err := r.Shutdown(); err != nil {
391398
r.slogger.Log(context.TODO(), slog.LevelWarn,
@@ -396,9 +403,5 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
396403
return err
397404
}
398405

399-
// reset the shutdown channel and interrupted state
400-
r.shutdown = make(chan struct{})
401-
r.interrupted = false
402-
403406
return nil
404407
}

pkg/osquery/runtime/runtime_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,11 +638,12 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
638638

639639
// Add in an extra instance
640640
extraRegistrationId := ulid.New()
641-
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId})
641+
updateErr := runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId})
642+
require.NoError(t, updateErr)
642643
waitHealthy(t, runner, logBytes)
643644
updatedInstanceStatuses := runner.InstanceStatuses()
644645
// verify that rerunRequired has been reset for any future changes
645-
require.False(t, runner.rerunRequired)
646+
require.False(t, runner.rerunRequired.Load())
646647
// now verify both instances are reported
647648
require.Equal(t, 2, len(runner.instances))
648649
require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID)
@@ -655,7 +656,8 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
655656

656657
// update registration IDs one more time, this time removing the additional registration
657658
originalDefaultInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime
658-
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID})
659+
updateErr = runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID})
660+
require.NoError(t, updateErr)
659661
waitHealthy(t, runner, logBytes)
660662

661663
// now verify only the default instance remains
@@ -666,7 +668,7 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
666668
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
667669
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
668670
// verify that rerunRequired has been reset for any future changes
669-
require.False(t, runner.rerunRequired)
671+
require.False(t, runner.rerunRequired.Load())
670672
// verify the default instance was restarted
671673
require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)
672674

@@ -726,7 +728,8 @@ func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) {
726728
extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime
727729

728730
// rerun with identical registrationIDs in swapped order and verify that the instances are not restarted
729-
runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID})
731+
updateErr := runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID})
732+
require.NoError(t, updateErr)
730733
waitHealthy(t, runner, logBytes)
731734

732735
require.Equal(t, 2, len(runner.instances))

0 commit comments

Comments
 (0)