Skip to content

Commit 8df13d5

Browse files
committed
Add a hook point to the Matching service once a task has been matched (or not)
This adds a hook point to the matching service once it has matched a task or decided to spool it. This can be leveraged to react to task sync or no-sync match events, while keeping matching service decoupled from the event handling logic.
1 parent 595713f commit 8df13d5

File tree

7 files changed

+232
-8
lines changed

7 files changed

+232
-8
lines changed

service/matching/handler.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"go.temporal.io/server/common/searchattribute"
2424
"go.temporal.io/server/common/testing/testhooks"
2525
"go.temporal.io/server/common/tqid"
26+
"go.temporal.io/server/service/matching/hooks"
2627
"go.temporal.io/server/service/matching/workers"
2728
"go.temporal.io/server/service/worker/workerdeployment"
2829
"go.uber.org/fx"
@@ -70,6 +71,7 @@ type (
7071
RateLimiter TaskDispatchRateLimiter `optional:"true"`
7172
WorkersRegistry workers.Registry
7273
Serializer serialization.Serializer
74+
TaskMatchHooks []hooks.TaskMatchHook `group:"TaskMatchHooks"`
7375
}
7476
)
7577

@@ -112,6 +114,7 @@ func NewHandler(
112114
params.SearchAttributeMapperProvider,
113115
params.RateLimiter,
114116
params.Serializer,
117+
params.TaskMatchHooks,
115118
),
116119
namespaceRegistry: params.NamespaceRegistry,
117120
workersRegistry: params.WorkersRegistry,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package hooks
2+
3+
import (
4+
"context"
5+
6+
deploymentpb "go.temporal.io/api/deployment/v1"
7+
enumspb "go.temporal.io/api/enums/v1"
8+
taskqueuepb "go.temporal.io/api/taskqueue/v1"
9+
"go.temporal.io/server/common/namespace"
10+
)
11+
12+
type (
13+
TaskMatchHookDetails struct {
14+
Namespace *namespace.Namespace
15+
TaskQueueName string
16+
TaskQueueType enumspb.TaskQueueType
17+
DeploymentVersion *deploymentpb.WorkerDeploymentVersion
18+
IsSyncMatch bool
19+
TaskQueueStatsRetriever func() map[int32]*taskqueuepb.TaskQueueStats
20+
}
21+
TaskMatchHook interface {
22+
ProcessTaskMatch(ctx context.Context, event *TaskMatchHookDetails)
23+
}
24+
)

service/matching/matching_engine.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ import (
6060
"go.temporal.io/server/common/util"
6161
"go.temporal.io/server/common/worker_versioning"
6262
"go.temporal.io/server/service/history/api"
63+
"go.temporal.io/server/service/matching/hooks"
6364
"go.temporal.io/server/service/worker/workerdeployment"
6465
"google.golang.org/protobuf/types/known/timestamppb"
6566
)
@@ -175,6 +176,8 @@ type (
175176
reachabilityCache reachabilityCache
176177
// Rate limiter to limit the task dispatch
177178
rateLimiter TaskDispatchRateLimiter
179+
180+
taskMatchHooks []hooks.TaskMatchHook
178181
}
179182
)
180183

