Skip to content

Commit 736ba40

Browse files
authored
Merge pull request #288 from castai/fix-concurrent-cluster-registration
fix: concurrent cluster registration
2 parents 2798593 + 0a6f97c commit 736ba40

File tree

4 files changed

+282
-45
lines changed

4 files changed

+282
-45
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
bin
77
.env
88
/.vscode
9+
.claude/settings.local.json

cmd/agent/run.go

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"castai-agent/pkg/castai"
3636
castailog "castai-agent/pkg/log"
3737
"castai-agent/pkg/services/providers"
38+
"castai-agent/pkg/services/providers/types"
3839
)
3940

4041
const (
@@ -224,7 +225,15 @@ func runAgentMode(parentCtx context.Context, castaiclient castai.Client, log *lo
224225
}
225226

226227
if clusterID == "" {
227-
reg, err := provider.RegisterCluster(ctx, castaiclient)
228+
registerFn := func(ctx context.Context) (*types.ClusterRegistration, error) {
229+
return provider.RegisterCluster(ctx, castaiclient)
230+
}
231+
var reg *types.ClusterRegistration
232+
if cfg.LeaderElection.Enabled {
233+
reg, err = replicas.RegisterClusterWithLease(ctx, log, cfg.LeaderElection, clientset, registerFn)
234+
} else {
235+
reg, err = registerFn(ctx)
236+
}
228237
if err != nil {
229238
return fmt.Errorf("registering cluster: %w", err)
230239
}
@@ -243,59 +252,39 @@ func runAgentMode(parentCtx context.Context, castaiclient castai.Client, log *lo
243252
return err
244253
}
245254

255+
var leaderStatusCh chan bool
246256
if cfg.LeaderElection.Enabled {
247257
// Buffered channel to avoid blocking leader election on slow consumers
248-
leaderStatusCh := make(chan bool, 10)
249-
250-
params := &controller.Params{
251-
Log: log,
252-
Clientset: clientset,
253-
MetricsClient: metricsClient,
254-
DynamicClient: dynamicClient,
255-
CastaiClient: castaiclient,
256-
Provider: provider,
257-
ClusterID: clusterID,
258-
Config: cfg,
259-
AgentVersion: agentVersion,
260-
HealthzProvider: ctrlHealthz,
261-
LeaderStatusCh: leaderStatusCh,
262-
}
258+
leaderStatusCh = make(chan bool, 10)
259+
}
263260

264-
go shutdown.RunThenTrigger(shutdownController.For("controller"), false, func() error {
265-
return controller.RunControllerWithRestart(ctx, params)
266-
})
261+
params := &controller.Params{
262+
Log: log,
263+
Clientset: clientset,
264+
MetricsClient: metricsClient,
265+
DynamicClient: dynamicClient,
266+
CastaiClient: castaiclient,
267+
Provider: provider,
268+
ClusterID: clusterID,
269+
Config: cfg,
270+
AgentVersion: agentVersion,
271+
HealthzProvider: ctrlHealthz,
272+
LeaderStatusCh: leaderStatusCh,
273+
}
274+
275+
go shutdown.RunThenTrigger(shutdownController.For("controller"), false, func() error {
276+
return controller.RunControllerWithRestart(ctx, params)
277+
})
267278

268-
// Run leader election with main context
269-
// If leader election returns an error, we shut down the entire app
279+
if cfg.LeaderElection.Enabled {
270280
go shutdown.RunThenTrigger(shutdownController.For("leader election"), false, func() error {
271281
return replicas.RunLeaderElection(ctx, log, cfg.LeaderElection, clientset, leaderWatchDog, leaderStatusCh)
272282
})
273-
274-
<-ctx.Done()
275-
log.Info("shutdown signal received, initiating graceful shutdown")
276-
} else {
277-
params := &controller.Params{
278-
Log: log,
279-
Clientset: clientset,
280-
MetricsClient: metricsClient,
281-
DynamicClient: dynamicClient,
282-
CastaiClient: castaiclient,
283-
Provider: provider,
284-
ClusterID: clusterID,
285-
Config: cfg,
286-
AgentVersion: agentVersion,
287-
HealthzProvider: ctrlHealthz,
288-
LeaderStatusCh: nil,
289-
}
290-
291-
go shutdown.RunThenTrigger(shutdownController.For("controller"), false, func() error {
292-
return controller.RunControllerWithRestart(ctx, params)
293-
})
294-
295-
<-ctx.Done()
296-
log.Info("shutdown signal received, initiating graceful shutdown")
297283
}
298284

285+
<-ctx.Done()
286+
log.Info("shutdown signal received, initiating graceful shutdown")
287+
299288
return nil
300289
}
301290

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package replicas
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/sirupsen/logrus"
10+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
11+
"k8s.io/client-go/kubernetes"
12+
"k8s.io/client-go/tools/leaderelection"
13+
"k8s.io/client-go/tools/leaderelection/resourcelock"
14+
15+
"castai-agent/internal/config"
16+
"castai-agent/pkg/services/providers/types"
17+
)
18+
19+
const (
20+
registrationLeaseDuration = 10 * time.Second
21+
registrationRenewDeadline = 6 * time.Second
22+
registrationRetryPeriod = 2 * time.Second
23+
)
24+
25+
// RegisterClusterWithLease serializes cluster registration across replicas using a short-lived
26+
// Kubernetes lease. Each pod acquires the lease, calls registerFn, then releases the lease so
27+
// the next pod can proceed.
28+
func RegisterClusterWithLease(
29+
ctx context.Context,
30+
log logrus.FieldLogger,
31+
cfg config.LeaderElectionConfig,
32+
client kubernetes.Interface,
33+
registerFn func(ctx context.Context) (*types.ClusterRegistration, error),
34+
) (*types.ClusterRegistration, error) {
35+
leaseName := cfg.LockName + "-registration"
36+
identity := uuid.New().String()
37+
38+
log = log.WithFields(logrus.Fields{
39+
"registration_identity": identity,
40+
"registration_lease": leaseName,
41+
})
42+
log.Info("acquiring registration lease")
43+
44+
var result *types.ClusterRegistration
45+
var registerErr error
46+
47+
// leaderCtx is cancelled by leaderCancel once registration completes,
48+
// causing the leader elector to release the lease and Run() to return.
49+
leaderCtx, leaderCancel := context.WithCancel(ctx)
50+
defer leaderCancel()
51+
52+
lock := &resourcelock.LeaseLock{
53+
LeaseMeta: metav1.ObjectMeta{
54+
Name: leaseName,
55+
Namespace: cfg.Namespace,
56+
},
57+
Client: client.CoordinationV1(),
58+
LockConfig: resourcelock.ResourceLockConfig{
59+
Identity: identity,
60+
},
61+
}
62+
63+
le, err := leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{
64+
Lock: lock,
65+
ReleaseOnCancel: true,
66+
LeaseDuration: registrationLeaseDuration,
67+
RenewDeadline: registrationRenewDeadline,
68+
RetryPeriod: registrationRetryPeriod,
69+
Callbacks: leaderelection.LeaderCallbacks{
70+
OnStartedLeading: func(innerCtx context.Context) {
71+
log.Info("registration lease acquired, registering cluster")
72+
result, registerErr = registerFn(innerCtx)
73+
leaderCancel()
74+
},
75+
OnStoppedLeading: func() {
76+
log.Info("registration lease released")
77+
},
78+
OnNewLeader: func(identity string) {},
79+
},
80+
})
81+
if err != nil {
82+
return nil, fmt.Errorf("creating registration leader elector: %w", err)
83+
}
84+
85+
le.Run(leaderCtx)
86+
87+
if registerErr != nil {
88+
return nil, fmt.Errorf("registration failed: %w", registerErr)
89+
}
90+
if result == nil {
91+
// Run returned without calling OnStartedLeading — context was canceled before acquisition
92+
return nil, fmt.Errorf("context canceled before registration was initiated: %w", ctx.Err())
93+
}
94+
95+
log.Info("registration lease complete")
96+
return result, nil
97+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package replicas
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
"github.com/sirupsen/logrus"
11+
"github.com/stretchr/testify/require"
12+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
13+
"k8s.io/client-go/kubernetes/fake"
14+
15+
"castai-agent/internal/config"
16+
"castai-agent/pkg/services/providers/types"
17+
)
18+
19+
func TestRegisterClusterWithLease(t *testing.T) {
20+
tests := map[string]struct {
21+
registerFn func(ctx context.Context) (*types.ClusterRegistration, error)
22+
setupCtx func() (context.Context, context.CancelFunc)
23+
verify func(t *testing.T, reg *types.ClusterRegistration, err error, clientset *fake.Clientset, cfg config.LeaderElectionConfig)
24+
}{
25+
"registers cluster successfully": {
26+
registerFn: func(ctx context.Context) (*types.ClusterRegistration, error) {
27+
return &types.ClusterRegistration{
28+
ClusterID: "cluster-123",
29+
OrganizationID: "org-456",
30+
}, nil
31+
},
32+
setupCtx: func() (context.Context, context.CancelFunc) {
33+
return context.WithTimeout(context.Background(), 10*time.Second)
34+
},
35+
verify: func(t *testing.T, reg *types.ClusterRegistration, err error, clientset *fake.Clientset, cfg config.LeaderElectionConfig) {
36+
require.NoError(t, err)
37+
require.NotNil(t, reg)
38+
require.Equal(t, "cluster-123", reg.ClusterID)
39+
require.Equal(t, "org-456", reg.OrganizationID)
40+
41+
// Verify lease was created
42+
leaseName := cfg.LockName + "-registration"
43+
lease, err := clientset.CoordinationV1().Leases(cfg.Namespace).Get(context.Background(), leaseName, metav1.GetOptions{})
44+
require.NoError(t, err)
45+
require.NotNil(t, lease)
46+
},
47+
},
48+
"propagates registration error": {
49+
registerFn: func(ctx context.Context) (*types.ClusterRegistration, error) {
50+
return nil, fmt.Errorf("api unavailable")
51+
},
52+
setupCtx: func() (context.Context, context.CancelFunc) {
53+
return context.WithTimeout(context.Background(), 10*time.Second)
54+
},
55+
verify: func(t *testing.T, reg *types.ClusterRegistration, err error, clientset *fake.Clientset, cfg config.LeaderElectionConfig) {
56+
require.Error(t, err)
57+
require.ErrorContains(t, err, "api unavailable")
58+
require.Nil(t, reg)
59+
},
60+
},
61+
"returns error on context cancellation": {
62+
registerFn: func(ctx context.Context) (*types.ClusterRegistration, error) {
63+
return &types.ClusterRegistration{ClusterID: "should-not-reach"}, nil
64+
},
65+
setupCtx: func() (context.Context, context.CancelFunc) {
66+
ctx, cancel := context.WithCancel(context.Background())
67+
cancel() // cancel immediately
68+
return ctx, cancel
69+
},
70+
verify: func(t *testing.T, reg *types.ClusterRegistration, err error, clientset *fake.Clientset, cfg config.LeaderElectionConfig) {
71+
require.Error(t, err)
72+
require.ErrorIs(t, err, context.Canceled)
73+
require.Nil(t, reg)
74+
},
75+
},
76+
}
77+
78+
for name, tt := range tests {
79+
t.Run(name, func(t *testing.T) {
80+
log := logrus.New()
81+
log.SetLevel(logrus.DebugLevel)
82+
83+
cfg := config.LeaderElectionConfig{
84+
LockName: "test-lock",
85+
Namespace: "default",
86+
}
87+
88+
clientset := fake.NewClientset()
89+
90+
ctx, cancel := tt.setupCtx()
91+
defer cancel()
92+
93+
reg, err := RegisterClusterWithLease(ctx, log, cfg, clientset, tt.registerFn)
94+
tt.verify(t, reg, err, clientset, cfg)
95+
})
96+
}
97+
}
98+
99+
func TestRegisterClusterWithLease_SerializesConcurrentRegistrations(t *testing.T) {
100+
log := logrus.New()
101+
log.SetLevel(logrus.DebugLevel)
102+
103+
cfg := config.LeaderElectionConfig{
104+
LockName: "test-lock",
105+
Namespace: "default",
106+
}
107+
108+
clientset := fake.NewClientset()
109+
110+
var maxConcurrent atomic.Int32
111+
var currentConcurrent atomic.Int32
112+
113+
registerFn := func(ctx context.Context) (*types.ClusterRegistration, error) {
114+
cur := currentConcurrent.Add(1)
115+
for {
116+
old := maxConcurrent.Load()
117+
if cur <= old || maxConcurrent.CompareAndSwap(old, cur) {
118+
break
119+
}
120+
}
121+
time.Sleep(50 * time.Millisecond) // simulate work
122+
currentConcurrent.Add(-1)
123+
return &types.ClusterRegistration{ClusterID: "cluster-123"}, nil
124+
}
125+
126+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
127+
defer cancel()
128+
129+
const numGoroutines = 3
130+
errs := make(chan error, numGoroutines)
131+
results := make(chan *types.ClusterRegistration, numGoroutines)
132+
133+
for i := 0; i < numGoroutines; i++ {
134+
go func() {
135+
reg, err := RegisterClusterWithLease(ctx, log, cfg, clientset, registerFn)
136+
errs <- err
137+
results <- reg
138+
}()
139+
}
140+
141+
for i := 0; i < numGoroutines; i++ {
142+
err := <-errs
143+
require.NoError(t, err)
144+
reg := <-results
145+
require.NotNil(t, reg)
146+
require.Equal(t, "cluster-123", reg.ClusterID)
147+
}
148+
149+
require.Equal(t, int32(1), maxConcurrent.Load(), "registerFn should never run concurrently")
150+
}

0 commit comments

Comments (0)