Compare commits

...

22 Commits

Author SHA1 Message Date
Eric Eastwood
e3770c7209
Merge 90780763c874f9ce47dca68e03a8f38a236fc515 into 24bcdb3f3c5aa2d8d2000dc34d54a8002a914616 2025-07-01 11:36:10 -05:00
Eric Eastwood
90780763c8 Fix ApplicationService sender usage in test2 2025-06-30 18:06:14 -05:00
Eric Eastwood
f7e6f0967f Fix CacheMetric splatting label objects as arguments 2025-06-30 18:05:07 -05:00
Eric Eastwood
b3ecd5cd88 Fix ApplicationService sender usage in test 2025-06-30 18:03:03 -05:00
Eric Eastwood
e943bb12fe Fix more ApplicationService usage/mocks 2025-06-30 17:58:56 -05:00
Eric Eastwood
64ed156532 Fill in server_name attribute for @cached 2025-06-30 17:32:20 -05:00
Eric Eastwood
1917a0bc93 Fill in server_name attribute for ApplicationService (for @cached) 2025-06-30 17:19:35 -05:00
Eric Eastwood
ee91f6b00d Better explain usage 2025-06-30 16:46:23 -05:00
Eric Eastwood
9eae037c1b Fix mypy complaining about unknown types
```
synapse/replication/tcp/streams/_base.py:568: error: Cannot determine type of "_device_list_id_gen"  [has-type]
synapse/storage/databases/main/event_push_actions.py:256: error: Cannot determine type of "server_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/event_push_actions.py:256: error: Cannot determine type of "server_name" in base class "EventsWorkerStore"  [misc]
synapse/storage/databases/main/event_push_actions.py:256: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/metrics.py:64: error: Cannot determine type of "server_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/metrics.py:64: error: Cannot determine type of "server_name" in base class "EventsWorkerStore"  [misc]
synapse/storage/databases/main/metrics.py:64: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/push_rule.py:118: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/push_rule.py:118: error: Cannot determine type of "server_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/push_rule.py:118: error: Cannot determine type of "server_name" in base class "EventsWorkerStore"  [misc]
synapse/storage/databases/main/account_data.py:60: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/account_data.py:60: error: Cannot determine type of "server_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/account_data.py:60: error: Cannot determine type of "server_name" in base class "EventsWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "server_name" in base class "PresenceStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "server_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "server_name" in base class "ClientIpWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "server_name" in base class "DeviceInboxWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "server_name" in base class "EventsWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
synapse/storage/databases/main/__init__.py:114: error: Cannot determine type of "_instance_name" in base class "DeviceInboxWorkerStore"  [misc]
synapse/app/generic_worker.py:117: error: Cannot determine type of "_instance_name" in base class "DeviceInboxWorkerStore"  [misc]
synapse/app/generic_worker.py:117: error: Cannot determine type of "_instance_name" in base class "ReceiptsWorkerStore"  [misc]
Found 22 errors in 7 files (checked 937 source files)
```
2025-06-30 16:42:16 -05:00
Eric Eastwood
9895b3b726 Fill in missing LruCache usage 2025-06-30 15:53:00 -05:00
Eric Eastwood
d10d862ae6 Add changelog 2025-06-30 14:55:10 -05:00
Eric Eastwood
4fcfda0256 Merge branch 'develop' into madlittlemods/per-hs-metrics-cache 2025-06-30 14:54:39 -05:00
Eric Eastwood
a17206f564 Fix arguments in DeferredCache usage 2025-06-27 17:46:01 -05:00
Eric Eastwood
0453666448 Fix missing server_name on ExpiringCache usage 2025-06-27 17:42:55 -05:00
Eric Eastwood
19c917cace Attempt @cached solution v1 2025-06-27 17:41:59 -05:00
Eric Eastwood
1e57b57e29 Fix LruCache positional argument lint 2025-06-27 17:16:33 -05:00
Eric Eastwood
8dbca87f44 Fill in LruCache except for @cached 2025-06-27 17:12:29 -05:00
Eric Eastwood
74610aabd2 Fill in TTLCache 2025-06-27 16:49:57 -05:00
Eric Eastwood
61fc9ba52a Fill in StreamChangeCache 2025-06-27 16:17:58 -05:00
Eric Eastwood
8e71fcdb82 Fill in ResponseCache 2025-06-27 16:06:44 -05:00
Eric Eastwood
749b7a493c Fill in ExpiringCache 2025-06-27 15:58:49 -05:00
Eric Eastwood
ba3bbbb13a Add INSTANCE_LABEL_NAME to register_cache(...) 2025-06-27 15:49:06 -05:00
89 changed files with 705 additions and 244 deletions

changelog.d/18604.misc (new file)
View File

@@ -0,0 +1 @@
+Refactor cache metrics to be homeserver-scoped.

View File

@@ -172,7 +172,7 @@ class BaseAuth:
         """
         # It's ok if the app service is trying to use the sender from their registration
-        if app_service.sender == user_id:
+        if app_service.sender.to_string() == user_id:
             pass
         # Check to make sure the app service is allowed to control the user
         elif not app_service.is_interested_in_user(user_id):

View File

@@ -176,6 +176,7 @@ class MSC3861DelegatedAuth(BaseAuth):
         assert self._config.client_id, "No client_id provided"
         assert auth_method is not None, "Invalid client_auth_method provided"

+        self.server_name = hs.hostname
         self._clock = hs.get_clock()
         self._http_client = hs.get_proxied_http_client()
         self._hostname = hs.hostname
@@ -206,8 +207,9 @@ class MSC3861DelegatedAuth(BaseAuth):
         # In this case, the device still exists and it's not the end of the world for
         # the old access token to continue working for a short time.
         self._introspection_cache: ResponseCache[str] = ResponseCache(
-            self._clock,
-            "token_introspection",
+            clock=self._clock,
+            name="token_introspection",
+            server_name=self.server_name,
             timeout_ms=120_000,
             # don't log because the keys are access tokens
             enable_logging=False,

View File

@@ -78,7 +78,7 @@ class ApplicationService:
         self,
         token: str,
         id: str,
-        sender: str,
+        sender: UserID,
         url: Optional[str] = None,
         namespaces: Optional[JsonDict] = None,
         hs_token: Optional[str] = None,
@@ -96,6 +96,8 @@
         self.hs_token = hs_token
         # The full Matrix ID for this application service's sender.
         self.sender = sender
+        # The application service user should be part of the server's domain.
+        self.server_name = sender.domain
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
         self.ip_range_whitelist = ip_range_whitelist
@@ -223,7 +225,7 @@
         """
         return (
             # User is the appservice's configured sender_localpart user
-            user_id == self.sender
+            user_id == self.sender.to_string()
             # User is in the appservice's user namespace
             or self.is_user_in_namespace(user_id)
         )
@@ -347,7 +349,7 @@
     def is_exclusive_user(self, user_id: str) -> bool:
         return (
             self._is_exclusive(ApplicationService.NS_USERS, user_id)
-            or user_id == self.sender
+            or user_id == self.sender.to_string()
         )

     def is_interested_in_protocol(self, protocol: str) -> bool:
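
Note: `sender` changing from `str` to `UserID` is what drives the `.to_string()` call sites throughout the rest of this diff. A quick sketch of the round-trip, assuming Synapse's `synapse.types.UserID` helper (a parsed `@localpart:domain` identifier; the example ID is illustrative):

```
from synapse.types import UserID

sender = UserID.from_string("@appservice_bot:example.com")
sender.localpart    # "appservice_bot"
sender.domain       # "example.com" -- what ApplicationService.server_name now stores
sender.to_string()  # "@appservice_bot:example.com" -- for comparing with raw user IDs
```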

View File

@@ -126,11 +126,15 @@ class ApplicationServiceApi(SimpleHttpClient):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
+        self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.config = hs.config.appservice

         self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
-            hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
+            clock=hs.get_clock(),
+            name="as_protocol_meta",
+            server_name=self.server_name,
+            timeout_ms=HOUR_IN_MS,
         )

     def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
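
`ResponseCache` appears in many hunks below, so for context, a hedged sketch of the calling pattern: `wrap` deduplicates concurrent requests that share a key, and the new `server_name` argument only affects how the cache's metrics are labeled. (`do_expensive_fetch` is a stand-in for the real callback.)

```
cache: ResponseCache[str] = ResponseCache(
    clock=hs.get_clock(),
    name="example_cache",     # illustrative name
    server_name=hs.hostname,  # feeds the `instance` metric label
    timeout_ms=30_000,
)

async def lookup(key: str) -> dict:
    # Concurrent callers with the same key await a single underlying fetch.
    return await cache.wrap(key, do_expensive_fetch, key)
```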

View File

@@ -319,7 +319,7 @@ class _ServiceQueuer:
         users: Set[str] = set()

         # The sender is always included
-        users.add(service.sender)
+        users.add(service.sender.to_string())

         # All AS users that would receive the PDUs or EDUs sent to these rooms
         # are classed as 'interesting'.

View File

@@ -122,8 +122,7 @@ def _load_appservice(
     localpart = as_info["sender_localpart"]
     if urlparse.quote(localpart) != localpart:
         raise ValueError("sender_localpart needs characters which are not URL encoded.")
-    user = UserID(localpart, hostname)
-    user_id = user.to_string()
+    user_id = UserID(localpart, hostname)

     # Rate limiting for users of this AS is on by default (excludes sender)
     rate_limited = as_info.get("rate_limited")

View File

@@ -137,13 +137,14 @@ class FederationClient(FederationBase):
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()

-        self.hostname = hs.hostname
+        self.server_name = hs.hostname
         self.signing_key = hs.signing_key

         # Cache mapping `event_id` to a tuple of the event itself and the `pull_origin`
         # (which server we pulled the event from)
         self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache(
             cache_name="get_pdu_cache",
+            server_name=self.server_name,
             clock=self._clock,
             max_len=1000,
             expiry_ms=120 * 1000,
@@ -162,6 +163,7 @@
             Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
         ] = ExpiringCache(
             cache_name="get_room_hierarchy_cache",
+            server_name=self.server_name,
             clock=self._clock,
             max_len=1000,
             expiry_ms=5 * 60 * 1000,
@@ -1068,7 +1070,7 @@
             # there's some we never care about
             ev = builder.create_local_event_from_event_dict(
                 self._clock,
-                self.hostname,
+                self.server_name,
                 self.signing_key,
                 room_version=room_version,
                 event_dict=pdu_dict,

View File

@@ -160,7 +160,10 @@ class FederationServer(FederationBase):
         # We cache results for transaction with the same ID
         self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
-            hs.get_clock(), "fed_txn_handler", timeout_ms=30000
+            clock=hs.get_clock(),
+            name="fed_txn_handler",
+            server_name=self.server_name,
+            timeout_ms=30000,
         )

         self.transaction_actions = TransactionActions(self.store)
@@ -170,10 +173,18 @@
         # We cache responses to state queries, as they take a while and often
         # come in waves.
         self._state_resp_cache: ResponseCache[Tuple[str, Optional[str]]] = (
-            ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
+            ResponseCache(
+                clock=hs.get_clock(),
+                name="state_resp",
+                server_name=self.server_name,
+                timeout_ms=30000,
+            )
         )
         self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
-            hs.get_clock(), "state_ids_resp", timeout_ms=30000
+            clock=hs.get_clock(),
+            name="state_ids_resp",
+            server_name=self.server_name,
+            timeout_ms=30000,
         )

         self._federation_metrics_domains = (

View File

@@ -839,7 +839,7 @@ class ApplicationServicesHandler:
             # user not found; could be the AS though, so check.
             services = self.store.get_app_services()
-            service_list = [s for s in services if s.sender == user_id]
+            service_list = [s for s in services if s.sender.to_string() == user_id]
             return len(service_list) == 0

     async def _check_user_exists(self, user_id: str) -> bool:

View File

@@ -1213,6 +1213,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
     "Handles incoming device list updates from federation and updates the DB"

     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1232,6 +1233,7 @@
         # resyncs.
         self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
             cache_name="device_update_edu",
+            server_name=self.server_name,
             clock=self.clock,
             max_len=10000,
             expiry_ms=30 * 60 * 1000,

View File

@@ -406,7 +406,7 @@ class DirectoryHandler:
         ]

         for service in interested_services:
-            if user_id == service.sender:
+            if user_id == service.sender.to_string():
                 # this user IS the app service so they can do whatever they like
                 return True
             elif service.is_exclusive_alias(alias.to_string()):

View File

@@ -60,6 +60,7 @@ logger = logging.getLogger(__name__)

 class InitialSyncHandler:
     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
@@ -77,7 +78,11 @@
                 bool,
                 bool,
             ]
-        ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
+        ] = ResponseCache(
+            clock=hs.get_clock(),
+            name="initial_sync_cache",
+            server_name=self.server_name,
+        )
         self._event_serializer = hs.get_event_client_serializer()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state

View File

@@ -558,8 +558,9 @@ class EventCreationHandler:
         self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
         if self._external_cache.is_enabled():
             self._external_cache_joined_hosts_updates = ExpiringCache(
-                "_external_cache_joined_hosts_updates",
-                self.clock,
+                cache_name="_external_cache_joined_hosts_updates",
+                server_name=self.server_name,
+                clock=self.clock,
                 expiry_ms=30 * 60 * 1000,
             )

View File

@@ -55,6 +55,7 @@ class ProfileHandler:
     """

     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.hs = hs

View File

@@ -118,6 +118,7 @@ class EventContext:

 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self.auth = hs.get_auth()
@@ -174,7 +175,10 @@
         # succession, only process the first attempt and return its result to
         # subsequent requests
         self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
-            hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+            clock=hs.get_clock(),
+            name="room_upgrade",
+            server_name=self.server_name,
+            timeout_ms=FIVE_MINUTES_IN_MS,
         )
         self._server_notices_mxid = hs.config.servernotices.server_notices_mxid

View File

@@ -61,16 +61,26 @@ MAX_PUBLIC_ROOMS_IN_RESPONSE = 100

 class RoomListHandler:
     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self.hs = hs
         self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
         self.response_cache: ResponseCache[
             Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
-        ] = ResponseCache(hs.get_clock(), "room_list")
+        ] = ResponseCache(
+            clock=hs.get_clock(),
+            name="room_list",
+            server_name=self.server_name,
+        )
         self.remote_response_cache: ResponseCache[
             Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
-        ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
+        ] = ResponseCache(
+            clock=hs.get_clock(),
+            name="remote_room_list",
+            server_name=self.server_name,
+            timeout_ms=30 * 1000,
+        )

     async def get_local_public_room_list(
         self,

View File

@@ -96,6 +96,7 @@ class RoomSummaryHandler:
     _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000

     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self._event_auth_handler = hs.get_event_auth_handler()
         self._store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
@@ -121,8 +122,9 @@
                 Optional[Tuple[str, ...]],
             ]
         ] = ResponseCache(
-            hs.get_clock(),
-            "get_room_hierarchy",
+            clock=hs.get_clock(),
+            name="get_room_hierarchy",
+            server_name=self.server_name,
         )
         self._msc3266_enabled = hs.config.experimental.msc3266_enabled

View File

@@ -329,6 +329,7 @@ class E2eeSyncResult:

 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.hs_config = hs.config
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
@@ -352,8 +353,9 @@
         # cached result any more, and we could flush the entry from the cache to save
         # memory.
         self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache(
-            hs.get_clock(),
-            "sync",
+            clock=hs.get_clock(),
+            name="sync",
+            server_name=self.server_name,
             timeout_ms=hs.config.caches.sync_response_cache_duration,
         )
@@ -361,8 +363,9 @@
         self.lazy_loaded_members_cache: ExpiringCache[
             Tuple[str, Optional[str]], LruCache[str, str]
         ] = ExpiringCache(
-            "lazy_loaded_members_cache",
-            self.clock,
+            cache_name="lazy_loaded_members_cache",
+            server_name=self.server_name,
+            clock=self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
@@ -1129,7 +1132,7 @@
         )
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
-            cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
+            cache = LruCache(max_size=LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
             self.lazy_loaded_members_cache[cache_key] = cache
         else:
             logger.debug("found LruCache for %r", cache_key)

View File

@@ -263,6 +263,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         assert hs.get_instance_name() in hs.config.worker.writers.typing

+        self.server_name = hs.hostname
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.event_auth_handler = hs.get_event_auth_handler()
@@ -280,7 +281,9 @@
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
-            "TypingStreamChangeCache", self._latest_room_serial
+            name="TypingStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=self._latest_room_serial,
         )

     def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:

View File

@@ -97,9 +97,23 @@ class MatrixFederationAgent:
         user_agent: bytes,
         ip_allowlist: Optional[IPSet],
         ip_blocklist: IPSet,
+        server_name: str,
         _srv_resolver: Optional[SrvResolver] = None,
         _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
+        """
+        Args:
+            reactor
+            tls_client_options_factory
+            user_agent
+            ip_allowlist
+            ip_blocklist
+            server_name: The homeserver name running this resolver
+                (used to label metrics) (`hs.hostname`).
+            _srv_resolver
+            _well_known_resolver
+        """
         # proxy_reactor is not blocklisting reactor
         proxy_reactor = reactor
@@ -139,6 +153,7 @@
                 ip_blocklist=ip_blocklist,
             ),
             user_agent=self.user_agent,
+            server_name=server_name,
         )
         self._well_known_resolver = _well_known_resolver

View File

@@ -77,10 +77,6 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3

 logger = logging.getLogger(__name__)

-_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
-_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")

 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class WellKnownLookupResult:
     delegated_server: Optional[bytes]
@@ -94,17 +90,33 @@ class WellKnownResolver:
         reactor: IReactorTime,
         agent: IAgent,
         user_agent: bytes,
+        server_name: str,
         well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
         had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
     ):
+        """
+        Args:
+            reactor
+            agent
+            user_agent
+            server_name: The homeserver name running this resolver
+                (used to label metrics) (`hs.hostname`).
+            well_known_cache
+            had_well_known_cache
+        """
         self._reactor = reactor
         self._clock = Clock(reactor)

         if well_known_cache is None:
-            well_known_cache = _well_known_cache
+            well_known_cache = TTLCache(
+                cache_name="well-known", server_name=server_name
+            )

         if had_well_known_cache is None:
-            had_well_known_cache = _had_valid_well_known_cache
+            had_well_known_cache = TTLCache(
+                cache_name="had-valid-well-known", server_name=server_name
+            )

         self._well_known_cache = well_known_cache
         self._had_valid_well_known_cache = had_well_known_cache
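
The notable part of this hunk is deleting the module-level `_well_known_cache` globals: with several homeservers in one process, a shared cache cannot be attributed to a single server. A toy illustration of the new per-instance shape (`reactor` and `agent` stand in for whatever the caller already has):

```
resolver_a = WellKnownResolver(reactor, agent, b"ua", server_name="hs-a.example.com")
resolver_b = WellKnownResolver(reactor, agent, b"ua", server_name="hs-b.example.com")
# Each resolver now builds its own TTLCache, so well-known lookups no longer
# share state across homeservers, and each cache's metrics carry its own
# `instance` label.
```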

View File

@@ -422,6 +422,7 @@ class MatrixFederationHttpClient:
                 user_agent.encode("ascii"),
                 hs.config.server.federation_ip_range_allowlist,
                 hs.config.server.federation_ip_range_blocklist,
+                server_name=self.server_name,
             )
         else:
             proxy_authorization_secret = hs.config.worker.worker_replication_secret

View File

@@ -200,6 +200,7 @@ class UrlPreviewer:
         # JSON-encoded OG metadata
         self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
             cache_name="url_previews",
+            server_name=self.server_name,
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,

View File

@@ -66,6 +66,17 @@ all_gauges: Dict[str, Collector] = {}

 HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")

+INSTANCE_LABEL_NAME = "instance"
+"""
+The standard Prometheus label name used to identify which server instance the metrics
+came from.
+
+In the case of a Synapse homeserver, this should be set to the homeserver name
+(`hs.hostname`).
+
+Normally, this would be set automatically by the Prometheus server scraping the data but
+since we support multiple instances of Synapse running in the same process and all
+metrics are in a single global `REGISTRY`, we need to manually label any metrics.
+"""

 class _RegistryProxy:
     @staticmethod
View File

@@ -128,6 +128,7 @@ class BulkPushRuleEvaluator:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
+        self.server_name = hs.hostname
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -136,10 +137,11 @@
         self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled

         self.room_push_rule_cache_metrics = register_cache(
-            "cache",
-            "room_push_rule_cache",
+            cache_type="cache",
+            cache_name="room_push_rule_cache",
             cache=[],  # Meaningless size, as this isn't a cache that stores values,
             resizable=False,
+            server_name=self.server_name,
         )

     async def _get_rules_for_event(

View File

@@ -121,9 +121,14 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
     WAIT_FOR_STREAMS: ClassVar[bool] = True

     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         if self.CACHE:
             self.response_cache: ResponseCache[str] = ResponseCache(
-                hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
+                clock=hs.get_clock(),
+                name="repl." + self.NAME,
+                server_name=self.server_name,
+                timeout_ms=30 * 60 * 1000,
             )

         # We reserve `instance_name` as a parameter to sending requests, so we

View File

@@ -313,7 +313,7 @@ class LoginRestServlet(RestServlet):
             should_issue_refresh_token=should_issue_refresh_token,
             # The user represented by an appservice's configured sender_localpart
             # is not actually created in Synapse.
-            should_check_deactivated=qualified_user_id != appservice.sender,
+            should_check_deactivated=qualified_user_id != appservice.sender.to_string(),
             request_info=request_info,
         )

View File

@@ -111,6 +111,7 @@ class SyncRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
+        self.server_name = hs.hostname
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.sync_handler = hs.get_sync_handler()
@@ -125,6 +126,7 @@
         self._json_filter_cache: LruCache[str, bool] = LruCache(
             max_size=1000,
             cache_name="sync_valid_filter",
+            server_name=self.server_name,
         )

         # Ratelimiter for presence updates, keyed by requester.

View File

@@ -35,6 +35,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice"

 class ServerNoticesManager:
     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self._store = hs.get_datastores().main
         self._config = hs.config
         self._account_data_handler = hs.get_account_data_handler()
@@ -44,7 +45,6 @@ class ServerNoticesManager:
         self._message_handler = hs.get_message_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self._is_mine_id = hs.is_mine_id
-        self._server_name = hs.hostname
         self._notifier = hs.get_notifier()
         self.server_notices_mxid = self._config.servernotices.server_notices_mxid
@@ -77,7 +77,7 @@
         assert self.server_notices_mxid is not None
         requester = create_requester(
-            self.server_notices_mxid, authenticated_entity=self._server_name
+            self.server_notices_mxid, authenticated_entity=self.server_name
         )

         logger.info("Sending server notice to %s", user_id)
@@ -151,7 +151,7 @@
         assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
         requester = create_requester(
-            self.server_notices_mxid, authenticated_entity=self._server_name
+            self.server_notices_mxid, authenticated_entity=self.server_name
         )

         room_id = await self.maybe_get_notice_room_for_user(user_id)
@@ -256,7 +256,7 @@
         """
         assert self.server_notices_mxid is not None
         requester = create_requester(
-            self.server_notices_mxid, authenticated_entity=self._server_name
+            self.server_notices_mxid, authenticated_entity=self.server_name
         )

         # Check whether the user has already joined or been invited to this room. If
@@ -279,7 +279,7 @@
         if self._config.servernotices.server_notices_auto_join:
             user_requester = create_requester(
-                user_id, authenticated_entity=self._server_name
+                user_id, authenticated_entity=self.server_name
             )
             await self._room_member_handler.update_membership(
                 requester=user_requester,

View File

@@ -631,6 +631,7 @@ class StateResolutionHandler:
     """

     def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -639,6 +640,7 @@
         self._state_cache: ExpiringCache[FrozenSet[int], _StateCacheEntry] = (
             ExpiringCache(
                 cache_name="state_cache",
+                server_name=self.server_name,
                 clock=self.clock,
                 max_len=100000,
                 expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,

View File

@@ -55,6 +55,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         hs: "HomeServer",
     ):
         self.hs = hs
+        self.server_name = hs.hostname
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database

View File

@@ -68,6 +68,7 @@ class StateStorageController:
     """

     def __init__(self, hs: "HomeServer", stores: "Databases"):
+        self.server_name = hs.hostname
         self._is_mine_id = hs.is_mine_id
         self._clock = hs.get_clock()
         self.stores = stores

View File

@@ -89,7 +89,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max
+            name="AccountDataAndTagsChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=account_max,
         )

         self.db_pool.updates.register_background_index_update(

View File

@@ -126,7 +126,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
             The application service or None.
         """
         for service in self.services_cache:
-            if service.sender == user_id:
+            if service.sender.to_string() == user_id:
                 return service
         return None

View File

@@ -421,6 +421,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
+        self.server_name = hs.hostname

         if hs.config.redis.redis_enabled:
             # If we're using Redis, we can shift this update process off to
@@ -434,7 +435,9 @@
         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
-            cache_name="client_ip_last_seen", max_size=50000
+            cache_name="client_ip_last_seen",
+            server_name=self.server_name,
+            max_size=50000,
         )

         if hs.config.worker.run_background_tasks and self.user_ips_max_age:

View File

@@ -94,6 +94,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             Tuple[str, Optional[str]], int
         ] = ExpiringCache(
             cache_name="last_device_delete_cache",
+            server_name=self.server_name,
             clock=self._clock,
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
@@ -127,8 +128,9 @@
                 limit=1000,
             )
             self._device_inbox_stream_cache = StreamChangeCache(
-                "DeviceInboxStreamChangeCache",
-                min_device_inbox_id,
+                name="DeviceInboxStreamChangeCache",
+                server_name=self.server_name,
+                current_stream_pos=min_device_inbox_id,
                 prefilled_cache=device_inbox_prefill,
             )
@@ -143,8 +145,9 @@
                 limit=1000,
             )
             self._device_federation_outbox_stream_cache = StreamChangeCache(
-                "DeviceFederationOutboxStreamChangeCache",
-                min_device_outbox_id,
+                name="DeviceFederationOutboxStreamChangeCache",
+                server_name=self.server_name,
+                current_stream_pos=min_device_outbox_id,
                 prefilled_cache=device_outbox_prefill,
             )

View File

@@ -128,8 +128,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             limit=10000,
         )
         self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache",
-            min_device_list_id,
+            name="DeviceListStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=min_device_list_id,
             prefilled_cache=device_list_prefill,
         )
@@ -142,8 +143,9 @@
             limit=10000,
         )
         self._device_list_room_stream_cache = StreamChangeCache(
-            "DeviceListRoomStreamChangeCache",
-            min_device_list_room_id,
+            name="DeviceListRoomStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=min_device_list_room_id,
             prefilled_cache=device_list_room_prefill,
         )
@@ -159,8 +161,9 @@
             limit=1000,
         )
         self._user_signature_stream_cache = StreamChangeCache(
-            "UserSignatureStreamChangeCache",
-            user_signature_stream_list_id,
+            name="UserSignatureStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=user_signature_stream_list_id,
             prefilled_cache=user_signature_stream_prefill,
         )
@@ -178,8 +181,9 @@
             limit=10000,
         )
         self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
+            name="DeviceListFederationStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=device_list_federation_list_id,
             prefilled_cache=device_list_federation_prefill,
        )
@@ -1769,11 +1773,16 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
+        self.server_name = hs.hostname

         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
         self.device_id_exists_cache: LruCache[Tuple[str, str], Literal[True]] = (
-            LruCache(cache_name="device_id_exists", max_size=10000)
+            LruCache(
+                cache_name="device_id_exists",
+                server_name=self.server_name,
+                max_size=10000,
+            )
         )

     async def store_device(

View File

@@ -145,7 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Cache of event ID to list of auth event IDs and their depths.
         self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
-            500000, "_event_auth_cache", size_callback=len
+            max_size=500000,
+            server_name=self.server_name,
+            cache_name="_event_auth_cache",
+            size_callback=len,
         )

         # Flag used by unit tests to disable fallback when there is no chain cover

View File

@@ -269,8 +269,9 @@ class EventsWorkerStore(SQLBaseStore):
             limit=1000,
         )
         self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache(
-            "_curr_state_delta_stream_cache",
-            min_curr_state_delta_id,
+            name="_curr_state_delta_stream_cache",
+            server_name=self.server_name,
+            current_stream_pos=min_curr_state_delta_id,
             prefilled_cache=curr_state_delta_prefill,
         )
@@ -283,6 +284,7 @@ class EventsWorkerStore(SQLBaseStore):
         self._get_event_cache: AsyncLruCache[Tuple[str], EventCacheEntry] = (
             AsyncLruCache(
+                server_name=self.server_name,
                 cache_name="*getEvent*",
                 max_size=hs.config.caches.event_cache_size,
                 # `extra_index_cb` Returns a tuple as that is the key type

View File

@@ -108,8 +108,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             max_value=self._presence_id_gen.get_current_token(),
         )
         self.presence_stream_cache = StreamChangeCache(
-            "PresenceStreamChangeCache",
-            min_presence_val,
+            name="PresenceStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=min_presence_val,
             prefilled_cache=presence_cache_prefill,
         )

View File

@@ -163,8 +163,9 @@ class PushRulesWorkerStore(
         )

         self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            push_rules_id,
+            name="PushRulesStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=push_rules_id,
             prefilled_cache=push_rules_prefill,
         )

View File

@@ -158,8 +158,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             limit=10000,
         )
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache",
-            min_receipts_stream_id,
+            name="ReceiptsRoomChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=min_receipts_stream_id,
             prefilled_cache=receipts_stream_prefill,
         )

View File

@@ -617,12 +617,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             max_value=events_max,
         )
         self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache",
-            min_event_val,
+            name="EventsRoomStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=min_event_val,
             prefilled_cache=event_cache_prefill,
         )
         self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max
+            name="MembershipStreamChangeCache",
+            server_name=self.server_name,
+            current_stream_pos=events_max,
         )

         self._stream_order_on_start = self.get_room_max_stream_ordering()

View File

@@ -92,6 +92,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
         self._state_deletion_store = state_deletion_store
+        self.server_name = hs.hostname

         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -123,14 +124,16 @@
         # vast majority of state in Matrix (today) is member events.
         self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
-            "*stateGroupCache*",
+            name="*stateGroupCache*",
+            server_name=self.server_name,
             # TODO: this hasn't been tuned yet
-            50000,
+            max_entries=50000,
         )
         self._state_group_members_cache: DictionaryCache[int, StateKey, str] = (
             DictionaryCache(
-                "*stateGroupMembersCache*",
-                500000,
+                name="*stateGroupMembersCache*",
+                server_name=self.server_name,
+                max_entries=500000,
             )
         )

View File

@@ -31,6 +31,7 @@ from prometheus_client import REGISTRY
 from prometheus_client.core import Gauge

 from synapse.config.cache import add_resizable_cache
+from synapse.metrics import INSTANCE_LABEL_NAME
 from synapse.util.metrics import DynamicCollectorRegistry

 logger = logging.getLogger(__name__)
@@ -46,50 +47,65 @@ CACHE_METRIC_REGISTRY = DynamicCollectorRegistry()
 caches_by_name: Dict[str, Sized] = {}

 cache_size = Gauge(
-    "synapse_util_caches_cache_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+    "synapse_util_caches_cache_size",
+    "",
+    labelnames=["name", INSTANCE_LABEL_NAME],
+    registry=CACHE_METRIC_REGISTRY,
 )
 cache_hits = Gauge(
-    "synapse_util_caches_cache_hits", "", ["name"], registry=CACHE_METRIC_REGISTRY
+    "synapse_util_caches_cache_hits",
+    "",
+    labelnames=["name", INSTANCE_LABEL_NAME],
+    registry=CACHE_METRIC_REGISTRY,
 )
 cache_evicted = Gauge(
     "synapse_util_caches_cache_evicted_size",
     "",
-    ["name", "reason"],
+    labelnames=["name", "reason", INSTANCE_LABEL_NAME],
     registry=CACHE_METRIC_REGISTRY,
 )
 cache_total = Gauge(
-    "synapse_util_caches_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+    "synapse_util_caches_cache",
+    "",
+    labelnames=["name", INSTANCE_LABEL_NAME],
+    registry=CACHE_METRIC_REGISTRY,
 )
 cache_max_size = Gauge(
-    "synapse_util_caches_cache_max_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+    "synapse_util_caches_cache_max_size",
+    "",
+    labelnames=["name", INSTANCE_LABEL_NAME],
+    registry=CACHE_METRIC_REGISTRY,
 )
 cache_memory_usage = Gauge(
     "synapse_util_caches_cache_size_bytes",
     "Estimated memory usage of the caches",
-    ["name"],
+    labelnames=["name", INSTANCE_LABEL_NAME],
     registry=CACHE_METRIC_REGISTRY,
 )

 response_cache_size = Gauge(
     "synapse_util_caches_response_cache_size",
     "",
-    ["name"],
+    labelnames=["name", INSTANCE_LABEL_NAME],
     registry=CACHE_METRIC_REGISTRY,
 )
 response_cache_hits = Gauge(
     "synapse_util_caches_response_cache_hits",
     "",
-    ["name"],
+    labelnames=["name", INSTANCE_LABEL_NAME],
    registry=CACHE_METRIC_REGISTRY,
 )
 response_cache_evicted = Gauge(
     "synapse_util_caches_response_cache_evicted_size",
     "",
-    ["name", "reason"],
+    labelnames=["name", "reason", INSTANCE_LABEL_NAME],
     registry=CACHE_METRIC_REGISTRY,
 )
 response_cache_total = Gauge(
-    "synapse_util_caches_response_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+    "synapse_util_caches_response_cache",
+    "",
+    labelnames=["name", INSTANCE_LABEL_NAME],
+    registry=CACHE_METRIC_REGISTRY,
 )
@@ -103,12 +119,17 @@ class EvictionReason(Enum):
     invalidation = auto()

-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True, kw_only=True)
 class CacheMetric:
+    """
+    Used to track cache metrics
+    """
+
     _cache: Sized
     _cache_type: str
     _cache_name: str
     _collect_callback: Optional[Callable]
+    _server_name: str

     hits: int = 0
     misses: int = 0
@@ -145,34 +166,34 @@
     def collect(self) -> None:
         try:
+            labels_base = {
+                "name": self._cache_name,
+                INSTANCE_LABEL_NAME: self._server_name,
+            }
             if self._cache_type == "response_cache":
-                response_cache_size.labels(self._cache_name).set(len(self._cache))
-                response_cache_hits.labels(self._cache_name).set(self.hits)
+                response_cache_size.labels(**labels_base).set(len(self._cache))
+                response_cache_hits.labels(**labels_base).set(self.hits)
                 for reason in EvictionReason:
-                    response_cache_evicted.labels(self._cache_name, reason.name).set(
-                        self.eviction_size_by_reason[reason]
-                    )
-                response_cache_total.labels(self._cache_name).set(
-                    self.hits + self.misses
-                )
+                    response_cache_evicted.labels(
+                        **{**labels_base, "reason": reason.name}
+                    ).set(self.eviction_size_by_reason[reason])
+                response_cache_total.labels(**labels_base).set(self.hits + self.misses)
             else:
-                cache_size.labels(self._cache_name).set(len(self._cache))
-                cache_hits.labels(self._cache_name).set(self.hits)
+                cache_size.labels(**labels_base).set(len(self._cache))
+                cache_hits.labels(**labels_base).set(self.hits)
                 for reason in EvictionReason:
-                    cache_evicted.labels(self._cache_name, reason.name).set(
+                    cache_evicted.labels(**{**labels_base, "reason": reason.name}).set(
                         self.eviction_size_by_reason[reason]
                     )
-                cache_total.labels(self._cache_name).set(self.hits + self.misses)
+                cache_total.labels(**labels_base).set(self.hits + self.misses)
                 max_size = getattr(self._cache, "max_size", None)
                 if max_size:
-                    cache_max_size.labels(self._cache_name).set(max_size)
+                    cache_max_size.labels(**labels_base).set(max_size)

                 if TRACK_MEMORY_USAGE:
                     # self.memory_usage can be None if nothing has been inserted
                     # into the cache yet.
-                    cache_memory_usage.labels(self._cache_name).set(
-                        self.memory_usage or 0
-                    )
+                    cache_memory_usage.labels(**labels_base).set(self.memory_usage or 0)
             if self._collect_callback:
                 self._collect_callback()
         except Exception as e:
@@ -181,9 +202,11 @@

 def register_cache(
+    *,
     cache_type: str,
     cache_name: str,
     cache: Sized,
+    server_name: str,
     collect_callback: Optional[Callable] = None,
     resizable: bool = True,
     resize_callback: Optional[Callable] = None,
@@ -196,6 +219,8 @@
         cache_name: name of the cache
         cache: cache itself, which must implement __len__(), and may optionally implement
             a max_size property
+        server_name: The homeserver name that this cache is associated with
+            (used to label the metric) (`hs.hostname`).
         collect_callback: If given, a function which is called during metric
             collection to update additional metrics.
         resizable: Whether this cache supports being resized, in which case either
@@ -210,7 +235,13 @@
             resize_callback = cache.set_cache_factor  # type: ignore
         add_resizable_cache(cache_name, resize_callback)

-    metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
+    metric = CacheMetric(
+        cache=cache,
+        cache_type=cache_type,
+        cache_name=cache_name,
+        server_name=server_name,
+        collect_callback=collect_callback,
+    )
     metric_name = "cache_%s_%s" % (cache_type, cache_name)
     caches_by_name[cache_name] = cache
     CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect)

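For reference, a minimal sketch of a call against the reworked keyword-only `register_cache()` (the cache object and server name are illustrative; in Synapse proper `server_name` would be `hs.hostname`):

```python
from synapse.util.caches import register_cache

my_cache: dict = {}  # any Sized object qualifies as the cache
metric = register_cache(
    cache_type="dict",
    cache_name="example_cache",
    cache=my_cache,
    server_name="example.com",  # hs.hostname in real code
    resizable=False,
)
metric.inc_hits()  # the returned CacheMetric behaves as before
```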

@@ -79,7 +79,9 @@ class DeferredCache(Generic[KT, VT]):
     def __init__(
         self,
+        *,
         name: str,
+        server_name: str,
         max_entries: int = 1000,
         tree: bool = False,
         iterable: bool = False,
@@ -89,6 +91,8 @@ class DeferredCache(Generic[KT, VT]):
         """
         Args:
             name: The name of the cache
+            server_name: The homeserver name that this cache is associated with
+                (used to label the metric) (`hs.hostname`).
             max_entries: Maximum amount of entries that the cache will hold
             tree: Use a TreeCache instead of a dict as the underlying cache type
             iterable: If True, count each item in the cached object as an entry,
@@ -113,6 +117,7 @@ class DeferredCache(Generic[KT, VT]):
         # a Deferred.
         self.cache: LruCache[KT, VT] = LruCache(
             max_size=max_entries,
+            server_name=server_name,
             cache_name=name,
             cache_type=cache_type,
             size_callback=(


@@ -33,6 +33,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     Type,
@@ -153,6 +154,14 @@ class _CacheDescriptorBase:
         )


+class HasServerName(Protocol):
+    server_name: str
+    """
+    The homeserver name that this cache is associated with (used to label the metric)
+    (`hs.hostname`).
+    """
+
+
 class DeferredCacheDescriptor(_CacheDescriptorBase):
     """A method decorator that applies a memoizing cache around the function.
@@ -200,6 +209,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
     def __init__(
         self,
+        *,
         orig: Callable[..., Any],
         max_entries: int = 1000,
         num_args: Optional[int] = None,
@@ -229,10 +239,20 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.prune_unread_entries = prune_unread_entries

     def __get__(
-        self, obj: Optional[Any], owner: Optional[Type]
+        self, obj: Optional[HasServerName], owner: Optional[Type]
     ) -> Callable[..., "defer.Deferred[Any]"]:
+        # We need access to instance-level `obj.server_name` attribute
+        assert obj is not None, (
+            "Cannot call cached method from class (❌ `MyClass.cached_method()`) "
+            "and must be called from an instance (✅ `MyClass().cached_method()`). "
+        )
+        assert obj.server_name is not None, (
+            "The `server_name` attribute must be set on the object where `@cached` decorator is used."
+        )
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.name,
+            server_name=obj.server_name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -490,7 +510,7 @@ class _CachedFunctionDescriptor:
     def __call__(self, orig: F) -> CachedFunction[F]:
         d = DeferredCacheDescriptor(
-            orig,
+            orig=orig,
             max_entries=self.max_entries,
             num_args=self.num_args,
             uncached_args=self.uncached_args,

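The new assertions in `__get__` mean that any class using `@cached` must carry an instance-level `server_name` attribute, which is what the `HasServerName` protocol above encodes. A hedged sketch of a conforming class (the store below is hypothetical):

```python
from synapse.util.caches.descriptors import cached

class ExampleStore:
    def __init__(self, server_name: str) -> None:
        self.server_name = server_name  # satisfies HasServerName

    @cached()
    async def get_display_name(self, user_id: str) -> str:
        return f"name-of-{user_id}"

store = ExampleStore(server_name="example.com")
# store.get_display_name("@alice:example.com") works; calling it on the
# class itself (ExampleStore.get_display_name) would trip the first assert.
```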

@@ -127,7 +127,15 @@ class DictionaryCache(Generic[KT, DKT, DV]):
     for the '2' dict key.
     """

-    def __init__(self, name: str, max_entries: int = 1000):
+    def __init__(self, *, name: str, server_name: str, max_entries: int = 1000):
+        """
+        Args:
+            name
+            server_name: The homeserver name that this cache is associated with
+                (used to label the metric) (`hs.hostname`).
+            max_entries
+        """
         # We use a single LruCache to store two different types of entries:
         # 1. Map from (key, dict_key) -> dict value (or sentinel, indicating
         #    the key doesn't exist in the dict); and
@@ -152,6 +160,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             Union[_PerKeyValue, Dict[DKT, DV]],
         ] = LruCache(
             max_size=max_entries,
+            server_name=server_name,
             cache_name=name,
             cache_type=TreeCache,
             size_callback=len,


@@ -46,7 +46,9 @@ VT = TypeVar("VT")

 class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
+        *,
         cache_name: str,
+        server_name: str,
         clock: Clock,
         max_len: int = 0,
         expiry_ms: int = 0,
@@ -56,6 +58,8 @@ class ExpiringCache(Generic[KT, VT]):
         """
         Args:
             cache_name: Name of this cache, used for logging.
+            server_name: The homeserver name that this cache is associated
+                with (used to label the metric) (`hs.hostname`).
             clock
             max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
@@ -83,7 +87,12 @@ class ExpiringCache(Generic[KT, VT]):
         self.iterable = iterable

-        self.metrics = register_cache("expiring", cache_name, self)
+        self.metrics = register_cache(
+            cache_type="expiring",
+            cache_name=cache_name,
+            cache=self,
+            server_name=server_name,
+        )

         if not self._expiry_ms:
             # Don't bother starting the loop if things never expire


@@ -376,9 +376,43 @@ class LruCache(Generic[KT, VT]):
         If cache_type=TreeCache, all keys must be tuples.
     """

+    @overload
     def __init__(
         self,
+        *,
         max_size: int,
+        server_name: str,
+        cache_name: str,
+        cache_type: Type[Union[dict, TreeCache]] = dict,
+        size_callback: Optional[Callable[[VT], int]] = None,
+        metrics_collection_callback: Optional[Callable[[], None]] = None,
+        apply_cache_factor_from_config: bool = True,
+        clock: Optional[Clock] = None,
+        prune_unread_entries: bool = True,
+        extra_index_cb: Optional[Callable[[KT, VT], KT]] = None,
+    ): ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        max_size: int,
+        server_name: Literal[None] = None,
+        cache_name: Literal[None] = None,
+        cache_type: Type[Union[dict, TreeCache]] = dict,
+        size_callback: Optional[Callable[[VT], int]] = None,
+        metrics_collection_callback: Optional[Callable[[], None]] = None,
+        apply_cache_factor_from_config: bool = True,
+        clock: Optional[Clock] = None,
+        prune_unread_entries: bool = True,
+        extra_index_cb: Optional[Callable[[KT, VT], KT]] = None,
+    ): ...
+
+    def __init__(
+        self,
+        *,
+        max_size: int,
+        server_name: Optional[str] = None,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable[[VT], int]] = None,
@@ -392,8 +426,13 @@ class LruCache(Generic[KT, VT]):
         Args:
             max_size: The maximum amount of entries the cache can hold

-            cache_name: The name of this cache, for the prometheus metrics. If unset,
-                no metrics will be reported on this cache.
+            server_name: The homeserver name that this cache is associated with
+                (used to label the metric) (`hs.hostname`). Must be set if `cache_name` is
+                set. If unset, no metrics will be reported on this cache.
+
+            cache_name: The name of this cache, for the prometheus metrics. Must be set
+                if `server_name` is set. If unset, no metrics will be reported on this
+                cache.

             cache_type:
                 type of underlying cache to be used. Typically one of dict
@@ -457,11 +496,12 @@ class LruCache(Generic[KT, VT]):
         # do yet when we get resized.
         self._on_resize: Optional[Callable[[], None]] = None

-        if cache_name is not None:
+        if cache_name is not None and server_name is not None:
             metrics: Optional[CacheMetric] = register_cache(
-                "lru_cache",
-                cache_name,
-                self,
+                cache_type="lru_cache",
+                cache_name=cache_name,
+                cache=self,
+                server_name=server_name,
                 collect_callback=metrics_collection_callback,
             )
         else:

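The paired `@overload`s encode an all-or-nothing rule for `LruCache` metrics. A hedged sketch of the two call shapes a type checker will now accept:

```python
from synapse.util.caches.lrucache import LruCache

# Both server_name and cache_name given: metrics are registered.
with_metrics: LruCache[str, int] = LruCache(
    max_size=100,
    server_name="example.com",  # hs.hostname in real code
    cache_name="example_cache",
)

# Neither given: an anonymous cache with no metrics.
anonymous: LruCache[str, int] = LruCache(max_size=100)

# Supplying only one of the two no longer type-checks, and at runtime the
# `cache_name is not None and server_name is not None` guard skips metrics.
```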

@@ -103,18 +103,35 @@ class ResponseCache(Generic[KV]):
     def __init__(
         self,
+        *,
         clock: Clock,
         name: str,
+        server_name: str,
         timeout_ms: float = 0,
         enable_logging: bool = True,
     ):
+        """
+        Args:
+            clock
+            name
+            server_name: The homeserver name that this cache is associated
+                with (used to label the metric) (`hs.hostname`).
+            timeout_ms
+            enable_logging
+        """
         self._result_cache: Dict[KV, ResponseCacheEntry] = {}

         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0

         self._name = name
-        self._metrics = register_cache("response_cache", name, self, resizable=False)
+        self._metrics = register_cache(
+            cache_type="response_cache",
+            cache_name=name,
+            cache=self,
+            server_name=server_name,
+            resizable=False,
+        )
         self._enable_logging = enable_logging

     def size(self) -> int:


@@ -73,11 +73,23 @@ class StreamChangeCache:
     def __init__(
         self,
+        *,
         name: str,
+        server_name: str,
         current_stream_pos: int,
         max_size: int = 10000,
         prefilled_cache: Optional[Mapping[EntityType, int]] = None,
     ) -> None:
+        """
+        Args:
+            name
+            server_name: The homeserver name that this cache is associated with
+                (used to label the metric) (`hs.hostname`).
+            current_stream_pos
+            max_size
+            prefilled_cache
+        """
         self._original_max_size: int = max_size
         self._max_size = math.floor(max_size)

@@ -96,7 +108,11 @@ class StreamChangeCache:
         self.name = name
         self.metrics = caches.register_cache(
-            "cache", self.name, self._cache, resize_callback=self.set_cache_factor
+            cache_type="cache",
+            cache_name=self.name,
+            server_name=server_name,
+            cache=self._cache,
+            resize_callback=self.set_cache_factor,
         )

         if prefilled_cache:


@@ -40,7 +40,21 @@ VT = TypeVar("VT")

 class TTLCache(Generic[KT, VT]):
     """A key/value cache implementation where each entry has its own TTL"""

-    def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
+    def __init__(
+        self,
+        *,
+        cache_name: str,
+        server_name: str,
+        timer: Callable[[], float] = time.time,
+    ):
+        """
+        Args:
+            cache_name
+            server_name: The homeserver name that this cache is associated with
+                (used to label the metric) (`hs.hostname`).
+            timer: Function used to get the current time in seconds since the epoch.
+        """
         # map from key to _CacheEntry
         self._data: Dict[KT, _CacheEntry[KT, VT]] = {}

@@ -49,7 +63,13 @@ class TTLCache(Generic[KT, VT]):
         self._timer = timer

-        self._metrics = register_cache("ttl", cache_name, self, resizable=False)
+        self._metrics = register_cache(
+            cache_type="ttl",
+            cache_name=cache_name,
+            cache=self,
+            server_name=server_name,
+            resizable=False,
+        )

     def set(self, key: KT, value: VT, ttl: float) -> None:
         """Add/update an entry in the cache

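A minimal sketch of constructing the reworked `TTLCache` (values illustrative; the federation tests further below pass `reactor.seconds` as the timer):

```python
import time

from synapse.util.caches.ttlcache import TTLCache

cache: TTLCache[str, str] = TTLCache(
    cache_name="example_ttl_cache",
    server_name="example.com",  # hs.hostname in real code
    timer=time.time,
)
cache.set("key", "value", ttl=30.0)  # entry expires 30 seconds after insertion
```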

@@ -29,7 +29,7 @@ async def main(reactor: ISynapseReactor, loops: int) -> float:
     """
     Benchmark `loops` number of insertions into LruCache without eviction.
     """
-    cache: LruCache[int, bool] = LruCache(loops)
+    cache: LruCache[int, bool] = LruCache(max_size=loops)

     start = perf_counter()


@@ -30,7 +30,7 @@ async def main(reactor: ISynapseReactor, loops: int) -> float:
     Benchmark `loops` number of insertions into LruCache where half of them are
     evicted.
     """
-    cache: LruCache[int, bool] = LruCache(loops // 2)
+    cache: LruCache[int, bool] = LruCache(max_size=loops // 2)

     start = perf_counter()


@@ -60,7 +60,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # modify its config instead of the hs'
         self.auth_blocking = AuthBlocking(hs)

-        self.test_user = "@foo:bar"
+        self.test_user_id = UserID.from_string("@foo:bar")
         self.test_token = b"_test_token_"

         # this is overridden for the appservice tests
@@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_user_valid_token(self) -> None:
         user_info = TokenLookupResult(
-            user_id=self.test_user, token_id=5, device_id="device"
+            user_id=self.test_user_id.to_string(), token_id=5, device_id="device"
         )
         self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
         self.store.mark_access_token_as_used = AsyncMock(return_value=None)
@@ -81,7 +81,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEqual(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user, self.test_user_id)

     def test_get_user_by_req_user_bad_token(self) -> None:
         self.store.get_user_by_access_token = AsyncMock(return_value=None)
@@ -96,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")

     def test_get_user_by_req_user_missing_token(self) -> None:
-        user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
+        user_info = TokenLookupResult(user_id=self.test_user_id.to_string(), token_id=5)
         self.store.get_user_by_access_token = AsyncMock(return_value=user_info)

         request = Mock(args={})
@@ -109,7 +109,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_valid_token(self) -> None:
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+            token="foobar",
+            url="a_url",
+            sender=self.test_user_id,
+            ip_range_whitelist=None,
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = AsyncMock(return_value=None)
@@ -119,7 +122,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEqual(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user, self.test_user_id)

     def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
         from netaddr import IPSet
@@ -127,7 +130,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service = Mock(
             token="foobar",
             url="a_url",
-            sender=self.test_user,
+            sender=self.test_user_id.to_string(),
             ip_range_whitelist=IPSet(["192.168.0.0/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -138,7 +141,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEqual(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user, self.test_user_id)

     def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
         from netaddr import IPSet
@@ -146,7 +149,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service = Mock(
             token="foobar",
             url="a_url",
-            sender=self.test_user,
+            sender=self.test_user_id,
             ip_range_whitelist=IPSet(["192.168.0.0/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -176,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")

     def test_get_user_by_req_appservice_missing_token(self) -> None:
-        app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+        app_service = Mock(token="foobar", url="a_url", sender=self.test_user_id)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = AsyncMock(return_value=None)
@@ -191,7 +194,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
         masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+            token="foobar",
+            url="a_url",
+            sender=self.test_user_id,
+            ip_range_whitelist=None,
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -215,7 +221,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
         masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+            token="foobar",
+            url="a_url",
+            sender=self.test_user_id,
+            ip_range_whitelist=None,
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -238,7 +247,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         masquerading_user_id = b"@doppelganger:matrix.org"
         masquerading_device_id = b"DOPPELDEVICE"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+            token="foobar",
+            url="a_url",
+            sender=self.test_user_id,
+            ip_range_whitelist=None,
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -270,7 +282,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         masquerading_user_id = b"@doppelganger:matrix.org"
         masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+            token="foobar",
+            url="a_url",
+            sender=self.test_user_id,
+            ip_range_whitelist=None,
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -436,7 +451,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             namespaces={
                 "users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
             },
-            sender="@appservice:sender",
+            sender=UserID.from_string("@appservice:server"),
         )
         requester = Requester(
             user=UserID.from_string("@appservice:server"),
@@ -467,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             namespaces={
                 "users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
             },
-            sender="@appservice:sender",
+            sender=UserID.from_string("@appservice:server"),
         )
         requester = Requester(
             user=UserID.from_string("@appservice:server"),


@@ -5,7 +5,7 @@ from synapse.appservice import ApplicationService
 from synapse.config.ratelimiting import RatelimitSettings
 from synapse.module_api import RatelimitOverride
 from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks
-from synapse.types import create_requester
+from synapse.types import UserID, create_requester

 from tests import unittest
@@ -40,7 +40,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             token="fake_token",
             id="foo",
             rate_limited=True,
-            sender="@as:example.com",
+            sender=UserID.from_string("@as:example.com"),
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -76,7 +76,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             token="fake_token",
             id="foo",
             rate_limited=False,
-            sender="@as:example.com",
+            sender=UserID.from_string("@as:example.com"),
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)


@@ -25,7 +25,7 @@ from twisted.test.proto_helpers import MemoryReactor

 from synapse.appservice import ApplicationService
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock

 from tests import unittest
@@ -41,7 +41,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
         self.api = hs.get_application_service_api()
         self.service = ApplicationService(
             id="unique_identifier",
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
             url=URL,
             token="unused",
             hs_token=TOKEN,


@@ -25,6 +25,7 @@ from unittest.mock import AsyncMock, Mock
 from twisted.internet import defer

 from synapse.appservice import ApplicationService, Namespace
+from synapse.types import UserID

 from tests import unittest
@@ -37,7 +38,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def setUp(self) -> None:
         self.service = ApplicationService(
             id="unique_identifier",
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
             url="some_url",
             token="some_token",
         )
@@ -226,11 +227,11 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
         # make sure invites get through
-        self.service.sender = "@appservice:name"
+        self.service.sender = UserID.from_string("@appservice:name")
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.type = "m.room.member"
         self.event.content = {"membership": "invite"}
-        self.event.state_key = self.service.sender
+        self.event.state_key = self.service.sender.to_string()
         self.assertTrue(
             (
                 yield self.service.is_interested_in_event(

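The test churn above and below all follows from `ApplicationService.sender` now being a structured `UserID` rather than a raw string; wherever a plain `"@user:domain"` string is still required, the tests convert back with `.to_string()`. A short sketch of that round trip:

```python
from synapse.types import UserID

sender = UserID.from_string("@as:test")
assert sender.localpart == "as"
assert sender.domain == "test"
assert sender.to_string() == "@as:test"  # back to the wire-format string
```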

@@ -75,7 +75,7 @@ class CacheConfigTests(TestCase):
         the default cache size in the interim, and then resized once the config
         is loaded.
         """
-        cache: LruCache = LruCache(100)
+        cache: LruCache = LruCache(max_size=100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 50)
@@ -96,7 +96,7 @@ class CacheConfigTests(TestCase):
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()

-        cache: LruCache = LruCache(100)
+        cache: LruCache = LruCache(max_size=100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 200)
@@ -106,7 +106,7 @@ class CacheConfigTests(TestCase):
         the default cache size in the interim, and then resized to the new
         default cache size once the config is loaded.
         """
-        cache: LruCache = LruCache(100)
+        cache: LruCache = LruCache(max_size=100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 50)
@@ -126,7 +126,7 @@ class CacheConfigTests(TestCase):
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()

-        cache: LruCache = LruCache(100)
+        cache: LruCache = LruCache(max_size=100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 150)
@@ -145,15 +145,15 @@ class CacheConfigTests(TestCase):
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()

-        cache_a: LruCache = LruCache(100)
+        cache_a: LruCache = LruCache(max_size=100)
         add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
         self.assertEqual(cache_a.max_size, 200)

-        cache_b: LruCache = LruCache(100)
+        cache_b: LruCache = LruCache(max_size=100)
         add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
         self.assertEqual(cache_b.max_size, 300)

-        cache_c: LruCache = LruCache(100)
+        cache_c: LruCache = LruCache(max_size=100)
         add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
         self.assertEqual(cache_c.max_size, 200)


@@ -43,6 +43,7 @@ from synapse.types import (
     MultiWriterStreamToken,
     RoomStreamToken,
     StreamKeyType,
+    UserID,
 )
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -1009,7 +1010,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         appservice = ApplicationService(
             token=random_string(10),
             id=random_string(10),
-            sender="@as:example.com",
+            sender=UserID.from_string("@as:example.com"),
             rate_limited=False,
             namespaces=namespaces,
             supports_ephemeral=True,
@@ -1087,7 +1088,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
         appservice = ApplicationService(
             token=random_string(10),
             id=random_string(10),
-            sender="@as:example.com",
+            sender=UserID.from_string("@as:example.com"),
             rate_limited=False,
             namespaces={
                 ApplicationService.NS_USERS: [
@@ -1151,9 +1152,9 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
         # Define an application service for the tests
         self._service_token = "VERYSECRET"
         self._service = ApplicationService(
-            self._service_token,
-            "as1",
-            "@as.sender:test",
+            token=self._service_token,
+            id="as1",
+            sender=UserID.from_string("@as.sender:test"),
             namespaces={
                 "users": [
                     {"regex": "@_as_.*:test", "exclusive": True},


@@ -34,7 +34,7 @@ from synapse.rest import admin
 from synapse.rest.client import devices, login, register
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.task_scheduler import TaskScheduler
@@ -419,7 +419,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             id="1234",
             namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
             # Note: this user does not have to match the regex above
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )
         self.hs.get_datastores().main.services_cache = [appservice]
         self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(


@@ -1457,7 +1457,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             id="1234",
             namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
             # Note: this user does not have to match the regex above
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )
         self.hs.get_datastores().main.services_cache = [appservice]
         self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
@@ -1525,7 +1525,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             id="1234",
             namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
             # Note: this user does not have to match the regex above
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )
         self.hs.get_datastores().main.services_cache = [appservice]
         self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
@@ -1751,7 +1751,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             id="1234",
             namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
             # Note: this user does not have to match the regex above
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )
         self.hs.get_datastores().main.services_cache = [appservice]
         self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(


@@ -726,7 +726,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
             token="i_am_an_app_service",
             id="1234",
             namespaces={"users": [{"regex": r"@alice:.+", "exclusive": True}]},
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )
         self.hs.get_datastores().main.services_cache = [appservice]


@@ -91,6 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             user_agent=b"SynapseInTrialTest/0.0.0",
             ip_allowlist=None,
             ip_blocklist=IPSet(),
+            server_name="test_server",
         )

         # the tests assume that we are starting at unix time 1000


@@ -31,7 +31,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client import login, register, room, user_directory
 from synapse.server import HomeServer
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict, UserProfile, create_requester
+from synapse.types import JsonDict, UserID, UserProfile, create_requester
 from synapse.util import Clock
 from tests import unittest
@@ -78,7 +78,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
             # Note: this user does not match the regex above, so that tests
             # can distinguish the sender from the AS user.
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )

         mock_load_appservices = Mock(return_value=[self.appservice])
@@ -196,7 +196,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         user = self.register_user("user", "pass")
         token = self.login(user, "pass")
         room = self.helper.create_room_as(user, is_public=True, tok=token)
-        self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
+        self.helper.join(
+            room, self.appservice.sender.to_string(), tok=self.appservice.token
+        )
         self._check_only_one_user_in_directory(user, room)

     def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
@@ -433,7 +435,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
     def test_handle_local_profile_change_with_appservice_sender(self) -> None:
         # profile is not in directory
         profile = self.get_success(
-            self.store._get_user_in_directory(self.appservice.sender)
+            self.store._get_user_in_directory(self.appservice.sender.to_string())
         )
         self.assertIsNone(profile)
@@ -441,13 +443,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
         self.get_success(
             self.handler.handle_local_profile_change(
-                self.appservice.sender, profile_info
+                self.appservice.sender.to_string(), profile_info
             )
         )

         # profile is still not in directory
         profile = self.get_success(
-            self.store._get_user_in_directory(self.appservice.sender)
+            self.store._get_user_in_directory(self.appservice.sender.to_string())
         )
         self.assertIsNone(profile)


@@ -85,15 +85,20 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.tls_factory = FederationPolicyForHTTPS(config)

         self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache(
-            "test_cache", timer=self.reactor.seconds
+            cache_name="test_cache",
+            server_name="test_server",
+            timer=self.reactor.seconds,
         )
         self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache(
-            "test_cache", timer=self.reactor.seconds
+            cache_name="test_cache",
+            server_name="test_server",
+            timer=self.reactor.seconds,
         )
         self.well_known_resolver = WellKnownResolver(
             self.reactor,
             Agent(self.reactor, contextFactory=self.tls_factory),
             b"test-agent",
+            server_name="test_server",
             well_known_cache=self.well_known_cache,
             had_well_known_cache=self.had_well_known_cache,
         )
@@ -274,6 +279,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             user_agent=b"test-agent",  # Note that this is unused since _well_known_resolver is provided.
             ip_allowlist=IPSet(),
             ip_blocklist=IPSet(),
+            server_name="test_server",
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
@@ -1016,11 +1022,13 @@ class MatrixFederationAgentTests(unittest.TestCase):
             user_agent=b"test-agent",  # This is unused since _well_known_resolver is passed below.
             ip_allowlist=IPSet(),
             ip_blocklist=IPSet(),
+            server_name="test_server",
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
                 cast(ISynapseReactor, self.reactor),
                 Agent(self.reactor, contextFactory=tls_factory),
                 b"test-agent",
+                server_name="test_server",
                 well_known_cache=self.well_known_cache,
                 had_well_known_cache=self.had_well_known_cache,
             ),


@@ -159,7 +159,9 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
         Caches produce metrics reflecting their state when scraped.
         """
         CACHE_NAME = "cache_metrics_test_fgjkbdfg"
-        cache: DeferredCache[str, str] = DeferredCache(CACHE_NAME, max_entries=777)
+        cache: DeferredCache[str, str] = DeferredCache(
+            name=CACHE_NAME, server_name=self.hs.hostname, max_entries=777
+        )

         items = {
             x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")


@@ -823,9 +823,9 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
         # Define an application service so that we can register appservice users
         self._service_token = "some_token"
         self._service = ApplicationService(
-            self._service_token,
-            "as1",
-            "@as.sender:test",
+            token=self._service_token,
+            id="as1",
+            sender=UserID.from_string("@as.sender:test"),
             namespaces={
                 "users": [
                     {"regex": "@_as_.*:test", "exclusive": True},


@@ -139,7 +139,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.hs.get_replication_command_handler()._streams["typing"].last_token = 0
         typing._latest_room_serial = 0
         typing._typing_stream_change_cache = StreamChangeCache(
-            "TypingStreamChangeCache", typing._latest_room_serial
+            name="TypingStreamChangeCache",
+            server_name=self.hs.hostname,
+            current_stream_pos=typing._latest_room_serial,
         )
         typing._reset()


@@ -73,6 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
             user_agent=b"SynapseInTrialTest/0.0.0",
             ip_allowlist=None,
             ip_blocklist=IPSet(),
+            server_name="test_server",
         )

     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:


@@ -35,6 +35,7 @@ KEY = "mykey"

 class TestCache:
     current_value = FIRST_VALUE
+    server_name = "test_server"

     @cached()
     async def cached_function(self, user_id: str) -> str:


@@ -764,10 +764,10 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
         as_token = "i_am_an_app_service"

         appservice = ApplicationService(
-            as_token,
+            token=as_token,
             id="1234",
             namespaces={"users": [{"regex": user_id, "exclusive": True}]},
-            sender=user_id,
+            sender=UserID.from_string(user_id),
         )
         self.hs.get_datastores().main.services_cache.append(appservice)


@@ -472,7 +472,7 @@ class MSC4190AppserviceDevicesTestCase(unittest.HomeserverTestCase):
             id="msc4190",
             token="some_token",
             hs_token="some_token",
-            sender="@as:example.com",
+            sender=UserID.from_string("@as:example.com"),
             namespaces={
                 ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
             },
@@ -483,7 +483,7 @@ class MSC4190AppserviceDevicesTestCase(unittest.HomeserverTestCase):
             id="regular",
             token="other_token",
             hs_token="other_token",
-            sender="@as2:example.com",
+            sender=UserID.from_string("@as2:example.com"),
             namespaces={
                 ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
             },


@@ -25,7 +25,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest import admin
 from synapse.rest.client import directory, login, room
 from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserID
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -140,7 +140,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             as_token,
             id="1234",
             namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]},
-            sender=user_id,
+            sender=UserID.from_string(user_id),
         )
         self.hs.get_datastores().main.services_cache.append(appservice)


@@ -51,7 +51,7 @@ from synapse.rest.client import account, devices, login, logout, profile, regist
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock

 from tests import unittest
@@ -1484,7 +1484,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.service = ApplicationService(
             id="unique_identifier",
             token="some_token",
-            sender="@asbot:example.com",
+            sender=UserID.from_string("@asbot:example.com"),
             namespaces={
                 ApplicationService.NS_USERS: [
                     {"regex": r"@as_user.*", "exclusive": False}
@@ -1496,7 +1496,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.another_service = ApplicationService(
             id="another__identifier",
             token="another_token",
-            sender="@as2bot:example.com",
+            sender=UserID.from_string("@as2bot:example.com"),
             namespaces={
                 ApplicationService.NS_USERS: [
                     {"regex": r"@as2_user.*", "exclusive": False}
@@ -1530,7 +1530,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
-            "identifier": {"type": "m.id.user", "user": self.service.sender},
+            "identifier": {
+                "type": "m.id.user",
+                "user": self.service.sender.to_string(),
+            },
         }
         channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token


@@ -39,7 +39,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
 from synapse.server import HomeServer
 from synapse.storage._base import db_to_json
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock

 from tests import unittest
@@ -75,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             as_token,
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
         )

         self.hs.get_datastores().main.services_cache.append(appservice)
@@ -99,7 +99,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             as_token,
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
         )

         self.hs.get_datastores().main.services_cache.append(appservice)
@@ -129,7 +129,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             as_token,
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
             msc4190_device_management=True,
         )


@@ -1426,7 +1426,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
             # Note: this user does not have to match the regex above
-            sender="@as_main:test",
+            sender=UserID.from_string("@as_main:test"),
         )

         mock_load_appservices = Mock(return_value=[self.appservice])


@@ -38,6 +38,7 @@ from synapse.storage.databases.main.user_directory import (
     _parse_words_with_regex,
 )
 from synapse.storage.roommember import ProfileInfo
+from synapse.types import UserID
 from synapse.util import Clock

 from tests.server import ThreadedMemoryReactorClock
@@ -161,7 +162,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
             token="i_am_an_app_service",
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
-            sender="@as:test",
+            sender=UserID.from_string("@as:test"),
         )

         mock_load_appservices = Mock(return_value=[self.appservice])
@@ -386,7 +387,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         # Join the AS sender to rooms owned by the normal user.
         public, private = self._create_rooms_and_inject_memberships(
-            user, token, self.appservice.sender
+            user, token, self.appservice.sender.to_string()
         )

         # Rebuild the directory.


@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.rest.client import register, sync
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock

 from tests import unittest
@@ -118,7 +118,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
             ApplicationService(
                 token=as_token,
                 id="SomeASID",
-                sender="@as_sender:test",
+                sender=UserID.from_string("@as_sender:test"),
                 namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
             )
         )
@@ -263,7 +263,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
             ApplicationService(
                 token=as_token_1,
                 id="SomeASID",
-                sender="@as_sender_1:test",
+                sender=UserID.from_string("@as_sender_1:test"),
                 namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]},
             )
         )
@@ -273,7 +273,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
             ApplicationService(
                 token=as_token_2,
                 id="AnotherASID",
-                sender="@as_sender_2:test",
+                sender=UserID.from_string("@as_sender_2:test"),
                 namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]},
             )
         )


@@ -31,18 +31,24 @@ from tests.unittest import TestCase
 class DeferredCacheTestCase(TestCase):
     def test_empty(self) -> None:
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         with self.assertRaises(KeyError):
             cache.get("foo")

     def test_hit(self) -> None:
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         cache.prefill("foo", 123)

         self.assertEqual(self.successResultOf(cache.get("foo")), 123)

     def test_hit_deferred(self) -> None:
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         origin_d: "defer.Deferred[int]" = defer.Deferred()
         set_d = cache.set("k1", origin_d)
@@ -65,7 +71,9 @@ class DeferredCacheTestCase(TestCase):
     def test_callbacks(self) -> None:
         """Invalidation callbacks are called at the right time"""
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         callbacks = set()

         # start with an entry, with a callback
@@ -98,7 +106,9 @@ class DeferredCacheTestCase(TestCase):
         self.assertEqual(callbacks, {"set", "get"})

     def test_set_fail(self) -> None:
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         callbacks = set()

         # start with an entry, with a callback
@@ -135,7 +145,9 @@ class DeferredCacheTestCase(TestCase):
         self.assertEqual(callbacks, {"prefill", "get2"})

     def test_get_immediate(self) -> None:
-        cache: DeferredCache[str, int] = DeferredCache("test")
+        cache: DeferredCache[str, int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         d1: "defer.Deferred[int]" = defer.Deferred()
         cache.set("key1", d1)
@@ -151,7 +163,9 @@ class DeferredCacheTestCase(TestCase):
         self.assertEqual(v, 2)

     def test_invalidate(self) -> None:
-        cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
+        cache: DeferredCache[Tuple[str], int] = DeferredCache(
+            name="test", server_name="test_server"
+        )
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
@@ -159,7 +173,9 @@ class DeferredCacheTestCase(TestCase):
             cache.get(("foo",))

     def test_invalidate_all(self) -> None:
-        cache: DeferredCache[str, str] = DeferredCache("testcache")
+        cache: DeferredCache[str, str] = DeferredCache(
+            name="testcache", server_name="test_server"
+        )
         callback_record = [False, False]
@@ -203,7 +219,10 @@ class DeferredCacheTestCase(TestCase):
     def test_eviction(self) -> None:
         cache: DeferredCache[int, str] = DeferredCache(
-            "test", max_entries=2, apply_cache_factor_from_config=False
+            name="test",
+            server_name="test_server",
+            max_entries=2,
+            apply_cache_factor_from_config=False,
         )

         cache.prefill(1, "one")
@@ -218,7 +237,10 @@ class DeferredCacheTestCase(TestCase):
     def test_eviction_lru(self) -> None:
         cache: DeferredCache[int, str] = DeferredCache(
-            "test", max_entries=2, apply_cache_factor_from_config=False
+            name="test",
+            server_name="test_server",
+            max_entries=2,
+            apply_cache_factor_from_config=False,
         )

         cache.prefill(1, "one")
@@ -237,7 +259,8 @@ class DeferredCacheTestCase(TestCase):
     def test_eviction_iterable(self) -> None:
         cache: DeferredCache[int, List[str]] = DeferredCache(
-            "test",
+            name="test",
+            server_name="test_server",
             max_entries=3,
             apply_cache_factor_from_config=False,
             iterable=True,
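
Note: the change running through this file is that `DeferredCache` is now built with explicit keyword arguments and must be told which homeserver owns it, presumably so its metrics can be labelled per server. A sketch of the new construction, assuming only the keywords visible in the diff (`my_cache` and `example.com` are illustrative):

```
from synapse.util.caches.deferred_cache import DeferredCache

# name labels the cache; server_name identifies the owning homeserver
cache: DeferredCache[str, int] = DeferredCache(
    name="my_cache",  # illustrative
    server_name="example.com",  # illustrative
    max_entries=100,
    apply_cache_factor_from_config=False,
)
cache.prefill("key", 42)
```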

View File

@@ -66,6 +66,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int, arg2: int) -> str:
@@ -100,6 +101,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached(num_args=1)
             def fn(self, arg1: int, arg2: int) -> str:
@@ -145,6 +147,7 @@ class DescriptorTestCase(unittest.TestCase):
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

         obj = Cls()
         obj.mock.return_value = "fish"
@@ -175,6 +178,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int, kwarg1: int = 2) -> str:
@@ -209,6 +213,8 @@ class DescriptorTestCase(unittest.TestCase):
         """If the wrapped function throws synchronously, things should continue to work"""

         class Cls:
+            server_name = "test_server"
+
             @cached()
             def fn(self, arg1: int) -> NoReturn:
                 raise SynapseError(100, "mai spoon iz too big!!1")
@@ -232,6 +238,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             result: Optional[Deferred] = None
             call_count = 0
+            server_name = "test_server"

             @cached()
             def fn(self, arg1: int) -> Deferred:
@@ -285,6 +292,8 @@ class DescriptorTestCase(unittest.TestCase):
         complete_lookup: Deferred = Deferred()

         class Cls:
+            server_name = "test_server"
+
             @descriptors.cached()
             def fn(self, arg1: int) -> "Deferred[int]":
                 @defer.inlineCallbacks
@@ -327,6 +336,8 @@ class DescriptorTestCase(unittest.TestCase):
         the lookup function throws an exception"""

         class Cls:
+            server_name = "test_server"
+
             @descriptors.cached()
             def fn(self, arg1: int) -> Deferred:
                 @defer.inlineCallbacks
@@ -369,6 +380,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str:
@@ -406,6 +418,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached(iterable=True)
             def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
@@ -439,6 +452,8 @@ class DescriptorTestCase(unittest.TestCase):
         """If the wrapped function throws synchronously, things should continue to work"""

         class Cls:
+            server_name = "test_server"
+
             @descriptors.cached(iterable=True)
             def fn(self, arg1: int) -> NoReturn:
                 raise SynapseError(100, "mai spoon iz too big!!1")
@@ -460,6 +475,8 @@ class DescriptorTestCase(unittest.TestCase):
         """Invalidations should cascade up through cache contexts"""

         class Cls:
+            server_name = "test_server"
+
             @cached(cache_context=True)
             async def func1(self, key: str, cache_context: _CacheContext) -> int:
                 return await self.func2(key, on_invalidate=cache_context.invalidate)
@@ -486,6 +503,8 @@ class DescriptorTestCase(unittest.TestCase):
         complete_lookup: "Deferred[None]" = Deferred()

         class Cls:
+            server_name = "test_server"
+
             @cached()
             async def fn(self, arg1: int) -> str:
                 await complete_lookup
@@ -517,6 +536,7 @@ class DescriptorTestCase(unittest.TestCase):
         class Cls:
             inner_context_was_finished = False
+            server_name = "test_server"

             @cached()
             async def fn(self, arg1: int) -> str:
@@ -562,6 +582,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     @defer.inlineCallbacks
     def test_passthrough(self) -> Generator["Deferred[Any]", object, None]:
         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 return key
@@ -576,6 +598,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]

         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 callcount[0] += 1
@@ -594,6 +618,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]

         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 callcount[0] += 1
@@ -612,6 +638,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_invalidate_missing(self) -> None:
         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 return key
@@ -623,6 +651,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]

         class A:
+            server_name = "test_server"
+
             @cached(max_entries=10)
             def func(self, key: int) -> int:
                 callcount[0] += 1
@@ -650,6 +680,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         d = defer.succeed(123)

         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> "Deferred[int]":
                 callcount[0] += 1
@@ -668,6 +700,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount2 = [0]

         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 callcount[0] += 1
@@ -701,6 +735,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount2 = [0]

         class A:
+            server_name = "test_server"
+
             @cached(max_entries=2)
             def func(self, key: str) -> str:
                 callcount[0] += 1
@@ -738,6 +774,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount2 = [0]

         class A:
+            server_name = "test_server"
+
             @cached()
             def func(self, key: str) -> str:
                 callcount[0] += 1
@@ -785,6 +823,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int, arg2: int) -> None:
@@ -850,6 +889,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int) -> None:
@@ -893,6 +933,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         class Cls:
             def __init__(self) -> None:
                 self.mock = mock.Mock()
+                self.server_name = "test_server"

             @descriptors.cached()
             def fn(self, arg1: int, arg2: int) -> None:
@@ -933,6 +974,8 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         complete_lookup: "Deferred[None]" = Deferred()

         class Cls:
+            server_name = "test_server"
+
             @cached()
             def fn(self, arg1: int) -> None:
                 pass
@@ -967,6 +1010,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         class Cls:
             inner_context_was_finished = False
+            server_name = "test_server"

             @cached()
             def fn(self, arg1: int) -> None:
@@ -1010,6 +1054,8 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         """

         class Cls:
+            server_name = "test_server"
+
             @descriptors.cached(tree=True)
             def fn(self, room_id: str, event_id: str) -> None:
                 pass
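
Note: the recurring edit in this file is that any class whose methods are decorated with `@cached`/`@descriptors.cached` must now expose a `server_name` attribute (either a class attribute or set in `__init__`), which the descriptor appears to read when creating its backing cache. A minimal sketch of the requirement (class and method names are hypothetical):

```
from synapse.util.caches.descriptors import cached


class MyStore:
    # the cached descriptor looks this attribute up on the instance,
    # so every class using @cached must provide it
    server_name = "example.com"  # illustrative

    @cached()
    async def get_thing(self, thing_id: str) -> str:
        return thing_id.upper()  # stand-in for an expensive lookup
```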

View File

@@ -46,7 +46,9 @@ class ResponseCacheTestCase(TestCase):
         self.reactor, self.clock = get_clock()

     def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
-        return ResponseCache(self.clock, name, timeout_ms=ms)
+        return ResponseCache(
+            clock=self.clock, name=name, server_name="test_server", timeout_ms=ms
+        )

     @staticmethod
     async def instant_return(o: str) -> str:

View File

@@ -28,7 +28,9 @@ from tests import unittest
 class CacheTestCase(unittest.TestCase):
     def setUp(self) -> None:
         self.mock_timer = Mock(side_effect=lambda: 100.0)
-        self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
+        self.cache: TTLCache[str, str] = TTLCache(
+            cache_name="test_cache", server_name="test_server", timer=self.mock_timer
+        )

     def test_get(self) -> None:
         """simple set/get tests"""

View File

@@ -28,7 +28,7 @@ from tests import unittest
 class DictCacheTestCase(unittest.TestCase):
     def setUp(self) -> None:
         self.cache: DictionaryCache[str, str, str] = DictionaryCache(
-            "foobar", max_entries=10
+            name="foobar", server_name="test_server", max_entries=10
        )

     def test_simple_cache_hit_full(self) -> None:
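
Note: the three files above give the same keyword-argument treatment to `ResponseCache`, `TTLCache`, and `DictionaryCache`, each of which now also takes a `server_name`. A sketch of the updated call shapes, assuming the keyword names shown in the diffs (the cache names and `example.com` are illustrative):

```
import time

from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.caches.ttlcache import TTLCache

from tests.server import get_clock

reactor, clock = get_clock()

response_cache: ResponseCache[str] = ResponseCache(
    clock=clock, name="my_response_cache", server_name="example.com", timeout_ms=0
)
ttl_cache: TTLCache[str, str] = TTLCache(
    cache_name="my_ttl_cache", server_name="example.com", timer=time.time
)
dict_cache: DictionaryCache[str, str, str] = DictionaryCache(
    name="my_dict_cache", server_name="example.com", max_entries=10
)
```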

View File

@@ -33,7 +33,10 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
     def test_get_set(self) -> None:
         clock = MockClock()
         cache: ExpiringCache[str, str] = ExpiringCache(
-            "test", cast(Clock, clock), max_len=1
+            cache_name="test",
+            server_name="testserver",
+            clock=cast(Clock, clock),
+            max_len=1,
         )

         cache["key"] = "value"
@@ -43,7 +46,10 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
     def test_eviction(self) -> None:
         clock = MockClock()
         cache: ExpiringCache[str, str] = ExpiringCache(
-            "test", cast(Clock, clock), max_len=2
+            cache_name="test",
+            server_name="testserver",
+            clock=cast(Clock, clock),
+            max_len=2,
         )

         cache["key"] = "value"
@@ -59,7 +65,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
     def test_iterable_eviction(self) -> None:
         clock = MockClock()
         cache: ExpiringCache[str, List[int]] = ExpiringCache(
-            "test", cast(Clock, clock), max_len=5, iterable=True
+            cache_name="test",
+            server_name="testserver",
+            clock=cast(Clock, clock),
+            max_len=5,
+            iterable=True,
         )

         cache["key"] = [1]
@@ -79,7 +89,10 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
     def test_time_eviction(self) -> None:
         clock = MockClock()
         cache: ExpiringCache[str, int] = ExpiringCache(
-            "test", cast(Clock, clock), expiry_ms=1000
+            cache_name="test",
+            server_name="testserver",
+            clock=cast(Clock, clock),
+            expiry_ms=1000,
         )

         cache["key"] = 1

View File

@@ -34,13 +34,13 @@ from tests.unittest import override_config
 class LruCacheTestCase(unittest.HomeserverTestCase):
     def test_get_set(self) -> None:
-        cache: LruCache[str, str] = LruCache(1)
+        cache: LruCache[str, str] = LruCache(max_size=1)
         cache["key"] = "value"
         self.assertEqual(cache.get("key"), "value")
         self.assertEqual(cache["key"], "value")

     def test_eviction(self) -> None:
-        cache: LruCache[int, int] = LruCache(2)
+        cache: LruCache[int, int] = LruCache(max_size=2)
         cache[1] = 1
         cache[2] = 2
@@ -54,7 +54,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get(3), 3)

     def test_setdefault(self) -> None:
-        cache: LruCache[str, int] = LruCache(1)
+        cache: LruCache[str, int] = LruCache(max_size=1)
         self.assertEqual(cache.setdefault("key", 1), 1)
         self.assertEqual(cache.get("key"), 1)
         self.assertEqual(cache.setdefault("key", 2), 1)
@@ -63,14 +63,16 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get("key"), 2)

     def test_pop(self) -> None:
-        cache: LruCache[str, int] = LruCache(1)
+        cache: LruCache[str, int] = LruCache(max_size=1)
         cache["key"] = 1
         self.assertEqual(cache.pop("key"), 1)
         self.assertEqual(cache.pop("key"), None)

     def test_del_multi(self) -> None:
         # The type here isn't quite correct as they don't handle TreeCache well.
-        cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
+        cache: LruCache[Tuple[str, str], str] = LruCache(
+            max_size=4, cache_type=TreeCache
+        )
         cache[("animal", "cat")] = "mew"
         cache[("animal", "dog")] = "woof"
         cache[("vehicles", "car")] = "vroom"
@@ -89,21 +91,23 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         # Man from del_multi say "Yes".

     def test_clear(self) -> None:
-        cache: LruCache[str, int] = LruCache(1)
+        cache: LruCache[str, int] = LruCache(max_size=1)
         cache["key"] = 1
         cache.clear()
         self.assertEqual(len(cache), 0)

     @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
     def test_special_size(self) -> None:
-        cache: LruCache = LruCache(10, "mycache")
+        cache: LruCache = LruCache(
+            max_size=10, server_name="test_server", cache_name="mycache"
+        )
         self.assertEqual(cache.max_size, 100)


 class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_get(self) -> None:
         m = Mock()
-        cache: LruCache[str, str] = LruCache(1)
+        cache: LruCache[str, str] = LruCache(max_size=1)
         cache.set("key", "value")
         self.assertFalse(m.called)
@@ -122,7 +126,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_multi_get(self) -> None:
         m = Mock()
-        cache: LruCache[str, str] = LruCache(1)
+        cache: LruCache[str, str] = LruCache(max_size=1)
         cache.set("key", "value")
         self.assertFalse(m.called)
@@ -141,7 +145,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_set(self) -> None:
         m = Mock()
-        cache: LruCache[str, str] = LruCache(1)
+        cache: LruCache[str, str] = LruCache(max_size=1)
         cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
@@ -157,7 +161,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_pop(self) -> None:
         m = Mock()
-        cache: LruCache[str, str] = LruCache(1)
+        cache: LruCache[str, str] = LruCache(max_size=1)
         cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)
@@ -177,7 +181,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         m3 = Mock()
         m4 = Mock()
         # The type here isn't quite correct as they don't handle TreeCache well.
-        cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
+        cache: LruCache[Tuple[str, str], str] = LruCache(
+            max_size=4, cache_type=TreeCache
+        )
         cache.set(("a", "1"), "value", callbacks=[m1])
         cache.set(("a", "2"), "value", callbacks=[m2])
@@ -199,7 +205,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_clear(self) -> None:
         m1 = Mock()
         m2 = Mock()
-        cache: LruCache[str, str] = LruCache(5)
+        cache: LruCache[str, str] = LruCache(max_size=5)
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
@@ -216,7 +222,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         m1 = Mock(name="m1")
         m2 = Mock(name="m2")
         m3 = Mock(name="m3")
-        cache: LruCache[str, str] = LruCache(2)
+        cache: LruCache[str, str] = LruCache(max_size=2)
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
@@ -252,7 +258,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
 class LruCacheSizedTestCase(unittest.HomeserverTestCase):
     def test_evict(self) -> None:
-        cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
+        cache: LruCache[str, List[int]] = LruCache(max_size=5, size_callback=len)
         cache["key1"] = [0]
         cache["key2"] = [1, 2]
         cache["key3"] = [3]
@@ -275,7 +281,9 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
     def test_zero_size_drop_from_cache(self) -> None:
         """Test that `drop_from_cache` works correctly with 0-sized entries."""
-        cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0)
+        cache: LruCache[str, List[int]] = LruCache(
+            max_size=5, size_callback=lambda x: 0
+        )
         cache["key1"] = []
         self.assertEqual(len(cache), 0)
@@ -299,7 +307,7 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase):
     def test_evict(self) -> None:
         setup_expire_lru_cache_entries(self.hs)
-        cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
+        cache: LruCache[str, int] = LruCache(max_size=5, clock=self.hs.get_clock())

         # Check that we evict entries we haven't accessed for 30 minutes.
         cache["key1"] = 1
@@ -351,7 +359,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
         mock_jemalloc_class.get_stat.return_value = 924288000

         setup_expire_lru_cache_entries(self.hs)
-        cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
+        cache: LruCache[str, int] = LruCache(max_size=4, clock=self.hs.get_clock())

         cache["key1"] = 1
         cache["key2"] = 2
@@ -387,7 +395,9 @@ class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
 class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
     def test_invalidate_simple(self) -> None:
-        cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+        cache: LruCache[str, int] = LruCache(
+            max_size=10, extra_index_cb=lambda k, v: str(v)
+        )
         cache["key1"] = 1
         cache["key2"] = 2
@@ -400,7 +410,9 @@ class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(cache.get("key2"), 2)

     def test_invalidate_multi(self) -> None:
-        cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+        cache: LruCache[str, int] = LruCache(
+            max_size=10, extra_index_cb=lambda k, v: str(v)
+        )
         cache["key1"] = 1
         cache["key2"] = 1
         cache["key3"] = 2

View File

@@ -15,7 +15,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         Providing a prefilled cache to StreamChangeCache will result in a cache
         with the prefilled-cache entered in.
         """
-        cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+        cache = StreamChangeCache(
+            name="#test",
+            server_name=self.hs.hostname,
+            current_stream_pos=1,
+            prefilled_cache={"user@foo.com": 2},
+        )
         self.assertTrue(cache.has_entity_changed("user@foo.com", 1))

     def test_has_entity_changed(self) -> None:
@@ -23,7 +28,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         StreamChangeCache.entity_has_changed will mark entities as changed, and
         has_entity_changed will observe the changed entities.
         """
-        cache = StreamChangeCache("#test", 3)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=3
+        )
         cache.entity_has_changed("user@foo.com", 6)
         cache.entity_has_changed("bar@baz.net", 7)
@@ -61,7 +68,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         StreamChangeCache.entity_has_changed will respect the max size and
         purge the oldest items upon reaching that max size.
         """
-        cache = StreamChangeCache("#test", 1, max_size=2)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=1, max_size=2
+        )
         cache.entity_has_changed("user@foo.com", 2)
         cache.entity_has_changed("bar@baz.net", 3)
@@ -100,7 +109,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         entities since the given position. If the position is before the start
         of the known stream, it returns None instead.
         """
-        cache = StreamChangeCache("#test", 1)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=1
+        )
         cache.entity_has_changed("user@foo.com", 2)
         cache.entity_has_changed("bar@baz.net", 3)
@@ -148,7 +159,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         stream position is before it, it will return True, otherwise False if
         the cache has no entries.
         """
-        cache = StreamChangeCache("#test", 1)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=1
+        )

         # With no entities, it returns True for the past, present, and False for
         # the future.
@@ -175,7 +188,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         stream position is earlier than the earliest known position, it will
         return all of the entities queried for.
         """
-        cache = StreamChangeCache("#test", 1)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=1
+        )
         cache.entity_has_changed("user@foo.com", 2)
         cache.entity_has_changed("bar@baz.net", 3)
@@ -242,7 +257,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         recent point where the entity could have changed. If the entity is not
         known, the stream start is provided instead.
         """
-        cache = StreamChangeCache("#test", 1)
+        cache = StreamChangeCache(
+            name="#test", server_name=self.hs.hostname, current_stream_pos=1
+        )
         cache.entity_has_changed("user@foo.com", 2)
         cache.entity_has_changed("bar@baz.net", 3)
@@ -260,7 +277,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         """
         `StreamChangeCache.all_entities_changed(...)` will mark all entites as changed.
         """
-        cache = StreamChangeCache("#test", 1, max_size=10)
+        cache = StreamChangeCache(
+            name="#test",
+            server_name=self.hs.hostname,
+            current_stream_pos=1,
+            max_size=10,
+        )
         cache.entity_has_changed("user@foo.com", 2)
         cache.entity_has_changed("bar@baz.net", 3)