Remove MockClock() (#18992)

Spawning from adding some logcontext debug logs in
https://github.com/element-hq/synapse/pull/18966 and since we're not
logging at the `set_current_context(...)` level (see reasoning there),
this removes some usage of `set_current_context(...)`.

Specifically, `MockClock.call_later(...)` doesn't handle logcontexts
correctly. It uses the calling logcontext as the callback context
(wrong, as the logcontext could finish before the callback finishes) and
it didn't reset back to the sentinel context before handing back to the
reactor. It was like this since it was [introduced 10+ years
ago](38da9884e7).
Instead of fixing the implementation which would just be a copy of our
normal `Clock`, we can just remove `MockClock`
This commit is contained in:
Eric Eastwood 2025-09-30 11:27:29 -05:00 committed by GitHub
parent ad8dcc2119
commit 5adb08f3c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 139 additions and 289 deletions

1
changelog.d/18992.misc Normal file
View File

@ -0,0 +1 @@
Remove `MockClock()` in tests.

View File

@ -33,15 +33,17 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
) )
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.constants import ONE_HOUR_SECONDS, ONE_MINUTE_SECONDS from synapse.util.constants import (
MILLISECONDS_PER_SECOND,
ONE_HOUR_SECONDS,
ONE_MINUTE_SECONDS,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
MILLISECONDS_PER_SECOND = 1000
INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS = 5 * ONE_MINUTE_SECONDS INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS = 5 * ONE_MINUTE_SECONDS
""" """
We wait 5 minutes to send the first set of stats as the server can be quite busy the We wait 5 minutes to send the first set of stats as the server can be quite busy the

View File

@ -18,3 +18,5 @@
# readability and catching bugs. # readability and catching bugs.
ONE_MINUTE_SECONDS = 60 ONE_MINUTE_SECONDS = 60
ONE_HOUR_SECONDS = 60 * ONE_MINUTE_SECONDS ONE_HOUR_SECONDS = 60 * ONE_MINUTE_SECONDS
MILLISECONDS_PER_SECOND = 1000

View File

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import List, Optional, Sequence, Tuple, cast from typing import List, Optional, Sequence, Tuple
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -44,13 +44,12 @@ from synapse.types import DeviceListUpdates, JsonDict
from synapse.util.clock import Clock from synapse.util.clock import Clock
from tests import unittest from tests import unittest
from tests.server import get_clock
from ..utils import MockClock
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.clock = MockClock() self.reactor, self.clock = get_clock()
self.store = Mock() self.store = Mock()
self.as_api = Mock() self.as_api = Mock()
@ -170,14 +169,14 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.clock = MockClock() self.reactor, self.clock = get_clock()
self.as_api = Mock() self.as_api = Mock()
self.store = Mock() self.store = Mock()
self.service = Mock() self.service = Mock()
self.callback = AsyncMock() self.callback = AsyncMock()
self.recoverer = _Recoverer( self.recoverer = _Recoverer(
server_name="test_server", server_name="test_server",
clock=cast(Clock, self.clock), clock=self.clock,
as_api=self.as_api, as_api=self.as_api,
store=self.store, store=self.store,
service=self.service, service=self.service,
@ -202,7 +201,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
txn.send = AsyncMock(return_value=True) txn.send = AsyncMock(return_value=True)
txn.complete = AsyncMock(return_value=None) txn.complete = AsyncMock(return_value=None)
# wait for exp backoff # wait for exp backoff
self.clock.advance_time(2) self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.send.call_count)
self.assertEqual(1, txn.complete.call_count) self.assertEqual(1, txn.complete.call_count)
# 2 because it needs to get None to know there are no more txns # 2 because it needs to get None to know there are no more txns
@ -229,21 +228,21 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = AsyncMock(return_value=False) txn.send = AsyncMock(return_value=False)
txn.complete = AsyncMock(return_value=None) txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2) self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count) self.assertEqual(0, self.callback.call_count)
self.clock.advance_time(4) self.reactor.advance(4)
self.assertEqual(2, txn.send.call_count) self.assertEqual(2, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count) self.assertEqual(0, self.callback.call_count)
self.clock.advance_time(8) self.reactor.advance(8)
self.assertEqual(3, txn.send.call_count) self.assertEqual(3, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count) self.assertEqual(0, self.callback.call_count)
txn.send = AsyncMock(return_value=True) # successfully send the txn txn.send = AsyncMock(return_value=True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more. pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16) self.reactor.advance(16)
self.assertEqual(1, txn.send.call_count) # new mock reset call count self.assertEqual(1, txn.send.call_count) # new mock reset call count
self.assertEqual(1, txn.complete.call_count) self.assertEqual(1, txn.complete.call_count)
self.callback.assert_called_once_with(self.recoverer) self.callback.assert_called_once_with(self.recoverer)
@ -268,7 +267,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = AsyncMock(return_value=False) txn.send = AsyncMock(return_value=False)
txn.complete = AsyncMock(return_value=None) txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2) self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count) self.assertEqual(0, self.callback.call_count)

