@@ -13,7 +13,15 @@ use crate::{
1313 } ,
1414} ;
1515use futures_util:: { stream, stream:: StreamExt } ;
16+ use std:: sync:: {
17+ Arc ,
18+ atomic:: { AtomicBool , Ordering } ,
19+ } ;
1620use std:: { cell:: RefCell , collections:: HashMap , time:: Duration } ;
21+ use temporalio_common:: protos:: temporal:: api:: {
22+ namespace:: v1:: { NamespaceInfo , namespace_info:: Capabilities } ,
23+ workflowservice:: v1:: DescribeNamespaceResponse ,
24+ } ;
1725use temporalio_common:: {
1826 protos:: {
1927 canned_histories,
@@ -51,6 +59,7 @@ use temporalio_common::{
5159 } ,
5260 worker:: WorkerTaskTypes ,
5361} ;
62+ use tokio:: sync:: Notify ;
5463use tokio:: sync:: { Barrier , watch} ;
5564use uuid:: Uuid ;
5665
@@ -1209,3 +1218,106 @@ async fn nexus_start_operation_failure_converts_to_legacy_for_old_server(
12091218 worker. shutdown ( ) . await ;
12101219 worker. finalize_shutdown ( ) . await ;
12111220}
1221+
1222+ /// Verifies that `initiate_shutdown` sends the `ShutdownWorker` RPC so that the server can
1223+ /// complete in-flight polls. Without this, graceful poll shutdown deadlocks: the SDK waits for
1224+ /// polls to drain, but the server was never told to flush them.
1225+ #[ tokio:: test]
1226+ async fn graceful_shutdown_sends_shutdown_worker_rpc_during_initiate ( ) {
1227+ let shutdown_rpc_called = Arc :: new ( AtomicBool :: new ( false ) ) ;
1228+ let shutdown_rpc_called_clone = shutdown_rpc_called. clone ( ) ;
1229+ // When the shutdown_worker RPC fires, it signals polls to complete (simulating server
1230+ // behavior where ShutdownWorker causes the server to return empty poll responses).
1231+ let poll_releaser = Arc :: new ( Notify :: new ( ) ) ;
1232+ let poll_releaser_for_rpc = poll_releaser. clone ( ) ;
1233+
1234+ let mut mock_client = MockWorkerClient :: new ( ) ;
1235+ mock_client
1236+ . expect_capabilities ( )
1237+ . returning ( || Some ( * DEFAULT_TEST_CAPABILITIES ) ) ;
1238+ mock_client
1239+ . expect_workers ( )
1240+ . returning ( || DEFAULT_WORKERS_REGISTRY . clone ( ) ) ;
1241+ mock_client. expect_is_mock ( ) . returning ( || true ) ;
1242+ mock_client
1243+ . expect_sdk_name_and_version ( )
1244+ . returning ( || ( "test-core" . to_string ( ) , "0.0.0" . to_string ( ) ) ) ;
1245+ mock_client
1246+ . expect_identity ( )
1247+ . returning ( || "test-identity" . to_string ( ) ) ;
1248+ mock_client
1249+ . expect_worker_grouping_key ( )
1250+ . returning ( Uuid :: new_v4) ;
1251+ mock_client
1252+ . expect_worker_instance_key ( )
1253+ . returning ( Uuid :: new_v4) ;
1254+ mock_client
1255+ . expect_set_heartbeat_client_fields ( )
1256+ . returning ( |hb| {
1257+ hb. sdk_name = "test-core" . to_string ( ) ;
1258+ hb. sdk_version = "0.0.0" . to_string ( ) ;
1259+ hb. worker_identity = "test-identity" . to_string ( ) ;
1260+ hb. heartbeat_time = Some ( std:: time:: SystemTime :: now ( ) . into ( ) ) ;
1261+ } ) ;
1262+ // Return the worker_poll_complete_on_shutdown capability so validate() enables graceful mode
1263+ mock_client. expect_describe_namespace ( ) . returning ( move || {
1264+ Ok ( DescribeNamespaceResponse {
1265+ namespace_info : Some ( NamespaceInfo {
1266+ capabilities : Some ( Capabilities {
1267+ worker_poll_complete_on_shutdown : true ,
1268+ ..Capabilities :: default ( )
1269+ } ) ,
1270+ ..NamespaceInfo :: default ( )
1271+ } ) ,
1272+ ..DescribeNamespaceResponse :: default ( )
1273+ } )
1274+ } ) ;
1275+ // When shutdown_worker RPC is called, mark it and release polls
1276+ mock_client
1277+ . expect_shutdown_worker ( )
1278+ . returning ( move |_, _, _, _| {
1279+ shutdown_rpc_called_clone. store ( true , Ordering :: SeqCst ) ;
1280+ poll_releaser_for_rpc. notify_waiters ( ) ;
1281+ Ok ( ShutdownWorkerResponse { } )
1282+ } ) ;
1283+ mock_client
1284+ . expect_complete_workflow_task ( )
1285+ . returning ( |_| Ok ( RespondWorkflowTaskCompletedResponse :: default ( ) ) ) ;
1286+
1287+ // Polls block until shutdown_worker RPC releases them (simulating server holding polls
1288+ // open until it receives the ShutdownWorker signal)
1289+ let poll_releaser_for_stream = poll_releaser. clone ( ) ;
1290+ let stream = stream:: unfold ( poll_releaser_for_stream, |releaser| async move {
1291+ releaser. notified ( ) . await ;
1292+ Some ( (
1293+ Ok ( PollWorkflowTaskQueueResponse :: default ( ) . try_into ( ) . unwrap ( ) ) ,
1294+ releaser,
1295+ ) )
1296+ } ) ;
1297+
1298+ let mw = MockWorkerInputs :: new ( stream. boxed ( ) ) ;
1299+ let worker = mock_worker ( MocksHolder :: from_mock_worker ( mock_client, mw) ) ;
1300+
1301+ // validate() reads describe_namespace and sets capabilities.graceful_poll_shutdown = true
1302+ worker. validate ( ) . await . unwrap ( ) ;
1303+
1304+ let poll_fut = worker. poll_workflow_activation ( ) ;
1305+ let shutdown_fut = async {
1306+ // initiate_shutdown must send the ShutdownWorker RPC, which releases the polls
1307+ worker. initiate_shutdown ( ) ;
1308+ } ;
1309+
1310+ let ( poll_result, _) = tokio:: time:: timeout ( Duration :: from_secs ( 5 ) , async {
1311+ tokio:: join!( poll_fut, shutdown_fut)
1312+ } )
1313+ . await
1314+ . expect ( "Shutdown should complete within 5s -- if it hangs, the ShutdownWorker RPC was not sent during initiate_shutdown" ) ;
1315+
1316+ assert_matches ! ( poll_result. unwrap_err( ) , PollError :: ShutDown ) ;
1317+ assert ! (
1318+ shutdown_rpc_called. load( Ordering :: SeqCst ) ,
1319+ "ShutdownWorker RPC must be called during initiate_shutdown"
1320+ ) ;
1321+
1322+ worker. finalize_shutdown ( ) . await ;
1323+ }
0 commit comments