Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions internal/controller/clientpool/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ func New(l log.Logger, c runtimeclient.Client) *ClientPool {
}
}

// EvictClient removes the client for the given key from the pool and closes it.
// Safe to call when the key is not present.
func (cp *ClientPool) EvictClient(key ClientPoolKey) {
cp.mux.Lock()
defer cp.mux.Unlock()
if info, ok := cp.clients[key]; ok {
info.client.Close()
delete(cp.clients, key)
}
}

func (cp *ClientPool) GetSDKClient(key ClientPoolKey) (sdkclient.Client, bool) {
cp.mux.RLock()
defer cp.mux.RUnlock()
Expand Down
33 changes: 32 additions & 1 deletion internal/controller/clientpool/clientpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,15 @@ type mockSDKClient struct {

// Used to verify that CheckHealth is not called when API Key auth is used.
checkHealthCalled bool
closed bool
}

func (m *mockSDKClient) CheckHealth(_ context.Context, _ *sdkclient.CheckHealthRequest) (*sdkclient.CheckHealthResponse, error) {
m.checkHealthCalled = true
return &sdkclient.CheckHealthResponse{}, nil
}

func (m *mockSDKClient) Close() {}
func (m *mockSDKClient) Close() { m.closed = true }

// ─── Tests: fetchClientUsingMTLSSecret ────────────────────────────────────────

Expand Down Expand Up @@ -355,6 +356,36 @@ func TestDialAndUpsert_NoCredsCallsCheckHealth(t *testing.T) {
assert.True(t, mock.checkHealthCalled, "CheckHealth must be called for no-credentials auth")
}

// ─── Tests: EvictClient ───────────────────────────────────────────────────────

func TestEvictClient_RemovesAndClosesClient(t *testing.T) {
cp := newTestPool()
key := ClientPoolKey{
HostPort: "localhost:7233",
Namespace: "default",
SecretName: "my-secret",
AuthMode: AuthModeAPIKey,
}
mock := &mockSDKClient{}
cp.SetClientForTesting(key, mock)

_, ok := cp.GetSDKClient(key)
require.True(t, ok, "client should be present before eviction")

cp.EvictClient(key)

assert.True(t, mock.closed, "Close should be called on eviction")
_, ok = cp.GetSDKClient(key)
assert.False(t, ok, "client should be absent after eviction")
}

func TestEvictClient_NoopWhenKeyAbsent(t *testing.T) {
cp := newTestPool()
key := ClientPoolKey{HostPort: "localhost:7233", Namespace: "default", AuthMode: AuthModeNoCredentials}
// Should not panic when key is not in the pool
cp.EvictClient(key)
}

// ─── Helpers ──────────────────────────────────────────────────────────────────

func decodePEMCert(certPEM []byte) (*x509.Certificate, error) {
Expand Down
21 changes: 19 additions & 2 deletions internal/controller/worker_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/temporalio/temporal-worker-controller/internal/k8s"
"github.com/temporalio/temporal-worker-controller/internal/temporal"
"go.temporal.io/api/serviceerror"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -182,12 +184,13 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req
}

// Get or update temporal client for connection
temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientpool.ClientPoolKey{
clientPoolKey := clientpool.ClientPoolKey{
HostPort: temporalConnection.Spec.HostPort,
Namespace: workerDeploy.Spec.WorkerOptions.TemporalNamespace,
SecretName: secretName,
AuthMode: authMode,
})
}
temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientPoolKey)
if !ok {
clientOpts, key, clientAuth, err := r.TemporalClientPool.ParseClientSecret(ctx, secretName, authMode, clientpool.NewClientOptions{
K8sNamespace: workerDeploy.Namespace,
Expand Down Expand Up @@ -243,6 +246,9 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req
getControllerIdentity(),
)
if err != nil {
if isAccessDeniedErr(err) {
r.TemporalClientPool.EvictClient(clientPoolKey)
}
var rateLimitErr *serviceerror.ResourceExhausted
if errors.As(err, &rateLimitErr) {
r.recordWarningAndSetBlocked(ctx, &workerDeploy,
Expand Down Expand Up @@ -286,6 +292,9 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req

// Execute the plan, handling any errors
if err := r.executePlan(ctx, l, &workerDeploy, temporalClient, plan); err != nil {
if isAccessDeniedErr(err) {
r.TemporalClientPool.EvictClient(clientPoolKey)
}
r.recordWarningAndSetBlocked(ctx, &workerDeploy,
ReasonPlanExecutionFailed,
fmt.Sprintf("Unable to execute reconciliation plan: %v", err),
Expand Down Expand Up @@ -525,3 +534,11 @@ func (r *TemporalWorkerDeploymentReconciler) findTWDsUsingConnection(ctx context

return requests
}

func isAccessDeniedErr(err error) bool {
var permDenied *serviceerror.PermissionDenied
if errors.As(err, &permDenied) {
return true
}
return grpcstatus.Code(err) == codes.Unauthenticated
}
Loading