View File

@ -231,7 +231,10 @@ class MSC3861OAuthDelegation(TestCase):
reactor, clock = get_clock() reactor, clock = get_clock()
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
setup_test_homeserver( setup_test_homeserver(
self.addCleanup, reactor=reactor, clock=clock, config=config cleanup_func=self.addCleanup,
config=config,
reactor=reactor,
clock=clock,
) )
def test_jwt_auth_cannot_be_enabled(self) -> None: def test_jwt_auth_cannot_be_enabled(self) -> None:
@ -395,7 +398,10 @@ class MasAuthDelegation(TestCase):
reactor, clock = get_clock() reactor, clock = get_clock()
with self.assertRaises(ConfigError): with self.assertRaises(ConfigError):
setup_test_homeserver( setup_test_homeserver(
self.addCleanup, reactor=reactor, clock=clock, config=config cleanup_func=self.addCleanup,
config=config,
reactor=reactor,
clock=clock,
) )
@skip_unless(HAS_AUTHLIB, "requires authlib") @skip_unless(HAS_AUTHLIB, "requires authlib")

View File

@ -49,9 +49,9 @@ from synapse.util.clock import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.server import get_clock
from tests.test_utils import event_injection from tests.test_utils import event_injection
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import MockClock
class AppServiceHandlerTestCase(unittest.TestCase): class AppServiceHandlerTestCase(unittest.TestCase):
@ -61,6 +61,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_store = Mock() self.mock_store = Mock()
self.mock_as_api = AsyncMock() self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock() self.mock_scheduler = Mock()
self.reactor, self.clock = get_clock()
hs = Mock() hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store) hs.get_datastores.return_value = Mock(main=self.mock_store)
self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
@ -68,7 +70,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = self.clock
self.handler = ApplicationServicesHandler(hs) self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources() self.event_source = hs.get_event_sources()

View File

@ -21,7 +21,6 @@
# #
import copy import copy
from unittest import mock
from twisted.internet.testing import MemoryReactor from twisted.internet.testing import MemoryReactor
@ -50,7 +49,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock()) return self.setup_test_homeserver()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler() self.handler = hs.get_e2e_room_keys_handler()

View File

@ -30,7 +30,7 @@ from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.logging.context import LoggingContext, current_context from synapse.logging.context import LoggingContext, current_context
from tests import unittest from tests import unittest
from tests.utils import MockClock from tests.server import get_clock
class SrvResolverTestCase(unittest.TestCase): class SrvResolverTestCase(unittest.TestCase):
@ -105,7 +105,7 @@ class SrvResolverTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_from_cache(self) -> Generator["Deferred[object]", object, None]: def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
clock = MockClock() reactor, clock = get_clock()
dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock = Mock(spec_set=["lookupService"])
dns_client_mock.lookupService = Mock(spec_set=[]) dns_client_mock.lookupService = Mock(spec_set=[])

View File

@ -63,10 +63,6 @@ def check_logcontext(context: LoggingContextOrSentinel) -> None:
class FederationClientTests(HomeserverTestCase): class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
def prepare( def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None: ) -> None:

View File

