diff --git a/changelog.d/19871.misc b/changelog.d/19871.misc new file mode 100644 index 00000000000..be10ee05403 --- /dev/null +++ b/changelog.d/19871.misc @@ -0,0 +1 @@ +Update `HomeserverTestCase.get_success(...)` and friends to drive async Rust (Tokio runtime/thread pool). diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py index 0f5d1c73387..20d314cb682 100644 --- a/tests/app/test_homeserver_shutdown.py +++ b/tests/app/test_homeserver_shutdown.py @@ -76,6 +76,13 @@ async def shutdown() -> None: self.get_success(shutdown()) + # XXX: There can be a few already dispatched database queries (from normal + # background tasks in Synapse) and the threadless `ThreadPool` that we use in + # tests uses *untracked* clock calls to pass database results back so `shutdown` + # doesn't cancel those calls. This is a quirk of our test infrastructure + # (threadless `ThreadPool`) so this kind of "hack" is fine. + self.reactor.advance(0) + # Cleanup the internal reference in our test case del self.hs @@ -106,7 +113,7 @@ def test_clean_homeserver_shutdown_mid_background_updates(self) -> None: # Pump the background updates by a single iteration, just to ensure any extra # resources it uses have been started. store = weakref.proxy(self.hs.get_datastores().main) - self.get_success(store.db_pool.updates.do_next_background_update(False), by=0.1) + self.get_success(store.db_pool.updates.do_next_background_update(False)) hs_ref = weakref.ref(self.hs) @@ -127,6 +134,13 @@ async def shutdown() -> None: self.get_success(shutdown()) + # XXX: There can be a few already dispatched database queries (from normal + # background tasks in Synapse) and the threadless `ThreadPool` that we use in + # tests uses *untracked* clock calls to pass database results back so `shutdown` + # doesn't cancel those calls. This is a quirk of our test infrastructure + # (threadless `ThreadPool`) so this kind of "hack" is fine. 
+ self.reactor.advance(0) + # Cleanup the internal reference in our test case del self.hs diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 794c0a3185f..0c7edbaa2da 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -357,7 +357,6 @@ def create_invite() -> EventBase: event.room_version, ), exc=LimitExceededError, - by=0.5, ) def _build_and_send_join_event( diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index c88f2c2d155..3c939b301c1 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -21,11 +21,10 @@ import json import threading -import time from http import HTTPStatus from http.server import BaseHTTPRequestHandler, HTTPServer from io import BytesIO -from typing import Any, ClassVar, Coroutine, Generator, TypeVar, Union +from typing import Any, ClassVar, TypeVar from unittest.mock import ANY, AsyncMock, Mock from urllib.parse import parse_qs @@ -37,7 +36,6 @@ ) from signedjson.sign import sign_json -from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor from synapse.api.auth.mas import MasDelegatedAuth @@ -809,31 +807,6 @@ class MasAuthDelegation(HomeserverTestCase): def device_scope(self) -> str: return self.device_scope_prefix + DEVICE - def till_deferred_has_result( - self, - awaitable: Union[ - "Coroutine[Deferred[Any], Any, T]", - "Generator[Deferred[Any], Any, T]", - "Deferred[T]", - ], - ) -> "Deferred[T]": - """Wait until a deferred has a result. - - This is useful because the Rust HTTP client will resolve the deferred - using reactor.callFromThread, which are only run when we call - reactor.advance. 
- """ - deferred = ensureDeferred(awaitable) - tries = 0 - while not deferred.called: - time.sleep(0.1) - self.reactor.advance(0) - tries += 1 - if tries > 100: - raise Exception("Timed out waiting for deferred to resolve") - - return deferred - def default_config(self) -> dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -883,11 +856,7 @@ def test_simple_introspection(self) -> None: "expires_in": 60, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertEqual(requester.device_id, DEVICE) @@ -906,11 +875,7 @@ def test_unexpiring_token(self) -> None: "username": USERNAME, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertEqual(requester.device_id, DEVICE) @@ -931,9 +896,7 @@ def test_inexistent_device(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), InvalidClientTokenError, ) self.assertEqual(failure.value.code, 401) @@ -948,9 +911,7 @@ def test_inexistent_user(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), AuthError, ) # This is a 500, it should never happen really @@ -966,9 +927,7 @@ def test_missing_scope(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), InvalidClientTokenError, 
) self.assertEqual(failure.value.code, 401) @@ -977,9 +936,7 @@ def test_invalid_response(self) -> None: self.server.introspection_response = {} failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), SynapseError, ) self.assertEqual(failure.value.code, 503) @@ -994,11 +951,7 @@ def test_device_id_in_body(self) -> None: "device_id": DEVICE, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.device_id, DEVICE) @@ -1011,11 +964,7 @@ def test_admin_scope(self) -> None: "expires_in": 60, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertTrue(self.get_success(self._auth.is_server_admin(requester))) @@ -1040,17 +989,15 @@ def test_cached_expired_introspection(self) -> None: request.requestHeaders.getRawHeaders = mock_getRawHeaders() # The first CS-API request causes a successful introspection - self.get_success( - self.till_deferred_has_result(self._auth.get_user_by_req(request)) - ) + self.get_success(self._auth.get_user_by_req(request)) self.assertEqual(self.server.calls, 1) # Sleep for 60 seconds so the token expires. 
self.reactor.advance(60.0) # Now the CS-API request fails because the token expired - self.assertFailure( - self.till_deferred_has_result(self._auth.get_user_by_req(request)), + self.get_failure( + self._auth.get_user_by_req(request), InvalidClientTokenError, ) # Ensure another introspection request was not sent diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 44f1e6432d6..2aeb9a927a9 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -29,6 +29,7 @@ get_verify_key, ) +from twisted.internet.defer import ensureDeferred from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState @@ -58,6 +59,7 @@ from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.clock import Clock +from synapse.util.duration import Duration from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -948,12 +950,17 @@ def test_external_process_timeout(self) -> None: ) worker_presence_handler = worker_to_sync_against.get_presence_handler() - self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + self.get_success(sync_d) # Check that if we wait a while without telling the handler the user has # stopped syncing that their presence state doesn't get timed out. 
@@ -1264,30 +1271,40 @@ def test_set_presence_from_syncing_multi_device( worker_presence_handler = worker_to_sync_against.get_presence_handler() # 1. Sync with the first device. - self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( user_id, "dev-1", affect_presence=dev_1_state != PresenceState.OFFLINE, presence_state=dev_1_state, - ), - by=0.01, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + self.get_success(sync_d) # 2. Wait half the idle timer. self.reactor.advance(IDLE_TIMER / 1000 / 2) self.reactor.pump([0.1]) # 3. Sync with the second device. - self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( user_id, "dev-2", affect_presence=dev_2_state != PresenceState.OFFLINE, presence_state=dev_2_state, - ), - by=0.01, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + self.get_success(sync_d) # 4. Assert the expected presence state. state = self.get_success( @@ -1305,15 +1322,21 @@ def test_set_presence_from_syncing_multi_device( # # This is due to EXTERNAL_PROCESS_EXPIRY being equivalent to IDLE_TIMER. 
if test_with_workers: - with self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( f"@other-user:{self.hs.config.server.server_name}", "dev-3", affect_presence=True, presence_state=PresenceState.ONLINE, - ), - by=0.01, - ): + ) + ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + + with self.get_success(sync_d): pass # 5. Advance such that the first device should be discarded (the idle timer), @@ -1501,26 +1524,36 @@ def test_set_presence_from_non_syncing_multi_device( worker_presence_handler = worker_to_sync_against.get_presence_handler() # 1. Sync with the first device. - sync_1 = self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( user_id, "dev-1", affect_presence=dev_1_state != PresenceState.OFFLINE, presence_state=dev_1_state, - ), - by=0.1, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + sync_1 = self.get_success(sync_d) # 2. Sync with the second device. - sync_2 = self.get_success( + sync_d = ensureDeferred( worker_presence_handler.user_syncing( user_id, "dev-2", affect_presence=dev_2_state != PresenceState.OFFLINE, presence_state=dev_2_state, - ), - by=0.1, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. 
The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + sync_2 = self.get_success(sync_d) # 3. Assert the expected presence state. state = self.get_success( @@ -1622,12 +1655,17 @@ def test_set_presence_from_syncing_keeps_busy( # Perform a sync with a presence state other than busy. This should NOT change # our presence status; we only change from busy if we explicitly set it via # /presence/*. - self.get_success( + sync_d = ensureDeferred( worker_to_sync_against.get_presence_handler().user_syncing( self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + ) ) + # `user_syncing` proxies the presence write to the main process over an HTTP + # replication request. The request body is streamed by a `Cooperator` that uses + # the clock to schedule each chunk at a tiny *non-zero* delay (`_EPSILON`), so + # we need to actually advance the clock for it to fire. + self.reactor.advance(Duration(microseconds=1).as_secs()) + self.get_success(sync_d) # Check against the main process that the user's presence did not change. 
state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5152e8fc536..561b45827fd 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -200,7 +200,7 @@ async def slow_update_membership(*args: Any, **kwargs: Any) -> tuple[str, int]: self.assertEqual(membership[state_tuple].content["displayname"], "Frank") # Let's be sure we are over the delay introduced by slow_update_membership - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) membership = self.get_success( self.storage_controllers.state.get_current_state( @@ -278,7 +278,7 @@ async def potentially_slow_update_membership( # Let's be sure we are over the delay introduced by slow_update_membership # and that the task was not executed as expected - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) membership = self.get_success( self.storage_controllers.state.get_current_state( @@ -299,8 +299,10 @@ async def potentially_slow_update_membership( ) ) + # Wait for the `TaskScheduler.SCHEDULE_INTERVAL` + self.reactor.advance(Duration(minutes=1).as_secs()) # Let's be sure we are over the delay introduced by slow_update_membership - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) # Updates should have been resumed from room 2 after the restart # so room 1 should not have been updated this time diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index d5b95e4ef6b..0a7475856a8 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -71,7 +71,6 @@ def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> Non action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) 
@override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}}) @@ -213,7 +212,6 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: remote_room_hosts=[self.OTHER_SERVER_NAME], ), LimitExceededError, - by=0.5, ) # TODO: test that remote joins to a room are rate limited. @@ -281,7 +279,6 @@ def test_local_users_joining_on_another_worker_contribute_to_rate_limit( action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) # Try to join as Chris on the original worker. Should get denied because Alice @@ -294,7 +291,6 @@ def test_local_users_joining_on_another_worker_contribute_to_rate_limit( action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index eea88cd136b..80d34791b65 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -145,8 +145,12 @@ def test_send_email(self) -> None: ) ) + # This matches the two `callLater` delays in `FakeTransport.registerProducer` + self.reactor.advance(0) + self.reactor.advance(0.1) + # the message should now get delivered - self.get_success(d, by=0.1) + self.get_success(d) # check it arrived self.assertEqual(len(message_delivery.messages), 1) @@ -212,8 +216,12 @@ def test_send_email_force_tls(self) -> None: ) ) + # This matches the two `callLater` delays in `FakeTransport.registerProducer` + self.reactor.advance(0) + self.reactor.advance(0.1) + # the message should now get delivered - self.get_success(d, by=0.1) + self.get_success(d) # check it arrived self.assertEqual(len(message_delivery.messages), 1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 623eef0ecb6..0bbe0845470 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -248,6 +248,14 @@ def test_started_typing_remote_send(self) -> None: ) ) + # Wait for the EDU to get pushed out over federation + # + # `started_typing` is fire-and-forget and handles the remote 
federation part as + # part of a background process which isn't waited on. + # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", @@ -367,6 +375,14 @@ def test_stopped_typing(self) -> None: [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] ) + # Wait for the EDU to get pushed out over federation + # + # `stopped_typing` is fire-and-forget and handles the remote federation part as + # part of a background process which isn't waited on. + # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f50fa1f4a02..dc6738ca286 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -555,7 +555,15 @@ def test_process_join_after_server_leaves_room(self) -> None: # Process the leave and join in one go. dir_handler.update_user_directory = True dir_handler.notify_new_event() - self.wait_for_background_updates() + + # Wait for the user directory to update + # + # `notify_new_event` is fire-and-forget and the actual changes happen as part of + # a background process loop which isn't waited on. + # + # We're specifically waiting for the database queries in the `notify_new_event` + # background process. + self.reactor.advance(0) # The user sharing tables should have been updated. public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) @@ -1124,7 +1132,6 @@ def test_local_user_leaving_room_remains_in_user_directory(self) -> None: # Alice leaves the other. She should still be in the directory. 
self.helper.leave(room2, alice, tok=alice_token) - self.wait_for_background_updates() users, in_public, in_private = self.get_success( self.user_dir_helper.get_tables() ) diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index f25b507aac5..855a623ec09 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -132,12 +132,7 @@ async def test_ensure_media() -> None: # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = defer.ensureDeferred(test_ensure_media()) - - # Hotloop until the threadpool does its job... - self.wait_on_thread(x) - - self.get_success(x) + self.get_success(test_ensure_media()) @attr.s(auto_attribs=True, slots=True, frozen=True) diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index a8eb7fc523c..4eda61cd237 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -206,6 +206,12 @@ def test_wait_for_stream_position_rdata(self) -> None: # Finish the context manager, triggering the data to be sent to master. self.get_success(ctx_worker1.__aexit__(None, None, None)) + # Wait for the stream position to be replicated to the master process + # + # Replication travels over `FakeTransport` and we're specifically flushing the + # write + self.reactor.advance(0) + # Master should get told about `next_token2`, so the deferred should # resolve. 
self.assertTrue(d.called) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index e6b9ea53832..07299c96983 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -81,6 +81,14 @@ def test_federation_ack_sent(self) -> None: ) ) + # Wait for the FEDERATION_ACK to be sent + # + # `on_rdata` handles this as part of a background process (see + # `FederationSenderHandler.update_token`) + # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + # now check that the FEDERATION_ACK was sent mock_connection.send_command.assert_called_once() cmd = mock_connection.send_command.call_args[0][0] diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e3f79d76707..139906e97ca 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -59,8 +59,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main async def update(self, progress: JsonDict, count: int) -> int: - duration_ms = 10 - await self.clock.sleep(Duration(milliseconds=count * duration_ms)) + fake_work_duration = Duration(seconds=1) + await self.clock.sleep(fake_work_duration) progress = {"my_key": progress["my_key"] + 1} await self.store.db_pool.runInteraction( "update_progress", @@ -86,10 +86,15 @@ def test_do_background_update(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.02, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + 
self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # on the first call, we should get run with the default background update size @@ -143,10 +148,15 @@ def test_background_update_default_batch_set_by_config(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.01, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # on the first call, we should get run with the default background update size specified in the config @@ -265,10 +275,15 @@ def test_background_update_duration_set_in_config(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.02, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # the first update was run with the default batch size, this should be run with 500ms as the @@ -298,9 +313,6 @@ def test_background_update_min_batch_set_in_config(self) -> None: """ Test that the minimum batch size set in the config is used """ - # a very long-running individual update - duration_ms = 50 - 
self.get_success( self.store.db_pool.simple_insert( "background_updates", @@ -310,7 +322,8 @@ def test_background_update_min_batch_set_in_config(self) -> None: # Run the update with the long-running update item async def update_long(progress: JsonDict, count: int) -> int: - await self.clock.sleep(Duration(milliseconds=count * duration_ms)) + very_long_fake_work_duration = Duration(seconds=5) + await self.clock.sleep(very_long_fake_work_duration) progress = {"my_key": progress["my_key"] + 1} await self.store.db_pool.runInteraction( "update_progress", @@ -322,10 +335,15 @@ async def update_long(progress: JsonDict, count: int) -> int: self.update_handler.side_effect = update_long self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=1, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `very_long_fake_work_duration` + self.reactor.advance(Duration(seconds=5).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # the first update was run with the default batch size, this should be run with minimum batch size diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 175a5ffc788..d09437c080b 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -755,7 +755,7 @@ def test_background_update_single_large_room(self) -> None: ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(False) ) # Ensure that we did actually take multiple iterations to process the @@ -814,7 +814,7 @@ def test_background_update_multiple_large_room(self) -> None: ): iterations += 1 self.get_success( - 
self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(False) ) # Ensure that we did actually take multiple iterations to process the diff --git a/tests/synapse_rust/test_http_client.py b/tests/synapse_rust/test_http_client.py index 56fab3a0e1d..7a4488d3abd 100644 --- a/tests/synapse_rust/test_http_client.py +++ b/tests/synapse_rust/test_http_client.py @@ -12,12 +12,11 @@ import json import logging +import os import threading -import time from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any, Coroutine, Generator, TypeVar, Union +from typing import Any, TypeVar -from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor from synapse.logging.context import ( @@ -118,31 +117,6 @@ def tearDown(self) -> None: for callbable, args, kwargs in triggers: callbable(*args, **kwargs) - def till_deferred_has_result( - self, - awaitable: Union[ - "Coroutine[Deferred[Any], Any, T]", - "Generator[Deferred[Any], Any, T]", - "Deferred[T]", - ], - ) -> "Deferred[T]": - """Wait until a deferred has a result. - - This is useful because the Rust HTTP client will resolve the deferred - using reactor.callFromThread, which are only run when we call - reactor.advance. 
- """ - deferred = ensureDeferred(awaitable) - tries = 0 - while not deferred.called: - time.sleep(0.1) - self.reactor.advance(0) - tries += 1 - if tries > 100: - raise Exception("Timed out waiting for deferred to resolve") - - return deferred - def _check_current_logcontext(self, expected_logcontext_string: str) -> None: context = current_context() assert isinstance(context, LoggingContext) or isinstance(context, _Sentinel), ( @@ -168,7 +142,7 @@ async def do_request() -> None: raw_response = json_decoder.decode(resp_body.decode("utf-8")) self.assertEqual(raw_response, {"ok": True}) - self.get_success(self.till_deferred_has_result(do_request())) + self.get_success(do_request()) self.assertEqual(self.server.calls, 1) def test_request_response_limit_exceeded(self) -> None: @@ -183,8 +157,8 @@ async def do_request() -> None: response_limit=1, ) - self.assertFailure( - self.till_deferred_has_result(do_request()), + self.get_failure( + do_request(), RuntimeError, ) self.assertEqual(self.server.calls, 1) @@ -227,8 +201,15 @@ async def do_request() -> None: # Now wait for the function under test to have run with PreserveLoggingContext(): while not callback_finished: - # await self.hs.get_clock().sleep(0) - time.sleep(0.1) + # Allow the async Rust to run + # + # Suspend execution of this thread to allow the Tokio thread + # pool to do work. + os.sched_yield() + # Advance the Twisted reactor and run any scheduled callbacks + # + # In terms of other threads, they may have scheduled something on the + # reactor to run (like `reactor.callFromThread(...)`) + self.reactor.advance(0) # check that the logcontext is left in a sane state. 
diff --git a/tests/unittest.py b/tests/unittest.py index 93131521d03..16b4b474fbb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -25,6 +25,7 @@ import hmac import json import logging +import os import secrets import time from typing import ( @@ -48,6 +49,7 @@ import unpaddedbase64 from typing_extensions import Concatenate, ParamSpec +from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor, MemoryReactorClock from twisted.python.failure import Failure @@ -77,6 +79,7 @@ from synapse.storage.keys import FetchKeyResult from synapse.types import ISynapseReactor, JsonDict, Requester, UserID, create_requester from synapse.util.clock import Clock +from synapse.util.duration import Duration from synapse.util.httpresourcetree import create_resource_tree from tests.server import ( @@ -474,27 +477,13 @@ def tearDown(self) -> None: # Reset to not use frozen dicts. events.USE_FROZEN_DICTS = False - def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None: - """ - Wait until a Deferred is done, where it's waiting on a real thread. - """ - start_time = time.time() - - while not deferred.called: - if start_time + timeout < time.time(): - raise ValueError("Timed out waiting for threadpool") - self.reactor.advance(0.01) - time.sleep(0.01) - def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): - self.get_success( - store.db_pool.updates.do_next_background_update(False), by=0.1 - ) + self.get_success(store.db_pool.updates.do_next_background_update(False)) def make_homeserver( self, reactor: ThreadedMemoryReactorClock, clock: Clock @@ -736,21 +725,152 @@ def pump(self, by: float = 0.0) -> None: # whole chain to completion. 
self.reactor.pump([by] * 100) - def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: + def _wait_for_deferred( + self, + d: "Deferred[Any]", + # 2-second default timeout as tests should be fast + timeout: Duration = Duration(seconds=2), + ) -> None: + """ + Wait for the deferred to finish or raise (with real-time timeout). + + Does not advance time in the Twisted reactor clock but will loop until the + real-time `timeout` waiting for a result. The loop 1) allows `clock.call_later` + scheduled callbacks to run if they are scheduled to run now and 2) will also + allow other threads to make progress. This could be things spawned on the + Twisted reactor threadpool or Tokio runtime (async Rust code). + + Args: + d: Twisted Deferred + timeout: Real-time time to wait for the deferred to have a result. + We use real-time as we may have to wait for work on other threads. + + Raises: + defer.TimeoutError: If the timeout expires before the deferred completes. + """ + start_time_seconds = time.time() + + # Wait until the deferred has a result + # + # Checking `d.called` by itself is not sufficient as this is possible: + # + # If you have a first `Deferred` `D1`, you can add a callback which returns + # another `Deferred` `D2`, and `D2` must then complete before any further + # callbacks on `D1` will execute (and later callbacks on `D1` get the *result* + # of `D2` rather than `D2` itself). + # + # So, `D1` might have `called=True` (as in, it has started running its + # callbacks), but any new callbacks added to `D1` won't get run until `D2` + # completes. Fortunately, we can detect this by checking `d.paused`. + while not d.called or d.paused: + if start_time_seconds + timeout.as_secs() < time.time(): + raise defer.TimeoutError( + "Timed out waiting for work happening on a thread to finish" + ) + + # Suspend execution of this thread to allow other threads to do work.
This + # could be things spawned on the Twisted reactor threadpool or Tokio thread + # pool (async Rust code). + # + # We could also use `time.sleep(0)` here but this is more precise + os.sched_yield() + + # Advance the Twisted reactor and run any scheduled callbacks + # + # In terms of other threads, they may have scheduled something on the + # reactor to run (like `reactor.callFromThread(...)`) + self.reactor.advance(0) + + def get_success( + self, + d: Awaitable[TV], + # 1 second default timeout as tests should be fast + timeout: Duration = Duration(seconds=1), + ) -> TV: + """ + Get the success result of an awaitable. + + Does not advance time in the Twisted reactor clock but will loop until the + real-time `timeout` waiting for a result. The loop 1) allows `clock.call_later` + scheduled callbacks to run if they are scheduled to run now and 2) will also + allow other threads to make progress. This could be things spawned on the + Twisted reactor threadpool or Tokio runtime (async Rust code). + + If you need to advance the Twisted reactor by an actual time increment, you can + use the following pattern: + ```python + # We use `ensureDeferred(...)` as a `Deferred` can run in the background on its own (unlike a Python coroutine) + task_d = ensureDeferred(my_async_task()) + # Please explain why/what scheduled call you're trying to trigger + self.reactor.advance(Duration(seconds=1).as_secs()) + result = self.get_success(task_d) + ``` + + Args: + d: awaitable + timeout: Real-time time to wait for the awaitable to have a result. + We use real-time as we may have to wait for work on other threads. + + Raises: + defer.TimeoutError: If the timeout expires before the awaitable completes. + SynchronousTestCase.failureException: If the awaitable has a failure result or has no result + (although you would probably run into `defer.TimeoutError` in that case).
+ """ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] - self.pump(by=by) + self._wait_for_deferred(deferred, timeout) + return self.successResultOf(deferred) def get_failure( - self, d: Awaitable[Any], exc: type[_ExcType], by: float = 0.0 + self, + d: Awaitable[Any], + exc: type[_ExcType], + # 1 second default timeout as tests should be fast + timeout: Duration = Duration(seconds=1), ) -> _TypedFailure[_ExcType]: """ - Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + Get the failure result of an awaitable. The failure must be of the type `exc`. + + Does not advance time in the Twisted reactor clock but will loop until the + real-time `timeout` waiting for a result. The loop 1) allows `clock.call_later` + scheduled callbacks to run if they are scheduled to run now and 2) will also + allow other threads to make progress. This could be things spawned on the + Twisted reactor threadpool or Tokio runtime (async Rust code). + + If you need to advance the Twisted reactor by an actual time increment, you can + use the following pattern: + ```python + # We use `ensureDeferred(...)` as a `Deferred` can run in the background on its own (unlike a Python coroutine) + task_d = ensureDeferred(my_async_task()) + # Please explain why/what scheduled call you're trying to trigger + self.reactor.advance(Duration(seconds=1).as_secs()) + failure = self.get_failure(task_d, exc) + ``` + + Args: + d: awaitable + exc: Exception type to expect + timeout: Real-time time to wait for the awaitable to have a result. + We use real-time as we may have to wait for work on other threads. + + Raises: + defer.TimeoutError: If the timeout expires before the awaitable completes. + SynchronousTestCase.failureException: If the awaitable has a success result, + or has an unexpected failure result, or has no result (although you would + probably run into `defer.TimeoutError` in that case).
""" deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] - self.pump(by) + self._wait_for_deferred(deferred, timeout) + return self.failureResultOf(deferred, exc) + # FIXME: Remove as this has the exact same semantics as `get_success()`. In + # https://github.com/matrix-org/synapse/pull/8402#discussion_r495992506 where it was + # introduced, it was claimed that "get_success fails the test if the deferred fails + # rather than raising, which I find a bit unintuitive." but `get_success()` actually + # does raise "@raise SynchronousTestCase.failureException : If the + # L{Deferred} has no result or has a failure + # result." at-least in today's world. def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: """Drive deferred to completion and return result or raise exception on failure.