diff --git a/client/src/lib.rs b/client/src/lib.rs index 41cd2b79f..4d6088f87 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -32,7 +32,9 @@ pub use temporal_sdk_core_protos::temporal::api::{ }, }; pub use tonic; -pub use worker_registry::{Slot, SlotManager, SlotProvider, WorkerKey}; +pub use worker_registry::{ + ClientWorker, ClientWorkerSet, HeartbeatCallback, SharedNamespaceWorkerTrait, Slot, +}; pub use workflow_handle::{ GetWorkflowResultOpts, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle, }; @@ -388,7 +390,7 @@ pub struct ConfiguredClient { headers: Arc>, /// Capabilities as read from the `get_system_info` RPC call made on client connection capabilities: Option, - workers: Arc, + workers: Arc, } impl ConfiguredClient { @@ -438,9 +440,14 @@ impl ConfiguredClient { } /// Returns a cloned reference to a registry with workers using this client instance - pub fn workers(&self) -> Arc { + pub fn workers(&self) -> Arc { self.workers.clone() } + + /// Returns the worker grouping key, this should be unique across each client + pub fn worker_grouping_key(&self) -> Uuid { + self.workers.worker_grouping_key() + } } #[derive(Debug)] @@ -584,7 +591,7 @@ impl ClientOptions { client: TemporalServiceClient::new(svc), options: Arc::new(self.clone()), capabilities: None, - workers: Arc::new(SlotManager::new()), + workers: Arc::new(ClientWorkerSet::new()), }; if !self.skip_get_system_info { match client @@ -901,6 +908,11 @@ impl Client { pub fn into_inner(self) -> ConfiguredClient { self.inner } + + /// Returns the client-wide key + pub fn worker_grouping_key(&self) -> Uuid { + self.inner.worker_grouping_key() + } } impl NamespacedClient for Client { diff --git a/client/src/raw.rs b/client/src/raw.rs index aa5500ee1..92e6a7956 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -7,7 +7,7 @@ use crate::{ TEMPORAL_NAMESPACE_HEADER_KEY, TemporalServiceClient, metrics::{namespace_kv, task_queue_kv}, raw::sealed::RawClientLike, - 
worker_registry::{Slot, SlotManager}, + worker_registry::{ClientWorkerSet, Slot}, }; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; use std::sync::Arc; @@ -68,7 +68,7 @@ pub(super) mod sealed { fn health_client_mut(&mut self) -> &mut HealthClient; /// Return a registry with workers using this client instance - fn get_workers_info(&self) -> Option>; + fn get_workers_info(&self) -> Option>; async fn call( &mut self, @@ -134,7 +134,7 @@ where self.get_client_mut().health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { self.get_client().get_workers_info() } @@ -213,7 +213,7 @@ where self.health_svc_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { None } } @@ -268,7 +268,7 @@ where self.client.health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { Some(self.workers()) } } @@ -316,7 +316,7 @@ impl RawClientLike for Client { self.inner.health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { self.inner.get_workers_info() } } diff --git a/client/src/worker_registry/mod.rs b/client/src/worker_registry/mod.rs index 90882a718..f10b128ce 100644 --- a/client/src/worker_registry/mod.rs +++ b/client/src/worker_registry/mod.rs @@ -2,27 +2,16 @@ //! This is needed to implement Eager Workflow Start, a latency optimization in which the client, //! after reserving a slot, directly forwards a WFT to a local worker. +use anyhow::bail; use parking_lot::RwLock; -use slotmap::SlotMap; -use std::collections::{HashMap, hash_map::Entry::Vacant}; - +use std::collections::{ + HashMap, + hash_map::Entry::{Occupied, Vacant}, +}; +use std::sync::Arc; +use temporal_sdk_core_protos::temporal::api::worker::v1::WorkerHeartbeat; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; - -slotmap::new_key_type! 
{ - /// Registration key for a worker - pub struct WorkerKey; -} - -/// This trait is implemented by an object associated with a worker, which provides WFT processing slots. -#[cfg_attr(test, mockall::automock)] -pub trait SlotProvider: std::fmt::Debug { - /// The namespace for the WFTs that it can process. - fn namespace(&self) -> &str; - /// The task queue this provider listens to. - fn task_queue(&self) -> &str; - /// Try to reserve a slot on this worker. - fn try_reserve_wft_slot(&self) -> Option>; -} +use uuid::Uuid; /// This trait represents a slot reserved for processing a WFT by a worker. #[cfg_attr(test, mockall::automock)] @@ -49,21 +38,23 @@ impl SlotKey { } } -/// This is an inner class for [SlotManager] needed to hide the mutex. -#[derive(Default, Debug)] -struct SlotManagerImpl { - /// Maps keys, i.e., namespace#task_queue, to provider. - providers: HashMap>, - /// Maps ids to keys in `providers`. - index: SlotMap, +/// This is an inner class for [ClientWorkerSet] needed to hide the mutex. +struct ClientWorkerSetImpl { + /// Maps slot keys to slot provider worker. + slot_providers: HashMap, + /// Maps worker_instance_key to registered workers + all_workers: HashMap>, + /// Maps namespace to shared worker for worker heartbeating + shared_worker: HashMap>, } -impl SlotManagerImpl { +impl ClientWorkerSetImpl { /// Factory method. 
fn new() -> Self { Self { - index: Default::default(), - providers: Default::default(), + slot_providers: Default::default(), + all_workers: Default::default(), + shared_worker: Default::default(), } } @@ -73,55 +64,140 @@ impl SlotManagerImpl { task_queue: String, ) -> Option> { let key = SlotKey::new(namespace, task_queue); - if let Some(p) = self.providers.get(&key) - && let Some(slot) = p.try_reserve_wft_slot() + if let Some(p) = self.slot_providers.get(&key) + && let Some(worker) = self.all_workers.get(p) + && let Some(slot) = worker.try_reserve_wft_slot() { return Some(slot); } None } - fn register(&mut self, provider: Box) -> Option { - let key = SlotKey::new( - provider.namespace().to_string(), - provider.task_queue().to_string(), + fn register( + &mut self, + worker: Arc, + ) -> Result<(), anyhow::Error> { + let slot_key = SlotKey::new( + worker.namespace().to_string(), + worker.task_queue().to_string(), ); - if let Vacant(p) = self.providers.entry(key.clone()) { - p.insert(provider); - Some(self.index.insert(key)) - } else { - warn!("Ignoring registration for worker: {key:?}."); - None + if self.slot_providers.contains_key(&slot_key) { + bail!( + "Registration of multiple workers on the same namespace and task queue for the same client not allowed: {slot_key:?}, worker_instance_key: {:?}.", + worker.worker_instance_key() + ); } + + if worker.heartbeat_enabled() + && let Some(heartbeat_callback) = worker.heartbeat_callback() + { + let worker_instance_key = worker.worker_instance_key(); + let namespace = worker.namespace().to_string(); + + let shared_worker = match self.shared_worker.entry(namespace.clone()) { + Occupied(o) => o.into_mut(), + Vacant(v) => { + let shared_worker = worker.new_shared_namespace_worker()?; + v.insert(shared_worker) + } + }; + shared_worker.register_callback(worker_instance_key, heartbeat_callback); + } + + self.slot_providers + .insert(slot_key.clone(), worker.worker_instance_key()); + + self.all_workers + 
.insert(worker.worker_instance_key(), worker); + + Ok(()) } - fn unregister(&mut self, id: WorkerKey) -> Option> { - if let Some(key) = self.index.remove(id) { - self.providers.remove(&key) - } else { - None + fn unregister( + &mut self, + worker_instance_key: Uuid, + ) -> Result, anyhow::Error> { + let worker = self + .all_workers + .remove(&worker_instance_key) + .ok_or_else(|| { + anyhow::anyhow!("Worker with worker_instance_key {worker_instance_key} not found") + })?; + + let slot_key = SlotKey::new( + worker.namespace().to_string(), + worker.task_queue().to_string(), + ); + + self.slot_providers.remove(&slot_key); + + if let Some(w) = self.shared_worker.get_mut(worker.namespace()) { + let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key()); + if let Some(cb) = callback { + if is_empty { + self.shared_worker.remove(worker.namespace()); + } + + // To maintain single ownership of the callback, we must re-register the callback + // back to the ClientWorker + worker.register_callback(cb); + } } + + Ok(worker) + } + + #[cfg(test)] + fn num_providers(&self) -> usize { + self.slot_providers.len() } #[cfg(test)] - fn num_providers(&self) -> (usize, usize) { - (self.index.len(), self.providers.len()) + fn num_heartbeat_workers(&self) -> usize { + self.shared_worker.values().map(|v| v.num_workers()).sum() } } +/// This trait represents a shared namespace worker that sends worker heartbeats and +/// receives worker commands. +pub trait SharedNamespaceWorkerTrait { + /// Namespace that the shared namespace worker is connected to. + fn namespace(&self) -> String; + + /// Registers a heartbeat callback. + fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatCallback); + + /// Unregisters a heartbeat callback. Returns the callback removed, as well as a bool that + /// indicates if there are no remaining callbacks in the SharedNamespaceWorker, indicating + /// the shared worker itself can be shut down. 
+ fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option, bool); + + /// Returns the number of workers registered to this shared worker. + fn num_workers(&self) -> usize; +} + /// Enables local workers to make themselves visible to a shared client instance. -/// There can only be one worker registered per namespace+queue_name+client, others will get ignored. +/// +/// For slot managing, there can only be one worker registered per +/// namespace+queue_name+client, others will return an error. /// It also provides a convenient method to find compatible slots within the collection. -#[derive(Default, Debug)] -pub struct SlotManager { - manager: RwLock, +pub struct ClientWorkerSet { + worker_grouping_key: Uuid, + worker_manager: RwLock, } -impl SlotManager { +impl Default for ClientWorkerSet { + fn default() -> Self { + Self::new() + } +} + +impl ClientWorkerSet { /// Factory method. pub fn new() -> Self { Self { - manager: RwLock::new(SlotManagerImpl::new()), + worker_grouping_key: Uuid::new_v4(), + worker_manager: RwLock::new(ClientWorkerSetImpl::new()), } } @@ -131,29 +207,93 @@ impl SlotManager { namespace: String, task_queue: String, ) -> Option> { - self.manager + self.worker_manager .read() .try_reserve_wft_slot(namespace, task_queue) } - /// Register a local worker that can provide WFT processing slots. - pub fn register(&self, provider: Box) -> Option { - self.manager.write().register(provider) + /// Unregisters a local worker, typically when that worker starts shutdown. + pub fn unregister_worker( + &self, + worker_instance_key: Uuid, + ) -> Result, anyhow::Error> { + self.worker_manager.write().unregister(worker_instance_key) + } + + /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating. + pub fn register_worker( + &self, + worker: Arc, + ) -> Result<(), anyhow::Error> { + self.worker_manager.write().register(worker) } - /// Unregister a provider, typically when its worker starts shutdown. 
- pub fn unregister(&self, id: WorkerKey) -> Option> { - self.manager.write().unregister(id) + /// Returns the worker grouping key, which is unique for each client. + pub fn worker_grouping_key(&self) -> Uuid { + self.worker_grouping_key } #[cfg(test)] /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue. /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`. - pub fn num_providers(&self) -> (usize, usize) { - self.manager.read().num_providers() + pub fn num_providers(&self) -> usize { + self.worker_manager.read().num_providers() + } + + #[cfg(test)] + /// Returns the total number of heartbeat workers registered across all namespaces. + pub fn num_heartbeat_workers(&self) -> usize { + self.worker_manager.read().num_heartbeat_workers() } } +impl std::fmt::Debug for ClientWorkerSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClientWorkerSet") + .field("worker_grouping_key", &self.worker_grouping_key) + .finish() + } +} + +/// Contains a worker heartbeat callback, wrapped for mocking +pub type HeartbeatCallback = Box WorkerHeartbeat + Send + Sync>; + +/// Represents a complete worker that can handle both slot management +/// and worker heartbeat functionality. +#[cfg_attr(test, mockall::automock)] +pub trait ClientWorker: Send + Sync { + /// The namespace this worker operates in + fn namespace(&self) -> &str; + + /// The task queue this worker listens to + fn task_queue(&self) -> &str; + + /// Try to reserve a slot for workflow task processing. + /// + /// This method should return `Some(slot)` if a workflow task slot is available, + /// or `None` if all slots are currently in use. The returned slot will be used + /// to process exactly one workflow task. + fn try_reserve_wft_slot(&self) -> Option>; + + /// Unique identifier for this worker instance. + /// This must be stable across the worker's lifetime but unique per instance. 
+ fn worker_instance_key(&self) -> Uuid; + + /// Indicates if worker heartbeating is enabled for this client worker. + fn heartbeat_enabled(&self) -> bool; + + /// Returns the heartbeat callback that can be used to get WorkerHeartbeat data. + fn heartbeat_callback(&self) -> Option; + + /// Creates a new worker that implements the [SharedNamespaceWorkerTrait] + fn new_shared_namespace_worker( + &self, + ) -> Result, anyhow::Error>; + + /// Registers a worker heartbeat callback, typically when a worker is unregistered from a client + fn register_callback(&self, callback: HeartbeatCallback); +} + #[cfg(test)] mod tests { use super::*; @@ -175,8 +315,9 @@ mod tests { task_queue: String, with_error: bool, no_slots: bool, - ) -> MockSlotProvider { - let mut mock_provider = MockSlotProvider::new(); + heartbeat_enabled: bool, + ) -> MockClientWorker { + let mut mock_provider = MockClientWorker::new(); mock_provider .expect_try_reserve_wft_slot() .returning(move || { @@ -189,78 +330,315 @@ mod tests { mock_provider.expect_namespace().return_const(namespace); mock_provider.expect_task_queue().return_const(task_queue); mock_provider + .expect_heartbeat_enabled() + .return_const(heartbeat_enabled); + mock_provider + .expect_worker_instance_key() + .return_const(Uuid::new_v4()); + mock_provider } #[test] - fn registry_respects_registration_order() { - let mock_provider1 = - new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); - let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); - - let manager = SlotManager::new(); - let some_slots = manager.register(Box::new(mock_provider1)); - let no_slots = manager.register(Box::new(mock_provider2)); - assert!(no_slots.is_none()); - - let mut found = 0; - for _ in 0..10 { - if manager - .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) - .is_some() - { - found += 1; + fn registry_keeps_one_provider_per_namespace() { + let manager = ClientWorkerSet::new(); + let mut 
worker_keys = vec![]; + let mut successful_registrations = 0; + + for i in 0..10 { + let namespace = format!("myId{}", i % 3); + let mock_provider = + new_mock_provider(namespace, "bar_q".to_string(), false, false, false); + let worker_instance_key = mock_provider.worker_instance_key(); + + let result = manager.register_worker(Arc::new(mock_provider)); + if result.is_ok() { + successful_registrations += 1; + worker_keys.push(worker_instance_key); + } else { + // Should get error for duplicate namespace+task_queue combinations + assert!(result.unwrap_err().to_string().contains( + "Registration of multiple workers on the same namespace and task queue" + )); } } - assert_eq!(found, 10); - assert_eq!((1, 1), manager.num_providers()); - - manager.unregister(some_slots.unwrap()); - assert_eq!((0, 0), manager.num_providers()); - - let mock_provider1 = - new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); - let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); - - let no_slots = manager.register(Box::new(mock_provider2)); - let some_slots = manager.register(Box::new(mock_provider1)); - assert!(some_slots.is_none()); - - let mut not_found = 0; - for _ in 0..10 { - if manager - .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) - .is_none() - { - not_found += 1; + + assert_eq!(successful_registrations, 3); + assert_eq!(3, manager.num_providers()); + + let count = worker_keys.iter().fold(0, |count, key| { + manager.unregister_worker(*key).unwrap(); + // expect error since worker is already unregistered + let result = manager.unregister_worker(*key); + assert!(result.is_err()); + count + 1 + }); + assert_eq!(3, count); + assert_eq!(0, manager.num_providers()); + } + + struct MockSharedNamespaceWorker { + namespace: String, + callbacks: Arc>>, + } + + impl std::fmt::Debug for MockSharedNamespaceWorker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
f.debug_struct("MockSharedNamespaceWorker") + .field("namespace", &self.namespace) + .field("callbacks_count", &self.callbacks.read().len()) + .finish() + } + } + + impl MockSharedNamespaceWorker { + fn new(namespace: String) -> Self { + Self { + namespace, + callbacks: Arc::new(RwLock::new(HashMap::new())), } } - assert_eq!(not_found, 10); - assert_eq!((1, 1), manager.num_providers()); - manager.unregister(no_slots.unwrap()); - assert_eq!((0, 0), manager.num_providers()); } - #[test] - fn registry_keeps_one_provider_per_namespace() { - let manager = SlotManager::new(); - let mut worker_keys = vec![]; - for i in 0..10 { - let namespace = format!("myId{}", i % 3); - let mock_provider = new_mock_provider(namespace, "bar_q".to_string(), false, false); - worker_keys.push(manager.register(Box::new(mock_provider))); + impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker { + fn namespace(&self) -> String { + self.namespace.clone() } - assert_eq!((3, 3), manager.num_providers()); - - let count = worker_keys - .iter() - .filter(|key| key.is_some()) - .fold(0, |count, key| { - manager.unregister(key.unwrap()); - // Should be idempotent - manager.unregister(key.unwrap()); - count + 1 - }); - assert_eq!(3, count); - assert_eq!((0, 0), manager.num_providers()); + + fn register_callback( + &self, + worker_instance_key: Uuid, + heartbeat_callback: HeartbeatCallback, + ) { + self.callbacks + .write() + .insert(worker_instance_key, heartbeat_callback); + } + + fn unregister_callback( + &self, + worker_instance_key: Uuid, + ) -> (Option, bool) { + let mut callbacks = self.callbacks.write(); + let callback = callbacks.remove(&worker_instance_key); + let is_empty = callbacks.is_empty(); + (callback, is_empty) + } + + fn num_workers(&self) -> usize { + self.callbacks.read().len() + } + } + + fn new_mock_provider_with_heartbeat( + namespace: String, + task_queue: String, + heartbeat_enabled: bool, + worker_instance_key: Uuid, + ) -> MockClientWorker { + let mut mock_provider = 
MockClientWorker::new(); + mock_provider + .expect_try_reserve_wft_slot() + .returning(|| Some(new_mock_slot(false))); + mock_provider + .expect_namespace() + .return_const(namespace.clone()); + mock_provider.expect_task_queue().return_const(task_queue); + mock_provider + .expect_heartbeat_enabled() + .return_const(heartbeat_enabled); + mock_provider + .expect_worker_instance_key() + .return_const(worker_instance_key); + + if heartbeat_enabled { + mock_provider + .expect_heartbeat_callback() + .returning(|| Some(Box::new(WorkerHeartbeat::default))); + + let namespace_clone = namespace.clone(); + mock_provider + .expect_new_shared_namespace_worker() + .returning(move || { + Ok(Box::new(MockSharedNamespaceWorker::new( + namespace_clone.clone(), + ))) + }); + + mock_provider.expect_register_callback().returning(|_| {}); + } + + mock_provider + } + + #[test] + fn duplicate_namespace_task_queue_registration_fails() { + let manager = ClientWorkerSet::new(); + + let worker1 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "test_queue".to_string(), + true, + Uuid::new_v4(), + ); + + // Same namespace+task_queue but different worker instance + let worker2 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "test_queue".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + + // second worker register should fail due to duplicate namespace+task_queue + let result = manager.register_worker(Arc::new(worker2)); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Registration of multiple workers on the same namespace and task queue") + ); + + assert_eq!(1, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 1); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + } + + #[test] + fn 
multiple_workers_same_namespace_share_heartbeat_manager() { + let manager = ClientWorkerSet::new(); + + let worker1 = new_mock_provider_with_heartbeat( + "shared_namespace".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + + // Same namespace but different task queue + let worker2 = new_mock_provider_with_heartbeat( + "shared_namespace".to_string(), + "queue2".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("shared_namespace")); + + let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap(); + assert_eq!(shared_worker.namespace(), "shared_namespace"); + } + + #[test] + fn different_namespaces_get_separate_heartbeat_managers() { + let manager = ClientWorkerSet::new(); + let worker1 = new_mock_provider_with_heartbeat( + "namespace1".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + let worker2 = new_mock_provider_with_heartbeat( + "namespace2".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.num_heartbeat_workers(), 2); + assert!(impl_ref.shared_worker.contains_key("namespace1")); + assert!(impl_ref.shared_worker.contains_key("namespace2")); + } + + #[test] + fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() { + let manager = ClientWorkerSet::new(); + + // Create two workers with same namespace but different task queues + let worker1 = 
new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + let worker2 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "queue2".to_string(), + true, + Uuid::new_v4(), + ); + let worker_instance_key1 = worker1.worker_instance_key(); + let worker_instance_key2 = worker2.worker_instance_key(); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + // Verify initial state: 2 slot providers, 2 heartbeat workers, 1 shared worker + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + assert_eq!( + impl_ref + .shared_worker + .get("test_namespace") + .unwrap() + .num_workers(), + 2 + ); + drop(impl_ref); + + // Unregister first worker + manager.unregister_worker(worker_instance_key1).unwrap(); + + // After unregistering first worker: 1 slot provider, 1 heartbeat worker, shared worker still exists + assert_eq!(1, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 1); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.num_heartbeat_workers(), 1); // SharedNamespaceWorker still exists + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + assert_eq!( + impl_ref + .shared_worker + .get("test_namespace") + .unwrap() + .num_workers(), + 1 + ); + drop(impl_ref); + + // Unregister second worker + manager.unregister_worker(worker_instance_key2).unwrap(); + + // After unregistering last worker: 0 slot providers, 0 heartbeat workers, shared worker is removed + assert_eq!(0, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 0); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 0); // SharedNamespaceWorker is cleaned up + 
assert!(!impl_ref.shared_worker.contains_key("test_namespace")); } } diff --git a/client/src/workflow_handle/mod.rs b/client/src/workflow_handle/mod.rs index 6d284711b..af6d7ab16 100644 --- a/client/src/workflow_handle/mod.rs +++ b/client/src/workflow_handle/mod.rs @@ -200,8 +200,7 @@ where o => Err(anyhow!( "Server returned an event that didn't match the CloseEvent filter. \ This is either a server bug or a new event the SDK does not understand. \ - Event details: {:?}", - o + Event details: {o:?}" )), }; } diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs index b7304c5c9..d92efeec0 100644 --- a/core-api/src/worker.rs +++ b/core-api/src/worker.rs @@ -161,12 +161,6 @@ pub struct WorkerConfig { /// A versioning strategy for this worker. pub versioning_strategy: WorkerVersioningStrategy, - - /// The interval within which the worker will send a heartbeat. - /// The timer is reset on each existing RPC call that also happens to send this data, like - /// `PollWorkflowTaskQueueRequest`. 
- #[builder(default)] - pub heartbeat_interval: Option, } impl WorkerConfig { diff --git a/core-c-bridge/include/temporal-sdk-core-c-bridge.h b/core-c-bridge/include/temporal-sdk-core-c-bridge.h index 803879f7d..b6ee3e275 100644 --- a/core-c-bridge/include/temporal-sdk-core-c-bridge.h +++ b/core-c-bridge/include/temporal-sdk-core-c-bridge.h @@ -397,6 +397,7 @@ typedef struct TemporalCoreTelemetryOptions { typedef struct TemporalCoreRuntimeOptions { const struct TemporalCoreTelemetryOptions *telemetry; + uint64_t worker_heartbeat_duration_millis; } TemporalCoreRuntimeOptions; typedef struct TemporalCoreTestServerOptions { @@ -868,8 +869,8 @@ void temporal_core_worker_validate(struct TemporalCoreWorker *worker, void *user_data, TemporalCoreWorkerCallback callback); -void temporal_core_worker_replace_client(struct TemporalCoreWorker *worker, - struct TemporalCoreClient *new_client); +const struct TemporalCoreByteArray *temporal_core_worker_replace_client(struct TemporalCoreWorker *worker, + struct TemporalCoreClient *new_client); void temporal_core_worker_poll_workflow_activation(struct TemporalCoreWorker *worker, void *user_data, diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index fde3548ce..ccdd660fd 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -685,7 +685,7 @@ async fn call_workflow_service( "UpdateWorkerBuildIdCompatibility" => { rpc_call!(client, call, update_worker_build_id_compatibility) } - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -715,7 +715,7 @@ async fn call_operator_service( "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, update_nexus_endpoint) } - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -785,7 +785,7 @@ async fn call_cloud_service(client: &CoreClient, call: &RpcCallOptions) -> anyho "GetConnectivityRule" => 
rpc_call!(client, call, get_connectivity_rule), "GetConnectivityRules" => rpc_call!(client, call, get_connectivity_rules), "DeleteConnectivityRule" => rpc_call!(client, call, delete_connectivity_rule), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -799,7 +799,7 @@ async fn call_test_service(client: &CoreClient, call: &RpcCallOptions) -> anyhow "Sleep" => rpc_call!(client, call, sleep), "UnlockTimeSkippingWithSleep" => rpc_call!(client, call, unlock_time_skipping_with_sleep), "UnlockTimeSkipping" => rpc_call!(client, call, unlock_time_skipping), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -814,7 +814,7 @@ async fn call_health_service( "Watch" => Err(anyhow::anyhow!( "Health service Watch method is not implemented in C bridge" )), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } diff --git a/core-c-bridge/src/runtime.rs b/core-c-bridge/src/runtime.rs index 94fe46929..5a330e268 100644 --- a/core-c-bridge/src/runtime.rs +++ b/core-c-bridge/src/runtime.rs @@ -16,7 +16,8 @@ use std::{ time::{Duration, UNIX_EPOCH}, }; use temporal_sdk_core::{ - CoreRuntime, TokioRuntimeBuilder, + CoreRuntime, RuntimeOptions as CoreRuntimeOptions, + RuntimeOptionsBuilder as CoreRuntimeOptionsBuilder, TokioRuntimeBuilder, telemetry::{build_otlp_metric_exporter, start_prometheus_metric_exporter}, }; use temporal_sdk_core_api::telemetry::{ @@ -30,6 +31,7 @@ use url::Url; #[repr(C)] pub struct RuntimeOptions { pub telemetry: *const TelemetryOptions, + pub worker_heartbeat_duration_millis: u64, } #[repr(C)] @@ -142,7 +144,7 @@ pub extern "C" fn temporal_core_runtime_new(options: *const RuntimeOptions) -> R let mut runtime = Runtime { core: Arc::new( CoreRuntime::new( - CoreTelemetryOptions::default(), + CoreRuntimeOptions::default(), TokioRuntimeBuilder::default(), ) 
.unwrap(), @@ -238,8 +240,21 @@ impl Runtime { CoreTelemetryOptions::default() }; + let heartbeat_interval = if options.worker_heartbeat_duration_millis == 0 { + None + } else { + Some(Duration::from_millis( + options.worker_heartbeat_duration_millis, + )) + }; + + let core_runtime_options = CoreRuntimeOptionsBuilder::default() + .telemetry_options(telemetry_options) + .heartbeat_interval(heartbeat_interval) + .build()?; + // Build core runtime - let mut core = CoreRuntime::new(telemetry_options, TokioRuntimeBuilder::default())?; + let mut core = CoreRuntime::new(core_runtime_options, TokioRuntimeBuilder::default())?; // We late-bind the metrics after core runtime is created since it needs // the Tokio handle diff --git a/core-c-bridge/src/tests/context.rs b/core-c-bridge/src/tests/context.rs index 5f889d442..0fb7aaa04 100644 --- a/core-c-bridge/src/tests/context.rs +++ b/core-c-bridge/src/tests/context.rs @@ -153,6 +153,7 @@ impl Context { let RuntimeOrFail { runtime, fail } = temporal_core_runtime_new(&RuntimeOptions { telemetry: std::ptr::null(), + worker_heartbeat_duration_millis: 0, }); if let Some(fail) = byte_array_to_string(runtime, fail) { @@ -162,11 +163,7 @@ impl Context { temporal_core_runtime_free(runtime); "" }; - Err(anyhow!( - "Runtime creation failed: {}{}", - runtime_is_null, - fail - )) + Err(anyhow!("Runtime creation failed: {runtime_is_null}{fail}")) } else if runtime.is_null() { Err(anyhow!("Runtime creation failed: runtime is null")) } else { @@ -522,8 +519,7 @@ extern "C" fn ephemeral_server_start_callback( if let Some(fail) = fail { ContextOperationState::CallbackError(anyhow!( - "Ephemeral server start failed: {}", - fail + "Ephemeral server start failed: {fail}" )) } else if server.is_null() { ContextOperationState::CallbackError(anyhow!( @@ -568,8 +564,7 @@ extern "C" fn ephemeral_server_shutdown_callback( let _ = context.complete_operation_catch_unwind(|guard| { if let Some(fail) = byte_array_to_string(guard.runtime, std::mem::take(&mut 
fail)) { ContextOperationState::CallbackError(anyhow!( - "Ephemeral server shutdown failed: {}", - fail + "Ephemeral server shutdown failed: {fail}" )) } else { ContextOperationState::CallbackOk(None) @@ -591,7 +586,7 @@ extern "C" fn client_connect_callback( if let Some(context) = user_data.context.upgrade() { let _ = context.complete_operation_catch_unwind(|guard| { if let Some(fail) = byte_array_to_string(guard.runtime, std::mem::take(&mut fail)) { - ContextOperationState::CallbackError(anyhow!("Client connect failed: {}", fail)) + ContextOperationState::CallbackError(anyhow!("Client connect failed: {fail}")) } else { guard.client = std::mem::take(&mut client); ContextOperationState::CallbackOk(None) diff --git a/core-c-bridge/src/worker.rs b/core-c-bridge/src/worker.rs index 8eaa13f26..c752fdae0 100644 --- a/core-c-bridge/src/worker.rs +++ b/core-c-bridge/src/worker.rs @@ -522,11 +522,20 @@ pub extern "C" fn temporal_core_worker_validate( pub extern "C" fn temporal_core_worker_replace_client( worker: *mut Worker, new_client: *mut Client, -) { +) -> *const ByteArray { let worker = unsafe { &*worker }; let core_worker = worker.worker.as_ref().expect("missing worker").clone(); let client = unsafe { &*new_client }; - core_worker.replace_client(client.core.get_client().clone()); + + match core_worker.replace_client(client.core.get_client().clone()) { + Ok(()) => std::ptr::null(), + Err(err) => worker + .runtime + .clone() + .alloc_utf8(&format!("Replace client failed: {err}")) + .into_raw() + .cast_const(), + } } /// If success or fail are present, they must be freed. 
They will both be null diff --git a/core/src/core_tests/activity_tasks.rs b/core/src/core_tests/activity_tasks.rs index 6a3acdaaf..b508a10e0 100644 --- a/core/src/core_tests/activity_tasks.rs +++ b/core/src/core_tests/activity_tasks.rs @@ -1,7 +1,7 @@ use crate::{ ActivityHeartbeat, Worker, advance_fut, job_assert, prost_dur, test_help::{ - MockPollCfg, MockWorkerInputs, MocksHolder, QueueResponse, TEST_Q, WorkerExt, + MockPollCfg, MockWorkerInputs, MocksHolder, QueueResponse, WorkerExt, WorkflowCachingPolicy, build_fake_worker, build_mock_pollers, fanout_tasks, gen_assert_and_reply, mock_manual_poller, mock_poller, mock_worker, poll_and_reply, single_hist_mock_sg, test_worker_cfg, @@ -734,7 +734,7 @@ async fn no_eager_activities_requested_when_worker_options_disable_it( ScheduleActivity { seq: 1, activity_id: "act_id".to_string(), - task_queue: TEST_Q.to_string(), + task_queue: core.get_config().task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } @@ -821,6 +821,7 @@ async fn activity_tasks_from_completion_are_delivered() { let mut mock = build_mock_pollers(mh); mock.worker_cfg(|wc| wc.max_cached_workflows = 2); let core = mock_worker(mock); + let task_queue = core.get_config().task_queue.clone(); // Test start let wf_task = core.poll_workflow_activation().await.unwrap(); @@ -829,7 +830,7 @@ async fn activity_tasks_from_completion_are_delivered() { ScheduleActivity { seq, activity_id: format!("act_id_{seq}_same_queue"), - task_queue: TEST_Q.to_string(), + task_queue: task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } @@ -840,7 +841,7 @@ async fn activity_tasks_from_completion_are_delivered() { ScheduleActivity { seq: 4, activity_id: "act_id_same_queue_not_eager".to_string(), - task_queue: TEST_Q.to_string(), + task_queue: task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } diff --git 
a/core/src/core_tests/workers.rs b/core/src/core_tests/workers.rs index b29fbe4b5..f5288a442 100644 --- a/core/src/core_tests/workers.rs +++ b/core/src/core_tests/workers.rs @@ -315,7 +315,7 @@ async fn worker_shutdown_api(#[case] use_cache: bool, #[case] api_success: bool) mock.expect_is_mock().returning(|| true); mock.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - mock.expect_get_identity() + mock.expect_identity() .returning(|| "test-identity".to_string()); if use_cache { if api_success { diff --git a/core/src/core_tests/workflow_tasks.rs b/core/src/core_tests/workflow_tasks.rs index e26b6c887..e5550938b 100644 --- a/core/src/core_tests/workflow_tasks.rs +++ b/core/src/core_tests/workflow_tasks.rs @@ -2996,7 +2996,9 @@ async fn both_normal_and_sticky_pollers_poll_concurrently() { Arc::new(mock_client), None, None, - ); + false, + ) + .unwrap(); for _ in 1..50 { let activation = worker.poll_workflow_activation().await.unwrap(); diff --git a/core/src/ephemeral_server/mod.rs b/core/src/ephemeral_server/mod.rs index f8f0ee57b..e64aa6482 100644 --- a/core/src/ephemeral_server/mod.rs +++ b/core/src/ephemeral_server/mod.rs @@ -241,8 +241,7 @@ impl EphemeralServer { } } Err(anyhow!( - "Failed connecting to test server after 5 seconds, last error: {:?}", - last_error + "Failed connecting to test server after 5 seconds, last error: {last_error:?}" )) } @@ -368,7 +367,7 @@ impl EphemeralExe { let arch = match std::env::consts::ARCH { "x86_64" => "amd64", "arm" | "aarch64" => "arm64", - other => return Err(anyhow!("Unsupported arch: {}", other)), + other => return Err(anyhow!("Unsupported arch: {other}")), }; let mut get_info_params = vec![("arch", arch), ("platform", platform)]; if let Some(format) = preferred_format { diff --git a/core/src/lib.rs b/core/src/lib.rs index db8f80da4..3995b7562 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -61,7 +61,8 @@ use crate::{ }; use anyhow::bail; use futures_util::Stream; -use 
std::sync::{Arc, OnceLock}; +use std::sync::Arc; +use std::time::Duration; use temporal_client::{ConfiguredClient, NamespacedClient, TemporalServiceClientWithMetrics}; use temporal_sdk_core_api::{ Worker as WorkerTrait, @@ -89,39 +90,41 @@ pub fn init_worker( where CT: Into, { - let client = init_worker_client(&worker_config, *client.into().into_inner()); - if client.namespace() != worker_config.namespace { + let client_inner = *client.into().into_inner(); + let client = init_worker_client( + worker_config.namespace.clone(), + worker_config.client_identity_override.clone(), + client_inner, + ); + let namespace = worker_config.namespace.clone(); + if client.namespace() != namespace { bail!("Passed in client is not bound to the same namespace as the worker"); } if client.namespace() == "" { bail!("Client namespace cannot be empty"); } let client_ident = client.get_identity().to_owned(); - let sticky_q = sticky_q_name_for_worker(&client_ident, &worker_config); + let sticky_q = sticky_q_name_for_worker(&client_ident, worker_config.max_cached_workflows); if client_ident.is_empty() { bail!("Client identity cannot be empty. Either lang or user should be setting this value"); } - let heartbeat_fn = worker_config - .heartbeat_interval - .map(|_| Arc::new(OnceLock::new())); - let client_bag = Arc::new(WorkerClientBag::new( client, - worker_config.namespace.clone(), - client_ident, + namespace.clone(), + client_ident.clone(), worker_config.versioning_strategy.clone(), - heartbeat_fn.clone(), )); - Ok(Worker::new( - worker_config, + Worker::new( + worker_config.clone(), sticky_q, - client_bag, + client_bag.clone(), Some(&runtime.telemetry), - heartbeat_fn, - )) + runtime.heartbeat_interval, + false, + ) } /// Create a worker for replaying one or more existing histories. 
It will auto-shutdown as soon as @@ -142,11 +145,12 @@ where } pub(crate) fn init_worker_client( - config: &WorkerConfig, + namespace: String, + client_identity_override: Option, client: ConfiguredClient, ) -> RetryClient { - let mut client = Client::new(client, config.namespace.clone()); - if let Some(ref id_override) = config.client_identity_override { + let mut client = Client::new(client, namespace); + if let Some(ref id_override) = client_identity_override { client.options_mut().identity.clone_from(id_override); } RetryClient::new(client, RetryConfig::default()) @@ -156,9 +160,9 @@ pub(crate) fn init_worker_client( /// workflows. pub(crate) fn sticky_q_name_for_worker( process_identity: &str, - config: &WorkerConfig, + max_cached_workflows: usize, ) -> Option { - if config.max_cached_workflows > 0 { + if max_cached_workflows > 0 { Some(format!( "{}-{}", &process_identity, @@ -220,6 +224,21 @@ pub struct CoreRuntime { telemetry: TelemetryInstance, runtime: Option, runtime_handle: tokio::runtime::Handle, + heartbeat_interval: Option, +} + +/// Holds telemetry options, as well as worker heartbeat_interval. Construct with [RuntimeOptionsBuilder] +#[derive(derive_builder::Builder)] +#[non_exhaustive] +#[derive(Default)] +pub struct RuntimeOptions { + /// Telemetry configuration options. + #[builder(default)] + telemetry_options: TelemetryOptions, + /// Optional worker heartbeat interval - This configures the heartbeat setting of all + /// workers created using this runtime. + #[builder(default = "Some(Duration::from_secs(60))")] + heartbeat_interval: Option, } /// Wraps a [tokio::runtime::Builder] to allow layering multiple on_thread_start functions @@ -254,13 +273,13 @@ impl CoreRuntime { /// If a tokio runtime has already been initialized. To re-use an existing runtime, call /// [CoreRuntime::new_assume_tokio]. 
pub fn new( - telemetry_options: TelemetryOptions, + runtime_options: RuntimeOptions, mut tokio_builder: TokioRuntimeBuilder, ) -> Result where F: Fn() + Send + Sync + 'static, { - let telemetry = telemetry_init(telemetry_options)?; + let telemetry = telemetry_init(runtime_options.telemetry_options)?; let subscriber = telemetry.trace_subscriber(); let runtime = tokio_builder .inner @@ -275,7 +294,8 @@ impl CoreRuntime { }) .build()?; let _rg = runtime.enter(); - let mut me = Self::new_assume_tokio_initialized_telem(telemetry); + let mut me = + Self::new_assume_tokio_initialized_telem(telemetry, runtime_options.heartbeat_interval); me.runtime = Some(runtime); Ok(me) } @@ -285,9 +305,12 @@ impl CoreRuntime { /// /// # Panics /// If there is no currently active Tokio runtime - pub fn new_assume_tokio(telemetry_options: TelemetryOptions) -> Result { - let telemetry = telemetry_init(telemetry_options)?; - Ok(Self::new_assume_tokio_initialized_telem(telemetry)) + pub fn new_assume_tokio(runtime_options: RuntimeOptions) -> Result { + let telemetry = telemetry_init(runtime_options.telemetry_options)?; + Ok(Self::new_assume_tokio_initialized_telem( + telemetry, + runtime_options.heartbeat_interval, + )) } /// Construct a runtime from an already-initialized telemetry instance, assuming a tokio runtime @@ -295,7 +318,10 @@ impl CoreRuntime { /// /// # Panics /// If there is no currently active Tokio runtime - pub fn new_assume_tokio_initialized_telem(telemetry: TelemetryInstance) -> Self { + pub fn new_assume_tokio_initialized_telem( + telemetry: TelemetryInstance, + heartbeat_interval: Option, + ) -> Self { let runtime_handle = tokio::runtime::Handle::current(); if let Some(sub) = telemetry.trace_subscriber() { set_trace_subscriber_for_current_thread(sub); @@ -304,6 +330,7 @@ impl CoreRuntime { telemetry, runtime: None, runtime_handle, + heartbeat_interval, } } diff --git a/core/src/pollers/poll_buffer.rs b/core/src/pollers/poll_buffer.rs index 72f2e8a41..7bc4311fb 100644 
--- a/core/src/pollers/poll_buffer.rs +++ b/core/src/pollers/poll_buffer.rs @@ -203,6 +203,7 @@ impl LongPollBuffer { permit_dealer: MeteredPermitDealer, shutdown: CancellationToken, num_pollers_handler: Option, + send_heartbeat: bool, ) -> Self { let no_retry = if matches!(poller_behavior, PollerBehavior::Autoscaling { .. }) { Some(NoRetryOnMatching { @@ -216,11 +217,14 @@ impl LongPollBuffer { let task_queue = task_queue.clone(); async move { client - .poll_nexus_task(PollOptions { - task_queue, - no_retry, - timeout_override, - }) + .poll_nexus_task( + PollOptions { + task_queue, + no_retry, + timeout_override, + }, + send_heartbeat, + ) .await } }; diff --git a/core/src/protosext/mod.rs b/core/src/protosext/mod.rs index fe0d60686..1a1119033 100644 --- a/core/src/protosext/mod.rs +++ b/core/src/protosext/mod.rs @@ -135,7 +135,7 @@ impl TryFrom for ValidPollWFTQResponse { _cant_construct_me: (), }) } - _ => Err(anyhow!("Unable to interpret poll response: {:?}", value)), + _ => Err(anyhow!("Unable to interpret poll response: {value:?}",)), } } } diff --git a/core/src/protosext/protocol_messages.rs b/core/src/protosext/protocol_messages.rs index 3962edec5..af2aab47d 100644 --- a/core/src/protosext/protocol_messages.rs +++ b/core/src/protosext/protocol_messages.rs @@ -116,7 +116,7 @@ impl TryFrom> for IncomingProtocolMessageBody { v.unpack_as(update::v1::Request::default())?.try_into()?, ) } - o => bail!("Could not understand protocol message type {}", o), + o => bail!("Could not understand protocol message type {o}"), }) } } diff --git a/core/src/replay/mod.rs b/core/src/replay/mod.rs index 650070b20..1e4990000 100644 --- a/core/src/replay/mod.rs +++ b/core/src/replay/mod.rs @@ -114,7 +114,7 @@ where hist_allow_tx.send("Failed".to_string()).unwrap(); async move { Ok(RespondWorkflowTaskFailedResponse::default()) }.boxed() }); - let mut worker = Worker::new(self.config, None, Arc::new(client), None, None); + let mut worker = Worker::new(self.config, None, 
Arc::new(client), None, None, false)?; worker.set_post_activate_hook(post_activate); shutdown_tok(worker.shutdown_token()); Ok(worker) diff --git a/core/src/telemetry/metrics.rs b/core/src/telemetry/metrics.rs index 28bf18845..fe6aec8e1 100644 --- a/core/src/telemetry/metrics.rs +++ b/core/src/telemetry/metrics.rs @@ -1,4 +1,6 @@ -use crate::{abstractions::dbg_panic, telemetry::TelemetryInstance}; +#[cfg(test)] +use crate::TelemetryInstance; +use crate::abstractions::dbg_panic; use std::{ fmt::{Debug, Display}, @@ -11,7 +13,7 @@ use temporal_sdk_core_api::telemetry::metrics::{ GaugeF64, GaugeF64Base, Histogram, HistogramBase, HistogramDuration, HistogramDurationBase, HistogramF64, HistogramF64Base, LazyBufferInstrument, MetricAttributable, MetricAttributes, MetricCallBufferer, MetricEvent, MetricKeyValue, MetricKind, MetricParameters, MetricUpdateVal, - NewAttributes, NoOpCoreMeter, + NewAttributes, NoOpCoreMeter, TemporalMeter, }; use temporal_sdk_core_protos::temporal::api::{ enums::v1::WorkflowTaskFailedCause, failure::v1::Failure, @@ -76,8 +78,17 @@ impl MetricsContext { } } + #[cfg(test)] pub(crate) fn top_level(namespace: String, tq: String, telemetry: &TelemetryInstance) -> Self { - if let Some(mut meter) = telemetry.get_temporal_metric_meter() { + MetricsContext::top_level_with_meter(namespace, tq, telemetry.get_temporal_metric_meter()) + } + + pub(crate) fn top_level_with_meter( + namespace: String, + tq: String, + temporal_meter: Option, + ) -> Self { + if let Some(mut meter) = temporal_meter { meter .default_attribs .attributes diff --git a/core/src/telemetry/prometheus_meter.rs b/core/src/telemetry/prometheus_meter.rs index 0b53e8c12..37810a534 100644 --- a/core/src/telemetry/prometheus_meter.rs +++ b/core/src/telemetry/prometheus_meter.rs @@ -315,8 +315,7 @@ where Ok(labels) } else { let e = anyhow!( - "Must use Prometheus attributes with a Prometheus metric implementation. 
Got: {:?}", - attributes + "Must use Prometheus attributes with a Prometheus metric implementation. Got: {attributes:?}" ); dbg_panic!("{:?}", e); Err(e) diff --git a/core/src/test_help/integ_helpers.rs b/core/src/test_help/integ_helpers.rs index 2c3417dc6..2f3cded78 100644 --- a/core/src/test_help/integ_helpers.rs +++ b/core/src/test_help/integ_helpers.rs @@ -62,13 +62,11 @@ use temporal_sdk_core_protos::{ }; use tokio::sync::{Notify, mpsc::unbounded_channel}; use tokio_stream::wrappers::UnboundedReceiverStream; +use uuid::Uuid; /// Default namespace for testing pub const NAMESPACE: &str = "default"; -/// Default task queue for testing -pub const TEST_Q: &str = "q"; - /// Initiate shutdown, drain the pollers (handling evictions), and wait for shutdown to complete. pub async fn drain_pollers_and_shutdown(worker: &dyn WorkerTrait) { worker.initiate_shutdown(); @@ -102,7 +100,7 @@ pub async fn drain_pollers_and_shutdown(worker: &dyn WorkerTrait) { pub fn test_worker_cfg() -> WorkerConfigBuilder { let mut wcb = WorkerConfigBuilder::default(); wcb.namespace(NAMESPACE) - .task_queue(TEST_Q) + .task_queue(Uuid::new_v4().to_string()) .versioning_strategy(WorkerVersioningStrategy::None { build_id: "test_bin_id".to_string(), }) @@ -185,7 +183,7 @@ pub fn build_fake_worker( } pub fn mock_worker(mocks: MocksHolder) -> Worker { - let sticky_q = sticky_q_name_for_worker("unit-test", &mocks.inputs.config); + let sticky_q = sticky_q_name_for_worker("unit-test", mocks.inputs.config.max_cached_workflows); let act_poller = if mocks.inputs.config.no_remote_activities { None } else { @@ -205,7 +203,9 @@ pub fn mock_worker(mocks: MocksHolder) -> Worker { }, None, None, + false, ) + .unwrap() } pub struct FakeWfResponses { @@ -275,7 +275,7 @@ impl MocksHolder { } } - /// Uses the provided list of tasks to create a mock poller for the `TEST_Q` + /// Uses the provided list of tasks to create a mock poller with a randomly generated task queue pub fn from_client_with_activities( client: 
impl WorkerClient + 'static, act_tasks: ACT, diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 519994fc3..5d773330e 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -1,17 +1,12 @@ //! Worker-specific client needs pub(crate) mod mocks; -use crate::{ - abstractions::dbg_panic, protosext::legacy_query_failure, worker::heartbeat::HeartbeatFn, -}; +use crate::protosext::legacy_query_failure; use parking_lot::RwLock; -use std::{ - sync::{Arc, OnceLock}, - time::Duration, -}; +use std::{sync::Arc, time::Duration}; use temporal_client::{ - Client, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, RetryClient, - SlotManager, WorkflowService, + Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, + RetryClient, WorkflowService, }; use temporal_sdk_core_api::worker::WorkerVersioningStrategy; use temporal_sdk_core_protos::{ @@ -38,6 +33,7 @@ use temporal_sdk_core_protos::{ }, }; use tonic::IntoRequest; +use uuid::Uuid; type Result = std::result::Result; @@ -52,7 +48,6 @@ pub(crate) struct WorkerClientBag { namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, - heartbeat_data: Option>>, } impl WorkerClientBag { @@ -61,14 +56,12 @@ impl WorkerClientBag { namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, - heartbeat_data: Option>>, ) -> Self { Self { replaceable_client: RwLock::new(client), namespace, identity, worker_versioning_strategy, - heartbeat_data, } } @@ -129,19 +122,6 @@ impl WorkerClientBag { None } } - - fn capture_heartbeat(&self) -> Option { - if let Some(heartbeat_data) = self.heartbeat_data.as_ref() { - if let Some(hb) = heartbeat_data.get() { - hb() - } else { - dbg_panic!("Heartbeat function never set"); - None - } - } else { - None - } - } } /// This trait contains everything workers need to interact with Temporal, and hence provides a @@ -165,6 +145,7 @@ pub trait 
WorkerClient: Sync + Send { async fn poll_nexus_task( &self, poll_options: PollOptions, + send_heartbeat: bool, ) -> Result; /// Complete a workflow task async fn complete_workflow_task( @@ -234,7 +215,8 @@ pub trait WorkerClient: Sync + Send { /// Record a worker heartbeat async fn record_worker_heartbeat( &self, - heartbeat: WorkerHeartbeat, + namespace: String, + worker_heartbeat: Vec, ) -> Result; /// Replace the underlying client @@ -242,13 +224,15 @@ pub trait WorkerClient: Sync + Send { /// Return server capabilities fn capabilities(&self) -> Option; /// Return workers using this client - fn workers(&self) -> Arc; + fn workers(&self) -> Arc; /// Indicates if this is a mock client fn is_mock(&self) -> bool; /// Return name and version of the SDK fn sdk_name_and_version(&self) -> (String, String); /// Get worker identity - fn get_identity(&self) -> String; + fn identity(&self) -> String; + /// Get worker grouping key + fn worker_grouping_key(&self) -> Uuid; } /// Configuration options shared by workflow, activity, and Nexus polling calls @@ -360,6 +344,7 @@ impl WorkerClient for WorkerClientBag { async fn poll_nexus_task( &self, poll_options: PollOptions, + _send_heartbeat: bool, ) -> Result { #[allow(deprecated)] // want to list all fields explicitly let mut request = PollNexusTaskQueueRequest { @@ -372,7 +357,7 @@ impl WorkerClient for WorkerClientBag { identity: self.identity.clone(), worker_version_capabilities: self.worker_version_capabilities(), deployment_options: self.deployment_options(), - worker_heartbeat: self.capture_heartbeat().into_iter().collect(), + worker_heartbeat: Vec::new(), } .into_request(); request.extensions_mut().insert(IsWorkerTaskLongPoll); @@ -661,7 +646,7 @@ impl WorkerClient for WorkerClientBag { identity: self.identity.clone(), sticky_task_queue, reason: "graceful shutdown".to_string(), - worker_heartbeat: self.capture_heartbeat(), + worker_heartbeat: None, }; Ok( @@ -671,32 +656,34 @@ impl WorkerClient for WorkerClientBag { ) } 
- fn replace_client(&self, new_client: RetryClient) { - let mut replaceable_client = self.replaceable_client.write(); - *replaceable_client = new_client; - } - async fn record_worker_heartbeat( &self, - heartbeat: WorkerHeartbeat, + namespace: String, + worker_heartbeat: Vec, ) -> Result { + let request = RecordWorkerHeartbeatRequest { + namespace, + identity: self.identity.clone(), + worker_heartbeat, + }; Ok(self .cloned_client() - .record_worker_heartbeat(RecordWorkerHeartbeatRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - worker_heartbeat: vec![heartbeat], - }) + .record_worker_heartbeat(request) .await? .into_inner()) } + fn replace_client(&self, new_client: RetryClient) { + let mut replaceable_client = self.replaceable_client.write(); + *replaceable_client = new_client; + } + fn capabilities(&self) -> Option { let client = self.replaceable_client.read(); client.get_client().inner().capabilities().cloned() } - fn workers(&self) -> Arc { + fn workers(&self) -> Arc { let client = self.replaceable_client.read(); client.get_client().inner().workers() } @@ -711,9 +698,16 @@ impl WorkerClient for WorkerClientBag { (opts.client_name.clone(), opts.client_version.clone()) } - fn get_identity(&self) -> String { + fn identity(&self) -> String { self.identity.clone() } + + fn worker_grouping_key(&self) -> Uuid { + self.replaceable_client + .read() + .get_client() + .worker_grouping_key() + } } impl NamespacedClient for WorkerClientBag { diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs index f6407f2a4..93984c364 100644 --- a/core/src/worker/client/mocks.rs +++ b/core/src/worker/client/mocks.rs @@ -1,10 +1,10 @@ use super::*; use futures_util::Future; use std::sync::{Arc, LazyLock}; -use temporal_client::SlotManager; +use temporal_client::ClientWorkerSet; -pub(crate) static DEFAULT_WORKERS_REGISTRY: LazyLock> = - LazyLock::new(|| Arc::new(SlotManager::new())); +pub(crate) static DEFAULT_WORKERS_REGISTRY: 
LazyLock> = + LazyLock::new(|| Arc::new(ClientWorkerSet::new())); pub(crate) static DEFAULT_TEST_CAPABILITIES: &Capabilities = &Capabilities { signal_and_query_header: true, @@ -33,8 +33,9 @@ pub fn mock_worker_client() -> MockWorkerClient { .returning(|_| Ok(ShutdownWorkerResponse {})); r.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - r.expect_get_identity() + r.expect_identity() .returning(|| "test-identity".to_string()); + r.expect_worker_grouping_key().returning(Uuid::new_v4); r } @@ -48,7 +49,7 @@ pub(crate) fn mock_manual_worker_client() -> MockManualWorkerClient { r.expect_is_mock().returning(|| true); r.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - r.expect_get_identity() + r.expect_identity() .returning(|| "test-identity".to_string()); r } @@ -68,7 +69,7 @@ mockall::mock! { -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn poll_nexus_task<'a, 'b>(&self, poll_options: PollOptions) + fn poll_nexus_task<'a, 'b>(&self, poll_options: PollOptions, send_heartbeat: bool) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -139,7 +140,7 @@ mockall::mock! { fn respond_legacy_query<'a, 'b>( &self, task_token: TaskToken, - query_result: LegacyQueryResult, + query_result: LegacyQueryResult, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -150,13 +151,18 @@ mockall::mock! 
{ fn shutdown_worker<'a, 'b>(&self, sticky_task_queue: String) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn record_worker_heartbeat<'a, 'b>(&self, heartbeat: WorkerHeartbeat) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; + fn record_worker_heartbeat<'a, 'b>( + &self, + namespace: String, + heartbeat: Vec + ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; fn replace_client(&self, new_client: RetryClient); fn capabilities(&self) -> Option; - fn workers(&self) -> Arc; + fn workers(&self) -> Arc; fn is_mock(&self) -> bool; fn sdk_name_and_version(&self) -> (String, String); - fn get_identity(&self) -> String; + fn identity(&self) -> String; + fn worker_grouping_key(&self) -> Uuid; } } diff --git a/core/src/worker/heartbeat.rs b/core/src/worker/heartbeat.rs index f7c8d5694..88774647f 100644 --- a/core/src/worker/heartbeat.rs +++ b/core/src/worker/heartbeat.rs @@ -1,55 +1,113 @@ -use crate::{WorkerClient, abstractions::dbg_panic}; -use gethostname::gethostname; +use crate::WorkerClient; +use crate::worker::{TaskPollers, WorkerTelemetry}; use parking_lot::Mutex; use prost_types::Duration as PbDuration; +use std::collections::HashMap; use std::{ - sync::{Arc, OnceLock}, + sync::Arc, time::{Duration, SystemTime}, }; -use temporal_sdk_core_api::worker::WorkerConfig; -use temporal_sdk_core_protos::temporal::api::worker::v1::{WorkerHeartbeat, WorkerHostInfo}; -use tokio::{sync::Notify, task::JoinHandle, time::MissedTickBehavior}; +use temporal_client::SharedNamespaceWorkerTrait; +use temporal_sdk_core_api::worker::{ + PollerBehavior, WorkerConfigBuilder, WorkerVersioningStrategy, +}; +use temporal_sdk_core_protos::temporal::api::worker::v1::WorkerHeartbeat; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; use uuid::Uuid; -pub(crate) type HeartbeatFn = Box Option + Send + Sync>; +/// Callback used to collect heartbeat data from each worker at the time of heartbeat +pub(crate) type HeartbeatFn = Box WorkerHeartbeat + Send + Sync>; 
-pub(crate) struct WorkerHeartbeatManager { - heartbeat_handle: JoinHandle<()>, +/// SharedNamespaceWorker is responsible for polling nexus-delivered worker commands and sending +/// worker heartbeats to the server. This invokes callbacks on all workers in the same process that +/// share the same namespace. +pub(crate) struct SharedNamespaceWorker { + heartbeat_map: Arc>>, + namespace: String, + cancel: CancellationToken, } -impl WorkerHeartbeatManager { +impl SharedNamespaceWorker { pub(crate) fn new( - config: WorkerConfig, - identity: String, - heartbeat_fn: Arc>, client: Arc, - ) -> Self { - let sdk_name_and_ver = client.sdk_name_and_version(); - let reset_notify = Arc::new(Notify::new()); - let data = Arc::new(Mutex::new(WorkerHeartbeatData::new( + namespace: String, + heartbeat_interval: Duration, + telemetry: Option, + ) -> Result { + let config = WorkerConfigBuilder::default() + .namespace(namespace.clone()) + .task_queue(format!( + "temporal-sys/worker-commands/{namespace}/{}", + client.worker_grouping_key(), + )) + .no_remote_activities(true) + .max_outstanding_nexus_tasks(5_usize) + .versioning_strategy(WorkerVersioningStrategy::None { + build_id: "1.0".to_owned(), + }) + .nexus_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) + .build() + .expect("all required fields should be implemented"); + let worker = crate::worker::Worker::new_with_pollers_inner( config, - identity, - sdk_name_and_ver, - reset_notify.clone(), - ))); - let data_clone = data.clone(); - - let heartbeat_handle = tokio::spawn(async move { - let mut ticker = tokio::time::interval(data_clone.lock().heartbeat_interval); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + None, + client.clone(), + TaskPollers::Real, + telemetry, + None, + true, + )?; + + let last_heartbeat_time_map = Mutex::new(HashMap::new()); + + let reset_notify = Arc::new(Notify::new()); + let cancel = CancellationToken::new(); + let cancel_clone = cancel.clone(); + + let client_clone = client; 
+ let namespace_clone = namespace.clone(); + + let heartbeat_map = Arc::new(Mutex::new(HashMap::::new())); + let heartbeat_map_clone = heartbeat_map.clone(); + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(heartbeat_interval); loop { tokio::select! { _ = ticker.tick() => { - let heartbeat = if let Some(heartbeat) = data_clone.lock().capture_heartbeat_if_needed() { - heartbeat - } else { - continue - }; - if let Err(e) = client.clone().record_worker_heartbeat(heartbeat).await { - if matches!( - e.code(), - tonic::Code::Unimplemented - ) { + let mut hb_to_send = Vec::new(); + for (instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() { + let mut heartbeat = heartbeat_callback(); + let mut last_heartbeat_time_map = last_heartbeat_time_map.lock(); + let now = SystemTime::now(); + let elapsed_since_last_heartbeat = last_heartbeat_time_map.get(instance_key).cloned().map( + |hb_time| { + let dur = now.duration_since(hb_time).unwrap_or(Duration::ZERO); + PbDuration { + seconds: dur.as_secs() as i64, + nanos: dur.subsec_nanos() as i32, + } + } + ); + + heartbeat.elapsed_since_last_heartbeat = elapsed_since_last_heartbeat; + heartbeat.heartbeat_time = Some(now.into()); + + // All of these heartbeat details rely on a client. 
To avoid circular + // dependencies, this must be populated from within SharedNamespaceWorker + // to get info from the current client + heartbeat.worker_identity = client_clone.identity(); + let sdk_name_and_ver = client_clone.sdk_name_and_version(); + heartbeat.sdk_name = sdk_name_and_ver.0; + heartbeat.sdk_version = sdk_name_and_ver.1; + + hb_to_send.push(heartbeat); + + last_heartbeat_time_map.insert(*instance_key, now); + } + if let Err(e) = client_clone.record_worker_heartbeat(namespace_clone.clone(), hb_to_send).await { + if matches!(e.code(), tonic::Code::Unimplemented) { return; } warn!(error=?e, "Network error while sending worker heartbeat"); @@ -58,131 +116,83 @@ impl WorkerHeartbeatManager { _ = reset_notify.notified() => { ticker.reset(); } + _ = cancel_clone.cancelled() => { + worker.shutdown().await; + return; + } } } }); - let data_clone = data.clone(); - if heartbeat_fn - .set(Box::new(move || { - data_clone.lock().capture_heartbeat_if_needed() - })) - .is_err() - { - dbg_panic!( - "Failed to set heartbeat_fn, heartbeat_fn should only be set once, when a singular WorkerHeartbeatInfo is created" - ); - } - - Self { heartbeat_handle } - } - - pub(crate) fn shutdown(&self) { - self.heartbeat_handle.abort() + Ok(Self { + heartbeat_map, + namespace, + cancel, + }) } } -#[derive(Debug, Clone)] -struct WorkerHeartbeatData { - worker_instance_key: String, - worker_identity: String, - host_info: WorkerHostInfo, - // Time of the last heartbeat. 
This is used to both for heartbeat_time and last_heartbeat_time - heartbeat_time: Option, - task_queue: String, - /// SDK name - sdk_name: String, - /// SDK version - sdk_version: String, - /// Worker start time - start_time: SystemTime, - heartbeat_interval: Duration, - reset_notify: Arc, -} +impl SharedNamespaceWorkerTrait for SharedNamespaceWorker { + fn namespace(&self) -> String { + self.namespace.clone() + } -impl WorkerHeartbeatData { - fn new( - worker_config: WorkerConfig, - worker_identity: String, - sdk_name_and_ver: (String, String), - reset_notify: Arc, - ) -> Self { - Self { - worker_identity, - host_info: WorkerHostInfo { - host_name: gethostname().to_string_lossy().to_string(), - process_id: std::process::id().to_string(), - ..Default::default() - }, - sdk_name: sdk_name_and_ver.0, - sdk_version: sdk_name_and_ver.1, - task_queue: worker_config.task_queue.clone(), - start_time: SystemTime::now(), - heartbeat_time: None, - worker_instance_key: Uuid::new_v4().to_string(), - heartbeat_interval: worker_config - .heartbeat_interval - .expect("WorkerHeartbeatData is only called when heartbeat_interval is Some"), - reset_notify, + fn register_callback( + &self, + worker_instance_key: Uuid, + heartbeat_callback: Box WorkerHeartbeat + Send + Sync>, + ) { + self.heartbeat_map + .lock() + .insert(worker_instance_key, heartbeat_callback); + } + fn unregister_callback( + &self, + worker_instance_key: Uuid, + ) -> (Option WorkerHeartbeat + Send + Sync>>, bool) { + let mut heartbeat_map = self.heartbeat_map.lock(); + let heartbeat_callback = heartbeat_map.remove(&worker_instance_key); + if heartbeat_map.is_empty() { + self.cancel.cancel(); } + (heartbeat_callback, heartbeat_map.is_empty()) } - fn capture_heartbeat_if_needed(&mut self) -> Option { - let now = SystemTime::now(); - let elapsed_since_last_heartbeat = if let Some(heartbeat_time) = self.heartbeat_time { - let dur = now.duration_since(heartbeat_time).unwrap_or(Duration::ZERO); - - // Only send poll data 
if it's nearly been a full interval since this data has been sent - // In this case, "nearly" is 90% of the interval - if dur.as_secs_f64() < 0.9 * self.heartbeat_interval.as_secs_f64() { - return None; - } - Some(PbDuration { - seconds: dur.as_secs() as i64, - nanos: dur.subsec_nanos() as i32, - }) - } else { - None - }; - - self.heartbeat_time = Some(now); - - self.reset_notify.notify_one(); - - Some(WorkerHeartbeat { - worker_instance_key: self.worker_instance_key.clone(), - worker_identity: self.worker_identity.clone(), - host_info: Some(self.host_info.clone()), - task_queue: self.task_queue.clone(), - sdk_name: self.sdk_name.clone(), - sdk_version: self.sdk_version.clone(), - status: 0, - start_time: Some(self.start_time.into()), - heartbeat_time: Some(SystemTime::now().into()), - elapsed_since_last_heartbeat, - ..Default::default() - }) + fn num_workers(&self) -> usize { + self.heartbeat_map.lock().len() } } #[cfg(test)] mod tests { - use super::*; use crate::{ test_help::{WorkerExt, test_worker_cfg}, worker, worker::client::mocks::mock_worker_client, }; - use std::{sync::Arc, time::Duration}; + use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, + }; use temporal_sdk_core_api::worker::PollerBehavior; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::RecordWorkerHeartbeatResponse; #[tokio::test] - async fn worker_heartbeat() { + async fn worker_heartbeat_basic() { let mut mock = mock_worker_client(); - mock.expect_record_worker_heartbeat() - .times(2) - .returning(move |heartbeat| { + let heartbeat_count = Arc::new(AtomicUsize::new(0)); + let heartbeat_count_clone = heartbeat_count.clone(); + mock.expect_poll_workflow_task() + .returning(move |_namespace, _task_queue| Ok(Default::default())); + mock.expect_poll_nexus_task() + .returning(move |_poll_options, _send_heartbeat| Ok(Default::default())); + mock.expect_record_worker_heartbeat().times(3).returning( + move |_namespace, worker_heartbeat| { + 
assert_eq!(1, worker_heartbeat.len()); + let heartbeat = worker_heartbeat[0].clone(); let host_info = heartbeat.host_info.clone().unwrap(); assert_eq!("test-identity", heartbeat.worker_identity); assert!(!heartbeat.worker_instance_key.is_empty()); @@ -193,38 +203,35 @@ mod tests { assert_eq!(host_info.process_id, std::process::id().to_string()); assert_eq!(heartbeat.sdk_name, "test-core"); assert_eq!(heartbeat.sdk_version, "0.0.0"); - assert_eq!(heartbeat.task_queue, "q"); assert!(heartbeat.heartbeat_time.is_some()); assert!(heartbeat.start_time.is_some()); + heartbeat_count_clone.fetch_add(1, Ordering::Relaxed); + Ok(RecordWorkerHeartbeatResponse {}) - }); + }, + ); let config = test_worker_cfg() .activity_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) .max_outstanding_activities(1_usize) - .heartbeat_interval(Duration::from_millis(200)) .build() .unwrap(); - let heartbeat_fn = Arc::new(OnceLock::new()); let client = Arc::new(mock); - let worker = worker::Worker::new(config, None, client, None, Some(heartbeat_fn.clone())); - heartbeat_fn.get().unwrap()(); - - // heartbeat timer fires once - advance_time(Duration::from_millis(300)).await; - // it hasn't been >90% of the interval since the last heartbeat, so no data should be returned here - assert_eq!(None, heartbeat_fn.get().unwrap()()); - // heartbeat timer fires once - advance_time(Duration::from_millis(300)).await; - + let worker = worker::Worker::new( + config, + None, + client.clone(), + None, + Some(Duration::from_millis(100)), + false, + ) + .unwrap(); + + tokio::time::sleep(Duration::from_millis(250)).await; worker.drain_activity_poller_and_shutdown().await; - } - async fn advance_time(dur: Duration) { - tokio::time::pause(); - tokio::time::advance(dur).await; - tokio::time::resume(); + assert_eq!(3, heartbeat_count.load(Ordering::Relaxed)); } } diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 5e57a26e8..5428c067a 100644 --- a/core/src/worker/mod.rs +++ 
b/core/src/worker/mod.rs @@ -20,11 +20,12 @@ pub(crate) use activities::{ pub(crate) use wft_poller::WFTPollerShared; pub use workflow::LEGACY_QUERY_ID; +use crate::worker::heartbeat::{HeartbeatFn, SharedNamespaceWorker}; use crate::{ ActivityHeartbeat, CompleteActivityError, PollError, WorkerTrait, abstractions::{MeteredPermitDealer, PermitDealerContextData, dbg_panic}, errors::CompleteWfError, - pollers::{ActivityTaskOptions, BoxedActPoller, BoxedNexusPoller, LongPollBuffer}, + pollers::{BoxedActPoller, BoxedNexusPoller}, protosext::validate_activity_completion, telemetry::{ TelemetryInstance, @@ -36,32 +37,41 @@ use crate::{ worker::{ activities::{LACompleteAction, LocalActivityManager, NextPendingLAAction}, client::WorkerClient, - heartbeat::{HeartbeatFn, WorkerHeartbeatManager}, nexus::NexusManager, workflow::{ - LAReqSink, LocalResolution, WorkflowBasics, Workflows, wft_poller, - wft_poller::make_wft_poller, + LAReqSink, LocalResolution, WorkflowBasics, Workflows, wft_poller::make_wft_poller, }, }, }; +use crate::{ + pollers::{ActivityTaskOptions, LongPollBuffer}, + worker::workflow::wft_poller, +}; use activities::WorkerActivityTasks; +use anyhow::bail; use futures_util::{StreamExt, stream}; -use parking_lot::Mutex; +use gethostname::gethostname; +use parking_lot::{Mutex, RwLock}; use slot_provider::SlotProvider; use std::{ convert::TryInto, future, sync::{ - Arc, OnceLock, + Arc, atomic::{AtomicBool, Ordering}, }, time::Duration, }; -use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics, WorkerKey}; +use temporal_client::{ClientWorker, HeartbeatCallback, Slot as SlotTrait}; +use temporal_client::{ + ConfiguredClient, SharedNamespaceWorkerTrait, TemporalServiceClientWithMetrics, +}; +use temporal_sdk_core_api::telemetry::metrics::TemporalMeter; use temporal_sdk_core_api::{ errors::{CompleteNexusError, WorkerValidationError}, worker::PollerBehavior, }; +use temporal_sdk_core_protos::temporal::api::worker::v1::{WorkerHeartbeat, 
WorkerHostInfo}; use temporal_sdk_core_protos::{ TaskToken, coresdk::{ @@ -80,7 +90,8 @@ use temporal_sdk_core_protos::{ use tokio::sync::{mpsc::unbounded_channel, watch}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; - +use tracing::Subscriber; +use uuid::Uuid; #[cfg(any(feature = "test-utilities", test))] use { crate::{ @@ -97,8 +108,8 @@ use { pub struct Worker { config: WorkerConfig, client: Arc, - /// Registration key to enable eager workflow start for this worker - worker_key: Mutex>, + /// Worker instance key, unique identifier for this worker + worker_instance_key: Uuid, /// Manages all workflows and WFT processing workflows: Workflows, /// Manages activity tasks for this worker/task queue @@ -118,8 +129,8 @@ pub struct Worker { local_activities_complete: Arc, /// Used to track all permits have been released all_permits_tracker: tokio::sync::Mutex, - /// Used to shutdown the worker heartbeat task - worker_heartbeat: Option, + /// Used to track worker client + client_worker_registrator: Arc, } struct AllPermitsTracker { @@ -136,6 +147,13 @@ impl AllPermitsTracker { } } +#[derive(Clone)] +pub(crate) struct WorkerTelemetry { + metric_meter: Option, + temporal_metric_meter: Option, + trace_subscriber: Option>, +} + #[async_trait::async_trait] impl WorkerTrait for Worker { async fn validate(&self) -> Result<(), WorkerValidationError> { @@ -231,10 +249,20 @@ impl WorkerTrait for Worker { ); } self.shutdown_token.cancel(); - // First, disable Eager Workflow Start - if let Some(key) = *self.worker_key.lock() { - self.client.workers().unregister(key); + // First, unregister worker from the client + if let Err(e) = self + .client + .workers() + .unregister_worker(self.worker_instance_key) + { + error!( + task_queue=%self.config.task_queue, + namespace=%self.config.namespace, + error=%e, + "Failed to unregister worker on shutdown", + ); } + // Second, we want to stop polling of both activity and workflow tasks if let 
Some(atm) = self.at_task_mgr.as_ref() { atm.initiate_shutdown(); @@ -272,8 +300,9 @@ impl Worker { sticky_queue_name: Option, client: Arc, telem_instance: Option<&TelemetryInstance>, - heartbeat_fn: Option>>, - ) -> Self { + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { info!(task_queue=%config.task_queue, namespace=%config.namespace, "Initializing worker"); Self::new_with_pollers( @@ -282,25 +311,43 @@ impl Worker { client, TaskPollers::Real, telem_instance, - heartbeat_fn, + worker_heartbeat_interval, + shared_namespace_worker, ) } - /// Replace client and return a new client. For eager workflow purposes, this new client will - /// now apply to future eager start requests and the older client will not. - pub fn replace_client(&self, new_client: ConfiguredClient) { + /// Replace client and return a new client. + /// + /// For eager workflow purposes, this new client will now apply to future eager start requests + /// and the older client will not. Note, if this registration fails, the worker heartbeat will + /// also not be registered. + /// + /// For worker heartbeat, this will remove an existing shared worker if it is the last worker of + /// the old client and create a new nexus worker if it's the first client of the namespace on + /// the new client. 
+ pub fn replace_client( + &self, + new_client: ConfiguredClient, + ) -> Result<(), anyhow::Error> { // Unregister worker from current client, register in new client at the end - let mut worker_key = self.worker_key.lock(); - let slot_provider = (*worker_key).and_then(|k| self.client.workers().unregister(k)); - self.client - .replace_client(super::init_worker_client(&self.config, new_client)); - *worker_key = - slot_provider.and_then(|slot_provider| self.client.workers().register(slot_provider)); + let client_worker = self + .client + .workers() + .unregister_worker(self.worker_instance_key)?; + let new_worker_client = super::init_worker_client( + self.config.namespace.clone(), + self.config.client_identity_override.clone(), + new_client, + ); + + self.client.replace_client(new_worker_client); + *self.client_worker_registrator.client.write() = self.client.clone(); + self.client.workers().register_worker(client_worker) } #[cfg(test)] pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self { - Self::new(config, None, Arc::new(client), None, None) + Self::new(config, None, Arc::new(client), None, None, false).unwrap() } pub(crate) fn new_with_pollers( @@ -309,16 +356,48 @@ impl Worker { client: Arc, task_pollers: TaskPollers, telem_instance: Option<&TelemetryInstance>, - heartbeat_fn: Option>>, - ) -> Self { - let (metrics, meter) = if let Some(ti) = telem_instance { + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { + let worker_telemetry = telem_instance.map(|telem| WorkerTelemetry { + metric_meter: telem.get_metric_meter(), + temporal_metric_meter: telem.get_temporal_metric_meter(), + trace_subscriber: telem.trace_subscriber(), + }); + + Worker::new_with_pollers_inner( + config, + sticky_queue_name, + client, + task_pollers, + worker_telemetry, + worker_heartbeat_interval, + shared_namespace_worker, + ) + } + + pub(crate) fn new_with_pollers_inner( + config: WorkerConfig, + sticky_queue_name: 
Option, + client: Arc, + task_pollers: TaskPollers, + worker_telemetry: Option, + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { + let (metrics, meter) = if let Some(wt) = worker_telemetry.as_ref() { ( - MetricsContext::top_level(config.namespace.clone(), config.task_queue.clone(), ti), - ti.get_metric_meter(), + MetricsContext::top_level_with_meter( + config.namespace.clone(), + config.task_queue.clone(), + wt.temporal_metric_meter.clone(), + ), + wt.metric_meter.clone(), ) } else { (MetricsContext::no_op(), None) }; + let tuner = config .tuner .as_ref() @@ -329,7 +408,7 @@ impl Worker { let shutdown_token = CancellationToken::new(); let slot_context_data = Arc::new(PermitDealerContextData { task_queue: config.task_queue.clone(), - worker_identity: client.get_identity(), + worker_identity: client.identity(), worker_deployment_version: config.computed_deployment_version(), }); let wft_slots = MeteredPermitDealer::new( @@ -408,6 +487,7 @@ impl Worker { nexus_slots.clone(), shutdown_token.child_token(), Some(move |np| np_metrics.record_num_pollers(np)), + shared_namespace_worker, )) as BoxedNexusPoller; #[cfg(any(feature = "test-utilities", test))] @@ -486,20 +566,33 @@ impl Worker { wft_slots.clone(), external_wft_tx, ); - let worker_key = Mutex::new(client.workers().register(Box::new(provider))); - let sdk_name_and_ver = client.sdk_name_and_version(); + let worker_instance_key = Uuid::new_v4(); - let worker_heartbeat = heartbeat_fn.map(|heartbeat_fn| { + let sdk_name_and_ver = client.sdk_name_and_version(); + let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| { WorkerHeartbeatManager::new( config.clone(), - client.get_identity(), - heartbeat_fn, - client.clone(), + worker_instance_key, + hb_interval, + worker_telemetry.clone(), ) }); - Self { - worker_key, + let client_worker_registrator = Arc::new(ClientWorkerRegistrator { + worker_instance_key, + slot_provider: provider, + heartbeat_manager: worker_heartbeat, + 
client: RwLock::new(client.clone()), + }); + + if !shared_namespace_worker { + client + .workers() + .register_worker(client_worker_registrator.clone())?; + } + + Ok(Self { + worker_instance_key, client: client.clone(), workflows: Workflows::new( WorkflowBasics { @@ -538,7 +631,9 @@ impl Worker { _ => Some(mgr.get_handle_for_workflows()), } }), - telem_instance, + worker_telemetry + .as_ref() + .and_then(|telem| telem.trace_subscriber.clone()), ), at_task_mgr, local_act_mgr, @@ -554,8 +649,8 @@ impl Worker { la_permits, }), nexus_mgr, - worker_heartbeat, - } + client_worker_registrator, + }) } /// Will shutdown the worker. Does not resolve until all outstanding workflow tasks have been @@ -599,9 +694,6 @@ impl Worker { dbg_panic!("Waiting for all slot permits to release took too long!"); } } - if let Some(heartbeat) = self.worker_heartbeat.as_ref() { - heartbeat.shutdown(); - } } /// Finish shutting down by consuming the background pollers and freeing all resources @@ -858,6 +950,124 @@ impl Worker { } } +struct ClientWorkerRegistrator { + worker_instance_key: Uuid, + slot_provider: SlotProvider, + heartbeat_manager: Option, + client: RwLock>, +} + +impl ClientWorker for ClientWorkerRegistrator { + fn namespace(&self) -> &str { + self.slot_provider.namespace() + } + fn task_queue(&self) -> &str { + self.slot_provider.task_queue() + } + + fn try_reserve_wft_slot(&self) -> Option> { + self.slot_provider.try_reserve_wft_slot() + } + + fn worker_instance_key(&self) -> Uuid { + self.worker_instance_key + } + + fn heartbeat_enabled(&self) -> bool { + self.heartbeat_manager.is_some() + } + + fn heartbeat_callback(&self) -> Option { + if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { + let mut heartbeat_manager = hb_mgr.heartbeat_callback.lock(); + heartbeat_manager.take() + } else { + None + } + } + fn new_shared_namespace_worker( + &self, + ) -> Result, anyhow::Error> { + if let Some(ref hb_mgr) = self.heartbeat_manager { + Ok(Box::new(SharedNamespaceWorker::new( + 
self.client.read().clone(), + self.namespace().to_string(), + hb_mgr.heartbeat_interval, + hb_mgr.telemetry.clone(), + )?)) + } else { + bail!("Shared namespace worker creation never be called without a heartbeat manager"); + } + } + + fn register_callback(&self, callback: HeartbeatCallback) { + if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { + hb_mgr.heartbeat_callback.lock().replace(callback); + } + } +} + +struct WorkerHeartbeatManager { + /// Heartbeat interval, defaults to 60s + heartbeat_interval: Duration, + /// Telemetry instance, needed to initialize [SharedNamespaceWorker] when replacing client + telemetry: Option, + /// Heartbeat callback + heartbeat_callback: Mutex WorkerHeartbeat + Send + Sync>>>, +} + +impl WorkerHeartbeatManager { + fn new( + config: WorkerConfig, + worker_instance_key: Uuid, + heartbeat_interval: Duration, + telemetry_instance: Option, + ) -> Self { + let worker_instance_key_clone = worker_instance_key.to_string(); + let task_queue = config.task_queue.clone(); + + // TODO: requires the metrics changes to get the rest of these fields + let worker_heartbeat_callback: HeartbeatFn = Box::new(move || { + WorkerHeartbeat { + worker_instance_key: worker_instance_key_clone.clone(), + host_info: Some(WorkerHostInfo { + host_name: gethostname().to_string_lossy().to_string(), + process_id: std::process::id().to_string(), + ..Default::default() + }), + task_queue: task_queue.clone(), + deployment_version: None, + + status: 0, + start_time: Some(std::time::SystemTime::now().into()), + workflow_task_slots_info: None, + activity_task_slots_info: None, + nexus_task_slots_info: None, + local_activity_slots_info: None, + workflow_poller_info: None, + workflow_sticky_poller_info: None, + activity_poller_info: None, + nexus_poller_info: None, + total_sticky_cache_hit: 0, + total_sticky_cache_miss: 0, + current_sticky_cache_size: 0, + plugins: vec![], + + // sdk_name, sdk_version, and worker_identity must be set by + // SharedNamespaceWorker 
because they rely on the client, and + // need to be pulled from the current client used by SharedNamespaceWorker + ..Default::default() + } + }); + + WorkerHeartbeatManager { + heartbeat_interval, + telemetry: telemetry_instance, + heartbeat_callback: Mutex::new(Some(worker_heartbeat_callback)), + } + } +} + pub(crate) struct PostActivateHookData<'a> { pub(crate) run_id: &'a str, pub(crate) replaying: bool, diff --git a/core/src/worker/slot_provider.rs b/core/src/worker/slot_provider.rs index 1b5fcba95..40e9fb00e 100644 --- a/core/src/worker/slot_provider.rs +++ b/core/src/worker/slot_provider.rs @@ -7,7 +7,7 @@ use crate::{ protosext::ValidPollWFTQResponse, worker::workflow::wft_poller::validate_wft, }; -use temporal_client::{Slot as SlotTrait, SlotProvider as SlotProviderTrait}; +use temporal_client::Slot as SlotTrait; use temporal_sdk_core_api::worker::WorkflowSlotKind; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; use tokio::sync::mpsc::UnboundedSender; @@ -74,16 +74,13 @@ impl SlotProvider { external_wft_tx, } } -} - -impl SlotProviderTrait for SlotProvider { - fn namespace(&self) -> &str { + pub(super) fn namespace(&self) -> &str { &self.namespace } - fn task_queue(&self) -> &str { + pub(super) fn task_queue(&self) -> &str { &self.task_queue } - fn try_reserve_wft_slot(&self) -> Option> { + pub(super) fn try_reserve_wft_slot(&self) -> Option> { match self.wft_semaphore.try_acquire_owned().ok() { Some(permit) => Some(Box::new(Slot::new(permit, self.external_wft_tx.clone()))), None => None, diff --git a/core/src/worker/workflow/mod.rs b/core/src/worker/workflow/mod.rs index 33b9c6dac..24aa59f11 100644 --- a/core/src/worker/workflow/mod.rs +++ b/core/src/worker/workflow/mod.rs @@ -23,7 +23,7 @@ use crate::{ internal_flags::InternalFlags, pollers::TrackedPermittedTqResp, protosext::{ValidPollWFTQResponse, protocol_messages::IncomingProtocolMessage}, - telemetry::{TelemetryInstance, VecDisplayer, 
set_trace_subscriber_for_current_thread}, + telemetry::{VecDisplayer, set_trace_subscriber_for_current_thread}, worker::{ LocalActRequest, LocalActivityExecutionResult, LocalActivityResolution, PostActivateHookData, @@ -94,7 +94,7 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; -use tracing::Span; +use tracing::{Span, Subscriber}; /// Id used by server for "legacy" queries. IE: Queries that come in the `query` rather than /// `queries` field of a WFT, and are responded to on the separate `respond_query_task_completed` @@ -166,7 +166,7 @@ impl Workflows { local_act_mgr: Arc, heartbeat_timeout_rx: UnboundedReceiver, activity_tasks_handle: Option, - telem_instance: Option<&TelemetryInstance>, + tracing_sub: Option>, ) -> Self { let (local_tx, local_rx) = unbounded_channel(); let (fetch_tx, fetch_rx) = unbounded_channel(); @@ -187,7 +187,6 @@ impl Workflows { let (start_polling_tx, start_polling_rx) = oneshot::channel(); // We must spawn a task to constantly poll the activation stream, because otherwise // activation completions would not cause anything to happen until the next poll. 
- let tracing_sub = telem_instance.and_then(|ti| ti.trace_subscriber()); let processing_task = thread::Builder::new() .name("workflow-processing".to_string()) .spawn(move || { diff --git a/sdk-core-protos/src/history_builder.rs b/sdk-core-protos/src/history_builder.rs index 3dcc306c5..9aef52737 100644 --- a/sdk-core-protos/src/history_builder.rs +++ b/sdk-core-protos/src/history_builder.rs @@ -621,7 +621,7 @@ fn default_attribs(et: EventType) -> Result { EventType::WorkflowExecutionStarted => default_wes_attribs().into(), EventType::WorkflowTaskScheduled => WorkflowTaskScheduledEventAttributes::default().into(), EventType::TimerStarted => TimerStartedEventAttributes::default().into(), - _ => bail!("Don't know how to construct default attrs for {:?}", et), + _ => bail!("Don't know how to construct default attrs for {et:?}"), }) } diff --git a/sdk/src/interceptors.rs b/sdk/src/interceptors.rs index 84a7b56fd..5d014fb17 100644 --- a/sdk/src/interceptors.rs +++ b/sdk/src/interceptors.rs @@ -88,10 +88,7 @@ impl WorkerInterceptor for FailOnNondeterminismInterceptor { activation.eviction_reason(), Some(EvictionReason::Nondeterminism) ) { - bail!( - "Workflow is being evicted because of nondeterminism! {}", - activation - ); + bail!("Workflow is being evicted because of nondeterminism! {activation}"); } Ok(()) } diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index 90115715e..e3970b4e2 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -10,7 +10,7 @@ //! ```no_run //! use std::{str::FromStr, sync::Arc}; //! use temporal_sdk::{sdk_client_options, ActContext, Worker}; -//! use temporal_sdk_core::{init_worker, Url, CoreRuntime}; +//! use temporal_sdk_core::{init_worker, Url, CoreRuntime, RuntimeOptionsBuilder}; //! use temporal_sdk_core_api::{ //! worker::{WorkerConfigBuilder, WorkerVersioningStrategy}, //! telemetry::TelemetryOptionsBuilder @@ -20,10 +20,11 @@ //! async fn main() -> Result<(), Box> { //! 
let server_options = sdk_client_options(Url::from_str("http://localhost:7233")?).build()?; //! -//! let client = server_options.connect("default", None).await?; -//! //! let telemetry_options = TelemetryOptionsBuilder::default().build()?; -//! let runtime = CoreRuntime::new_assume_tokio(telemetry_options)?; +//! let runtime_options = RuntimeOptionsBuilder::default().telemetry_options(telemetry_options).build().unwrap(); +//! let runtime = CoreRuntime::new_assume_tokio(runtime_options)?; +//! +//! let client = server_options.connect("default", None).await?; //! //! let worker_config = WorkerConfigBuilder::default() //! .namespace("default") @@ -497,11 +498,7 @@ impl WorkflowHalf { // In all other cases, we want to error as the runtime could be in an inconsistent state // at this point. - bail!( - "Got activation {:?} for unknown workflow {}", - activation, - run_id - ); + bail!("Got activation {activation:?} for unknown workflow {run_id}"); }; Ok(res) diff --git a/sdk/src/workflow_future.rs b/sdk/src/workflow_future.rs index 9e1b6036e..648d58f82 100644 --- a/sdk/src/workflow_future.rs +++ b/sdk/src/workflow_future.rs @@ -135,7 +135,7 @@ impl WorkflowFuture { }; let unblocker = self.command_status.remove(&cmd_id); let _ = unblocker - .ok_or_else(|| anyhow!("Command {:?} not found to unblock!", cmd_id))? + .ok_or_else(|| anyhow!("Command {cmd_id:?} not found to unblock!"))? 
.unblocker .send(event); Ok(()) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b17d07d31..35b2825c3 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -39,8 +39,8 @@ use temporal_sdk::{ }, }; use temporal_sdk_core::{ - ClientOptions, ClientOptionsBuilder, CoreRuntime, WorkerConfigBuilder, init_replay_worker, - init_worker, + ClientOptions, ClientOptionsBuilder, CoreRuntime, RuntimeOptions, RuntimeOptionsBuilder, + WorkerConfigBuilder, init_replay_worker, init_worker, replay::{HistoryForReplay, ReplayWorkerInput}, telemetry::{build_otlp_metric_exporter, start_prometheus_metric_exporter}, }; @@ -164,8 +164,12 @@ pub(crate) fn init_integ_telem() -> Option<&'static CoreRuntime> { } Some(INTEG_TESTS_RT.get_or_init(|| { let telemetry_options = get_integ_telem_options(); + let runtime_options = RuntimeOptionsBuilder::default() + .telemetry_options(telemetry_options) + .build() + .expect("Runtime options build cleanly"); let rt = - CoreRuntime::new_assume_tokio(telemetry_options).expect("Core runtime inits cleanly"); + CoreRuntime::new_assume_tokio(runtime_options).expect("Core runtime inits cleanly"); if let Some(sub) = rt.telemetry().trace_subscriber() { let _ = tracing::subscriber::set_global_default(sub); } @@ -314,8 +318,7 @@ impl CoreWfStarter { pub(crate) async fn worker(&mut self) -> TestWorker { let w = self.get_worker().await; - let tq = w.get_config().task_queue.clone(); - let mut w = TestWorker::new(w, tq); + let mut w = TestWorker::new(w); w.client = Some(self.get_client().await); w @@ -477,8 +480,11 @@ pub(crate) struct TestWorker { } impl TestWorker { /// Create a new test worker - pub(crate) fn new(core_worker: Arc, task_queue: impl Into) -> Self { - let inner = Worker::new_from_core(core_worker.clone(), task_queue); + pub(crate) fn new(core_worker: Arc) -> Self { + let inner = Worker::new_from_core( + core_worker.clone(), + core_worker.get_config().task_queue.clone(), + ); Self { inner, core_worker, @@ -807,6 +813,13 @@ pub(crate) 
fn get_integ_telem_options() -> TelemetryOptions { .unwrap() } +pub(crate) fn get_integ_runtime_options(telemopts: TelemetryOptions) -> RuntimeOptions { + RuntimeOptionsBuilder::default() + .telemetry_options(telemopts) + .build() + .unwrap() +} + #[async_trait::async_trait(?Send)] pub(crate) trait WorkflowHandleExt { async fn fetch_history_and_replay( @@ -932,10 +945,7 @@ pub(crate) fn mock_sdk_cfg( let mut mock = build_mock_pollers(poll_cfg); mock.worker_cfg(mutator); let core = mock_worker(mock); - TestWorker::new( - Arc::new(core), - temporal_sdk_core::test_help::TEST_Q.to_string(), - ) + TestWorker::new(Arc::new(core)) } #[derive(Default)] diff --git a/tests/global_metric_tests.rs b/tests/global_metric_tests.rs index 14e799195..822bf238c 100644 --- a/tests/global_metric_tests.rs +++ b/tests/global_metric_tests.rs @@ -2,6 +2,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::CoreWfStarter; use parking_lot::Mutex; use std::{sync::Arc, time::Duration}; @@ -71,18 +72,16 @@ async fn otel_errors_logged_as_errors() { .unwrap(), ) .unwrap(); + let telemopts = TelemetryOptionsBuilder::default() + .metrics(Arc::new(exporter) as Arc) + // Importantly, _not_ using subscriber override, is using console. + .logging(Logger::Console { + filter: construct_filter_string(Level::INFO, Level::WARN), + }) + .build() + .unwrap(); - let rt = CoreRuntime::new_assume_tokio( - TelemetryOptionsBuilder::default() - .metrics(Arc::new(exporter) as Arc) - // Importantly, _not_ using subscriber override, is using console. 
- .logging(Logger::Console { - filter: construct_filter_string(Level::INFO, Level::WARN), - }) - .build() - .unwrap(), - ) - .unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("otel_errors_logged_as_errors", rt); let _worker = starter.get_worker().await; diff --git a/tests/heavy_tests.rs b/tests/heavy_tests.rs index f5dc9018d..64bb298fb 100644 --- a/tests/heavy_tests.rs +++ b/tests/heavy_tests.rs @@ -2,6 +2,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::{ CoreWfStarter, init_integ_telem, prom_metrics, rand_6_chars, workflows::la_problem_workflow, }; @@ -194,7 +195,7 @@ async fn workflow_load() { // cause us to encounter the tracing span drop bug telemopts.logging = None; init_integ_telem(); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("workflow_load", rt); starter .worker_config diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index dd4068ac4..dc7caa812 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -1,3 +1,4 @@ +use crate::common::get_integ_runtime_options; use crate::{ common::{ ANY_PORT, CoreWfStarter, NAMESPACE, OTEL_URL_ENV_VAR, PROMETHEUS_QUERY_API, @@ -97,7 +98,7 @@ async fn prometheus_metrics_exported( }); } let (telemopts, addr, _aborter) = prom_metrics(Some(opts_builder.build().unwrap())); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let opts = get_integ_server_options(); let mut raw_client = opts .connect_no_namespace(rt.telemetry().get_temporal_metric_meter()) @@ -148,7 +149,7 @@ async fn prometheus_metrics_exported( async fn one_slot_worker_reports_available_slot() { let 
(telemopts, addr, _aborter) = prom_metrics(None); let tq = "one_slot_worker_tq"; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let worker_cfg = WorkerConfigBuilder::default() .namespace(NAMESPACE) @@ -401,7 +402,7 @@ async fn query_of_closed_workflow_doesnt_tick_terminal_metric( completion: workflow_command::Variant, ) { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("query_of_closed_workflow_doesnt_tick_terminal_metric", rt); // Disable cache to ensure replay happens completely @@ -523,8 +524,11 @@ async fn query_of_closed_workflow_doesnt_tick_terminal_metric( #[test] fn runtime_new() { - let mut rt = - CoreRuntime::new(get_integ_telem_options(), TokioRuntimeBuilder::default()).unwrap(); + let mut rt = CoreRuntime::new( + get_integ_runtime_options(get_integ_telem_options()), + TokioRuntimeBuilder::default(), + ) + .unwrap(); let handle = rt.tokio_handle(); let _rt = handle.enter(); let (telemopts, addr, _aborter) = prom_metrics(None); @@ -570,7 +574,7 @@ async fn latency_metrics( .build() .unwrap(), )); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("latency_metrics", rt); let worker = starter.get_worker().await; starter.start_wf().await; @@ -624,7 +628,7 @@ async fn latency_metrics( #[tokio::test] async fn request_fail_codes() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let opts = get_integ_server_options(); let mut client = opts 
.connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) @@ -667,8 +671,8 @@ async fn request_fail_codes_otel() { let mut telemopts = TelemetryOptionsBuilder::default(); let exporter = Arc::new(exporter); telemopts.metrics(exporter as Arc); - - let rt = CoreRuntime::new_assume_tokio(telemopts.build().unwrap()).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts.build().unwrap())) + .unwrap(); let opts = get_integ_server_options(); let mut client = opts .connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) @@ -718,7 +722,7 @@ async fn docker_metrics_with_prometheus( .metric_prefix(test_uid.clone()) .build() .unwrap(); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let test_name = "docker_metrics_with_prometheus"; let mut starter = CoreWfStarter::new_with_runtime(test_name, rt); let worker = starter.get_worker().await; @@ -758,8 +762,15 @@ async fn docker_metrics_with_prometheus( assert!(!data.is_empty(), "No metrics found for query: {test_uid}"); assert_eq!(data[0]["metric"]["exported_job"], "temporal-core-sdk"); assert_eq!(data[0]["metric"]["job"], "otel-collector"); + // Worker heartbeating nexus worker assert!( data[0]["metric"]["task_queue"] + .as_str() + .unwrap() + .starts_with("temporal-sys/worker-commands/default/") + ); + assert!( + data[1]["metric"]["task_queue"] .as_str() .unwrap() .starts_with(test_name) @@ -772,7 +783,7 @@ async fn docker_metrics_with_prometheus( #[tokio::test] async fn activity_metrics() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "activity_metrics"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter @@ -906,7 +917,7 @@ async fn activity_metrics() { #[tokio::test] async fn nexus_metrics() 
{ let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "nexus_metrics"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); @@ -1083,7 +1094,7 @@ async fn nexus_metrics() { #[tokio::test] async fn evict_on_complete_does_not_count_as_forced_eviction() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "evict_on_complete_does_not_count_as_forced_eviction"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); @@ -1166,7 +1177,7 @@ where #[tokio::test] async fn metrics_available_from_custom_slot_supplier() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("metrics_available_from_custom_slot_supplier", rt); starter.worker_config.no_remote_activities(true); diff --git a/tests/integ_tests/polling_tests.rs b/tests/integ_tests/polling_tests.rs index a69006e39..aaa69ec16 100644 --- a/tests/integ_tests/polling_tests.rs +++ b/tests/integ_tests/polling_tests.rs @@ -198,7 +198,9 @@ async fn switching_worker_client_changes_poll() { // Swap client, poll for next task, confirm it's second wf, and respond w/ empty info!("Replacing client and polling again"); - worker.replace_client(client2.get_client().inner().clone()); + worker + .replace_client(client2.get_client().inner().clone()) + .unwrap(); let act2 = worker.poll_workflow_activation().await.unwrap(); assert_eq!(wf2.run_id, act2.run_id); 
worker.complete_execution(&act2.run_id).await; diff --git a/tests/integ_tests/worker_tests.rs b/tests/integ_tests/worker_tests.rs index 728ba3222..920e2606d 100644 --- a/tests/integ_tests/worker_tests.rs +++ b/tests/integ_tests/worker_tests.rs @@ -1,3 +1,4 @@ +use crate::common::get_integ_runtime_options; use crate::{ common::{CoreWfStarter, get_integ_server_options, get_integ_telem_options, mock_sdk_cfg}, shared_tests, @@ -17,8 +18,8 @@ use temporal_sdk::{ActivityOptions, WfContext, interceptors::WorkerInterceptor}; use temporal_sdk_core::{ CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, init_worker, test_help::{ - FakeWfResponses, MockPollCfg, ResponseType, TEST_Q, build_mock_pollers, - drain_pollers_and_shutdown, hist_to_poll_resp, mock_worker, mock_worker_client, + FakeWfResponses, MockPollCfg, ResponseType, build_mock_pollers, drain_pollers_and_shutdown, + hist_to_poll_resp, mock_worker, mock_worker_client, }, }; use temporal_sdk_core_api::{ @@ -61,7 +62,9 @@ use uuid::Uuid; #[tokio::test] async fn worker_validation_fails_on_nonexistent_namespace() { let opts = get_integ_server_options(); - let runtime = CoreRuntime::new_assume_tokio(get_integ_telem_options()).unwrap(); + let runtime = + CoreRuntime::new_assume_tokio(get_integ_runtime_options(get_integ_telem_options())) + .unwrap(); let retrying_client = opts .connect_no_namespace(runtime.telemetry().get_temporal_metric_meter()) .await @@ -318,7 +321,7 @@ async fn activity_tasks_from_completion_reserve_slots() { cfg.max_outstanding_activities = Some(2); }); let core = Arc::new(mock_worker(mock)); - let mut worker = crate::common::TestWorker::new(core.clone(), TEST_Q.to_string()); + let mut worker = crate::common::TestWorker::new(core.clone()); // First poll for activities twice, occupying both slots let at1 = core.poll_activity_task().await.unwrap(); diff --git a/tests/integ_tests/workflow_tests.rs b/tests/integ_tests/workflow_tests.rs index 6d10fbb71..263c942b0 100644 --- 
a/tests/integ_tests/workflow_tests.rs +++ b/tests/integ_tests/workflow_tests.rs @@ -18,6 +18,7 @@ mod stickyness; mod timers; mod upsert_search_attrs; +use crate::common::get_integ_runtime_options; use crate::{ common::{ CoreWfStarter, history_from_proto_binary, init_core_and_create_wf, @@ -67,7 +68,6 @@ use temporal_sdk_core_protos::{ test_utils::schedule_activity_cmd, }; use tokio::{join, sync::Notify, time::sleep}; - // TODO: We should get expected histories for these tests and confirm that the history at the end // matches. @@ -764,7 +764,7 @@ async fn nondeterminism_errors_fail_workflow_when_configured_to( #[values(true, false)] whole_worker: bool, ) { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "nondeterminism_errors_fail_workflow_when_configured_to"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); diff --git a/tests/main.rs b/tests/main.rs index bf05e2a1f..8a71f03ca 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -27,7 +27,8 @@ mod integ_tests { mod workflow_tests; use crate::common::{ - CoreWfStarter, get_integ_server_options, get_integ_telem_options, rand_6_chars, + CoreWfStarter, get_integ_runtime_options, get_integ_server_options, + get_integ_telem_options, rand_6_chars, }; use std::time::Duration; use temporal_client::{NamespacedClient, WorkflowService}; @@ -44,7 +45,9 @@ mod integ_tests { #[ignore] // Really a compile time check more than anything async fn lang_bridge_example() { let opts = get_integ_server_options(); - let runtime = CoreRuntime::new_assume_tokio(get_integ_telem_options()).unwrap(); + let runtime = + CoreRuntime::new_assume_tokio(get_integ_runtime_options(get_integ_telem_options())) + .unwrap(); let mut retrying_client = opts .connect_no_namespace(runtime.telemetry().get_temporal_metric_meter()) 
.await diff --git a/tests/manual_tests.rs b/tests/manual_tests.rs index 8f5ef4c5b..9588be3f3 100644 --- a/tests/manual_tests.rs +++ b/tests/manual_tests.rs @@ -5,6 +5,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::{CoreWfStarter, prom_metrics, rand_6_chars}; use futures_util::{ StreamExt, @@ -41,7 +42,7 @@ async fn poller_load_spiky() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config @@ -200,7 +201,7 @@ async fn poller_load_sustained() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config @@ -291,7 +292,7 @@ async fn poller_load_spike_then_sustained() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config diff --git a/tests/workflow_replay_bench.rs b/tests/workflow_replay_bench.rs index 4200ebd3a..d80796b0a 100644 --- a/tests/workflow_replay_bench.rs +++ b/tests/workflow_replay_bench.rs @@ -5,7 +5,9 @@ #[allow(dead_code)] mod common; -use crate::common::{DONT_AUTO_INIT_INTEG_TELEM, prom_metrics, replay_sdk_worker}; +use crate::common::{ + DONT_AUTO_INIT_INTEG_TELEM, get_integ_runtime_options, prom_metrics, replay_sdk_worker, +}; use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use futures_util::StreamExt; use std::{ @@ -80,7 +82,7 @@ pub fn bench_metrics(c: &mut Criterion) { let _tokio = tokio_runtime.enter(); let (mut telemopts, 
_addr, _aborter) = prom_metrics(None); telemopts.logging = None; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let meter = rt.telemetry().get_metric_meter().unwrap(); c.bench_function("Record with new attributes on each call", move |b| {