@ -37,7 +37,6 @@ from synapse.util.stringutils import (
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import MockClock
class MediaRetentionTestCase(unittest.HomeserverTestCase): class MediaRetentionTestCase(unittest.HomeserverTestCase):
@ -51,12 +50,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
admin.register_servlets_for_client_rest_resource, admin.register_servlets_for_client_rest_resource,
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# We need to be able to test advancing time in the homeserver, so we
# replace the test homeserver's default clock with a MockClock, which
# supports advancing time.
return self.setup_test_homeserver(clock=MockClock())
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.remote_server_name = "remote.homeserver" self.remote_server_name = "remote.homeserver"
self.store = hs.get_datastores().main self.store = hs.get_datastores().main

View File

@ -29,16 +29,19 @@ from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_co
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.types import ISynapseReactor, JsonDict from synapse.types import ISynapseReactor, JsonDict
from synapse.util.clock import Clock from synapse.util.clock import Clock
from synapse.util.constants import (
MILLISECONDS_PER_SECOND,
)
from tests import unittest from tests import unittest
from tests.utils import MockClock from tests.server import get_clock
reactor = cast(ISynapseReactor, _reactor) reactor = cast(ISynapseReactor, _reactor)
class HttpTransactionCacheTestCase(unittest.TestCase): class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.clock = MockClock() self.reactor, self.clock = get_clock()
self.hs = Mock() self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock() self.hs.get_auth = Mock()
@ -180,8 +183,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute_request( yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg" self.mock_request, self.mock_requester, cb, "an arg"
) )
# should NOT have cleaned up yet # Advance time just under the cleanup period.
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) # Should NOT have cleaned up yet
self.reactor.advance((CLEANUP_PERIOD_MS - 1) / MILLISECONDS_PER_SECOND)
yield self.cache.fetch_or_execute_request( yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg" self.mock_request, self.mock_requester, cb, "an arg"
@ -189,7 +193,8 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
# still using cache # still using cache
cb.assert_called_once_with("an arg") cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS) # Advance time just after the cleanup period.
self.reactor.advance(2 / MILLISECONDS_PER_SECOND)
yield self.cache.fetch_or_execute_request( yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg" self.mock_request, self.mock_requester, cb, "an arg"

View File

@ -170,7 +170,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary # make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock() self.http_client2 = Mock()
config = default_config(name="keyclient") config = default_config(server_name="keyclient")
config["trusted_key_servers"] = [ config["trusted_key_servers"] = [
{ {
"server_name": self.hs.hostname, "server_name": self.hs.hostname,

View File

@ -114,7 +114,6 @@ from tests.utils import (
POSTGRES_USER, POSTGRES_USER,
SQLITE_PERSIST_DB, SQLITE_PERSIST_DB,
USE_POSTGRES_FOR_TESTS, USE_POSTGRES_FOR_TESTS,
MockClock,
default_config, default_config,
) )
@ -786,9 +785,9 @@ class ThreadPool:
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock() reactor = ThreadedMemoryReactorClock()
hs_clock = Clock(clock, server_name="test_server") hs_clock = Clock(reactor, server_name="test_server")
return clock, hs_clock return reactor, hs_clock
@implementer(ITCPTransport) @implementer(ITCPTransport)
@ -1020,12 +1019,14 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver( def setup_test_homeserver(
*,
cleanup_func: Callable[[Callable[[], None]], None], cleanup_func: Callable[[Callable[[], None]], None],
name: str = "test", server_name: str = "test",
config: Optional[HomeServerConfig] = None, config: Optional[HomeServerConfig] = None,
reactor: Optional[ISynapseReactor] = None, reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer, homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs: Any, db_txn_limit: Optional[int] = None,
**extra_homeserver_attributes: Any,
) -> HomeServer: ) -> HomeServer:
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
@ -1034,29 +1035,41 @@ def setup_test_homeserver(
If no datastore is supplied, one is created and given to the homeserver. If no datastore is supplied, one is created and given to the homeserver.
Args: Args:
cleanup_func : The function used to register a cleanup routine for cleanup_func: The function used to register a cleanup routine for after the
after the test. test.
server_name: Homeserver name
config: Homeserver config
reactor: Twisted reactor
homeserver_to_use: Homeserver class to instantiate.
db_txn_limit: Gives the maximum number of database transactions to run per
connection before reconnecting. 0 means no limit. If unset, defaults to None
here which will default upstream to `0`.
**extra_homeserver_attributes: Additional keyword arguments to install as
`@cache_in_self` attributes on the homeserver. For example, `clock` will be
installed as `hs._clock`.
Calling this method directly is deprecated: you should instead derive from Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase. HomeserverTestCase.
""" """
if reactor is None: if reactor is None:
from twisted.internet import reactor as _reactor reactor = ThreadedMemoryReactorClock()
reactor = cast(ISynapseReactor, _reactor)
if config is None: if config is None:
config = default_config(name, parse=True) config = default_config(server_name, parse=True)
server_name = config.server.server_name
if not isinstance(server_name, str):
raise ConfigError("Must be a string", ("server_name",))
if "clock" not in extra_homeserver_attributes:
extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name)
config.caches.resize_all_caches() config.caches.resize_all_caches()
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = { database_config: JsonDict = {
"name": "psycopg2", "name": "psycopg2",
"args": { "args": {
"dbname": test_db, "dbname": test_db,
@ -1088,10 +1101,6 @@ def setup_test_homeserver(
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1}, "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
} }
server_name = config.server.server_name
if not isinstance(server_name, str):
raise ConfigError("Must be a string", ("server_name",))
# Check if we have set up a DB that we can use as a template. # Check if we have set up a DB that we can use as a template.
global PREPPED_SQLITE_DB_CONN global PREPPED_SQLITE_DB_CONN
if PREPPED_SQLITE_DB_CONN is None: if PREPPED_SQLITE_DB_CONN is None:
@ -1111,8 +1120,8 @@ def setup_test_homeserver(
database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
if "db_txn_limit" in kwargs: if db_txn_limit is not None:
database_config["txn_limit"] = kwargs["db_txn_limit"] database_config["txn_limit"] = db_txn_limit
database = DatabaseConnectionConfig("master", database_config) database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database] config.database.databases = [database]
@ -1139,7 +1148,7 @@ def setup_test_homeserver(
db_conn.close() db_conn.close()
hs = homeserver_to_use( hs = homeserver_to_use(
name, server_name,
config=config, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
reactor=reactor, reactor=reactor,
@ -1149,7 +1158,7 @@ def setup_test_homeserver(
cleanup_func(hs.cleanup) cleanup_func(hs.cleanup)
# Install @cache_in_self attributes # Install @cache_in_self attributes
for key, val in kwargs.items(): for key, val in extra_homeserver_attributes.items():
setattr(hs, "_" + key, val) setattr(hs, "_" + key, val)
# Mock TLS # Mock TLS

View File

@ -86,7 +86,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
conn_pool.runWithConnection = runWithConnection conn_pool.runWithConnection = runWithConnection
config = default_config(name="test", parse=True) config = default_config(server_name="test", parse=True)
hs = TestHomeServer("test", config=config) hs = TestHomeServer("test", config=config)
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:

View File

@ -55,9 +55,9 @@ class JsonResourceTests(unittest.TestCase):
reactor, clock = get_clock() reactor, clock = get_clock()
self.reactor = reactor self.reactor = reactor
self.homeserver = setup_test_homeserver( self.homeserver = setup_test_homeserver(
self.addCleanup, cleanup_func=self.addCleanup,
clock=clock,
reactor=self.reactor, reactor=self.reactor,
clock=clock,
) )
def test_handler_for_request(self) -> None: def test_handler_for_request(self) -> None:
@ -217,9 +217,9 @@ class OptionsResourceTests(unittest.TestCase):
reactor, clock = get_clock() reactor, clock = get_clock()
self.reactor = reactor self.reactor = reactor
self.homeserver = setup_test_homeserver( self.homeserver = setup_test_homeserver(
self.addCleanup, cleanup_func=self.addCleanup,
clock=clock,
reactor=self.reactor, reactor=self.reactor,
clock=clock,
) )
class DummyResource(Resource): class DummyResource(Resource):

View File

@ -29,7 +29,6 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
cast,
) )
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
@ -43,12 +42,11 @@ from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util.clock import Clock
from synapse.util.macaroons import MacaroonGenerator from synapse.util.macaroons import MacaroonGenerator
from tests import unittest from tests import unittest
from tests.server import get_clock
from .utils import MockClock, default_config from tests.utils import default_config
_next_event_id = 1000 _next_event_id = 1000
@ -248,7 +246,7 @@ class StateTestCase(unittest.TestCase):
"hostname", "hostname",
] ]
) )
clock = cast(Clock, MockClock()) reactor, clock = get_clock()
hs.config = default_config("tesths", True) hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock( hs.get_datastores.return_value = Mock(
main=self.dummy_store, main=self.dummy_store,

View File

@ -1,79 +0,0 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from tests import unittest
from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
def setUp(self) -> None:
self.clock = MockClock()
def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
def test_later(self) -> None:
invoked = [0, 0]
def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
self.assertFalse(invoked[0])
self.clock.advance_time(15)
self.assertTrue(invoked[0])
self.assertFalse(invoked[1])
self.clock.advance_time(5)
self.assertTrue(invoked[1])
def test_cancel_later(self) -> None:
invoked = [0, 0]
def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
self.clock.cancel_call_later(t0)
self.clock.advance_time(30)
self.assertFalse(invoked[0])
self.assertTrue(invoked[1])

View File

@ -80,7 +80,7 @@ from synapse.logging.context import (
from synapse.rest import RegisterServletsFunc from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.types import ISynapseReactor, JsonDict, Requester, UserID, create_requester
from synapse.util.clock import Clock from synapse.util.clock import Clock
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -99,6 +99,8 @@ from tests.utils import checked_cast, default_config, setupdb
setupdb() setupdb()
setup_logging() setup_logging()
logger = logging.getLogger(__name__)
TV = TypeVar("TV") TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
@ -135,7 +137,7 @@ def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
return _around return _around
_TConfig = TypeVar("_TConfig", Config, RootConfig) _TConfig = TypeVar("_TConfig", Config, HomeServerConfig)
def deepcopy_config(config: _TConfig) -> _TConfig: def deepcopy_config(config: _TConfig) -> _TConfig:
@ -161,13 +163,13 @@ def deepcopy_config(config: _TConfig) -> _TConfig:
@functools.lru_cache(maxsize=8) @functools.lru_cache(maxsize=8)
def _parse_config_dict(config: str) -> RootConfig: def _parse_config_dict(config: str) -> HomeServerConfig:
config_obj = HomeServerConfig() config_obj = HomeServerConfig()
config_obj.parse_config_dict(json.loads(config), "", "") config_obj.parse_config_dict(json.loads(config), "", "")
return config_obj return config_obj
def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig: def make_homeserver_config_obj(config: Dict[str, Any]) -> HomeServerConfig:
"""Creates a :class:`HomeServerConfig` instance with the given configuration dict. """Creates a :class:`HomeServerConfig` instance with the given configuration dict.
This is equivalent to:: This is equivalent to::
@ -392,8 +394,8 @@ class HomeserverTestCase(TestCase):
hijacking the authentication system to return a fixed user, and then hijacking the authentication system to return a fixed user, and then
calling the prepare function. calling the prepare function.
""" """
# We need to share the reactor between the homeserver and all of our test utils.
self.reactor, self.clock = get_clock() self.reactor, self.clock = get_clock()
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock) self.hs = self.make_homeserver(self.reactor, self.clock)
self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
@ -511,7 +513,7 @@ class HomeserverTestCase(TestCase):
Function to be overridden in subclasses. Function to be overridden in subclasses.
""" """
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs return hs
def create_test_resource(self) -> Resource: def create_test_resource(self) -> Resource:
@ -634,7 +636,12 @@ class HomeserverTestCase(TestCase):
) )
def setup_test_homeserver( def setup_test_homeserver(
self, server_name: Optional[str] = None, **kwargs: Any self,
server_name: Optional[str] = None,
config: Optional[JsonDict] = None,
reactor: Optional[ISynapseReactor] = None,
clock: Optional[Clock] = None,
**extra_homeserver_attributes: Any,
) -> HomeServer: ) -> HomeServer:
""" """
Set up the test homeserver, meant to be called by the overridable Set up the test homeserver, meant to be called by the overridable
@ -647,12 +654,15 @@ class HomeserverTestCase(TestCase):
Returns: Returns:
synapse.server.HomeServer synapse.server.HomeServer
""" """
kwargs = dict(kwargs) if config is None:
kwargs.update(self._hs_args)
if "config" not in kwargs:
config = self.default_config() config = self.default_config()
else:
config = kwargs["config"] # The sane default is to use the same reactor and clock as our other test utils
if reactor is None:
reactor = self.reactor
if clock is None:
clock = self.clock
# The server name can be specified using either the `name` argument or a config # The server name can be specified using either the `name` argument or a config
# override. The `name` argument takes precedence over any config overrides. # override. The `name` argument takes precedence over any config overrides.
@ -661,19 +671,24 @@ class HomeserverTestCase(TestCase):
# Parse the config from a config dict into a HomeServerConfig # Parse the config from a config dict into a HomeServerConfig
config_obj = make_homeserver_config_obj(config) config_obj = make_homeserver_config_obj(config)
kwargs["config"] = config_obj
# The server name in the config is now `name`, if provided, or the `server_name` # The server name in the config is now `name`, if provided, or the `server_name`
# from a config override, or the default of "test". Whichever it is, we # from a config override, or the default of "test". Whichever it is, we
# construct a homeserver with a matching name. # construct a homeserver with a matching name.
server_name = config_obj.server.server_name server_name = config_obj.server.server_name
kwargs["name"] = server_name
async def run_bg_updates() -> None: async def run_bg_updates() -> None:
with LoggingContext(name="run_bg_updates", server_name=server_name): with LoggingContext(name="run_bg_updates", server_name=server_name):
self.get_success(stor.db_pool.updates.run_background_updates(False)) self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, **kwargs) hs = setup_test_homeserver(
cleanup_func=self.addCleanup,
server_name=server_name,
config=config_obj,
reactor=reactor,
clock=clock,
**extra_homeserver_attributes,
)
stor = hs.get_datastores().main stor = hs.get_datastores().main
# Run the database background updates, when running against "master". # Run the database background updates, when running against "master".

View File

@ -19,23 +19,22 @@
# #
# #
from typing import List, cast from typing import List
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.clock import Clock
from tests.utils import MockClock from tests.server import get_clock
from .. import unittest from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase): class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self) -> None: def test_get_set(self) -> None:
clock = MockClock() reactor, clock = get_clock()
cache: ExpiringCache[str, str] = ExpiringCache( cache: ExpiringCache[str, str] = ExpiringCache(
cache_name="test", cache_name="test",
server_name="testserver", server_name="testserver",
clock=cast(Clock, clock), clock=clock,
max_len=1, max_len=1,
) )
@ -44,11 +43,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache["key"], "value") self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None: def test_eviction(self) -> None:
clock = MockClock() reactor, clock = get_clock()
cache: ExpiringCache[str, str] = ExpiringCache( cache: ExpiringCache[str, str] = ExpiringCache(
cache_name="test", cache_name="test",
server_name="testserver", server_name="testserver",
clock=cast(Clock, clock), clock=clock,
max_len=2, max_len=2,
) )
@ -63,11 +62,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), "value3") self.assertEqual(cache.get("key3"), "value3")
def test_iterable_eviction(self) -> None: def test_iterable_eviction(self) -> None:
clock = MockClock() reactor, clock = get_clock()
cache: ExpiringCache[str, List[int]] = ExpiringCache( cache: ExpiringCache[str, List[int]] = ExpiringCache(
cache_name="test", cache_name="test",
server_name="testserver", server_name="testserver",
clock=cast(Clock, clock), clock=clock,
max_len=5, max_len=5,
iterable=True, iterable=True,
) )
@ -87,25 +86,25 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key4"), [6, 7]) self.assertEqual(cache.get("key4"), [6, 7])
def test_time_eviction(self) -> None: def test_time_eviction(self) -> None:
clock = MockClock() reactor, clock = get_clock()
cache: ExpiringCache[str, int] = ExpiringCache( cache: ExpiringCache[str, int] = ExpiringCache(
cache_name="test", cache_name="test",
server_name="testserver", server_name="testserver",
clock=cast(Clock, clock), clock=clock,
expiry_ms=1000, expiry_ms=1000,
) )
cache["key"] = 1 cache["key"] = 1
clock.advance_time(0.5) reactor.advance(0.5)
cache["key2"] = 2 cache["key2"] = 2
self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.get("key2"), 2) self.assertEqual(cache.get("key2"), 2)
clock.advance_time(0.9) reactor.advance(0.9)
self.assertEqual(cache.get("key"), None) self.assertEqual(cache.get("key"), None)
self.assertEqual(cache.get("key2"), 2) self.assertEqual(cache.get("key2"), 2)
clock.advance_time(1) reactor.advance(1)
self.assertEqual(cache.get("key"), None) self.assertEqual(cache.get("key"), None)
self.assertEqual(cache.get("key2"), None) self.assertEqual(cache.get("key2"), None)

