diff --git a/internal/controller/clientpool/clientpool.go b/internal/controller/clientpool/clientpool.go
index b5b9976b..7f290cf9 100644
--- a/internal/controller/clientpool/clientpool.go
+++ b/internal/controller/clientpool/clientpool.go
@@ -79,6 +79,17 @@ func New(l log.Logger, c runtimeclient.Client) *ClientPool {
 	}
 }
 
+// EvictClient removes the client for the given key from the pool and closes it.
+// Safe to call when the key is not present.
+func (cp *ClientPool) EvictClient(key ClientPoolKey) {
+	cp.mux.Lock()
+	defer cp.mux.Unlock()
+	if info, ok := cp.clients[key]; ok {
+		info.client.Close()
+		delete(cp.clients, key)
+	}
+}
+
 func (cp *ClientPool) GetSDKClient(key ClientPoolKey) (sdkclient.Client, bool) {
 	cp.mux.RLock()
 	defer cp.mux.RUnlock()
diff --git a/internal/controller/clientpool/clientpool_test.go b/internal/controller/clientpool/clientpool_test.go
index b00cf014..53809760 100644
--- a/internal/controller/clientpool/clientpool_test.go
+++ b/internal/controller/clientpool/clientpool_test.go
@@ -126,6 +126,7 @@ type mockSDKClient struct {
 
 	// Used to verify that CheckHealth is not called when API Key auth is used.
 	checkHealthCalled bool
+	closed            bool
 }
 
 func (m *mockSDKClient) CheckHealth(_ context.Context, _ *sdkclient.CheckHealthRequest) (*sdkclient.CheckHealthResponse, error) {
@@ -133,7 +134,7 @@ func (m *mockSDKClient) CheckHealth(_ context.Context, _ *sdkclient.CheckHealthR
 	return &sdkclient.CheckHealthResponse{}, nil
 }
 
-func (m *mockSDKClient) Close() {}
+func (m *mockSDKClient) Close() { m.closed = true }
 
 // ─── Tests: fetchClientUsingMTLSSecret ────────────────────────────────────────
 
@@ -355,6 +356,36 @@ func TestDialAndUpsert_NoCredsCallsCheckHealth(t *testing.T) {
 	assert.True(t, mock.checkHealthCalled, "CheckHealth must be called for no-credentials auth")
 }
 
+// ─── Tests: EvictClient ───────────────────────────────────────────────────────
+
+func TestEvictClient_RemovesAndClosesClient(t *testing.T) {
+	cp := newTestPool()
+	key := ClientPoolKey{
+		HostPort:   "localhost:7233",
+		Namespace:  "default",
+		SecretName: "my-secret",
+		AuthMode:   AuthModeAPIKey,
+	}
+	mock := &mockSDKClient{}
+	cp.SetClientForTesting(key, mock)
+
+	_, ok := cp.GetSDKClient(key)
+	require.True(t, ok, "client should be present before eviction")
+
+	cp.EvictClient(key)
+
+	assert.True(t, mock.closed, "Close should be called on eviction")
+	_, ok = cp.GetSDKClient(key)
+	assert.False(t, ok, "client should be absent after eviction")
+}
+
+func TestEvictClient_NoopWhenKeyAbsent(t *testing.T) {
+	cp := newTestPool()
+	key := ClientPoolKey{HostPort: "localhost:7233", Namespace: "default", AuthMode: AuthModeNoCredentials}
+	// Should not panic when key is not in the pool
+	cp.EvictClient(key)
+}
+
 // ─── Helpers ──────────────────────────────────────────────────────────────────
 
 func decodePEMCert(certPEM []byte) (*x509.Certificate, error) {
diff --git a/internal/controller/worker_controller.go b/internal/controller/worker_controller.go
index a544d602..70fcb5ba 100644
--- a/internal/controller/worker_controller.go
+++ b/internal/controller/worker_controller.go
@@ -15,6 +15,8 @@ import (
 	"github.com/temporalio/temporal-worker-controller/internal/k8s"
 	"github.com/temporalio/temporal-worker-controller/internal/temporal"
 	"go.temporal.io/api/serviceerror"
+	"google.golang.org/grpc/codes"
+	grpcstatus "google.golang.org/grpc/status"
 	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
 	apierrors "k8s.io/apimachinery/pkg/api/errors"
@@ -182,12 +184,13 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req
 	}
 
 	// Get or update temporal client for connection
-	temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientpool.ClientPoolKey{
+	clientPoolKey := clientpool.ClientPoolKey{
 		HostPort:   temporalConnection.Spec.HostPort,
 		Namespace:  workerDeploy.Spec.WorkerOptions.TemporalNamespace,
 		SecretName: secretName,
 		AuthMode:   authMode,
-	})
+	}
+	temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientPoolKey)
 	if !ok {
 		clientOpts, key, clientAuth, err := r.TemporalClientPool.ParseClientSecret(ctx, secretName, authMode, clientpool.NewClientOptions{
 			K8sNamespace: workerDeploy.Namespace,
@@ -243,6 +246,9 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req
 		getControllerIdentity(),
 	)
 	if err != nil {
+		if isAccessDeniedErr(err) {
+			r.TemporalClientPool.EvictClient(clientPoolKey)
+		}
 		var rateLimitErr *serviceerror.ResourceExhausted
 		if errors.As(err, &rateLimitErr) {
 			r.recordWarningAndSetBlocked(ctx, &workerDeploy,
@@ -286,6 +292,9 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req
 
 	// Execute the plan, handling any errors
 	if err := r.executePlan(ctx, l, &workerDeploy, temporalClient, plan); err != nil {
+		if isAccessDeniedErr(err) {
+			r.TemporalClientPool.EvictClient(clientPoolKey)
+		}
 		r.recordWarningAndSetBlocked(ctx, &workerDeploy,
 			ReasonPlanExecutionFailed,
 			fmt.Sprintf("Unable to execute reconciliation plan: %v", err),
@@ -525,3 +534,11 @@ func (r *TemporalWorkerDeploymentReconciler) findTWDsUsingConnection(ctx context
 
 	return requests
 }
+
+func isAccessDeniedErr(err error) bool {
+	var permDenied *serviceerror.PermissionDenied
+	if errors.As(err, &permDenied) {
+		return true
+	}
+	return grpcstatus.Code(err) == codes.Unauthenticated
+}