Skip to content

Commit aa0dc07

Browse files
committed
restarts seem to be working this way
1 parent 5748b5d commit aa0dc07

File tree

3 files changed

+218
-2
lines changed

3 files changed

+218
-2
lines changed

ee/agent/types/registration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ type RegistrationTracker interface {
1111
RegistrationIDs() []string
1212
SetRegistrationIDs(registrationIDs []string) error
1313
}
14+

pkg/osquery/runtime/runner.go

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"log/slog"
8+
"slices"
89
"sync"
910
"time"
1011

@@ -27,6 +28,7 @@ type Runner struct {
2728
serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS
2829
opts []OsqueryInstanceOption // global options applying to all osquery instances
2930
shutdown chan struct{}
31+
rerunRequired bool
3032
interrupted bool
3133
}
3234

@@ -38,6 +40,7 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
3840
knapsack: k,
3941
serviceClient: serviceClient,
4042
shutdown: make(chan struct{}),
43+
rerunRequired: false,
4144
opts: opts,
4245
}
4346

@@ -49,6 +52,31 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
4952
}
5053

5154
func (r *Runner) Run() error {
55+
for {
56+
// if our instances ever exit unexpectedly, return immediately
57+
if err := r.runRegisteredInstances(); err != nil {
58+
return err
59+
}
60+
61+
// if we're in a state that required re-running all registered instances,
62+
// reset the field and do that
63+
if r.rerunRequired {
64+
r.rerunRequired = false
65+
continue
66+
}
67+
68+
// otherwise, exit cleanly
69+
return nil
70+
}
71+
}
72+
73+
func (r *Runner) runRegisteredInstances() error {
74+
// clear the internal instances to add back in fresh as we runInstance,
75+
// this prevents old instances from sticking around if a registrationID is ever removed
76+
r.instanceLock.Lock()
77+
r.instances = make(map[string]*OsqueryInstance)
78+
r.instanceLock.Unlock()
79+
5280
// Create a group to track the workers running each instance
5381
wg, ctx := errgroup.WithContext(context.Background())
5482

@@ -334,7 +362,43 @@ func (r *Runner) InstanceStatuses() map[string]types.InstanceStatus {
334362
return instanceStatuses
335363
}
336364

337-
func (r *Runner) UpdateRegistrationIDs(registrationIDs []string) error {
338-
// TODO: detect any difference in reg IDs and shut down/spin up instances accordingly
365+
func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
366+
slices.Sort(newRegistrationIDs)
367+
existingRegistrationIDs := r.registrationIds
368+
slices.Sort(existingRegistrationIDs)
369+
370+
if slices.Equal(newRegistrationIDs, existingRegistrationIDs) {
371+
r.slogger.Log(context.TODO(), slog.LevelDebug,
372+
"skipping runner restarts for updated registration IDs, no changes detected",
373+
)
374+
375+
return nil
376+
}
377+
378+
r.slogger.Log(context.TODO(), slog.LevelDebug,
379+
"detected changes to registrationIDs, will restart runner instances",
380+
"previous_registration_ids", existingRegistrationIDs,
381+
"new_registration_ids", newRegistrationIDs,
382+
)
383+
384+
// we know there are changes, safe to update the internal registrationIDs now
385+
r.registrationIds = newRegistrationIDs
386+
// mark rerun as required so that we can safely shutdown all workers and have the changes
387+
// picked back up from within the main Run function
388+
r.rerunRequired = true
389+
390+
if err := r.Shutdown(); err != nil {
391+
r.slogger.Log(context.TODO(), slog.LevelWarn,
392+
"could not shut down runner instances for restart after registration changes",
393+
"err", err,
394+
)
395+
396+
return err
397+
}
398+
399+
// reset the shutdown channel and interrupted state
400+
r.shutdown = make(chan struct{})
401+
r.interrupted = false
402+
339403
return nil
340404
}

pkg/osquery/runtime/runtime_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,157 @@ func TestExtensionIsCleanedUp(t *testing.T) {
593593
<-timer1.C
594594
}
595595

596+
func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
597+
t.Parallel()
598+
rootDirectory := testRootDirectory(t)
599+
600+
logBytes, slogger := setUpTestSlogger()
601+
602+
k := typesMocks.NewKnapsack(t)
603+
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID})
604+
k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe()
605+
k.On("WatchdogEnabled").Return(false)
606+
k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
607+
k.On("Slogger").Return(slogger)
608+
k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory)
609+
k.On("RootDirectory").Return(rootDirectory).Maybe()
610+
k.On("OsqueryFlags").Return([]string{})
611+
k.On("OsqueryVerbose").Return(true)
612+
k.On("LoggingInterval").Return(5 * time.Minute).Maybe()
613+
k.On("LogMaxBytesPerBatch").Return(0).Maybe()
614+
k.On("Transport").Return("jsonrpc").Maybe()
615+
k.On("ReadEnrollSecret").Return("", nil).Maybe()
616+
setUpMockStores(t, k)
617+
serviceClient := mockServiceClient()
618+
619+
runner := New(k, serviceClient)
620+
621+
// Start the instance
622+
go runner.Run()
623+
waitHealthy(t, runner, logBytes)
624+
625+
// Confirm the default instance was started
626+
require.Contains(t, runner.instances, types.DefaultRegistrationID)
627+
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
628+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
629+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
630+
631+
// confirm only the default instance has started
632+
require.Equal(t, 1, len(runner.instances))
633+
634+
// Confirm instance statuses are reported correctly
635+
instanceStatuses := runner.InstanceStatuses()
636+
require.Contains(t, instanceStatuses, types.DefaultRegistrationID)
637+
require.Equal(t, instanceStatuses[types.DefaultRegistrationID], types.InstanceStatusHealthy)
638+
639+
// Add in an extra instance
640+
extraRegistrationId := ulid.New()
641+
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId})
642+
waitHealthy(t, runner, logBytes)
643+
updatedInstanceStatuses := runner.InstanceStatuses()
644+
// verify that rerunRequired has been reset for any future changes
645+
require.False(t, runner.rerunRequired)
646+
// now verify both instances are reported
647+
require.Equal(t, 2, len(runner.instances))
648+
require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID)
649+
require.Contains(t, updatedInstanceStatuses, extraRegistrationId)
650+
// Confirm the additional instance was started and is healthy
651+
require.NotNil(t, runner.instances[extraRegistrationId].stats)
652+
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to secondary instance stats on start up")
653+
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to secondary instance stats on start up")
654+
require.Equal(t, updatedInstanceStatuses[extraRegistrationId], types.InstanceStatusHealthy)
655+
656+
// update registration IDs one more time, this time removing the additional registration
657+
originalDefaultInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime
658+
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID})
659+
waitHealthy(t, runner, logBytes)
660+
661+
// now verify only the default instance remains
662+
require.Equal(t, 1, len(runner.instances))
663+
// Confirm the default instance was started and is healthy
664+
require.Contains(t, runner.instances, types.DefaultRegistrationID)
665+
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
666+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
667+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
668+
// verify that rerunRequired has been reset for any future changes
669+
require.False(t, runner.rerunRequired)
670+
// verify the default instance was restarted
671+
require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)
672+
673+
waitShutdown(t, runner, logBytes)
674+
675+
// Confirm both instances exited
676+
require.Contains(t, runner.instances, types.DefaultRegistrationID)
677+
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
678+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown")
679+
}
680+
681+
func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) {
682+
t.Parallel()
683+
rootDirectory := testRootDirectory(t)
684+
685+
logBytes, slogger := setUpTestSlogger()
686+
extraRegistrationId := ulid.New()
687+
688+
k := typesMocks.NewKnapsack(t)
689+
k.On("RegistrationIDs").Return([]string{types.DefaultRegistrationID, extraRegistrationId})
690+
k.On("OsqueryHealthcheckStartupDelay").Return(0 * time.Second).Maybe()
691+
k.On("WatchdogEnabled").Return(false)
692+
k.On("RegisterChangeObserver", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything)
693+
k.On("Slogger").Return(slogger)
694+
k.On("LatestOsquerydPath", mock.Anything).Return(testOsqueryBinaryDirectory)
695+
k.On("RootDirectory").Return(rootDirectory).Maybe()
696+
k.On("OsqueryFlags").Return([]string{})
697+
k.On("OsqueryVerbose").Return(true)
698+
k.On("LoggingInterval").Return(5 * time.Minute).Maybe()
699+
k.On("LogMaxBytesPerBatch").Return(0).Maybe()
700+
k.On("Transport").Return("jsonrpc").Maybe()
701+
k.On("ReadEnrollSecret").Return("", nil).Maybe()
702+
setUpMockStores(t, k)
703+
serviceClient := mockServiceClient()
704+
705+
runner := New(k, serviceClient)
706+
707+
// Start the instance
708+
go runner.Run()
709+
waitHealthy(t, runner, logBytes)
710+
711+
require.Equal(t, 2, len(runner.instances))
712+
// Confirm the default instance was started
713+
require.Contains(t, runner.instances, types.DefaultRegistrationID)
714+
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
715+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
716+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
717+
// note the original start time
718+
defaultInstanceStartTime := runner.instances[types.DefaultRegistrationID].stats.StartTime
719+
720+
// Confirm the extra instance was started
721+
require.Contains(t, runner.instances, extraRegistrationId)
722+
require.NotNil(t, runner.instances[extraRegistrationId].stats)
723+
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.StartTime, "start time should be added to extra instance stats on start up")
724+
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ConnectTime, "connect time should be added to extra instance stats on start up")
725+
// note the original start time
726+
extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime
727+
728+
// rerun with identical registrationIDs in swapped order and verify that the instances are not restarted
729+
runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID})
730+
waitHealthy(t, runner, logBytes)
731+
732+
require.Equal(t, 2, len(runner.instances))
733+
require.Equal(t, extraInstanceStartTime, runner.instances[extraRegistrationId].stats.StartTime)
734+
require.Equal(t, defaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)
735+
736+
waitShutdown(t, runner, logBytes)
737+
738+
// Confirm both instances exited
739+
require.Contains(t, runner.instances, types.DefaultRegistrationID)
740+
require.NotNil(t, runner.instances[types.DefaultRegistrationID].stats)
741+
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ExitTime, "exit time should be added to default instance stats on shutdown")
742+
require.Contains(t, runner.instances, extraRegistrationId)
743+
require.NotNil(t, runner.instances[extraRegistrationId].stats)
744+
require.NotEmpty(t, runner.instances[extraRegistrationId].stats.ExitTime, "exit time should be added to secondary instance stats on shutdown")
745+
}
746+
596747
// sets up an osquery instance with a running extension to be used in tests.
597748
func setupOsqueryInstanceForTests(t *testing.T) (runner *Runner, logBytes *threadsafebuffer.ThreadSafeBuffer, teardown func()) {
598749
rootDirectory := testRootDirectory(t)

0 commit comments

Comments
 (0)