View File

@ -24,27 +24,19 @@ import os
import signal import signal
from types import FrameType, TracebackType from types import FrameType, TracebackType
from typing import ( from typing import (
Any,
Callable,
Dict, Dict,
List,
Literal, Literal,
Optional, Optional,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
overload, overload,
) )
import attr
from typing_extensions import ParamSpec
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -140,21 +132,27 @@ def setupdb() -> None:
@overload @overload
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ... def default_config(
server_name: str, parse: Literal[False] = ...
) -> Dict[str, object]: ...
@overload @overload
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ... def default_config(server_name: str, parse: Literal[True]) -> HomeServerConfig: ...
def default_config( def default_config(
name: str, parse: bool = False server_name: str, parse: bool = False
) -> Union[Dict[str, object], HomeServerConfig]: ) -> Union[Dict[str, object], HomeServerConfig]:
""" """
Create a reasonable test config. Create a reasonable test config.
Args:
server_name: homeserver name
parse: TODO
""" """
config_dict = { config_dict = {
"server_name": name, "server_name": server_name,
# Setting this to an empty list turns off federation sending. # Setting this to an empty list turns off federation sending.
"federation_sender_instances": [], "federation_sender_instances": [],
"media_store_path": "media", "media_store_path": "media",
@ -247,101 +245,6 @@ def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
return getRawHeaders return getRawHeaders
P = ParamSpec("P")
@attr.s(slots=True, auto_attribs=True)
class Timer:
absolute_time: float
callback: Callable[[], None]
expired: bool
# TODO: Make this generic over a ParamSpec?
@attr.s(slots=True, auto_attribs=True)
class Looper:
func: Callable[..., Any]
interval: float # seconds
last: float
args: Tuple[object, ...]
kwargs: Dict[str, object]
class MockClock:
now = 1000.0
def __init__(self) -> None:
# Timers in no particular order
self.timers: List[Timer] = []
self.loopers: List[Looper] = []
def time(self) -> float:
return self.now
def time_msec(self) -> int:
return int(self.time() * 1000)
def call_later(
self,
delay: float,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> Timer:
ctx = current_context()
def wrapped_callback() -> None:
set_current_context(ctx)
callback(*args, **kwargs)
t = Timer(self.now + delay, wrapped_callback, False)
self.timers.append(t)
return t
def looping_call(
self,
function: Callable[P, object],
interval: float,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs))
def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
if timer.expired:
if not ignore_errs:
raise Exception("Cannot cancel an expired timer")
timer.expired = True
self.timers = [t for t in self.timers if t != timer]
# For unit testing
def advance_time(self, secs: float) -> None:
self.now += secs
timers = self.timers
self.timers = []
for t in timers:
if t.expired:
raise Exception("Timer already expired")
if self.now >= t.absolute_time:
t.expired = True
t.callback()
else:
self.timers.append(t)
for looped in self.loopers:
if looped.last + looped.interval < self.now:
looped.func(*looped.args, **looped.kwargs)
looped.last = self.now
def advance_time_msec(self, ms: float) -> None:
self.advance_time(ms / 1000.0)
async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
"""Creates and persist a creation event for the given room""" """Creates and persist a creation event for the given room"""