@@ -254,6 +257,7 @@ func NewEngine(
254257
saMapperProvider searchattribute.MapperProvider,
255258
rateLimiter TaskDispatchRateLimiter,
256259
historySerializer serialization.Serializer,
260+
taskMatchHooks []hooks.TaskMatchHook,
257261
) Engine {
258262
scopedMetricsHandler := metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingEngineScope))
259263
e := &matchingEngineImpl{
@@ -296,6 +300,7 @@ func NewEngine(
296300
namespaceReplicationQueue: namespaceReplicationQueue,
297301
userDataUpdateBatchers: collection.NewSyncMap[namespace.ID, *stream_batcher.Batcher[*userDataUpdate, error]](),
298302
rateLimiter: rateLimiter,
303+
taskMatchHooks: taskMatchHooks,
299304
}
300305
e.reachabilityCache = newReachabilityCache(
301306
metrics.NoopMetricsHandler,
@@ -497,6 +502,7 @@ func (e *matchingEngineImpl) getTaskQueuePartitionManager(
497502
throttledLogger,
498503
metricsHandler,
499504
userDataManager,
505+
e.taskMatchHooks,
500506
)
501507
if err != nil {
502508
return nil, false, err

service/matching/matching_engine_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ func newMatchingEngine(
272272
func (s *matchingEngineSuite) newPartitionManager(prtn tqid.Partition, config *Config) taskQueuePartitionManager {
273273
tqConfig := newTaskQueueConfig(prtn.TaskQueue(), config, matchingTestNamespace)
274274
logger, _, metricsHandler := s.matchingEngine.loggerAndMetricsForPartition(s.ns, prtn, tqConfig)
275-
pm, err := newTaskQueuePartitionManager(s.matchingEngine, s.ns, prtn, tqConfig, logger, logger, metricsHandler, &mockUserDataManager{})
275+
pm, err := newTaskQueuePartitionManager(s.matchingEngine, s.ns, prtn, tqConfig, logger, logger, metricsHandler, &mockUserDataManager{}, nil)
276276
s.Require().NoError(err)
277277
return pm
278278
}

service/matching/physical_task_queue_manager_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s *PhysicalTaskQueueManagerTestSuite) SetupTest() {
8282
onFatalErr := func(unloadCause) { s.T().Fatal("user data manager called onFatalErr") }
8383
udMgr := newUserDataManager(engine.taskManager, engine.matchingRawClient, onFatalErr, nil, nil, prtn, tqConfig, engine.logger, engine.namespaceRegistry)
8484

85-
prtnMgr, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr)
85+
prtnMgr, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr, nil)
8686
s.NoError(err)
8787
engine.partitions[prtn.Key()] = prtnMgr
8888

@@ -533,7 +533,7 @@ func TestDrainCompletionNoReloadDraining(t *testing.T) {
533533
onFatalErr := func(unloadCause) { t.Fatal("user data manager called onFatalErr") }
534534
udMgr := newUserDataManager(engine.taskManager, engine.matchingRawClient, onFatalErr, nil, nil, prtn, tqConfig, engine.logger, engine.namespaceRegistry)
535535

536-
prtnMgr, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr)
536+
prtnMgr, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr, nil)
537537
require.NoError(t, err)
538538
engine.partitions[prtn.Key()] = prtnMgr
539539

@@ -578,7 +578,7 @@ func TestDrainCompletionNoReloadDraining(t *testing.T) {
578578
prevPriStats, prevFairStats := priQueueData.persistenceStats(), fairQueueData.persistenceStats()
579579

580580
// create a new manager (reload)
581-
prtnMgr2, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr)
581+
prtnMgr2, err := newTaskQueuePartitionManager(engine, ns, prtn, tqConfig, engine.logger, nil, metrics.NoopMetricsHandler, udMgr, nil)
582582
require.NoError(t, err)
583583
engine.partitions[prtn.Key()] = prtnMgr2
584584

service/matching/task_queue_partition_manager.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"go.temporal.io/server/common/tqid"
3737
"go.temporal.io/server/common/util"
3838
"go.temporal.io/server/common/worker_versioning"
39+
"go.temporal.io/server/service/matching/hooks"
3940
"google.golang.org/protobuf/types/known/durationpb"
4041
"google.golang.org/protobuf/types/known/timestamppb"
4142
)
@@ -74,6 +75,8 @@ type (
7475
// TODO(stephanos): move cache out of partition manager
7576
cache cache.Cache // non-nil for root-partition
7677

78+
taskMatchHooks []hooks.TaskMatchHook
79+
7780
goroGroup goro.Group
7881

7982
autoEnableRateLimiter quotas.RateLimiter
@@ -113,6 +116,7 @@ func newTaskQueuePartitionManager(
113116
throttledLogger log.Logger,
114117
metricsHandler metrics.Handler,
115118
userDataManager userDataManager,
119+
taskMatchHooks []hooks.TaskMatchHook,
116120
) (*taskQueuePartitionManagerImpl, error) {
117121
rateLimitManager := newRateLimitManager(
118122
userDataManager,
@@ -132,6 +136,7 @@ func newTaskQueuePartitionManager(
132136
rateLimitManager: rateLimitManager,
133137
defaultQueueFuture: future.NewFuture[physicalTaskQueueManager](),
134138
autoEnableRateLimiter: quotas.NewRateLimiter(1.0/60, 1),
139+
taskMatchHooks: taskMatchHooks,
135140
}
136141
pm.initCtx, pm.initCancel = context.WithCancel(context.Background())
137142

@@ -357,6 +362,20 @@ reredirectTask:
357362
if isActive {
358363
syncMatched, err = syncMatchQueue.TrySyncMatch(ctx, syncMatchTask)
359364
if syncMatched && !pm.shouldBacklogSyncMatchTaskOnError(err) {
365+
for _, l := range pm.taskMatchHooks {
366+
var deploymentVersion *deploymentpb.WorkerDeploymentVersion
367+
if targetVersion != nil {
368+
deploymentVersion = &deploymentpb.WorkerDeploymentVersion{DeploymentName: targetVersion.DeploymentName, BuildId: targetVersion.BuildId}
369+
}
370+
l.ProcessTaskMatch(ctx, &hooks.TaskMatchHookDetails{
371+
Namespace: pm.ns,
372+
TaskQueueName: pm.partition.TaskQueue().Name(),
373+
TaskQueueType: pm.partition.TaskType(),
374+
DeploymentVersion: deploymentVersion,
375+
IsSyncMatch: syncMatched,
376+
TaskQueueStatsRetriever: func() map[int32]*taskqueuepb.TaskQueueStats { return dbq.GetStatsByPriority(true) },
377+
})
378+
}
360379

361380
// Build ID is not returned for sync match. The returned build ID is used by History to update
362381
// mutable state (and visibility) when the first workflow task is spooled.
@@ -382,6 +401,21 @@ reredirectTask:
382401
assignedBuildId = spoolQueue.QueueKey().Version().BuildId()
383402
}
384403

404+
for _, l := range pm.taskMatchHooks {
405+
var deploymentVersion *deploymentpb.WorkerDeploymentVersion
406+
if targetVersion != nil {
407+
deploymentVersion = &deploymentpb.WorkerDeploymentVersion{DeploymentName: targetVersion.DeploymentName, BuildId: targetVersion.BuildId}
408+
}
409+
l.ProcessTaskMatch(ctx, &hooks.TaskMatchHookDetails{
410+
Namespace: pm.ns,
411+
TaskQueueName: pm.partition.TaskQueue().Name(),
412+
TaskQueueType: pm.partition.TaskType(),
413+
DeploymentVersion: deploymentVersion,
414+
IsSyncMatch: false,
415+
TaskQueueStatsRetriever: func() map[int32]*taskqueuepb.TaskQueueStats { return dbq.GetStatsByPriority(true) },
416+
})
417+
}
418+
385419
return assignedBuildId, false, spoolQueue.SpoolTask(params.taskInfo)
386420
}
387421

@@ -960,7 +994,7 @@ func (pm *taskQueuePartitionManagerImpl) Describe(
960994
unversionedCurrentShareByPriority, unversionedRampingShareByPriority =
961995
splitStatsByPriorityByRampPercentage(unversionedStatsByPriority, rampPercentage)
962996
} else if currentExists {
963-
// If there exist no ramping version, weattribute the entire unversioned backlog to the current version.
997+
// If there exist no ramping version, we attribute the entire unversioned backlog to the current version.
964998
unversionedCurrentShareByPriority = cloneStatsByPriority(unversionedStatsByPriority)
965999
}
9661000
}

0 commit comments

Comments
 (0)