Mirror of https://github.com/element-hq/synapse.git (synced 2025-11-15 00:02:05 -05:00)
Move device changes off the main process (#18581)
The main goal of this PR is to move handling of device list changes onto multiple writers, off the main process, so that logins can keep happening while Synapse is rolling-restarting. This is quite an intrusive change, so I would advise reviewing it commit by commit; I tried to keep the history as clean as possible. There are a few things to consider:

- the `device_list_key` in stream tokens becomes a `MultiWriterStreamToken`, which has a few implications in sync and on the storage layer
- we had a split between `DeviceHandler` and `DeviceWorkerHandler` for the master vs. worker process. I've kept this split, but made it a writer vs. non-writer worker split instead, using method overrides to do replication calls when needed
- a few operations need to happen on a single worker at a time. Instead of using cross-worker locks, for now I made them run on the first writer in the list

---------

Co-authored-by: Eric Eastwood <erice@element.io>
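For context, a rough sketch of the multi-writer token idea referred to above. The real `MultiWriterStreamToken` lives in `synapse.types` and has more behaviour than this; the field names here (`stream`, `instance_map`) follow the receipts code further down in this diff, and the method names are illustrative.

```python
# Sketch only: a multi-writer stream token is a minimum position that every
# writer is known to have reached, plus the positions of writers that are
# ahead of that minimum.
from typing import Mapping

import attr
from immutabledict import immutabledict


@attr.s(frozen=True, slots=True, auto_attribs=True)
class SketchMultiWriterToken:
    # Position that *all* writers are known to have reached.
    stream: int = 0
    # Writers whose position is ahead of `stream`.
    instance_map: Mapping[str, int] = immutabledict()

    def get_max_stream_pos(self) -> int:
        """Upper bound, used when a single integer is needed (e.g. over federation)."""
        return max(self.instance_map.values(), default=self.stream)

    def get_stream_pos_for_instance(self, instance_name: str) -> int:
        """Position a given writer is known to have reached."""
        return self.instance_map.get(instance_name, self.stream)


# Example: writer "device_lists1" has written up to 12 while the others are at 10.
token = SketchMultiWriterToken(stream=10, instance_map=immutabledict({"device_lists1": 12}))
assert token.get_max_stream_pos() == 12
assert token.get_stream_pos_for_instance("device_lists2") == 10
```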
This commit is contained in:
parent 66504d1144
commit 5ea2cf2484

1 changed file: changelog.d/18581.feature (new file, 1 line)
@@ -0,0 +1 @@
Enable workers to write directly to the device lists stream and handle device list updates, reducing load on the main process.

@@ -54,7 +54,6 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
export SYNAPSE_WORKER_TYPES="\
event_persister:2, \
background_worker, \
frontend_proxy, \
event_creator, \
user_dir, \
media_repository, \
@@ -65,6 +64,7 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
client_reader, \
appservice, \
pusher, \
device_lists:2, \
stream_writers=account_data+presence+receipts+to_device+typing"

fi

@@ -178,6 +178,8 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/login$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)",
"^/_matrix/client/(r0|v3)/delete_devices$",
"^/_matrix/client/versions$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$",
"^/_matrix/client/(r0|v3|unstable)/register$",
@@ -194,6 +196,9 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_matrix/client/(api/v1|r0|v3|unstable)/directory/room/.*$",
"^/_matrix/client/(r0|v3|unstable)/capabilities$",
"^/_matrix/client/(r0|v3|unstable)/notifications$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$",
"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$",
],
"shared_extra_conf": {},
"worker_extra_conf": "",
@@ -265,13 +270,6 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"frontend_proxy": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"account_data": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
@@ -306,6 +304,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"device_lists": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": [],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
"typing": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],

@@ -412,16 +417,17 @@ def add_worker_roles_to_shared_config(
# streams
instance_map = shared_config.setdefault("instance_map", {})

# This is a list of the stream_writers that there can be only one of. Events can be
# sharded, and therefore doesn't belong here.
singular_stream_writers = [
# This is a list of the stream_writers.
stream_writers = {
"account_data",
"events",
"device_lists",
"presence",
"receipts",
"to_device",
"typing",
"push_rules",
]
}

# Worker-type specific sharding config. Now a single worker can fulfill multiple
# roles, check each.
@@ -431,28 +437,11 @@ def add_worker_roles_to_shared_config(
if "federation_sender" in worker_types_set:
shared_config.setdefault("federation_sender_instances", []).append(worker_name)

if "event_persister" in worker_types_set:
# Event persisters write to the events stream, so we need to update
# the list of event stream writers
shared_config.setdefault("stream_writers", {}).setdefault("events", []).append(
worker_name
)

# Map of stream writer instance names to host/ports combos
if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
instance_map[worker_name] = {
"path": f"/run/worker.{worker_port}",
}
else:
instance_map[worker_name] = {
"host": "localhost",
"port": worker_port,
}
# Update the list of stream writers. It's convenient that the name of the worker
# type is the same as the stream to write. Iterate over the whole list in case there
# is more than one.
for worker in worker_types_set:
if worker in singular_stream_writers:
if worker in stream_writers:
shared_config.setdefault("stream_writers", {}).setdefault(
worker, []
).append(worker_name)

@@ -876,6 +865,13 @@ def generate_worker_files(
else:
healthcheck_urls.append("http://localhost:%d/health" % (worker_port,))

# Special case for event_persister: those are just workers that write to
# the `events` stream. For other workers, the worker name is the same
# name of the stream they write to, but for some reason it is not the
# case for event_persister.
if "event_persister" in worker_types_set:
worker_types_set.add("events")

# Update the shared config with sharding-related options if necessary
add_worker_roles_to_shared_config(
shared_config, worker_types_set, worker_name, worker_port
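To illustrate the effect of the worker-setup changes above, here is a hedged sketch of roughly what the generated shared configuration ends up containing when two `device_lists` workers are requested. The worker names and ports are invented for the example; only the shape of `instance_map` and `stream_writers` is taken from the code in this diff.

```python
# Illustrative only: mimic how the script above registers a stream writer.
def add_stream_writer(shared_config: dict, stream: str, worker_name: str, port: int) -> None:
    """Register the worker in instance_map and as a writer for the given stream."""
    shared_config.setdefault("instance_map", {})[worker_name] = {
        "host": "localhost",
        "port": port,
    }
    shared_config.setdefault("stream_writers", {}).setdefault(stream, []).append(worker_name)


config: dict = {}
# Hypothetical worker names/ports; "device_lists:2" would yield two such workers.
add_stream_writer(config, "device_lists", "device_lists1", 18015)
add_stream_writer(config, "device_lists", "device_lists2", 18016)

assert config["stream_writers"]["device_lists"] == ["device_lists1", "device_lists2"]
```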
@@ -4341,6 +4341,8 @@ This setting has the following sub-options:

* `push_rules` (string): Name of a worker assigned to the `push_rules` stream.

* `device_lists` (string): Name of a worker assigned to the `device_lists` stream.

Example configuration:
```yaml
stream_writers:

@@ -238,7 +238,8 @@ information.
^/_matrix/client/unstable/im.nheko.summary/summary/.*$
^/_matrix/client/(r0|v3|unstable)/account/3pid$
^/_matrix/client/(r0|v3|unstable)/account/whoami$
^/_matrix/client/(r0|v3|unstable)/devices$
^/_matrix/client/(r0|v3)/delete_devices$
^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)
^/_matrix/client/versions$
^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
@@ -257,7 +258,9 @@ information.
^/_matrix/client/(r0|v3|unstable)/keys/changes$
^/_matrix/client/(r0|v3|unstable)/keys/claim$
^/_matrix/client/(r0|v3|unstable)/room_keys/
^/_matrix/client/(r0|v3|unstable)/keys/upload$
^/_matrix/client/(r0|v3|unstable)/keys/upload
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$

# Registration/login requests
^/_matrix/client/(api/v1|r0|v3|unstable)/login$

@@ -282,7 +285,6 @@ Additionally, the following REST endpoints can be handled for GET requests:

^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
^/_matrix/client/unstable/org.matrix.msc4140/delayed_events
^/_matrix/client/(api/v1|r0|v3|unstable)/devices/

# Account data requests
^/_matrix/client/(r0|v3|unstable)/.*/tags

@@ -329,7 +331,6 @@ set to `true`), the following endpoints can be handled by the worker:
^/_synapse/admin/v2/users/[^/]+$
^/_synapse/admin/v1/username_available$
^/_synapse/admin/v1/users/[^/]+/_allow_cross_signing_replacement_without_uia$
# Only the GET method:
^/_synapse/admin/v1/users/[^/]+/devices$

Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)

@@ -550,6 +551,18 @@ the stream writer for the `push_rules` stream:

^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/

##### The `device_lists` stream

The `device_lists` stream supports multiple writers. The following endpoints
can be handled by any worker, but should be routed directly to one of the workers
configured as stream writer for the `device_lists` stream:

^/_matrix/client/(r0|v3)/delete_devices$
^/_matrix/client/(api/v1|r0|v3|unstable)/devices/
^/_matrix/client/(r0|v3|unstable)/keys/upload
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$
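Purely illustrative: one way to turn the endpoint list above plus a `stream_writers` configuration into routing rules for a reverse-proxy generator, similar in spirit to what the Docker worker-setup script in this PR does. The worker names are examples, not part of the PR; the regexes are the ones listed above.

```python
# Sketch: map each device_lists endpoint pattern to the workers allowed to handle it.
DEVICE_LISTS_ENDPOINTS = [
    r"^/_matrix/client/(r0|v3)/delete_devices$",
    r"^/_matrix/client/(api/v1|r0|v3|unstable)/devices/",
    r"^/_matrix/client/(r0|v3|unstable)/keys/upload",
    r"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$",
    r"^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$",
]


def device_list_routes(stream_writers: dict) -> dict:
    """Return {endpoint pattern: [workers that may handle it]}."""
    writers = stream_writers.get("device_lists", [])
    return {pattern: list(writers) for pattern in DEVICE_LISTS_ENDPOINTS}


routes = device_list_routes({"device_lists": ["device_lists1", "device_lists2"]})
```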
#### Restrict outbound federation traffic to a specific set of workers

The

@@ -5383,6 +5383,9 @@ properties:
push_rules:
type: string
description: Name of a worker assigned to the `push_rules` stream.
device_lists:
type: string
description: Name of a worker assigned to the `device_lists` stream.
default: {}
examples:
- events: worker1

@@ -134,40 +134,44 @@ class WriterLocations:
can only be a single instance.
account_data: The instances that write to the account data streams. Currently
can only be a single instance.
receipts: The instances that write to the receipts stream. Currently
can only be a single instance.
receipts: The instances that write to the receipts stream.
presence: The instances that write to the presence stream. Currently
can only be a single instance.
push_rules: The instances that write to the push stream. Currently
can only be a single instance.
device_lists: The instances that write to the device list stream.
"""

events: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
typing: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
to_device: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
account_data: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
receipts: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
presence: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
push_rules: List[str] = attr.ib(
default=["master"],
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
device_lists: List[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
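The attributes above all share the `_instance_to_list_converter` converter, which is not shown in this diff. A minimal sketch of what such a converter presumably does, together with the `MAIN_PROCESS_INSTANCE_NAME` constant (assumed here to simply name the main process):

```python
from typing import List, Union

MAIN_PROCESS_INSTANCE_NAME = "master"  # assumption: the constant names the main process


def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
    """Allow a stream writer to be configured as a single name or a list of names."""
    if isinstance(obj, str):
        return [obj]
    return obj


# Both spellings in the config end up as a list internally:
assert _instance_to_list_converter("worker1") == ["worker1"]
assert _instance_to_list_converter(["worker1", "worker2"]) == ["worker1", "worker2"]
```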
@@ -358,7 +362,10 @@ class WorkerConfig(Config):
):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
if (
instance != MAIN_PROCESS_INSTANCE_NAME
and instance not in self.instance_map
):
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
@@ -397,6 +404,11 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `push` messages."
)

if len(self.writers.device_lists) == 0:
raise ConfigError(
"Must specify at least one instance to handle `device_lists` messages."
)

self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
@@ -419,9 +431,12 @@ class WorkerConfig(Config):
#
# No effort is made to ensure only a single instance of these tasks is
# running.
background_tasks_instance = config.get("run_background_tasks_on") or "master"
background_tasks_instance = (
config.get("run_background_tasks_on") or MAIN_PROCESS_INSTANCE_NAME
)
self.run_background_tasks = (
self.worker_name is None and background_tasks_instance == "master"
self.worker_name is None
and background_tasks_instance == MAIN_PROCESS_INSTANCE_NAME
) or self.worker_name == background_tasks_instance

self.should_notify_appservices = self._should_this_worker_perform_duty(
@@ -493,9 +508,10 @@ class WorkerConfig(Config):
# 'don't run here'.
new_option_should_run_here = None
if new_option_name in config:
designated_worker = config[new_option_name] or "master"
designated_worker = config[new_option_name] or MAIN_PROCESS_INSTANCE_NAME
new_option_should_run_here = (
designated_worker == "master" and self.worker_name is None
designated_worker == MAIN_PROCESS_INSTANCE_NAME
and self.worker_name is None
) or designated_worker == self.worker_name

legacy_option_should_run_here = None
@@ -592,7 +608,7 @@ class WorkerConfig(Config):
# If no worker instances are set we check if the legacy option
# is set, which means use the main process.
if legacy_option:
worker_instances = ["master"]
worker_instances = [MAIN_PROCESS_INSTANCE_NAME]

if self.worker_app == legacy_app_name:
if legacy_option:

@@ -638,7 +638,8 @@ class ApplicationServicesHandler:

# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = await self.store.get_all_devices_changed(
from_key, to_key=new_key
MultiWriterStreamToken(stream=from_key),
to_key=MultiWriterStreamToken(stream=new_key),
)

# Filter out any users the application service is not interested in

@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Optional

from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester

@@ -84,10 +83,6 @@ class DeactivateAccountHandler:
Returns:
True if identity server supports removing threepids, otherwise False.
"""

# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)

# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin

File diff suppressed because it is too large

@@ -33,9 +33,6 @@ from synapse.logging.opentracing import (
log_kv,
set_tag,
)
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -56,9 +53,9 @@ class DeviceMessageHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
self.device_handler = hs.get_device_handler()
if hs.config.experimental.msc3814_enabled:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()

# We only need to poke the federation sender explicitly if it's on the
# same instance. Other federation sender instances will get notified by
@@ -80,18 +77,6 @@ class DeviceMessageHandler:
hs.config.worker.writers.to_device,
)

# The handler to call when we think a user's device list might be out of
# sync. We do all device list resyncing on the master instance, so if
# we're on a worker we hit the device resync replication API.
if hs.config.worker.worker_app is None:
self._multi_user_device_resync = (
hs.get_device_handler().device_list_updater.multi_user_device_resync
)
else:
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)

# a rate limiter for room key requests. The keys are
# (sending_user_id, sending_device_id).
self._ratelimiter = Ratelimiter(
@@ -213,7 +198,10 @@ class DeviceMessageHandler:
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))

# Immediately attempt a resync in the background
run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
run_in_background(
self.device_handler.device_list_updater.multi_user_device_resync,
user_ids=[sender_user_id],
)

async def send_device_message(
self,

@@ -32,10 +32,9 @@ from twisted.internet import defer

from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@@ -76,8 +75,10 @@ class E2eKeysHandler:

federation_registry = hs.get_federation_registry()

is_master = hs.config.worker.worker_app is None
if is_master:
# Only the first writer in the list should handle EDUs for signing key
# updates, so that we can use an in-memory linearizer instead of worker locks.
edu_writer = hs.config.worker.writers.device_lists[0]
if hs.get_instance_name() == edu_writer:
edu_updater = SigningKeyEduUpdater(hs)

# Only register this edu handler on master as it requires writing
@@ -92,11 +93,14 @@ class E2eKeysHandler:
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
edu_updater.incoming_signing_key_update,
)

self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
federation_registry.register_instances_for_edu(
EduTypes.SIGNING_KEY_UPDATE,
[edu_writer],
)
federation_registry.register_instances_for_edu(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
[edu_writer],
)

# doesn't really work as part of the generic query API, because the
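The comment in the hunk above ("use an in-memory linearizer instead of worker locks") is why these EDUs are pinned to the first `device_lists` writer: a per-key lock only serialises callers within one process, so all signing-key updates for a user must land on the same instance. A toy sketch of that idea (Synapse's real `Linearizer` lives in `synapse.util.async_helpers`; this stand-in is only illustrative):

```python
import asyncio
from collections import defaultdict


class PerKeyLinearizer:
    """Serialise work per key, but only within this process."""

    def __init__(self) -> None:
        self._locks = defaultdict(asyncio.Lock)  # key -> lock

    def queue(self, key: str) -> asyncio.Lock:
        # Callers in *this* process using the same key run one at a time.
        # A second process would have its own locks, hence the single-writer rule.
        return self._locks[key]


async def handle_signing_key_update(linearizer: PerKeyLinearizer, user_id: str) -> None:
    async with linearizer.queue(user_id):
        ...  # resync and persist the user's cross-signing keys
```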
@@ -847,7 +851,7 @@ class E2eKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
await self.device_key_uploader(
await self.upload_device_keys_for_user(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
@@ -904,9 +908,6 @@ class E2eKeysHandler:
device_keys: the `device_keys` of a /keys/upload request.

"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

device_keys = keys["device_keys"]
@@ -998,9 +999,6 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
if "master_key" in keys:
@@ -1091,9 +1089,6 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

failures = {}

# signatures to be stored. Each item will be a SignatureListItem
@@ -1467,9 +1462,6 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1770,7 +1762,7 @@ class SigningKeyEduUpdater:
self.clock = hs.get_clock()

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self._device_handler = device_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

@@ -698,10 +698,19 @@ class FederationHandler:
# We may want to reset the partial state info if it's from an
# old, failed partial state join.
# https://github.com/matrix-org/synapse/issues/13000

# FIXME: Ideally, we would store the full stream token here
# not just the minimum stream ID, so that we can compute an
# accurate list of device changes when un-partial-ing the
# room. The only side effect of this is that we may send
# extra unnecessary device list outbound pokes through
# federation, which is harmless.
device_lists_stream_id = self.store.get_device_stream_token().stream

await self.store.store_partial_state_room(
room_id=room_id,
servers=ret.servers_in_room,
device_lists_stream_id=self.store.get_device_stream_token(),
device_lists_stream_id=device_lists_stream_id,
joined_via=origin,
)

@@ -77,9 +77,6 @@ from synapse.logging.opentracing import (
trace,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
@@ -180,11 +177,6 @@ class FederationEventHandler:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages

self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
if hs.config.worker.worker_app:
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater

# When joining a room we need to queue any events for that room up.
@@ -1544,12 +1536,7 @@ class FederationEventHandler:
await self._store.mark_remote_users_device_caches_as_stale((sender,))

# Immediately attempt a resync in the background
if self._config.worker.worker_app:
await self._multi_user_device_resync(user_ids=[sender])
else:
await self._device_list_updater.multi_user_device_resync(
user_ids=[sender]
)
await self._device_list_updater.multi_user_device_resync(user_ids=[sender])
except Exception:
logger.exception("Failed to resync device for %s", sender)

@@ -44,7 +44,6 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -840,9 +839,6 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None

# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)

registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,

@@ -21,7 +21,6 @@ import logging
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester

if TYPE_CHECKING:
@@ -36,17 +35,7 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()

# We don't need the device handler if password changing is disabled.
# This allows us to instantiate the SetPasswordHandler on the workers
# that have admin APIs for MAS
if self._auth_handler.can_change_password():
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler: Optional[DeviceHandler] = device_handler
else:
self._device_handler = None
self._device_handler = hs.get_device_handler()

async def set_password(
self,
@@ -58,9 +47,6 @@ class SetPasswordHandler:
if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)

# We should have this available only if password changing is enabled.
assert self._device_handler is not None

try:
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:

@@ -46,7 +46,6 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType, ProfileFields
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -1181,8 +1180,6 @@ class SsoHandler:
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.

Can only be called from the main process.

Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -1191,11 +1188,6 @@ class SsoHandler:
sessions belonging to other users and log an error.
"""

# It is expected that this is the main process.
assert isinstance(self._device_handler, DeviceHandler), (
"revoking SSO sessions can only be called on the main process"
)

# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():

@@ -66,7 +66,6 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
@@ -925,8 +924,6 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user

Can only be called from the main process.

Added in Synapse v0.25.0.

Args:
@@ -939,10 +936,6 @@ class ModuleApi:
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(self._device_handler, DeviceHandler), (
"invalidate_access_token can only be called on the main process"
)

# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)

@@ -59,10 +59,10 @@ class ReplicationRestResource(JsonResource):
account_data.register_servlets(hs, self)
push.register_servlets(hs, self)
state.register_servlets(hs, self)
devices.register_servlets(hs, self)

# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
delayed_events.register_servlets(hs, self)

@@ -34,6 +34,92 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class ReplicationNotifyDeviceUpdateRestServlet(ReplicationEndpoint):
"""Notify a device writer that a user's device list has changed.

Request format:

POST /_synapse/replication/notify_device_update/:user_id

{
"device_ids": ["JLAFKJWSCS", "JLAFKJWSCS"]
}
"""

NAME = "notify_device_update"
PATH_ARGS = ("user_id",)
CACHE = False

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(  # type: ignore[override]
user_id: str, device_ids: List[str]
) -> JsonDict:
return {"device_ids": device_ids}

async def _handle_request(  # type: ignore[override]
self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
device_ids = content["device_ids"]

span = active_span()
if span:
span.set_tag("user_id", user_id)
span.set_tag("device_ids", f"{device_ids!r}")

await self.device_handler.notify_device_update(user_id, device_ids)

return 200, {}

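A sketch of how a non-writer worker might invoke the endpoint above, assuming the usual `ReplicationEndpoint.make_client` calling convention (the keyword arguments mirror `PATH_ARGS` plus `_serialize_payload`, and `instance_name` selects the target instance; treat the exact call shape as an assumption):

```python
from typing import List

from synapse.replication.http.devices import ReplicationNotifyDeviceUpdateRestServlet


async def notify_writer_of_device_change(hs, user_id: str, device_ids: List[str]) -> None:
    # Build the replication client once; it issues
    # POST /_synapse/replication/notify_device_update/<user_id>
    # with {"device_ids": [...]} as the body.
    client = ReplicationNotifyDeviceUpdateRestServlet.make_client(hs)

    # Route the call to one of the configured device_lists writers.
    await client(
        instance_name=hs.config.worker.writers.device_lists[0],
        user_id=user_id,
        device_ids=device_ids,
    )
```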
class ReplicationNotifyUserSignatureUpdateRestServlet(ReplicationEndpoint):
"""Notify a device writer that a user has made new signatures of other users.

Request format:

POST /_synapse/replication/notify_user_signature_update/:from_user_id

{
"user_ids": ["@alice:example.org", "@bob:example.org", ...]
}
"""

NAME = "notify_user_signature_update"
PATH_ARGS = ("from_user_id",)
CACHE = False

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(from_user_id: str, user_ids: List[str]) -> JsonDict:  # type: ignore[override]
return {"user_ids": user_ids}

async def _handle_request(  # type: ignore[override]
self, request: Request, content: JsonDict, from_user_id: str
) -> Tuple[int, JsonDict]:
user_ids = content["user_ids"]

span = active_span()
if span:
span.set_tag("from_user_id", from_user_id)
span.set_tag("user_ids", f"{user_ids!r}")

await self.device_handler.notify_user_signature_update(from_user_id, user_ids)

return 200, {}


class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
"""Ask master to resync the device list for multiple users from the same
remote server by contacting their server.
@@ -73,11 +159,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

from synapse.handlers.device import DeviceHandler

handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater
self.device_list_updater = hs.get_device_handler().device_list_updater

self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -103,32 +185,10 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, multi_user_devices


# FIXME(2025-07-22): Remove this on the next release, this will only get used
# during rollout to Synapse 1.135 and can be removed after that release.
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"""Ask master to upload keys for the user and send them out over federation to
update other servers.

For now, only the master is permitted to handle key upload requests;
any worker can handle key query requests (since they're read-only).

Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
the main process to accomplish this.

Request format for this endpoint (borrowed and expanded from KeyUploadServlet):

POST /_synapse/replication/upload_keys_for_user

{
"user_id": "<user_id>",
"device_id": "<device_id>",
"keys": {
....this part can be found in KeyUploadServlet in rest/client/keys.py....
or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
}
}

Response is equivalent to `/_matrix/client/v3/keys/upload` found in KeyUploadServlet

"""
"""Unused endpoint, kept for backwards compatibility during rollout."""

NAME = "upload_keys_for_user"
PATH_ARGS = ()
@@ -165,6 +225,71 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
return 200, results


class ReplicationHandleNewDeviceUpdateRestServlet(ReplicationEndpoint):
"""Wake up a device writer to send local device list changes as federation outbound pokes.

Request format:

POST /_synapse/replication/handle_new_device_update

{}
"""

NAME = "handle_new_device_update"
PATH_ARGS = ()
CACHE = False

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_handler = hs.get_device_handler()

@staticmethod
async def _serialize_payload() -> JsonDict:  # type: ignore[override]
return {}

async def _handle_request(  # type: ignore[override]
self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]:
await self.device_handler.handle_new_device_update()
return 200, {}


class ReplicationDeviceHandleRoomUnPartialStated(ReplicationEndpoint):
"""Handles sending appropriate device list updates in a room that has
gone from partial to full state.

Request format:

POST /_synapse/replication/device_handle_room_un_partial_stated/:room_id

{}
"""

NAME = "device_handle_room_un_partial_stated"
PATH_ARGS = ("room_id",)
CACHE = True

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_handler = hs.get_device_handler()

@staticmethod
async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
return {}

async def _handle_request(  # type: ignore[override]
self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]:
await self.device_handler.handle_room_un_partial_stated(room_id)
return 200, {}


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationNotifyDeviceUpdateRestServlet(hs).register(http_server)
ReplicationNotifyUserSignatureUpdateRestServlet(hs).register(http_server)
ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationHandleNewDeviceUpdateRestServlet(hs).register(http_server)
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
ReplicationDeviceHandleRoomUnPartialStated(hs).register(http_server)

@@ -116,7 +116,11 @@ class ReplicationDataHandler:
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token()
# This only uses the minimum stream position on the device lists
# stream, which means that we may process a device list change
# twice in case of concurrent writes. This is fine, as this only
# triggers cache invalidation, which is harmless if done twice.
prev_token = self.store.get_device_stream_token().stream
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)

@@ -72,6 +72,7 @@ from synapse.replication.tcp.streams import (
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams._base import DeviceListsStream

if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -185,6 +186,12 @@ class ReplicationCommandHandler:

continue

if isinstance(stream, DeviceListsStream):
if hs.get_instance_name() in hs.config.worker.writers.device_lists:
self._streams_to_replicate.append(stream)

continue

# Only add any other streams if we're on master.
if hs.config.worker.worker_app is not None:
continue

@@ -51,7 +51,6 @@ from synapse.rest.admin.background_updates import (
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
DevicesGetRestServlet,
DevicesRestServlet,
)
from synapse.rest.admin.event_reports import (
@@ -375,4 +374,5 @@ def register_servlets_for_msc3861_delegation(
UserRestServletV2(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server)
DevicesGetRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)

@@ -23,7 +23,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -51,9 +50,7 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine

@@ -113,7 +110,7 @@ class DeviceRestServlet(RestServlet):
return HTTPStatus.OK, {}


class DevicesGetRestServlet(RestServlet):
class DevicesRestServlet(RestServlet):
"""
Retrieve the given user's devices

@@ -158,19 +155,6 @@ class DevicesGetRestServlet(RestServlet):

return HTTPStatus.OK, {"devices": devices, "total": len(devices)}


class DevicesRestServlet(DevicesGetRestServlet):
"""
Retrieve the given user's devices
"""

PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert isinstance(self.device_worker_handler, DeviceHandler)
self.device_handler = self.device_worker_handler

async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
@@ -194,7 +178,7 @@ class DevicesRestServlet(DevicesGetRestServlet):
if not isinstance(device_id, str):
raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string")

await self.device_handler.check_device_registered(
await self.device_worker_handler.check_device_registered(
user_id=user_id, device_id=device_id
)

@@ -211,9 +195,7 @@ class DeleteDevicesRestServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine

@@ -27,7 +27,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse._pydantic_compat import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -91,7 +90,6 @@ class DeleteDevicesRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler()

@@ -147,7 +145,6 @@ class DeviceRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
self._is_main_process = hs.config.worker.worker_app is None

async def on_GET(
self, request: SynapseRequest, device_id: str
@@ -179,14 +176,6 @@ class DeviceRestServlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
# Can only be run on main process, as changes to device lists must
# happen on main.
if not self._is_main_process:
error_message = "DELETE on /devices/ must be routed to main process"
logger.error(error_message)
raise SynapseError(500, error_message)
assert isinstance(self.device_handler, DeviceHandler)

requester = await self.auth.get_user_by_req(request)

try:
@@ -231,14 +220,6 @@ class DeviceRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
# Can only be run on main process, as changes to device lists must
# happen on main.
if not self._is_main_process:
error_message = "PUT on /devices/ must be routed to main process"
logger.error(error_message)
raise SynapseError(500, error_message)
assert isinstance(self.device_handler, DeviceHandler)

requester = await self.auth.get_user_by_req(request, allow_guest=True)

body = parse_and_validate_json_object_from_request(request, self.PutBody)
@@ -317,7 +298,6 @@ class DehydratedDeviceServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
@@ -377,7 +357,6 @@ class ClaimDehydratedDeviceServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler

class PostBody(RequestBodyModel):
@@ -517,7 +496,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler

@@ -595,15 +573,11 @@ class DehydratedDeviceV2Servlet(RestServlet):


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if (
hs.config.worker.worker_app is None
and not hs.config.experimental.msc3861.enabled
):
if not hs.config.experimental.msc3861.enabled:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

if hs.config.worker.worker_app is None:
if hs.config.experimental.msc2697_enabled:
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)

@@ -504,6 +504,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)

@@ -22,7 +22,6 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@@ -42,9 +41,7 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
self._device_handler = hs.get_device_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
@@ -71,9 +68,7 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
self._device_handler = hs.get_device_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(

@@ -69,7 +69,7 @@ from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
from synapse.handlers.cas import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.delayed_events import DelayedEventsHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.device import DeviceHandler, DeviceWriterHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
@@ -586,10 +586,10 @@ class HomeServer(metaclass=abc.ABCMeta):
)

@cache_in_self
def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:
def get_device_handler(self) -> DeviceHandler:
if self.get_instance_name() in self.config.worker.writers.device_lists:
return DeviceWriterHandler(self)

return DeviceHandler(self)

@cache_in_self

File diff suppressed because it is too large
@@ -59,7 +59,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@@ -120,6 +120,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self.hs.config.federation.allow_device_name_lookup_over_federation
)

self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
# No one reads the stream positions, so we're allowed to have an empty list of writers
writers=[],
)

def process_replication_rows(
self,
stream_name: str,
@@ -145,7 +159,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
(stream_id, devices)
"""
now_stream_id = self.get_device_stream_token()
# Here, we don't use the individual instances positions, as we *need* to
# give out the stream_id as an integer in the federation API.
# This means that we'll potentially return the same data twice with a
# different stream_id, and invalidate cache more often than necessary,
# which is fine overall.
now_stream_id = self.get_device_stream_token().stream

# We need to be careful with the caching here, as we need to always
# return *all* persisted devices, however there may be a lag between a
@@ -164,8 +183,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
# have to check for potential invalidations after the
# `now_stream_id`.
sql = """
SELECT user_id FROM device_lists_stream
SELECT 1
FROM device_lists_stream
WHERE stream_id >= ? AND user_id = ?
LIMIT 1
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
@@ -1117,7 +1138,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)

@abc.abstractmethod
def get_device_stream_token(self) -> int:
def get_device_stream_token(self) -> MultiWriterStreamToken:
"""Get the current stream id from the _device_list_id_gen"""
...

@@ -1540,27 +1561,44 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
impl,
)

async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
"device_id": device_id,
"user_id": user_id,
}
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
self.db_pool.simple_delete_txn(
txn,
table="dehydrated_devices",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.db_pool.simple_delete_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
writers=["master"],
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

async def set_e2e_device_keys(
@@ -1754,3 +1792,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
],
desc="add_e2e_signing_key",
)


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

@@ -36,7 +36,6 @@ from typing import (
)

import attr
from immutabledict import immutabledict

from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
@@ -167,25 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""

min_pos = self._receipts_id_gen.get_current_token()

positions = {}
if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
positions = {
i: p
for i, p in self._receipts_id_gen.get_positions().items()
if p > min_pos
}

return MultiWriterStreamToken(
stream=min_pos, instance_map=immutabledict(positions)
)
return MultiWriterStreamToken.from_generator(self._receipts_id_gen)

def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
return self._receipts_id_gen.get_current_token_for_writer(instance_name)

@ -2093,6 +2093,58 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)

async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device

Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""

await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)

def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)

self.db_pool.simple_update_txn(
txn, "access_tokens", {"token": token}, {"device_id": device_id}
)

self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))

return old_device_id

async def set_device_for_access_token(self, token: str, device_id: str) -> str:
"""Sets the device ID associated with an access token.

Args:
token: The access token to modify.
device_id: The new device ID.
Returns:
The old device ID associated with the access token.
"""

return await self.db_pool.runInteraction(
"set_device_for_access_token",
self._set_device_for_access_token_txn,
token,
device_id,
)

async def add_login_token_to_user(
self,
user_id: str,

@ -2396,6 +2448,154 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))

async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user

Args:
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""

def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler.
refresh_where_clause = where_clause
refresh_values = values.copy()
if except_token_id:
# TODO: support that for refresh tokens
where_clause += " AND id != ?"
values.append(except_token_id)

txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]

self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(token,) for token, _, _ in tokens_and_devices],
)

txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)

txn.execute(
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
refresh_values,
)

return tokens_and_devices

return await self.db_pool.runInteraction("user_delete_access_tokens", f)

async def user_delete_access_tokens_for_devices(
self,
user_id: str,
device_ids: StrCollection,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user

Args:
user_id: ID of user the tokens belong to
device_ids: The devices to delete tokens for.
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""

def user_delete_access_tokens_for_devices_txn(
txn: LoggingTransaction, batch_device_ids: StrCollection
) -> List[Tuple[str, int, Optional[str]]]:
self.db_pool.simple_delete_many_txn(
txn,
table="refresh_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)

clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", batch_device_ids
)
args.append(user_id)

if self.database_engine.supports_returning:
sql = f"""
DELETE FROM access_tokens
WHERE {clause} AND user_id = ?
RETURNING token, id, device_id
"""
txn.execute(sql, args)
tokens_and_devices = txn.fetchall()
else:
tokens_and_devices = self.db_pool.simple_select_many_txn(
txn,
table="access_tokens",
column="device_id",
iterable=batch_device_ids,
keyvalues={"user_id": user_id},
retcols=("token", "id", "device_id"),
)

self.db_pool.simple_delete_many_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)

self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(t[0],) for t in tokens_and_devices],
)
return tokens_and_devices

results = []
for batch_device_ids in batch_iter(device_ids, 1000):
tokens_and_devices = await self.db_pool.runInteraction(
"user_delete_access_tokens_for_devices",
user_delete_access_tokens_for_devices_txn,
batch_device_ids,
)
results.extend(tokens_and_devices)

return results
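
For orientation, a hedged sketch of how a caller might use the new batched helper above; the `store` variable, user ID, and device IDs are made up for illustration and are not part of this diff:

# Illustrative only: `store` stands in for a RegistrationWorkerStore instance.
deleted = await store.user_delete_access_tokens_for_devices(
    "@alice:example.com", ["DEVICE_ONE", "DEVICE_TWO"]
)
# `deleted` holds one (token, token id, device id) tuple per invalidated access token.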

async def delete_access_token(self, access_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)

self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)

await self.db_pool.runInteraction("delete_access_token", f)


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(

@ -2620,58 +2820,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):

return next_id

async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device

Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""

await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)

def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
)

self.db_pool.simple_update_txn(
txn, "access_tokens", {"token": token}, {"device_id": device_id}
)

self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))

return old_device_id

async def set_device_for_access_token(self, token: str, device_id: str) -> str:
"""Sets the device ID associated with an access token.

Args:
token: The access token to modify.
device_id: The new device ID.
Returns:
The old device ID associated with the access token.
"""

return await self.db_pool.runInteraction(
"set_device_for_access_token",
self._set_device_for_access_token_txn,
token,
device_id,
)

async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:

@ -2743,162 +2891,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):

await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)

async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user

Args:
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""

def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler.
refresh_where_clause = where_clause
refresh_values = values.copy()
if except_token_id:
# TODO: support that for refresh tokens
where_clause += " AND id != ?"
values.append(except_token_id)

txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]

self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(token,) for token, _, _ in tokens_and_devices],
)

txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)

txn.execute(
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
refresh_values,
)

return tokens_and_devices

return await self.db_pool.runInteraction("user_delete_access_tokens", f)

async def user_delete_access_tokens_for_devices(
self,
user_id: str,
device_ids: StrCollection,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access and refresh tokens belonging to a user

Args:
user_id: ID of user the tokens belong to
device_ids: The devices to delete tokens for.
Returns:
A tuple of (token, token id, device id) for each of the deleted tokens
"""

def user_delete_access_tokens_for_devices_txn(
txn: LoggingTransaction, batch_device_ids: StrCollection
) -> List[Tuple[str, int, Optional[str]]]:
self.db_pool.simple_delete_many_txn(
txn,
table="refresh_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)

clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", batch_device_ids
)
args.append(user_id)

if self.database_engine.supports_returning:
sql = f"""
DELETE FROM access_tokens
WHERE {clause} AND user_id = ?
RETURNING token, id, device_id
"""
txn.execute(sql, args)
tokens_and_devices = txn.fetchall()
else:
tokens_and_devices = self.db_pool.simple_select_many_txn(
txn,
table="access_tokens",
column="device_id",
iterable=batch_device_ids,
keyvalues={"user_id": user_id},
retcols=("token", "id", "device_id"),
)

self.db_pool.simple_delete_many_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
column="device_id",
values=batch_device_ids,
)

self._invalidate_cache_and_stream_bulk(
txn,
self.get_user_by_access_token,
[(t[0],) for t in tokens_and_devices],
)
return tokens_and_devices

results = []
for batch_device_ids in batch_iter(device_ids, 1000):
tokens_and_devices = await self.db_pool.runInteraction(
"user_delete_access_tokens_for_devices",
user_delete_access_tokens_for_devices_txn,
batch_device_ids,
)
results.extend(tokens_and_devices)

return results

async def delete_access_token(self, access_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)

self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)

await self.db_pool.runInteraction("delete_access_token", f)

async def delete_refresh_token(self, refresh_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
)

await self.db_pool.runInteraction("delete_refresh_token", f)

async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're

@ -324,7 +324,7 @@ class RelationsWorkerStore(SQLBaseStore):
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)

@ -1574,6 +1574,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.

This only returns the minimum device list stream ID at the time of
joining, not the full device list stream token. The only impact of this
is that we may be sending again device list updates that we've already
sent to some destinations, which is harmless.
"""

return cast(

@ -61,7 +61,6 @@ from typing import (
)

import attr
from immutabledict import immutabledict
from typing_extensions import assert_never

from twisted.internet import defer

@ -657,23 +656,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
component.
"""

min_pos = self._stream_id_gen.get_current_token()

positions = {}
if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
positions = {
i: p
for i, p in self._stream_id_gen.get_positions().items()
if p > min_pos
}

return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
return RoomStreamToken.from_generator(self._stream_id_gen)

def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._stream_id_gen

@ -203,7 +203,7 @@ class EventSources:
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)

@ -75,6 +75,7 @@ if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator


logger = logging.getLogger(__name__)

@ -570,6 +571,25 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
),
)

@classmethod
def from_generator(cls, generator: "MultiWriterIdGenerator") -> Self:
"""Get the current token out of a MultiWriterIdGenerator"""

# The `min_pos` is the minimum position that we know all instances
# have finished persisting to, so we only care about instances whose
# positions are ahead of that. (Instance positions can be behind the
# min position as there are times we can work out that the minimum
# position is ahead of the naive minimum across all current
# positions. See MultiWriterIdGenerator for details)
min_pos = generator.get_current_token()
positions = {
instance: position
for instance, position in generator.get_positions().items()
if position > min_pos
}

return cls(stream=min_pos, instance_map=immutabledict(positions))
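
For intuition, a hedged example of the token this helper would build for hypothetical writer positions; the instance names and numbers below are invented, not taken from the diff:

# Suppose the generator reports a global minimum of 5 and per-writer positions
# {"writer1": 7, "writer2": 5}. Only positions ahead of the minimum are kept,
# so the resulting token is equivalent to:
token = MultiWriterStreamToken(stream=5, instance_map=immutabledict({"writer1": 7}))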

@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):

@ -980,7 +1000,9 @@ class StreamToken:
account_data_key: int
push_rules_key: int
to_device_key: int
device_list_key: int
device_list_key: MultiWriterStreamToken = attr.ib(
validator=attr.validators.instance_of(MultiWriterStreamToken)
)
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
un_partial_stated_rooms_key: int

@ -1021,7 +1043,9 @@ class StreamToken:
account_data_key=int(account_data_key),
push_rules_key=int(push_rules_key),
to_device_key=int(to_device_key),
device_list_key=int(device_list_key),
device_list_key=await MultiWriterStreamToken.parse(
store, device_list_key
),
groups_key=int(groups_key),
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
)

@ -1040,7 +1064,7 @@ class StreamToken:
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
str(self.device_list_key),
await self.device_list_key.to_string(store),
# Note that the groups key is no longer used, but it is still
# serialized so that there will not be confusion in the future
# if additional tokens are added.

@ -1069,6 +1093,12 @@ class StreamToken:
StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
)
return new_token
elif key == StreamKeyType.DEVICE_LIST:
new_token = self.copy_and_replace(
StreamKeyType.DEVICE_LIST,
self.device_list_key.copy_and_advance(new_value),
)
return new_token

new_token = self.copy_and_replace(key, new_value)
new_id = new_token.get_field(key)

@ -1087,7 +1117,11 @@ class StreamToken:

@overload
def get_field(
self, key: Literal[StreamKeyType.RECEIPT]
self,
key: Literal[
StreamKeyType.RECEIPT,
StreamKeyType.DEVICE_LIST,
],
) -> MultiWriterStreamToken: ...

@overload
@ -1095,7 +1129,6 @@ class StreamToken:
self,
key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.TO_DEVICE,
@ -1161,7 +1194,16 @@ class StreamToken:


StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
room_key=RoomStreamToken(stream=0),
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
)
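
To make the type change concrete, a hedged sketch of advancing the device list portion of a token now that it is a MultiWriterStreamToken; the writer name and positions here are hypothetical, not values from the commit:

# Illustrative only: advance the device list position of the starting token.
new_device_pos = MultiWriterStreamToken(
    stream=5, instance_map=immutabledict({"device_lists1": 7})
)
advanced = StreamToken.START.copy_and_advance(StreamKeyType.DEVICE_LIST, new_device_pos)
# `advanced.device_list_key` is now a MultiWriterStreamToken rather than a plain int.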

@ -30,7 +30,7 @@ from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.api.presence import UserPresenceState
from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceListUpdater, DeviceWriterHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer

@ -500,7 +500,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self.device_handler = device_handler

# whenever send_transaction is called, record the edu data

@ -554,6 +554,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"devices": [{"device_id": "D1"}],
}

assert isinstance(self.device_handler.device_list_updater, DeviceListUpdater)

self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
"host2",

@ -29,7 +29,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceWriterHandler
from synapse.rest import admin
from synapse.rest.client import devices, login, register
from synapse.server import HomeServer

@ -53,7 +53,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
application_service_api=self.appservice_api,
)
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler
self.store = hs.get_datastores().main
self.device_message_handler = hs.get_device_message_handler()

@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):

# queue a bunch of messages in the inbox
requester = create_requester(sender, device_id=DEVICE_ID)
for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
for i in range(DeviceWriterHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
self.get_success(
self.device_message_handler.send_device_message(
requester, "message_type", {receiver: {"*": {"val": i}}}

@ -462,7 +462,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()

@ -31,7 +31,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict, UserID

@ -856,7 +856,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))

device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
e = self.get_failure(
device_handler.check_device_registered(
user_id=local_user,

@ -28,7 +28,7 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
from synapse.module_api import ModuleApi

@ -819,7 +819,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):

# Delete the device.
device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
assert isinstance(device_handler, DeviceWriterHandler)
self.get_success(device_handler.delete_devices(user_id, [device_id]))

# Check that the callback was called and the pushers still existed.

@ -26,7 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceWriterHandler
from synapse.rest.client import devices, login
from synapse.server import HomeServer
from synapse.util import Clock

@ -42,7 +42,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
assert isinstance(handler, DeviceWriterHandler)
self.handler = handler

self.admin_user = self.register_user("admin", "pass", admin=True)