Move device changes off the main process (#18581)

The main goal of this PR is to handle device list changes onto multiple
writers, off the main process, so that we can have logins happening
whilst Synapse is rolling-restarting.

This is quite an intrusive change, so I would advise to review this
commit by commit; I tried to keep the history as clean as possible.

There are a few things to consider:

- the `device_list_key` in stream tokens becomes a
`MultiWriterStreamToken`, which has a few implications in sync and on
the storage layer
- we had a split between `DeviceHandler` and `DeviceWorkerHandler` for
master vs. worker process. I've kept this split, but making it rather
writer vs. non-writer worker, using method overrides for doing
replication calls when needed
- there are a few operations that need to happen on a single worker at a
time. Instead of using cross-worker locks, for now I made them run on
the first writer on the list

---------

Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
Quentin Gliech 2025-07-18 09:06:14 +02:00 committed by GitHub
parent 66504d1144
commit 5ea2cf2484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 1753 additions and 1422 deletions

View File

@ -0,0 +1 @@
Enable workers to write directly to the device lists stream and handle device list updates, reducing load on the main process.

View File

@ -54,7 +54,6 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
export SYNAPSE_WORKER_TYPES="\
event_persister:2, \
background_worker, \
frontend_proxy, \
event_creator, \
user_dir, \
media_repository, \
@ -65,6 +64,7 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
client_reader, \
appservice, \
pusher, \
device_lists:2, \
stream_writers=account_data+presence+receipts+to_device+typing"
fi

View File

@ -178,6 +178,8 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/login$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)",
"^/_matrix/client/(r0|v3)/delete_devices$",
"^/_matrix/client/versions$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$",
"^/_matrix/client/(r0|v3|unstable)/register$",
@ -194,6 +196,9 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/directory/room/.*$",
"^/_matrix/client/(r0|v3|unstable)/capabilities$",
"^/_matrix/client/(r0|v3|unstable)/notifications$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
@ -265,13 +270,6 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"frontend_proxy": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"account_data": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
@ -306,6 +304,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"device_lists": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"typing": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
@ -412,16 +417,17 @@ def add_worker_roles_to_shared_config(
# streams
instance_map = shared_config.setdefault("instance_map", {})
# This is a list of the stream_writers that there can be only one of. Events can be
# sharded, and therefore doesn't belong here.
singular_stream_writers = [
# This is a list of the stream_writers.
stream_writers = {
"account_data",
"events",
"device_lists",
"presence",
"receipts",
"to_device",
"typing",
"push_rules",
]
}
# Worker-type specific sharding config. Now a single worker can fulfill multiple
# roles, check each.
@ -431,28 +437,11 @@ def add_worker_roles_to_shared_config(
if "federation_sender" in worker_types_set:
shared_config.setdefault("federation_sender_instances", []).append(worker_name)
if "event_persister" in worker_types_set:
# Event persisters write to the events stream, so we need to update
# the list of event stream writers
shared_config.setdefault("stream_writers", {}).setdefault("events", []).append(
worker_name
)
# Map of stream writer instance names to host/ports combos
if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
instance_map[worker_name] = {
"path": f"/run/worker.{worker_port}",
}
else:
instance_map[worker_name] = {
"host": "localhost",
"port": worker_port,
}
# Update the list of stream writers. It's convenient that the name of the worker
# type is the same as the stream to write. Iterate over the whole list in case there
# is more than one.
for worker in worker_types_set:
if worker in singular_stream_writers:
if worker in stream_writers:
shared_config.setdefault("stream_writers", {}).setdefault(
worker, []
).append(worker_name)
@ -876,6 +865,13 @@ def generate_worker_files(
else:
healthcheck_urls.append("http://localhost:%d/health" % (worker_port,))
# Special case for event_persister: those are just workers that write to
# the `events` stream. For other workers, the worker name is the same
# name of the stream they write to, but for some reason it is not the
# case for event_persister.
if "event_persister" in worker_types_set:
worker_types_set.add("events")
# Update the shared config with sharding-related options if necessary
add_worker_roles_to_shared_config(
shared_config, worker_types_set, worker_name, worker_port

View File

@ -4341,6 +4341,8 @@ This setting has the following sub-options:
* `push_rules` (string): Name of a worker assigned to the `push_rules` stream.
* `device_lists` (string): Name of a worker assigned to the `device_lists` stream.
Example configuration:
```yaml
stream_writers:

View File

@ -238,7 +238,8 @@ information.
^/_matrix/client/unstable/im.nheko.summary/summary/.*$
^/_matrix/client/(r0|v3|unstable)/account/3pid$
^/_matrix/client/(r0|v3|unstable)/account/whoami$
^/_matrix/client/(r0|v3|unstable)/devices$
^/_matrix/client/(r0|v3)/delete_devices$
^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)
^/_matrix/client/versions$
^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
@ -257,7 +258,9 @@ information.
^/_matrix/client/(r0|v3|unstable)/keys/changes$
^/_matrix/client/(r0|v3|unstable)/keys/claim$
^/_matrix/client/(r0|v3|unstable)/room_keys/
^/_matrix/client/(r0|v3|unstable)/keys/upload$
^/_matrix/client/(r0|v3|unstable)/keys/upload
^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$
# Registration/login requests
^/_matrix/client/(api/v1|r0|v3|unstable)/login$
@ -282,7 +285,6 @@ Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
^/_matrix/client/unstable/org.matrix.msc4140/delayed_events
^/_matrix/client/(api/v1|r0|v3|unstable)/devices/
# Account data requests
^/_matrix/client/(r0|v3|unstable)/.*/tags
@ -329,7 +331,6 @@ set to `true`), the following endpoints can be handled by the worker:
^/_synapse/admin/v2/users/[^/]+$
^/_synapse/admin/v1/username_available$
^/_synapse/admin/v1/users/[^/]+/_allow_cross_signing_replacement_without_uia$
# Only the GET method:
^/_synapse/admin/v1/users/[^/]+/devices$
Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)
@ -550,6 +551,18 @@ the stream writer for the `push_rules` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
##### The `device_lists` stream
The `device_lists` stream supports multiple writers. The following endpoints
can be handled by any worker, but should be routed directly one of the workers
configured as stream writer for the `device_lists` stream:
^/_matrix/client/(r0|v3)/delete_devices$
^/_matrix/client/(api/v1|r0|v3|unstable)/devices/
^/_matrix/client/(r0|v3|unstable)/keys/upload
^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$
#### Restrict outbound federation traffic to a specific set of workers
The

View File

@ -5383,6 +5383,9 @@ properties:
push_rules:
type: string
description: Name of a worker assigned to the `push_rules` stream.
device_lists:
type: string
description: Name of a worker assigned to the `device_lists` stream.
default: {}
examples:
- events: worker1

View File

@ -134,40 +134,44 @@ class WriterLocations:
can only be a single instance.
account_data: The instances that write to the account data streams. Currently
can only be a single instance.
receipts: The instances that write to the receipts stream. Currently
can only be a single instance.
receipts: The instances that write to the receipts stream.
presence: The instances that write to the presence stream. Currently
can only be a single instance.
push_rules: The instances that write to the push stream. Currently
can only be a single instance.
device_lists: The instances that write to the device list stream.
"""
events: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
typing: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
to_device: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
account_data: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
receipts: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
presence: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
push_rules: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
device_lists: List[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
@ -358,7 +362,10 @@ class WorkerConfig(Config):
):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
if (
instance != MAIN_PROCESS_INSTANCE_NAME
and instance not in self.instance_map
):
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
@ -397,6 +404,11 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `push` messages."
)
if len(self.writers.device_lists) == 0:
raise ConfigError(
"Must specify at least one instance to handle `device_lists` messages."
)
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
@ -419,9 +431,12 @@ class WorkerConfig(Config):
#
# No effort is made to ensure only a single instance of these tasks is
# running.
background_tasks_instance = config.get("run_background_tasks_on") or "master"
background_tasks_instance = (
config.get("run_background_tasks_on") or MAIN_PROCESS_INSTANCE_NAME
)
self.run_background_tasks = (
self.worker_name is None and background_tasks_instance == "master"
self.worker_name is None
and background_tasks_instance == MAIN_PROCESS_INSTANCE_NAME
) or self.worker_name == background_tasks_instance
self.should_notify_appservices = self._should_this_worker_perform_duty(
@ -493,9 +508,10 @@ class WorkerConfig(Config):
# 'don't run here'.
new_option_should_run_here = None
if new_option_name in config:
designated_worker = config[new_option_name] or "master"
designated_worker = config[new_option_name] or MAIN_PROCESS_INSTANCE_NAME
new_option_should_run_here = (
designated_worker == "master" and self.worker_name is None
designated_worker == MAIN_PROCESS_INSTANCE_NAME
and self.worker_name is None
) or designated_worker == self.worker_name
legacy_option_should_run_here = None
@ -592,7 +608,7 @@ class WorkerConfig(Config):
# If no worker instances are set we check if the legacy option
# is set, which means use the main process.
if legacy_option:
worker_instances = ["master"]
worker_instances = [MAIN_PROCESS_INSTANCE_NAME]
if self.worker_app == legacy_app_name:
if legacy_option:

View File

@ -638,7 +638,8 @@ class ApplicationServicesHandler:
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = await self.store.get_all_devices_changed(
from_key, to_key=new_key
MultiWriterStreamToken(stream=from_key),
to_key=MultiWriterStreamToken(stream=new_key),
)
# Filter out any users the application service is not interested in

View File

@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester
@ -84,10 +83,6 @@ class DeactivateAccountHandler:
Returns:
True if identity server supports removing threepids, otherwise False.
"""
# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)
# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin

File diff suppressed because it is too large Load Diff

View File

@ -33,9 +33,6 @@ from synapse.logging.opentracing import (
log_kv,
set_tag,
)
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@ -56,9 +53,9 @@ class DeviceMessageHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
self.device_handler = hs.get_device_handler()
if hs.config.experimental.msc3814_enabled:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
@ -80,18 +77,6 @@ class DeviceMessageHandler:
hs.config.worker.writers.to_device,
)
# The handler to call when we think a user's device list might be out of
# sync. We do all device list resyncing on the master instance, so if
# we're on a worker we hit the device resync replication API.
if hs.config.worker.worker_app is None:
self._multi_user_device_resync = (
hs.get_device_handler().device_list_updater.multi_user_device_resync
)
else:
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
# a rate limiter for room key requests. The keys are
# (sending_user_id, sending_device_id).
self._ratelimiter = Ratelimiter(
@ -213,7 +198,10 @@ class DeviceMessageHandler:
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
# Immediately attempt a resync in the background
run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
run_in_background(
self.device_handler.device_list_updater.multi_user_device_resync,
user_ids=[sender_user_id],
)
async def send_device_message(
self,

View File

@ -32,10 +32,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@ -76,8 +75,10 @@ class E2eKeysHandler:
federation_registry = hs.get_federation_registry()
is_master = hs.config.worker.worker_app is None
if is_master:
# Only the first writer in the list should handle EDUs for signing key
# updates, so that we can use an in-memory linearizer instead of worker locks.
edu_writer = hs.config.worker.writers.device_lists[0]
if hs.get_instance_name() == edu_writer:
edu_updater = SigningKeyEduUpdater(hs)
# Only register this edu handler on master as it requires writing
@ -92,11 +93,14 @@ class E2eKeysHandler:
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
edu_updater.incoming_signing_key_update,
)
self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
federation_registry.register_instances_for_edu(
EduTypes.SIGNING_KEY_UPDATE,
[edu_writer],
)
federation_registry.register_instances_for_edu(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
[edu_writer],
)
# doesn't really work as part of the generic query API, because the
@ -847,7 +851,7 @@ class E2eKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
await self.device_key_uploader(
await self.upload_device_keys_for_user(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
@ -904,9 +908,6 @@ class E2eKeysHandler:
device_keys: the `device_keys` of an /keys/upload request.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
device_keys = keys["device_keys"]
@ -998,9 +999,6 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
if "master_key" in keys:
@ -1091,9 +1089,6 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
failures = {}
# signatures to be stored. Each item will be a SignatureListItem
@ -1467,9 +1462,6 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@ -1770,7 +1762,7 @@ class SigningKeyEduUpdater:
self.clock = hs.get_clock()
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

View File

@ -698,10 +698,19 @@ class FederationHandler:
# We may want to reset the partial state info if it's from an
# old, failed partial state join.
# https://github.com/matrix-org/synapse/issues/13000
# FIXME: Ideally, we would store the full stream token here
# not just the minimum stream ID, so that we can compute an
# accurate list of device changes when un-partial-ing the
# room. The only side effect of this is that we may send
# extra unecessary device list outbound pokes through
# federation, which is harmless.
device_lists_stream_id = self.store.get_device_stream_token().stream
await self.store.store_partial_state_room(
room_id=room_id,
servers=ret.servers_in_room,
device_lists_stream_id=self.store.get_device_stream_token(),
device_lists_stream_id=device_lists_stream_id,
joined_via=origin,
)

View File

@ -77,9 +77,6 @@ from synapse.logging.opentracing import (
trace,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
@ -180,12 +177,7 @@ class FederationEventHandler:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
if hs.config.worker.worker_app:
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
self._device_list_updater = hs.get_device_handler().device_list_updater
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
@ -1544,12 +1536,7 @@ class FederationEventHandler:
await self._store.mark_remote_users_device_caches_as_stale((sender,))
# Immediately attempt a resync in the background
if self._config.worker.worker_app:
await self._multi_user_device_resync(user_ids=[sender])
else:
await self._device_list_updater.multi_user_device_resync(
user_ids=[sender]
)
await self._device_list_updater.multi_user_device_resync(user_ids=[sender])
except Exception:
logger.exception("Failed to resync device for %s", sender)

View File

@ -44,7 +44,6 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@ -840,9 +839,6 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None
# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)
registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,

View File

@ -21,7 +21,6 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester
if TYPE_CHECKING:
@ -36,17 +35,7 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
# We don't need the device handler if password changing is disabled.
# This allows us to instantiate the SetPasswordHandler on the workers
# that have admin APIs for MAS
if self._auth_handler.can_change_password():
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler: Optional[DeviceHandler] = device_handler
else:
self._device_handler = None
self._device_handler = hs.get_device_handler()
async def set_password(
self,
@ -58,9 +47,6 @@ class SetPasswordHandler:
if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
# We should have this available only if password changing is enabled.
assert self._device_handler is not None
try:
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:

View File

@ -46,7 +46,6 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType, ProfileFields
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@ -1181,8 +1180,6 @@ class SsoHandler:
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.
Can only be called from the main process.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@ -1191,11 +1188,6 @@ class SsoHandler:
sessions belonging to other users and log an error.
"""
# It is expected that this is the main process.
assert isinstance(self._device_handler, DeviceHandler), (
"revoking SSO sessions can only be called on the main process"
)
# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():

View File

@ -66,7 +66,6 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
@ -925,8 +924,6 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user
Can only be called from the main process.
Added in Synapse v0.25.0.
Args:
@ -939,10 +936,6 @@ class ModuleApi:
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(self._device_handler, DeviceHandler), (
"invalidate_access_token can only be called on the main process"
)
# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)

View File

@ -59,10 +59,10 @@ class ReplicationRestResource(JsonResource):
account_data.register_servlets(hs, self)
push.register_servlets(hs, self)
state.register_servlets(hs, self)
devices.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
delayed_events.register_servlets(hs, self)

View File

@ -34,6 +34,92 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ReplicationNotifyDeviceUpdateRestServlet(ReplicationEndpoint):
"""Notify a device writer that a user's device list has changed.
Request format:
POST /_synapse/replication/notify_device_update/:user_id
{
"device_ids": ["JLAFKJWSCS", "JLAFKJWSCS"]
}
"""
NAME = "notify_device_update"
PATH_ARGS = ("user_id",)
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str, device_ids: List[str]
) -> JsonDict:
return {"device_ids": device_ids}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
device_ids = content["device_ids"]
span = active_span()
if span:
span.set_tag("user_id", user_id)
span.set_tag("device_ids", f"{device_ids!r}")
await self.device_handler.notify_device_update(user_id, device_ids)
return 200, {}
class ReplicationNotifyUserSignatureUpdateRestServlet(ReplicationEndpoint):
"""Notify a device writer that a user have made new signatures of other users.
Request format:
POST /_synapse/replication/notify_user_signature_update/:from_user_id
{
"user_ids": ["@alice:example.org", "@bob:example.org", ...]
}
"""
NAME = "notify_user_signature_update"
PATH_ARGS = ("from_user_id",)
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(from_user_id: str, user_ids: List[str]) -> JsonDict: # type: ignore[override]
return {"user_ids": user_ids}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict, from_user_id: str
) -> Tuple[int, JsonDict]:
user_ids = content["user_ids"]
span = active_span()
if span:
span.set_tag("from_user_id", from_user_id)
span.set_tag("user_ids", f"{user_ids!r}")
await self.device_handler.notify_user_signature_update(from_user_id, user_ids)
return 200, {}
class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
"""Ask master to resync the device list for multiple users from the same
remote server by contacting their server.
@ -73,11 +159,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
from synapse.handlers.device import DeviceHandler
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater
self.device_list_updater = hs.get_device_handler().device_list_updater
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@ -103,32 +185,10 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, multi_user_devices
# FIXME(2025-07-22): Remove this on the next release, this will only get used
# during rollout to Synapse 1.135 and can be removed after that release.
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"""Ask master to upload keys for the user and send them out over federation to
update other servers.
For now, only the master is permitted to handle key upload requests;
any worker can handle key query requests (since they're read-only).
Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
the main process to accomplish this.
Request format for this endpoint (borrowed and expanded from KeyUploadServlet):
POST /_synapse/replication/upload_keys_for_user
{
"user_id": "<user_id>",
"device_id": "<device_id>",
"keys": {
....this part can be found in KeyUploadServlet in rest/client/keys.py....
or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
}
}
Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet
"""
"""Unused endpoint, kept for backwards compatibility during rollout."""
NAME = "upload_keys_for_user"
PATH_ARGS = ()
@ -165,6 +225,71 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
return 200, results
class ReplicationHandleNewDeviceUpdateRestServlet(ReplicationEndpoint):
"""Wake up a device writer to send local device list changes as federation outbound pokes.
Request format:
POST /_synapse/replication/handle_new_device_update
{}
"""
NAME = "handle_new_device_update"
PATH_ARGS = ()
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.device_handler = hs.get_device_handler()
@staticmethod
async def _serialize_payload() -> JsonDict: # type: ignore[override]
return {}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
await self.device_handler.handle_new_device_update()
return 200, {}
class ReplicationDeviceHandleRoomUnPartialStated(ReplicationEndpoint):
"""Handles sending appropriate device list updates in a room that has
gone from partial to full state.
Request format:
POST /_synapse/replication/device_handle_room_un_partial_stated/:room_id
{}
"""
NAME = "device_handle_room_un_partial_stated"
PATH_ARGS = ("room_id",)
CACHE = True
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.device_handler = hs.get_device_handler()
@staticmethod
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.device_handler.handle_room_un_partial_stated(room_id)
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationNotifyDeviceUpdateRestServlet(hs).register(http_server)
ReplicationNotifyUserSignatureUpdateRestServlet(hs).register(http_server)
ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationHandleNewDeviceUpdateRestServlet(hs).register(http_server)
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
ReplicationDeviceHandleRoomUnPartialStated(hs).register(http_server)

View File

@ -116,7 +116,11 @@ class ReplicationDataHandler:
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token()
# This only uses the minimum stream position on the device lists
# stream, which means that we may process a device list change
# twice in case of concurrent writes. This is fine, as this only
# triggers cache invalidation, which is harmless if done twice.
prev_token = self.store.get_device_stream_token().stream
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)

View File

@ -72,6 +72,7 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams._base import DeviceListsStream
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -185,6 +186,12 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, DeviceListsStream):
if hs.get_instance_name() in hs.config.worker.writers.device_lists:
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master.
if hs.config.worker.worker_app is not None:
continue

View File

@ -51,7 +51,6 @@ from synapse.rest.admin.background_updates import (
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
DevicesGetRestServlet,
DevicesRestServlet,
)
from synapse.rest.admin.event_reports import (
@ -375,4 +374,5 @@ def register_servlets_for_msc3861_delegation(
UserRestServletV2(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server)
DevicesGetRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)

View File

@ -23,7 +23,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@ -51,9 +50,7 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@ -113,7 +110,7 @@ class DeviceRestServlet(RestServlet):
return HTTPStatus.OK, {}
class DevicesGetRestServlet(RestServlet):
class DevicesRestServlet(RestServlet):
"""
Retrieve the given user's devices
@ -158,19 +155,6 @@ class DevicesGetRestServlet(RestServlet):
return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
class DevicesRestServlet(DevicesGetRestServlet):
"""
Retrieve the given user's devices
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert isinstance(self.device_worker_handler, DeviceHandler)
self.device_handler = self.device_worker_handler
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
@ -194,7 +178,7 @@ class DevicesRestServlet(DevicesGetRestServlet):
if not isinstance(device_id, str):
raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string")
await self.device_handler.check_device_registered(
await self.device_worker_handler.check_device_registered(
user_id=user_id, device_id=device_id
)
@ -211,9 +195,7 @@ class DeleteDevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine

View File

@ -27,7 +27,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse._pydantic_compat import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -91,7 +90,6 @@ class DeleteDevicesRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
@ -147,7 +145,6 @@ class DeviceRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
self._is_main_process = hs.config.worker.worker_app is None
async def on_GET(
self, request: SynapseRequest, device_id: str
@ -179,14 +176,6 @@ class DeviceRestServlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
# Can only be run on main process, as changes to device lists must
# happen on main.
if not self._is_main_process:
error_message = "DELETE on /devices/ must be routed to main process"
logger.error(error_message)
raise SynapseError(500, error_message)
assert isinstance(self.device_handler, DeviceHandler)
requester = await self.auth.get_user_by_req(request)
try:
@ -231,14 +220,6 @@ class DeviceRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
# Can only be run on main process, as changes to device lists must
# happen on main.
if not self._is_main_process:
error_message = "PUT on /devices/ must be routed to main process"
logger.error(error_message)
raise SynapseError(500, error_message)
assert isinstance(self.device_handler, DeviceHandler)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_and_validate_json_object_from_request(request, self.PutBody)
@ -317,7 +298,6 @@ class DehydratedDeviceServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
@ -377,7 +357,6 @@ class ClaimDehydratedDeviceServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
class PostBody(RequestBodyModel):
@ -517,7 +496,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler
@ -595,18 +573,14 @@ class DehydratedDeviceV2Servlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if (
hs.config.worker.worker_app is None
and not hs.config.experimental.msc3861.enabled
):
if not hs.config.experimental.msc3861.enabled:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
if hs.config.experimental.msc2697_enabled:
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)
if hs.config.experimental.msc3814_enabled:
DehydratedDeviceV2Servlet(hs).register(http_server)
DehydratedDeviceEventsServlet(hs).register(http_server)
if hs.config.experimental.msc2697_enabled:
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)
if hs.config.experimental.msc3814_enabled:
DehydratedDeviceV2Servlet(hs).register(http_server)
DehydratedDeviceEventsServlet(hs).register(http_server)

View File

@ -504,6 +504,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)

View File

@ -22,7 +22,6 @@
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@ -42,9 +41,7 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
self._device_handler = hs.get_device_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
@ -71,9 +68,7 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
self._device_handler = hs.get_device_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(

View File

@ -69,7 +69,7 @@ from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
from synapse.handlers.cas import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.delayed_events import DelayedEventsHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.device import DeviceHandler, DeviceWriterHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
@ -586,11 +586,11 @@ class HomeServer(metaclass=abc.ABCMeta):
)
@cache_in_self
def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
def get_device_handler(self) -> DeviceHandler:
if self.get_instance_name() in self.config.worker.writers.device_lists:
return DeviceWriterHandler(self)
return DeviceHandler(self)
@cache_in_self
def get_device_message_handler(self) -> DeviceMessageHandler:

File diff suppressed because it is too large Load Diff

View File

@ -59,7 +59,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@ -120,6 +120,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self.hs.config.federation.allow_device_name_lookup_over_federation
)
self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
# No one reads the stream positions, so we're allowed to have an empty list of writers
writers=[],
)
def process_replication_rows(
self,
stream_name: str,
@ -145,7 +159,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
(stream_id, devices)
"""
now_stream_id = self.get_device_stream_token()
# Here, we don't use the individual instances positions, as we *need* to
# give out the stream_id as an integer in the federation API.
# This means that we'll potentially return the same data twice with a
# different stream_id, and invalidate cache more often than necessary,
# which is fine overall.
now_stream_id = self.get_device_stream_token().stream
# We need to be careful with the caching here, as we need to always
# return *all* persisted devices, however there may be a lag between a
@ -164,8 +183,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# have to check for potential invalidations after the
# `now_stream_id`.
sql = """
SELECT user_id FROM device_lists_stream
SELECT 1
FROM device_lists_stream
WHERE stream_id >= ? AND user_id = ?
LIMIT 1
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
@ -1117,7 +1138,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
def get_device_stream_token(self) -> MultiWriterStreamToken:
"""Get the current stream id from the _device_list_id_gen"""
...
@ -1540,27 +1561,44 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
impl,
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
"device_id": device_id,
"user_id": user_id,
}
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
self.db_pool.simple_delete_txn(
txn,
table="dehydrated_devices",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
writers=["master"],
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
async def set_e2e_device_keys(
@ -1754,3 +1792,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
],
desc="add_e2e_signing_key",
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

View File

@ -36,7 +36,6 @@ from typing import (
)
import attr
from immutabledict import immutabledict
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
@ -167,25 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
min_pos = self._receipts_id_gen.get_current_token()
positions = {}
if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
positions = {
i: p
for i, p in self._receipts_id_gen.get_positions().items()
if p > min_pos
}
return MultiWriterStreamToken(
stream=min_pos, instance_map=immutabledict(positions)
)
return MultiWriterStreamToken.from_generator(self._receipts_id_gen)
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
return self._receipts_id_gen.get_current_token_for_writer(instance_name)

View File

@ -2093,6 +2093,58 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device
Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""
await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)
def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
self.db_pool.simple_update_txn(
txn, "access_tokens", {"token": token}, {"device_id": device_id}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
return old_device_id
async def set_device_for_access_token(self, token: str, device_id: str) -> str:
"""Sets the device ID associated with an access token.
Args:
token: The access token to modify.
device_id: The new device ID.
Returns:
The old device ID associated with the access token.
"""
return await self.db_pool.runInteraction(
"set_device_for_access_token",
self._set_device_for_access_token_txn,
token,
device_id,
)
async def add_login_token_to_user(
self,
user_id: str,
@ -2396,6 +2448,154 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler.
refresh_where_clause = where_clause
refresh_values = values.copy()
if except_token_id:
# TODO: support that for refresh tokens
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(token,) for token, _, _ in tokens_and_devices],
)
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
txn.execute(
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
refresh_values,
)
return tokens_and_devices
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def user_delete_access_tokens_for_devices(
self,
user_id: str,
device_ids: StrCollection,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
device_ids: The devices to delete tokens for.
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def user_delete_access_tokens_for_devices_txn(
txn: LoggingTransaction, batch_device_ids: StrCollection
) -> List[Tuple[str, int, Optional[str]]]:
self.db_pool.simple_delete_many_txn(
txn,
table="refresh_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)
clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", batch_device_ids
)
args.append(user_id)
if self.database_engine.supports_returning:
sql = f"""
DELETE FROM access_tokens
WHERE {clause} AND user_id = ?
RETURNING token, id, device_id
"""
txn.execute(sql, args)
tokens_and_devices = txn.fetchall()
else:
tokens_and_devices = self.db_pool.simple_select_many_txn(
txn,
table="access_tokens",
column="device_id",
iterable=batch_device_ids,
keyvalues={"user_id": user_id},
retcols=("token", "id", "device_id"),
)
self.db_pool.simple_delete_many_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)
self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(t[0],) for t in tokens_and_devices],
)
return tokens_and_devices
results = []
for batch_device_ids in batch_iter(device_ids, 1000):
tokens_and_devices = await self.db_pool.runInteraction(
"user_delete_access_tokens_for_devices",
user_delete_access_tokens_for_devices_txn,
batch_device_ids,
)
results.extend(tokens_and_devices)
return results
async def delete_access_token(self, access_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
await self.db_pool.runInteraction("delete_access_token", f)
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@ -2620,58 +2820,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_id
async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device
Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""
await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)
def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)
self.db_pool.simple_update_txn(
txn, "access_tokens", {"token": token}, {"device_id": device_id}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
return old_device_id
async def set_device_for_access_token(self, token: str, device_id: str) -> str:
"""Sets the device ID associated with an access token.
Args:
token: The access token to modify.
device_id: The new device ID.
Returns:
The old device ID associated with the access token.
"""
return await self.db_pool.runInteraction(
"set_device_for_access_token",
self._set_device_for_access_token_txn,
token,
device_id,
)
async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
@ -2743,162 +2891,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler.
refresh_where_clause = where_clause
refresh_values = values.copy()
if except_token_id:
# TODO: support that for refresh tokens
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(token,) for token, _, _ in tokens_and_devices],
)
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
txn.execute(
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
refresh_values,
)
return tokens_and_devices
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def user_delete_access_tokens_for_devices(
self,
user_id: str,
device_ids: StrCollection,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
device_ids: The devices to delete tokens for.
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def user_delete_access_tokens_for_devices_txn(
txn: LoggingTransaction, batch_device_ids: StrCollection
) -> List[Tuple[str, int, Optional[str]]]:
self.db_pool.simple_delete_many_txn(
txn,
table="refresh_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)
clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", batch_device_ids
)
args.append(user_id)
if self.database_engine.supports_returning:
sql = f"""
DELETE FROM access_tokens
WHERE {clause} AND user_id = ?
RETURNING token, id, device_id
"""
txn.execute(sql, args)
tokens_and_devices = txn.fetchall()
else:
tokens_and_devices = self.db_pool.simple_select_many_txn(
txn,
table="access_tokens",
column="device_id",
iterable=batch_device_ids,
keyvalues={"user_id": user_id},
retcols=("token", "id", "device_id"),
)
self.db_pool.simple_delete_many_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)
self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(t[0],) for t in tokens_and_devices],
)
return tokens_and_devices
results = []
for batch_device_ids in batch_iter(device_ids, 1000):
tokens_and_devices = await self.db_pool.runInteraction(
"user_delete_access_tokens_for_devices",
user_delete_access_tokens_for_devices_txn,
batch_device_ids,
)
results.extend(tokens_and_devices)
return results
async def delete_access_token(self, access_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)
await self.db_pool.runInteraction("delete_refresh_token", f)
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're

View File

@ -324,7 +324,7 @@ class RelationsWorkerStore(SQLBaseStore):
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)

View File

@ -1574,6 +1574,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.
This only returns the minimum device list stream ID at the time of
joining, not the full device list stream token. The only impact of this
is that we may be sending again device list updates that we've already
sent to some destinations, which is harmless.
"""
return cast(

View File

@ -61,7 +61,6 @@ from typing import (
)
import attr
from immutabledict import immutabledict
from typing_extensions import assert_never
from twisted.internet import defer
@ -657,23 +656,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
component.
"""
min_pos = self._stream_id_gen.get_current_token()
positions = {}
if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
positions = {
i: p
for i, p in self._stream_id_gen.get_positions().items()
if p > min_pos
}
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
return RoomStreamToken.from_generator(self._stream_id_gen)
def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._stream_id_gen

View File

@ -203,7 +203,7 @@ class EventSources:
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)

View File

@ -75,6 +75,7 @@ if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
@ -570,6 +571,25 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
),
)
@classmethod
def from_generator(cls, generator: "MultiWriterIdGenerator") -> Self:
"""Get the current token out of a MultiWriterIdGenerator"""
# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
min_pos = generator.get_current_token()
positions = {
instance: position
for instance, position in generator.get_positions().items()
if position > min_pos
}
return cls(stream=min_pos, instance_map=immutabledict(positions))
@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
@ -980,7 +1000,9 @@ class StreamToken:
account_data_key: int
push_rules_key: int
to_device_key: int
device_list_key: int
device_list_key: MultiWriterStreamToken = attr.ib(
validator=attr.validators.instance_of(MultiWriterStreamToken)
)
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
un_partial_stated_rooms_key: int
@ -1021,7 +1043,9 @@ class StreamToken:
account_data_key=int(account_data_key),
push_rules_key=int(push_rules_key),
to_device_key=int(to_device_key),
device_list_key=int(device_list_key),
device_list_key=await MultiWriterStreamToken.parse(
store, device_list_key
),
groups_key=int(groups_key),
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
)
@ -1040,7 +1064,7 @@ class StreamToken:
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
str(self.device_list_key),
await self.device_list_key.to_string(store),
# Note that the groups key is no longer used, but it is still
# serialized so that there will not be confusion in the future
# if additional tokens are added.
@ -1069,6 +1093,12 @@ class StreamToken:
StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
)
return new_token
elif key == StreamKeyType.DEVICE_LIST:
new_token = self.copy_and_replace(
StreamKeyType.DEVICE_LIST,
self.device_list_key.copy_and_advance(new_value),
)
return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = new_token.get_field(key)
@ -1087,7 +1117,11 @@ class StreamToken:
@overload
def get_field(
self, key: Literal[StreamKeyType.RECEIPT]
self,
key: Literal[
StreamKeyType.RECEIPT,
StreamKeyType.DEVICE_LIST,
],
) -> MultiWriterStreamToken: ...
@overload
@ -1095,7 +1129,6 @@ class StreamToken:
self,
key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.TO_DEVICE,
@ -1161,7 +1194,16 @@ class StreamToken:
StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
room_key=RoomStreamToken(stream=0),
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)

View File

@ -30,7 +30,7 @@ from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.api.presence import UserPresenceState
from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceListUpdater, DeviceWriterHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
@ -500,7 +500,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self.device_handler = device_handler
# whenever send_transaction is called, record the edu data
@ -554,6 +554,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"devices": [{"device_id": "D1"}],
}
assert isinstance(self.device_handler.device_list_updater, DeviceListUpdater)
self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
"host2",

View File

@ -29,7 +29,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceWriterHandler
from synapse.rest import admin
from synapse.rest.client import devices, login, register
from synapse.server import HomeServer
@ -53,7 +53,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
application_service_api=self.appservice_api,
)
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler
self.store = hs.get_datastores().main
self.device_message_handler = hs.get_device_message_handler()
@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# queue a bunch of messages in the inbox
requester = create_requester(sender, device_id=DEVICE_ID)
for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
for i in range(DeviceWriterHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
self.get_success(
self.device_message_handler.send_device_message(
requester, "message_type", {receiver: {"*": {"val": i}}}
@ -462,7 +462,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()

View File

@ -31,7 +31,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict, UserID
@ -856,7 +856,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
e = self.get_failure(
device_handler.check_device_registered(
user_id=local_user,

View File

@ -28,7 +28,7 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
from synapse.module_api import ModuleApi
@ -819,7 +819,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
# Delete the device.
device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self.get_success(device_handler.delete_devices(user_id, [device_id]))
# Check that the callback was called and the pushers still existed.

View File

@ -26,7 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.rest.client import devices, login
from synapse.server import HomeServer
from synapse.util import Clock
@ -42,7 +42,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True)