mirror of
				https://github.com/element-hq/synapse.git
				synced 2025-11-03 00:03:21 -05:00 
			
		
		
		
	Move device changes off the main process (#18581)
The main goal of this PR is to handle device list changes onto multiple writers, off the main process, so that we can have logins happening whilst Synapse is rolling-restarting. This is quite an intrusive change, so I would advise to review this commit by commit; I tried to keep the history as clean as possible. There are a few things to consider: - the `device_list_key` in stream tokens becomes a `MultiWriterStreamToken`, which has a few implications in sync and on the storage layer - we had a split between `DeviceHandler` and `DeviceWorkerHandler` for master vs. worker process. I've kept this split, but making it rather writer vs. non-writer worker, using method overrides for doing replication calls when needed - there are a few operations that need to happen on a single worker at a time. Instead of using cross-worker locks, for now I made them run on the first writer on the list --------- Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
		
							parent
							
								
									66504d1144
								
							
						
					
					
						commit
						5ea2cf2484
					
				
							
								
								
									
										1
									
								
								changelog.d/18581.feature
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/18581.feature
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
Enable workers to write directly to the device lists stream and handle device list updates, reducing load on the main process.
 | 
			
		||||
@ -54,7 +54,6 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
 | 
			
		||||
    export SYNAPSE_WORKER_TYPES="\
 | 
			
		||||
      event_persister:2, \
 | 
			
		||||
      background_worker, \
 | 
			
		||||
      frontend_proxy, \
 | 
			
		||||
      event_creator, \
 | 
			
		||||
      user_dir, \
 | 
			
		||||
      media_repository, \
 | 
			
		||||
@ -65,6 +64,7 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
 | 
			
		||||
      client_reader, \
 | 
			
		||||
      appservice, \
 | 
			
		||||
      pusher, \
 | 
			
		||||
      device_lists:2, \
 | 
			
		||||
      stream_writers=account_data+presence+receipts+to_device+typing"
 | 
			
		||||
 | 
			
		||||
  fi
 | 
			
		||||
 | 
			
		||||
@ -178,6 +178,8 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/login$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)",
 | 
			
		||||
            "^/_matrix/client/(r0|v3)/delete_devices$",
 | 
			
		||||
            "^/_matrix/client/versions$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$",
 | 
			
		||||
            "^/_matrix/client/(r0|v3|unstable)/register$",
 | 
			
		||||
@ -194,6 +196,9 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/directory/room/.*$",
 | 
			
		||||
            "^/_matrix/client/(r0|v3|unstable)/capabilities$",
 | 
			
		||||
            "^/_matrix/client/(r0|v3|unstable)/notifications$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$",
 | 
			
		||||
            "^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$",
 | 
			
		||||
        ],
 | 
			
		||||
        "shared_extra_conf": {},
 | 
			
		||||
        "worker_extra_conf": "",
 | 
			
		||||
@ -265,13 +270,6 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
 | 
			
		||||
        "shared_extra_conf": {},
 | 
			
		||||
        "worker_extra_conf": "",
 | 
			
		||||
    },
 | 
			
		||||
    "frontend_proxy": {
 | 
			
		||||
        "app": "synapse.app.generic_worker",
 | 
			
		||||
        "listener_resources": ["client", "replication"],
 | 
			
		||||
        "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"],
 | 
			
		||||
        "shared_extra_conf": {},
 | 
			
		||||
        "worker_extra_conf": "",
 | 
			
		||||
    },
 | 
			
		||||
    "account_data": {
 | 
			
		||||
        "app": "synapse.app.generic_worker",
 | 
			
		||||
        "listener_resources": ["client", "replication"],
 | 
			
		||||
@ -306,6 +304,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
 | 
			
		||||
        "shared_extra_conf": {},
 | 
			
		||||
        "worker_extra_conf": "",
 | 
			
		||||
    },
 | 
			
		||||
    "device_lists": {
 | 
			
		||||
        "app": "synapse.app.generic_worker",
 | 
			
		||||
        "listener_resources": ["client", "replication"],
 | 
			
		||||
        "endpoint_patterns": [],
 | 
			
		||||
        "shared_extra_conf": {},
 | 
			
		||||
        "worker_extra_conf": "",
 | 
			
		||||
    },
 | 
			
		||||
    "typing": {
 | 
			
		||||
        "app": "synapse.app.generic_worker",
 | 
			
		||||
        "listener_resources": ["client", "replication"],
 | 
			
		||||
@ -412,16 +417,17 @@ def add_worker_roles_to_shared_config(
 | 
			
		||||
    # streams
 | 
			
		||||
    instance_map = shared_config.setdefault("instance_map", {})
 | 
			
		||||
 | 
			
		||||
    # This is a list of the stream_writers that there can be only one of. Events can be
 | 
			
		||||
    # sharded, and therefore doesn't belong here.
 | 
			
		||||
    singular_stream_writers = [
 | 
			
		||||
    # This is a list of the stream_writers.
 | 
			
		||||
    stream_writers = {
 | 
			
		||||
        "account_data",
 | 
			
		||||
        "events",
 | 
			
		||||
        "device_lists",
 | 
			
		||||
        "presence",
 | 
			
		||||
        "receipts",
 | 
			
		||||
        "to_device",
 | 
			
		||||
        "typing",
 | 
			
		||||
        "push_rules",
 | 
			
		||||
    ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Worker-type specific sharding config. Now a single worker can fulfill multiple
 | 
			
		||||
    # roles, check each.
 | 
			
		||||
@ -431,28 +437,11 @@ def add_worker_roles_to_shared_config(
 | 
			
		||||
    if "federation_sender" in worker_types_set:
 | 
			
		||||
        shared_config.setdefault("federation_sender_instances", []).append(worker_name)
 | 
			
		||||
 | 
			
		||||
    if "event_persister" in worker_types_set:
 | 
			
		||||
        # Event persisters write to the events stream, so we need to update
 | 
			
		||||
        # the list of event stream writers
 | 
			
		||||
        shared_config.setdefault("stream_writers", {}).setdefault("events", []).append(
 | 
			
		||||
            worker_name
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Map of stream writer instance names to host/ports combos
 | 
			
		||||
        if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
 | 
			
		||||
            instance_map[worker_name] = {
 | 
			
		||||
                "path": f"/run/worker.{worker_port}",
 | 
			
		||||
            }
 | 
			
		||||
        else:
 | 
			
		||||
            instance_map[worker_name] = {
 | 
			
		||||
                "host": "localhost",
 | 
			
		||||
                "port": worker_port,
 | 
			
		||||
            }
 | 
			
		||||
    # Update the list of stream writers. It's convenient that the name of the worker
 | 
			
		||||
    # type is the same as the stream to write. Iterate over the whole list in case there
 | 
			
		||||
    # is more than one.
 | 
			
		||||
    for worker in worker_types_set:
 | 
			
		||||
        if worker in singular_stream_writers:
 | 
			
		||||
        if worker in stream_writers:
 | 
			
		||||
            shared_config.setdefault("stream_writers", {}).setdefault(
 | 
			
		||||
                worker, []
 | 
			
		||||
            ).append(worker_name)
 | 
			
		||||
@ -876,6 +865,13 @@ def generate_worker_files(
 | 
			
		||||
        else:
 | 
			
		||||
            healthcheck_urls.append("http://localhost:%d/health" % (worker_port,))
 | 
			
		||||
 | 
			
		||||
        # Special case for event_persister: those are just workers that write to
 | 
			
		||||
        # the `events` stream. For other workers, the worker name is the same
 | 
			
		||||
        # name of the stream they write to, but for some reason it is not the
 | 
			
		||||
        # case for event_persister.
 | 
			
		||||
        if "event_persister" in worker_types_set:
 | 
			
		||||
            worker_types_set.add("events")
 | 
			
		||||
 | 
			
		||||
        # Update the shared config with sharding-related options if necessary
 | 
			
		||||
        add_worker_roles_to_shared_config(
 | 
			
		||||
            shared_config, worker_types_set, worker_name, worker_port
 | 
			
		||||
 | 
			
		||||
@ -4341,6 +4341,8 @@ This setting has the following sub-options:
 | 
			
		||||
 | 
			
		||||
* `push_rules` (string): Name of a worker assigned to the `push_rules` stream.
 | 
			
		||||
 | 
			
		||||
* `device_lists` (string): Name of a worker assigned to the `device_lists` stream.
 | 
			
		||||
 | 
			
		||||
Example configuration:
 | 
			
		||||
```yaml
 | 
			
		||||
stream_writers:
 | 
			
		||||
 | 
			
		||||
@ -238,7 +238,8 @@ information.
 | 
			
		||||
    ^/_matrix/client/unstable/im.nheko.summary/summary/.*$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/account/3pid$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/account/whoami$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/devices$
 | 
			
		||||
    ^/_matrix/client/(r0|v3)/delete_devices$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$)
 | 
			
		||||
    ^/_matrix/client/versions$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
 | 
			
		||||
@ -257,7 +258,9 @@ information.
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/keys/changes$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/keys/claim$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/room_keys/
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/keys/upload$
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/keys/upload
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$
 | 
			
		||||
 | 
			
		||||
    # Registration/login requests
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/login$
 | 
			
		||||
@ -282,7 +285,6 @@ Additionally, the following REST endpoints can be handled for GET requests:
 | 
			
		||||
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
 | 
			
		||||
    ^/_matrix/client/unstable/org.matrix.msc4140/delayed_events
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/devices/
 | 
			
		||||
 | 
			
		||||
    # Account data requests
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/.*/tags
 | 
			
		||||
@ -329,7 +331,6 @@ set to `true`), the following endpoints can be handled by the worker:
 | 
			
		||||
    ^/_synapse/admin/v2/users/[^/]+$
 | 
			
		||||
    ^/_synapse/admin/v1/username_available$
 | 
			
		||||
    ^/_synapse/admin/v1/users/[^/]+/_allow_cross_signing_replacement_without_uia$
 | 
			
		||||
    # Only the GET method:
 | 
			
		||||
    ^/_synapse/admin/v1/users/[^/]+/devices$
 | 
			
		||||
 | 
			
		||||
Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)
 | 
			
		||||
@ -550,6 +551,18 @@ the stream writer for the `push_rules` stream:
 | 
			
		||||
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
 | 
			
		||||
 | 
			
		||||
##### The `device_lists` stream
 | 
			
		||||
 | 
			
		||||
The `device_lists` stream supports multiple writers. The following endpoints
 | 
			
		||||
can be handled by any worker, but should be routed directly one of the workers
 | 
			
		||||
configured as stream writer for the `device_lists` stream:
 | 
			
		||||
 | 
			
		||||
    ^/_matrix/client/(r0|v3)/delete_devices$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/devices/
 | 
			
		||||
    ^/_matrix/client/(r0|v3|unstable)/keys/upload
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$
 | 
			
		||||
    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$
 | 
			
		||||
 | 
			
		||||
#### Restrict outbound federation traffic to a specific set of workers
 | 
			
		||||
 | 
			
		||||
The
 | 
			
		||||
 | 
			
		||||
@ -5383,6 +5383,9 @@ properties:
 | 
			
		||||
      push_rules:
 | 
			
		||||
        type: string
 | 
			
		||||
        description: Name of a worker assigned to the `push_rules` stream.
 | 
			
		||||
      device_lists:
 | 
			
		||||
        type: string
 | 
			
		||||
        description: Name of a worker assigned to the `device_lists` stream.
 | 
			
		||||
    default: {}
 | 
			
		||||
    examples:
 | 
			
		||||
      - events: worker1
 | 
			
		||||
 | 
			
		||||
@ -134,40 +134,44 @@ class WriterLocations:
 | 
			
		||||
            can only be a single instance.
 | 
			
		||||
        account_data: The instances that write to the account data streams. Currently
 | 
			
		||||
            can only be a single instance.
 | 
			
		||||
        receipts: The instances that write to the receipts stream. Currently
 | 
			
		||||
            can only be a single instance.
 | 
			
		||||
        receipts: The instances that write to the receipts stream.
 | 
			
		||||
        presence: The instances that write to the presence stream. Currently
 | 
			
		||||
            can only be a single instance.
 | 
			
		||||
        push_rules: The instances that write to the push stream. Currently
 | 
			
		||||
            can only be a single instance.
 | 
			
		||||
        device_lists: The instances that write to the device list stream.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    events: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    typing: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    to_device: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    account_data: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    receipts: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    presence: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    push_rules: List[str] = attr.ib(
 | 
			
		||||
        default=["master"],
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
    device_lists: List[str] = attr.ib(
 | 
			
		||||
        default=[MAIN_PROCESS_INSTANCE_NAME],
 | 
			
		||||
        converter=_instance_to_list_converter,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -358,7 +362,10 @@ class WorkerConfig(Config):
 | 
			
		||||
        ):
 | 
			
		||||
            instances = _instance_to_list_converter(getattr(self.writers, stream))
 | 
			
		||||
            for instance in instances:
 | 
			
		||||
                if instance != "master" and instance not in self.instance_map:
 | 
			
		||||
                if (
 | 
			
		||||
                    instance != MAIN_PROCESS_INSTANCE_NAME
 | 
			
		||||
                    and instance not in self.instance_map
 | 
			
		||||
                ):
 | 
			
		||||
                    raise ConfigError(
 | 
			
		||||
                        "Instance %r is configured to write %s but does not appear in `instance_map` config."
 | 
			
		||||
                        % (instance, stream)
 | 
			
		||||
@ -397,6 +404,11 @@ class WorkerConfig(Config):
 | 
			
		||||
                "Must only specify one instance to handle `push` messages."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if len(self.writers.device_lists) == 0:
 | 
			
		||||
            raise ConfigError(
 | 
			
		||||
                "Must specify at least one instance to handle `device_lists` messages."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.events_shard_config = RoutableShardedWorkerHandlingConfig(
 | 
			
		||||
            self.writers.events
 | 
			
		||||
        )
 | 
			
		||||
@ -419,9 +431,12 @@ class WorkerConfig(Config):
 | 
			
		||||
        #
 | 
			
		||||
        # No effort is made to ensure only a single instance of these tasks is
 | 
			
		||||
        # running.
 | 
			
		||||
        background_tasks_instance = config.get("run_background_tasks_on") or "master"
 | 
			
		||||
        background_tasks_instance = (
 | 
			
		||||
            config.get("run_background_tasks_on") or MAIN_PROCESS_INSTANCE_NAME
 | 
			
		||||
        )
 | 
			
		||||
        self.run_background_tasks = (
 | 
			
		||||
            self.worker_name is None and background_tasks_instance == "master"
 | 
			
		||||
            self.worker_name is None
 | 
			
		||||
            and background_tasks_instance == MAIN_PROCESS_INSTANCE_NAME
 | 
			
		||||
        ) or self.worker_name == background_tasks_instance
 | 
			
		||||
 | 
			
		||||
        self.should_notify_appservices = self._should_this_worker_perform_duty(
 | 
			
		||||
@ -493,9 +508,10 @@ class WorkerConfig(Config):
 | 
			
		||||
        # 'don't run here'.
 | 
			
		||||
        new_option_should_run_here = None
 | 
			
		||||
        if new_option_name in config:
 | 
			
		||||
            designated_worker = config[new_option_name] or "master"
 | 
			
		||||
            designated_worker = config[new_option_name] or MAIN_PROCESS_INSTANCE_NAME
 | 
			
		||||
            new_option_should_run_here = (
 | 
			
		||||
                designated_worker == "master" and self.worker_name is None
 | 
			
		||||
                designated_worker == MAIN_PROCESS_INSTANCE_NAME
 | 
			
		||||
                and self.worker_name is None
 | 
			
		||||
            ) or designated_worker == self.worker_name
 | 
			
		||||
 | 
			
		||||
        legacy_option_should_run_here = None
 | 
			
		||||
@ -592,7 +608,7 @@ class WorkerConfig(Config):
 | 
			
		||||
            # If no worker instances are set we check if the legacy option
 | 
			
		||||
            # is set, which means use the main process.
 | 
			
		||||
            if legacy_option:
 | 
			
		||||
                worker_instances = ["master"]
 | 
			
		||||
                worker_instances = [MAIN_PROCESS_INSTANCE_NAME]
 | 
			
		||||
 | 
			
		||||
            if self.worker_app == legacy_app_name:
 | 
			
		||||
                if legacy_option:
 | 
			
		||||
 | 
			
		||||
@ -638,7 +638,8 @@ class ApplicationServicesHandler:
 | 
			
		||||
 | 
			
		||||
        # Fetch the users who have modified their device list since then.
 | 
			
		||||
        users_with_changed_device_lists = await self.store.get_all_devices_changed(
 | 
			
		||||
            from_key, to_key=new_key
 | 
			
		||||
            MultiWriterStreamToken(stream=from_key),
 | 
			
		||||
            to_key=MultiWriterStreamToken(stream=new_key),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Filter out any users the application service is not interested in
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Optional
 | 
			
		||||
 | 
			
		||||
from synapse.api.constants import Membership
 | 
			
		||||
from synapse.api.errors import SynapseError
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.metrics.background_process_metrics import run_as_background_process
 | 
			
		||||
from synapse.types import Codes, Requester, UserID, create_requester
 | 
			
		||||
 | 
			
		||||
@ -84,10 +83,6 @@ class DeactivateAccountHandler:
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if identity server supports removing threepids, otherwise False.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # This can only be called on the main process.
 | 
			
		||||
        assert isinstance(self._device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        # Check if this user can be deactivated
 | 
			
		||||
        if not await self._third_party_rules.check_can_deactivate_user(
 | 
			
		||||
            user_id, by_admin
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -33,9 +33,6 @@ from synapse.logging.opentracing import (
 | 
			
		||||
    log_kv,
 | 
			
		||||
    set_tag,
 | 
			
		||||
)
 | 
			
		||||
from synapse.replication.http.devices import (
 | 
			
		||||
    ReplicationMultiUserDevicesResyncRestServlet,
 | 
			
		||||
)
 | 
			
		||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
 | 
			
		||||
from synapse.util import json_encoder
 | 
			
		||||
from synapse.util.stringutils import random_string
 | 
			
		||||
@ -56,9 +53,9 @@ class DeviceMessageHandler:
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.notifier = hs.get_notifier()
 | 
			
		||||
        self.is_mine = hs.is_mine
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        if hs.config.experimental.msc3814_enabled:
 | 
			
		||||
            self.event_sources = hs.get_event_sources()
 | 
			
		||||
            self.device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
        # We only need to poke the federation sender explicitly if its on the
 | 
			
		||||
        # same instance. Other federation sender instances will get notified by
 | 
			
		||||
@ -80,18 +77,6 @@ class DeviceMessageHandler:
 | 
			
		||||
                hs.config.worker.writers.to_device,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # The handler to call when we think a user's device list might be out of
 | 
			
		||||
        # sync. We do all device list resyncing on the master instance, so if
 | 
			
		||||
        # we're on a worker we hit the device resync replication API.
 | 
			
		||||
        if hs.config.worker.worker_app is None:
 | 
			
		||||
            self._multi_user_device_resync = (
 | 
			
		||||
                hs.get_device_handler().device_list_updater.multi_user_device_resync
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self._multi_user_device_resync = (
 | 
			
		||||
                ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # a rate limiter for room key requests.  The keys are
 | 
			
		||||
        # (sending_user_id, sending_device_id).
 | 
			
		||||
        self._ratelimiter = Ratelimiter(
 | 
			
		||||
@ -213,7 +198,10 @@ class DeviceMessageHandler:
 | 
			
		||||
            await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
 | 
			
		||||
 | 
			
		||||
            # Immediately attempt a resync in the background
 | 
			
		||||
            run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
 | 
			
		||||
            run_in_background(
 | 
			
		||||
                self.device_handler.device_list_updater.multi_user_device_resync,
 | 
			
		||||
                user_ids=[sender_user_id],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    async def send_device_message(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -32,10 +32,9 @@ from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api.constants import EduTypes
 | 
			
		||||
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.device import DeviceWriterHandler
 | 
			
		||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
 | 
			
		||||
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 | 
			
		||||
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 | 
			
		||||
from synapse.types import (
 | 
			
		||||
    JsonDict,
 | 
			
		||||
    JsonMapping,
 | 
			
		||||
@ -76,8 +75,10 @@ class E2eKeysHandler:
 | 
			
		||||
 | 
			
		||||
        federation_registry = hs.get_federation_registry()
 | 
			
		||||
 | 
			
		||||
        is_master = hs.config.worker.worker_app is None
 | 
			
		||||
        if is_master:
 | 
			
		||||
        # Only the first writer in the list should handle EDUs for signing key
 | 
			
		||||
        # updates, so that we can use an in-memory linearizer instead of worker locks.
 | 
			
		||||
        edu_writer = hs.config.worker.writers.device_lists[0]
 | 
			
		||||
        if hs.get_instance_name() == edu_writer:
 | 
			
		||||
            edu_updater = SigningKeyEduUpdater(hs)
 | 
			
		||||
 | 
			
		||||
            # Only register this edu handler on master as it requires writing
 | 
			
		||||
@ -92,11 +93,14 @@ class E2eKeysHandler:
 | 
			
		||||
                EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
 | 
			
		||||
                edu_updater.incoming_signing_key_update,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.device_key_uploader = self.upload_device_keys_for_user
 | 
			
		||||
        else:
 | 
			
		||||
            self.device_key_uploader = (
 | 
			
		||||
                ReplicationUploadKeysForUserRestServlet.make_client(hs)
 | 
			
		||||
            federation_registry.register_instances_for_edu(
 | 
			
		||||
                EduTypes.SIGNING_KEY_UPDATE,
 | 
			
		||||
                [edu_writer],
 | 
			
		||||
            )
 | 
			
		||||
            federation_registry.register_instances_for_edu(
 | 
			
		||||
                EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
 | 
			
		||||
                [edu_writer],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # doesn't really work as part of the generic query API, because the
 | 
			
		||||
@ -847,7 +851,7 @@ class E2eKeysHandler:
 | 
			
		||||
        # TODO: Validate the JSON to make sure it has the right keys.
 | 
			
		||||
        device_keys = keys.get("device_keys", None)
 | 
			
		||||
        if device_keys:
 | 
			
		||||
            await self.device_key_uploader(
 | 
			
		||||
            await self.upload_device_keys_for_user(
 | 
			
		||||
                user_id=user_id,
 | 
			
		||||
                device_id=device_id,
 | 
			
		||||
                keys={"device_keys": device_keys},
 | 
			
		||||
@ -904,9 +908,6 @@ class E2eKeysHandler:
 | 
			
		||||
            device_keys: the `device_keys` of an /keys/upload request.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        # This can only be called from the main process.
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        time_now = self.clock.time_msec()
 | 
			
		||||
 | 
			
		||||
        device_keys = keys["device_keys"]
 | 
			
		||||
@ -998,9 +999,6 @@ class E2eKeysHandler:
 | 
			
		||||
            user_id: the user uploading the keys
 | 
			
		||||
            keys: the signing keys
 | 
			
		||||
        """
 | 
			
		||||
        # This can only be called from the main process.
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        # if a master key is uploaded, then check it.  Otherwise, load the
 | 
			
		||||
        # stored master key, to check signatures on other keys
 | 
			
		||||
        if "master_key" in keys:
 | 
			
		||||
@ -1091,9 +1089,6 @@ class E2eKeysHandler:
 | 
			
		||||
        Raises:
 | 
			
		||||
            SynapseError: if the signatures dict is not valid.
 | 
			
		||||
        """
 | 
			
		||||
        # This can only be called from the main process.
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        failures = {}
 | 
			
		||||
 | 
			
		||||
        # signatures to be stored.  Each item will be a SignatureListItem
 | 
			
		||||
@ -1467,9 +1462,6 @@ class E2eKeysHandler:
 | 
			
		||||
            A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
 | 
			
		||||
            If the key cannot be retrieved, all values in the tuple will instead be None.
 | 
			
		||||
        """
 | 
			
		||||
        # This can only be called from the main process.
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            remote_result = await self.federation.query_user_devices(
 | 
			
		||||
                user.domain, user.to_string()
 | 
			
		||||
@ -1770,7 +1762,7 @@ class SigningKeyEduUpdater:
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
 | 
			
		||||
        device_handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(device_handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(device_handler, DeviceWriterHandler)
 | 
			
		||||
        self._device_handler = device_handler
 | 
			
		||||
 | 
			
		||||
        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 | 
			
		||||
 | 
			
		||||
@ -698,10 +698,19 @@ class FederationHandler:
 | 
			
		||||
                    #     We may want to reset the partial state info if it's from an
 | 
			
		||||
                    #     old, failed partial state join.
 | 
			
		||||
                    #     https://github.com/matrix-org/synapse/issues/13000
 | 
			
		||||
 | 
			
		||||
                    # FIXME: Ideally, we would store the full stream token here
 | 
			
		||||
                    # not just the minimum stream ID, so that we can compute an
 | 
			
		||||
                    # accurate list of device changes when un-partial-ing the
 | 
			
		||||
                    # room. The only side effect of this is that we may send
 | 
			
		||||
                    # extra unecessary device list outbound pokes through
 | 
			
		||||
                    # federation, which is harmless.
 | 
			
		||||
                    device_lists_stream_id = self.store.get_device_stream_token().stream
 | 
			
		||||
 | 
			
		||||
                    await self.store.store_partial_state_room(
 | 
			
		||||
                        room_id=room_id,
 | 
			
		||||
                        servers=ret.servers_in_room,
 | 
			
		||||
                        device_lists_stream_id=self.store.get_device_stream_token(),
 | 
			
		||||
                        device_lists_stream_id=device_lists_stream_id,
 | 
			
		||||
                        joined_via=origin,
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -77,9 +77,6 @@ from synapse.logging.opentracing import (
 | 
			
		||||
    trace,
 | 
			
		||||
)
 | 
			
		||||
from synapse.metrics.background_process_metrics import run_as_background_process
 | 
			
		||||
from synapse.replication.http.devices import (
 | 
			
		||||
    ReplicationMultiUserDevicesResyncRestServlet,
 | 
			
		||||
)
 | 
			
		||||
from synapse.replication.http.federation import (
 | 
			
		||||
    ReplicationFederationSendEventsRestServlet,
 | 
			
		||||
)
 | 
			
		||||
@ -180,12 +177,7 @@ class FederationEventHandler:
 | 
			
		||||
        self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
 | 
			
		||||
 | 
			
		||||
        self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
 | 
			
		||||
        if hs.config.worker.worker_app:
 | 
			
		||||
            self._multi_user_device_resync = (
 | 
			
		||||
                ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self._device_list_updater = hs.get_device_handler().device_list_updater
 | 
			
		||||
        self._device_list_updater = hs.get_device_handler().device_list_updater
 | 
			
		||||
 | 
			
		||||
        # When joining a room we need to queue any events for that room up.
 | 
			
		||||
        # For each room, a list of (pdu, origin) tuples.
 | 
			
		||||
@ -1544,12 +1536,7 @@ class FederationEventHandler:
 | 
			
		||||
            await self._store.mark_remote_users_device_caches_as_stale((sender,))
 | 
			
		||||
 | 
			
		||||
            # Immediately attempt a resync in the background
 | 
			
		||||
            if self._config.worker.worker_app:
 | 
			
		||||
                await self._multi_user_device_resync(user_ids=[sender])
 | 
			
		||||
            else:
 | 
			
		||||
                await self._device_list_updater.multi_user_device_resync(
 | 
			
		||||
                    user_ids=[sender]
 | 
			
		||||
                )
 | 
			
		||||
            await self._device_list_updater.multi_user_device_resync(user_ids=[sender])
 | 
			
		||||
        except Exception:
 | 
			
		||||
            logger.exception("Failed to resync device for %s", sender)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,7 +44,6 @@ from synapse.api.errors import (
 | 
			
		||||
)
 | 
			
		||||
from synapse.appservice import ApplicationService
 | 
			
		||||
from synapse.config.server import is_threepid_reserved
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.http.servlet import assert_params_in_dict
 | 
			
		||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
 | 
			
		||||
from synapse.replication.http.register import (
 | 
			
		||||
@ -840,9 +839,6 @@ class RegistrationHandler:
 | 
			
		||||
        refresh_token = None
 | 
			
		||||
        refresh_token_id = None
 | 
			
		||||
 | 
			
		||||
        # This can only run on the main process.
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        registered_device_id = await self.device_handler.check_device_registered(
 | 
			
		||||
            user_id,
 | 
			
		||||
            device_id,
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,6 @@ import logging
 | 
			
		||||
from typing import TYPE_CHECKING, Optional
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import Codes, StoreError, SynapseError
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.types import Requester
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -36,17 +35,7 @@ class SetPasswordHandler:
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self._auth_handler = hs.get_auth_handler()
 | 
			
		||||
 | 
			
		||||
        # We don't need the device handler if password changing is disabled.
 | 
			
		||||
        # This allows us to instantiate the SetPasswordHandler on the workers
 | 
			
		||||
        # that have admin APIs for MAS
 | 
			
		||||
        if self._auth_handler.can_change_password():
 | 
			
		||||
            # This can only be instantiated on the main process.
 | 
			
		||||
            device_handler = hs.get_device_handler()
 | 
			
		||||
            assert isinstance(device_handler, DeviceHandler)
 | 
			
		||||
            self._device_handler: Optional[DeviceHandler] = device_handler
 | 
			
		||||
        else:
 | 
			
		||||
            self._device_handler = None
 | 
			
		||||
        self._device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
    async def set_password(
 | 
			
		||||
        self,
 | 
			
		||||
@ -58,9 +47,6 @@ class SetPasswordHandler:
 | 
			
		||||
        if not self._auth_handler.can_change_password():
 | 
			
		||||
            raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 | 
			
		||||
 | 
			
		||||
        # We should have this available only if password changing is enabled.
 | 
			
		||||
        assert self._device_handler is not None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            await self.store.user_set_password_hash(user_id, password_hash)
 | 
			
		||||
        except StoreError as e:
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,6 @@ from twisted.web.server import Request
 | 
			
		||||
from synapse.api.constants import LoginType, ProfileFields
 | 
			
		||||
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 | 
			
		||||
from synapse.config.sso import SsoAttributeRequirement
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.register import init_counters_for_auth_provider
 | 
			
		||||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 | 
			
		||||
from synapse.http import get_request_user_agent
 | 
			
		||||
@ -1181,8 +1180,6 @@ class SsoHandler:
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Revoke any devices and in-flight logins tied to a provider session.
 | 
			
		||||
 | 
			
		||||
        Can only be called from the main process.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            auth_provider_id: A unique identifier for this SSO provider, e.g.
 | 
			
		||||
                "oidc" or "saml".
 | 
			
		||||
@ -1191,11 +1188,6 @@ class SsoHandler:
 | 
			
		||||
                sessions belonging to other users and log an error.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # It is expected that this is the main process.
 | 
			
		||||
        assert isinstance(self._device_handler, DeviceHandler), (
 | 
			
		||||
            "revoking SSO sessions can only be called on the main process"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Invalidate any running user-mapping sessions
 | 
			
		||||
        to_delete = []
 | 
			
		||||
        for session_id, session in self._username_mapping_sessions.items():
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,6 @@ from synapse.handlers.auth import (
 | 
			
		||||
    ON_LOGGED_OUT_CALLBACK,
 | 
			
		||||
    AuthHandler,
 | 
			
		||||
)
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.push_rules import RuleSpec, check_actions
 | 
			
		||||
from synapse.http.client import SimpleHttpClient
 | 
			
		||||
from synapse.http.server import (
 | 
			
		||||
@ -925,8 +924,6 @@ class ModuleApi:
 | 
			
		||||
    ) -> Generator["defer.Deferred[Any]", Any, None]:
 | 
			
		||||
        """Invalidate an access token for a user
 | 
			
		||||
 | 
			
		||||
        Can only be called from the main process.
 | 
			
		||||
 | 
			
		||||
        Added in Synapse v0.25.0.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
@ -939,10 +936,6 @@ class ModuleApi:
 | 
			
		||||
        Raises:
 | 
			
		||||
            synapse.api.errors.AuthError: the access token is invalid
 | 
			
		||||
        """
 | 
			
		||||
        assert isinstance(self._device_handler, DeviceHandler), (
 | 
			
		||||
            "invalidate_access_token can only be called on the main process"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # see if the access token corresponds to a device
 | 
			
		||||
        user_info = yield defer.ensureDeferred(
 | 
			
		||||
            self._auth.get_user_by_access_token(access_token)
 | 
			
		||||
 | 
			
		||||
@ -59,10 +59,10 @@ class ReplicationRestResource(JsonResource):
 | 
			
		||||
        account_data.register_servlets(hs, self)
 | 
			
		||||
        push.register_servlets(hs, self)
 | 
			
		||||
        state.register_servlets(hs, self)
 | 
			
		||||
        devices.register_servlets(hs, self)
 | 
			
		||||
 | 
			
		||||
        # The following can't currently be instantiated on workers.
 | 
			
		||||
        if hs.config.worker.worker_app is None:
 | 
			
		||||
            login.register_servlets(hs, self)
 | 
			
		||||
            register.register_servlets(hs, self)
 | 
			
		||||
            devices.register_servlets(hs, self)
 | 
			
		||||
            delayed_events.register_servlets(hs, self)
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,92 @@ if TYPE_CHECKING:
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplicationNotifyDeviceUpdateRestServlet(ReplicationEndpoint):
 | 
			
		||||
    """Notify a device writer that a user's device list has changed.
 | 
			
		||||
 | 
			
		||||
    Request format:
 | 
			
		||||
 | 
			
		||||
        POST /_synapse/replication/notify_device_update/:user_id
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            "device_ids": ["JLAFKJWSCS", "JLAFKJWSCS"]
 | 
			
		||||
        }
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "notify_device_update"
 | 
			
		||||
    PATH_ARGS = ("user_id",)
 | 
			
		||||
    CACHE = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def _serialize_payload(  # type: ignore[override]
 | 
			
		||||
        user_id: str, device_ids: List[str]
 | 
			
		||||
    ) -> JsonDict:
 | 
			
		||||
        return {"device_ids": device_ids}
 | 
			
		||||
 | 
			
		||||
    async def _handle_request(  # type: ignore[override]
 | 
			
		||||
        self, request: Request, content: JsonDict, user_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        device_ids = content["device_ids"]
 | 
			
		||||
 | 
			
		||||
        span = active_span()
 | 
			
		||||
        if span:
 | 
			
		||||
            span.set_tag("user_id", user_id)
 | 
			
		||||
            span.set_tag("device_ids", f"{device_ids!r}")
 | 
			
		||||
 | 
			
		||||
        await self.device_handler.notify_device_update(user_id, device_ids)
 | 
			
		||||
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplicationNotifyUserSignatureUpdateRestServlet(ReplicationEndpoint):
 | 
			
		||||
    """Notify a device writer that a user have made new signatures of other users.
 | 
			
		||||
 | 
			
		||||
    Request format:
 | 
			
		||||
 | 
			
		||||
        POST /_synapse/replication/notify_user_signature_update/:from_user_id
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            "user_ids": ["@alice:example.org", "@bob:example.org", ...]
 | 
			
		||||
        }
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "notify_user_signature_update"
 | 
			
		||||
    PATH_ARGS = ("from_user_id",)
 | 
			
		||||
    CACHE = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def _serialize_payload(from_user_id: str, user_ids: List[str]) -> JsonDict:  # type: ignore[override]
 | 
			
		||||
        return {"user_ids": user_ids}
 | 
			
		||||
 | 
			
		||||
    async def _handle_request(  # type: ignore[override]
 | 
			
		||||
        self, request: Request, content: JsonDict, from_user_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        user_ids = content["user_ids"]
 | 
			
		||||
 | 
			
		||||
        span = active_span()
 | 
			
		||||
        if span:
 | 
			
		||||
            span.set_tag("from_user_id", from_user_id)
 | 
			
		||||
            span.set_tag("user_ids", f"{user_ids!r}")
 | 
			
		||||
 | 
			
		||||
        await self.device_handler.notify_user_signature_update(from_user_id, user_ids)
 | 
			
		||||
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
 | 
			
		||||
    """Ask master to resync the device list for multiple users from the same
 | 
			
		||||
    remote server by contacting their server.
 | 
			
		||||
@ -73,11 +159,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
 | 
			
		||||
        from synapse.handlers.device import DeviceHandler
 | 
			
		||||
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_list_updater = handler.device_list_updater
 | 
			
		||||
        self.device_list_updater = hs.get_device_handler().device_list_updater
 | 
			
		||||
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
@ -103,32 +185,10 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
 | 
			
		||||
        return 200, multi_user_devices
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# FIXME(2025-07-22): Remove this on the next release, this will only get used
 | 
			
		||||
# during rollout to Synapse 1.135 and can be removed after that release.
 | 
			
		||||
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
 | 
			
		||||
    """Ask master to upload keys for the user and send them out over federation to
 | 
			
		||||
    update other servers.
 | 
			
		||||
 | 
			
		||||
    For now, only the master is permitted to handle key upload requests;
 | 
			
		||||
    any worker can handle key query requests (since they're read-only).
 | 
			
		||||
 | 
			
		||||
    Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
 | 
			
		||||
    the main process to accomplish this.
 | 
			
		||||
 | 
			
		||||
    Request format for this endpoint (borrowed and expanded from KeyUploadServlet):
 | 
			
		||||
 | 
			
		||||
        POST /_synapse/replication/upload_keys_for_user
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        "user_id": "<user_id>",
 | 
			
		||||
        "device_id": "<device_id>",
 | 
			
		||||
        "keys": {
 | 
			
		||||
            ....this part can be found in KeyUploadServlet in rest/client/keys.py....
 | 
			
		||||
            or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    """Unused endpoint, kept for backwards compatibility during rollout."""
 | 
			
		||||
 | 
			
		||||
    NAME = "upload_keys_for_user"
 | 
			
		||||
    PATH_ARGS = ()
 | 
			
		||||
@ -165,6 +225,71 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
 | 
			
		||||
        return 200, results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplicationHandleNewDeviceUpdateRestServlet(ReplicationEndpoint):
 | 
			
		||||
    """Wake up a device writer to send local device list changes as federation outbound pokes.
 | 
			
		||||
 | 
			
		||||
    Request format:
 | 
			
		||||
 | 
			
		||||
        POST /_synapse/replication/handle_new_device_update
 | 
			
		||||
 | 
			
		||||
        {}
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "handle_new_device_update"
 | 
			
		||||
    PATH_ARGS = ()
 | 
			
		||||
    CACHE = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def _serialize_payload() -> JsonDict:  # type: ignore[override]
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
    async def _handle_request(  # type: ignore[override]
 | 
			
		||||
        self, request: Request, content: JsonDict
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        await self.device_handler.handle_new_device_update()
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplicationDeviceHandleRoomUnPartialStated(ReplicationEndpoint):
 | 
			
		||||
    """Handles sending appropriate device list updates in a room that has
 | 
			
		||||
    gone from partial to full state.
 | 
			
		||||
 | 
			
		||||
    Request format:
 | 
			
		||||
 | 
			
		||||
        POST /_synapse/replication/device_handle_room_un_partial_stated/:room_id
 | 
			
		||||
 | 
			
		||||
        {}
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    NAME = "device_handle_room_un_partial_stated"
 | 
			
		||||
    PATH_ARGS = ("room_id",)
 | 
			
		||||
    CACHE = True
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
    async def _handle_request(  # type: ignore[override]
 | 
			
		||||
        self, request: Request, content: JsonDict, room_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        await self.device_handler.handle_room_un_partial_stated(room_id)
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
 | 
			
		||||
    ReplicationNotifyDeviceUpdateRestServlet(hs).register(http_server)
 | 
			
		||||
    ReplicationNotifyUserSignatureUpdateRestServlet(hs).register(http_server)
 | 
			
		||||
    ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
 | 
			
		||||
    ReplicationHandleNewDeviceUpdateRestServlet(hs).register(http_server)
 | 
			
		||||
    ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
 | 
			
		||||
    ReplicationDeviceHandleRoomUnPartialStated(hs).register(http_server)
 | 
			
		||||
 | 
			
		||||
@ -116,7 +116,11 @@ class ReplicationDataHandler:
 | 
			
		||||
        all_room_ids: Set[str] = set()
 | 
			
		||||
        if stream_name == DeviceListsStream.NAME:
 | 
			
		||||
            if any(not row.is_signature and not row.hosts_calculated for row in rows):
 | 
			
		||||
                prev_token = self.store.get_device_stream_token()
 | 
			
		||||
                # This only uses the minimum stream position on the device lists
 | 
			
		||||
                # stream, which means that we may process a device list change
 | 
			
		||||
                # twice in case of concurrent writes. This is fine, as this only
 | 
			
		||||
                # triggers cache invalidation, which is harmless if done twice.
 | 
			
		||||
                prev_token = self.store.get_device_stream_token().stream
 | 
			
		||||
                all_room_ids = await self.store.get_all_device_list_changes(
 | 
			
		||||
                    prev_token, token
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
@ -72,6 +72,7 @@ from synapse.replication.tcp.streams import (
 | 
			
		||||
    ToDeviceStream,
 | 
			
		||||
    TypingStream,
 | 
			
		||||
)
 | 
			
		||||
from synapse.replication.tcp.streams._base import DeviceListsStream
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from synapse.server import HomeServer
 | 
			
		||||
@ -185,6 +186,12 @@ class ReplicationCommandHandler:
 | 
			
		||||
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if isinstance(stream, DeviceListsStream):
 | 
			
		||||
                if hs.get_instance_name() in hs.config.worker.writers.device_lists:
 | 
			
		||||
                    self._streams_to_replicate.append(stream)
 | 
			
		||||
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # Only add any other streams if we're on master.
 | 
			
		||||
            if hs.config.worker.worker_app is not None:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
@ -51,7 +51,6 @@ from synapse.rest.admin.background_updates import (
 | 
			
		||||
from synapse.rest.admin.devices import (
 | 
			
		||||
    DeleteDevicesRestServlet,
 | 
			
		||||
    DeviceRestServlet,
 | 
			
		||||
    DevicesGetRestServlet,
 | 
			
		||||
    DevicesRestServlet,
 | 
			
		||||
)
 | 
			
		||||
from synapse.rest.admin.event_reports import (
 | 
			
		||||
@ -375,4 +374,5 @@ def register_servlets_for_msc3861_delegation(
 | 
			
		||||
    UserRestServletV2(hs).register(http_server)
 | 
			
		||||
    UsernameAvailableRestServlet(hs).register(http_server)
 | 
			
		||||
    UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server)
 | 
			
		||||
    DevicesGetRestServlet(hs).register(http_server)
 | 
			
		||||
    DeviceRestServlet(hs).register(http_server)
 | 
			
		||||
    DevicesRestServlet(hs).register(http_server)
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,6 @@ from http import HTTPStatus
 | 
			
		||||
from typing import TYPE_CHECKING, Tuple
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import NotFoundError, SynapseError
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.http.servlet import (
 | 
			
		||||
    RestServlet,
 | 
			
		||||
    assert_params_in_dict,
 | 
			
		||||
@ -51,9 +50,7 @@ class DeviceRestServlet(RestServlet):
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.is_mine = hs.is_mine
 | 
			
		||||
 | 
			
		||||
@ -113,7 +110,7 @@ class DeviceRestServlet(RestServlet):
 | 
			
		||||
        return HTTPStatus.OK, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DevicesGetRestServlet(RestServlet):
 | 
			
		||||
class DevicesRestServlet(RestServlet):
 | 
			
		||||
    """
 | 
			
		||||
    Retrieve the given user's devices
 | 
			
		||||
 | 
			
		||||
@ -158,19 +155,6 @@ class DevicesGetRestServlet(RestServlet):
 | 
			
		||||
 | 
			
		||||
        return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DevicesRestServlet(DevicesGetRestServlet):
 | 
			
		||||
    """
 | 
			
		||||
    Retrieve the given user's devices
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        super().__init__(hs)
 | 
			
		||||
        assert isinstance(self.device_worker_handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = self.device_worker_handler
 | 
			
		||||
 | 
			
		||||
    async def on_POST(
 | 
			
		||||
        self, request: SynapseRequest, user_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
@ -194,7 +178,7 @@ class DevicesRestServlet(DevicesGetRestServlet):
 | 
			
		||||
        if not isinstance(device_id, str):
 | 
			
		||||
            raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string")
 | 
			
		||||
 | 
			
		||||
        await self.device_handler.check_device_registered(
 | 
			
		||||
        await self.device_worker_handler.check_device_registered(
 | 
			
		||||
            user_id=user_id, device_id=device_id
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -211,9 +195,7 @@ class DeleteDevicesRestServlet(RestServlet):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
        self.device_handler = hs.get_device_handler()
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.is_mine = hs.is_mine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 | 
			
		||||
from synapse._pydantic_compat import Extra, StrictStr
 | 
			
		||||
from synapse.api import errors
 | 
			
		||||
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.http.server import HttpServer
 | 
			
		||||
from synapse.http.servlet import (
 | 
			
		||||
    RestServlet,
 | 
			
		||||
@ -91,7 +90,6 @@ class DeleteDevicesRestServlet(RestServlet):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
        self.auth_handler = hs.get_auth_handler()
 | 
			
		||||
 | 
			
		||||
@ -147,7 +145,6 @@ class DeviceRestServlet(RestServlet):
 | 
			
		||||
        self.auth_handler = hs.get_auth_handler()
 | 
			
		||||
        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 | 
			
		||||
        self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
 | 
			
		||||
        self._is_main_process = hs.config.worker.worker_app is None
 | 
			
		||||
 | 
			
		||||
    async def on_GET(
 | 
			
		||||
        self, request: SynapseRequest, device_id: str
 | 
			
		||||
@ -179,14 +176,6 @@ class DeviceRestServlet(RestServlet):
 | 
			
		||||
    async def on_DELETE(
 | 
			
		||||
        self, request: SynapseRequest, device_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        # Can only be run on main process, as changes to device lists must
 | 
			
		||||
        # happen on main.
 | 
			
		||||
        if not self._is_main_process:
 | 
			
		||||
            error_message = "DELETE on /devices/ must be routed to main process"
 | 
			
		||||
            logger.error(error_message)
 | 
			
		||||
            raise SynapseError(500, error_message)
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        requester = await self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
@ -231,14 +220,6 @@ class DeviceRestServlet(RestServlet):
 | 
			
		||||
    async def on_PUT(
 | 
			
		||||
        self, request: SynapseRequest, device_id: str
 | 
			
		||||
    ) -> Tuple[int, JsonDict]:
 | 
			
		||||
        # Can only be run on main process, as changes to device lists must
 | 
			
		||||
        # happen on main.
 | 
			
		||||
        if not self._is_main_process:
 | 
			
		||||
            error_message = "PUT on /devices/ must be routed to main process"
 | 
			
		||||
            logger.error(error_message)
 | 
			
		||||
            raise SynapseError(500, error_message)
 | 
			
		||||
        assert isinstance(self.device_handler, DeviceHandler)
 | 
			
		||||
 | 
			
		||||
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 | 
			
		||||
 | 
			
		||||
        body = parse_and_validate_json_object_from_request(request, self.PutBody)
 | 
			
		||||
@ -317,7 +298,6 @@ class DehydratedDeviceServlet(RestServlet):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
 | 
			
		||||
    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
 | 
			
		||||
@ -377,7 +357,6 @@ class ClaimDehydratedDeviceServlet(RestServlet):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
 | 
			
		||||
    class PostBody(RequestBodyModel):
 | 
			
		||||
@ -517,7 +496,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self.e2e_keys_handler = hs.get_e2e_keys_handler()
 | 
			
		||||
        self.device_handler = handler
 | 
			
		||||
 | 
			
		||||
@ -595,18 +573,14 @@ class DehydratedDeviceV2Servlet(RestServlet):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
 | 
			
		||||
    if (
 | 
			
		||||
        hs.config.worker.worker_app is None
 | 
			
		||||
        and not hs.config.experimental.msc3861.enabled
 | 
			
		||||
    ):
 | 
			
		||||
    if not hs.config.experimental.msc3861.enabled:
 | 
			
		||||
        DeleteDevicesRestServlet(hs).register(http_server)
 | 
			
		||||
    DevicesRestServlet(hs).register(http_server)
 | 
			
		||||
    DeviceRestServlet(hs).register(http_server)
 | 
			
		||||
 | 
			
		||||
    if hs.config.worker.worker_app is None:
 | 
			
		||||
        if hs.config.experimental.msc2697_enabled:
 | 
			
		||||
            DehydratedDeviceServlet(hs).register(http_server)
 | 
			
		||||
            ClaimDehydratedDeviceServlet(hs).register(http_server)
 | 
			
		||||
        if hs.config.experimental.msc3814_enabled:
 | 
			
		||||
            DehydratedDeviceV2Servlet(hs).register(http_server)
 | 
			
		||||
            DehydratedDeviceEventsServlet(hs).register(http_server)
 | 
			
		||||
    if hs.config.experimental.msc2697_enabled:
 | 
			
		||||
        DehydratedDeviceServlet(hs).register(http_server)
 | 
			
		||||
        ClaimDehydratedDeviceServlet(hs).register(http_server)
 | 
			
		||||
    if hs.config.experimental.msc3814_enabled:
 | 
			
		||||
        DehydratedDeviceV2Servlet(hs).register(http_server)
 | 
			
		||||
        DehydratedDeviceEventsServlet(hs).register(http_server)
 | 
			
		||||
 | 
			
		||||
@ -504,6 +504,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
 | 
			
		||||
    OneTimeKeyServlet(hs).register(http_server)
 | 
			
		||||
    if hs.config.experimental.msc3983_appservice_otk_claims:
 | 
			
		||||
        UnstableOneTimeKeyServlet(hs).register(http_server)
 | 
			
		||||
    if hs.config.worker.worker_app is None:
 | 
			
		||||
        SigningKeyUploadServlet(hs).register(http_server)
 | 
			
		||||
        SignaturesUploadServlet(hs).register(http_server)
 | 
			
		||||
    SigningKeyUploadServlet(hs).register(http_server)
 | 
			
		||||
    SignaturesUploadServlet(hs).register(http_server)
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,6 @@
 | 
			
		||||
import logging
 | 
			
		||||
from typing import TYPE_CHECKING, Tuple
 | 
			
		||||
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.http.server import HttpServer
 | 
			
		||||
from synapse.http.servlet import RestServlet
 | 
			
		||||
from synapse.http.site import SynapseRequest
 | 
			
		||||
@ -42,9 +41,7 @@ class LogoutRestServlet(RestServlet):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        self._auth_handler = hs.get_auth_handler()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self._device_handler = handler
 | 
			
		||||
        self._device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
 | 
			
		||||
        requester = await self.auth.get_user_by_req(
 | 
			
		||||
@ -71,9 +68,7 @@ class LogoutAllRestServlet(RestServlet):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
        self._auth_handler = hs.get_auth_handler()
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        self._device_handler = handler
 | 
			
		||||
        self._device_handler = hs.get_device_handler()
 | 
			
		||||
 | 
			
		||||
    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
 | 
			
		||||
        requester = await self.auth.get_user_by_req(
 | 
			
		||||
 | 
			
		||||
@ -69,7 +69,7 @@ from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
 | 
			
		||||
from synapse.handlers.cas import CasHandler
 | 
			
		||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
 | 
			
		||||
from synapse.handlers.delayed_events import DelayedEventsHandler
 | 
			
		||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
 | 
			
		||||
from synapse.handlers.device import DeviceHandler, DeviceWriterHandler
 | 
			
		||||
from synapse.handlers.devicemessage import DeviceMessageHandler
 | 
			
		||||
from synapse.handlers.directory import DirectoryHandler
 | 
			
		||||
from synapse.handlers.e2e_keys import E2eKeysHandler
 | 
			
		||||
@ -586,11 +586,11 @@ class HomeServer(metaclass=abc.ABCMeta):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @cache_in_self
 | 
			
		||||
    def get_device_handler(self) -> DeviceWorkerHandler:
 | 
			
		||||
        if self.config.worker.worker_app:
 | 
			
		||||
            return DeviceWorkerHandler(self)
 | 
			
		||||
        else:
 | 
			
		||||
            return DeviceHandler(self)
 | 
			
		||||
    def get_device_handler(self) -> DeviceHandler:
 | 
			
		||||
        if self.get_instance_name() in self.config.worker.writers.device_lists:
 | 
			
		||||
            return DeviceWriterHandler(self)
 | 
			
		||||
 | 
			
		||||
        return DeviceHandler(self)
 | 
			
		||||
 | 
			
		||||
    @cache_in_self
 | 
			
		||||
    def get_device_message_handler(self) -> DeviceMessageHandler:
 | 
			
		||||
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -59,7 +59,7 @@ from synapse.storage.database import (
 | 
			
		||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 | 
			
		||||
from synapse.storage.engines import PostgresEngine
 | 
			
		||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
 | 
			
		||||
from synapse.types import JsonDict, JsonMapping
 | 
			
		||||
from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken
 | 
			
		||||
from synapse.util import json_decoder, json_encoder
 | 
			
		||||
from synapse.util.caches.descriptors import cached, cachedList
 | 
			
		||||
from synapse.util.cancellation import cancellable
 | 
			
		||||
@ -120,6 +120,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		||||
            self.hs.config.federation.allow_device_name_lookup_over_federation
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._cross_signing_id_gen = MultiWriterIdGenerator(
 | 
			
		||||
            db_conn=db_conn,
 | 
			
		||||
            db=database,
 | 
			
		||||
            notifier=hs.get_replication_notifier(),
 | 
			
		||||
            stream_name="e2e_cross_signing_keys",
 | 
			
		||||
            instance_name=self._instance_name,
 | 
			
		||||
            tables=[
 | 
			
		||||
                ("e2e_cross_signing_keys", "instance_name", "stream_id"),
 | 
			
		||||
            ],
 | 
			
		||||
            sequence_name="e2e_cross_signing_keys_sequence",
 | 
			
		||||
            # No one reads the stream positions, so we're allowed to have an empty list of writers
 | 
			
		||||
            writers=[],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def process_replication_rows(
 | 
			
		||||
        self,
 | 
			
		||||
        stream_name: str,
 | 
			
		||||
@ -145,7 +159,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		||||
        Returns:
 | 
			
		||||
            (stream_id, devices)
 | 
			
		||||
        """
 | 
			
		||||
        now_stream_id = self.get_device_stream_token()
 | 
			
		||||
        # Here, we don't use the individual instances positions, as we *need* to
 | 
			
		||||
        # give out the stream_id as an integer in the federation API.
 | 
			
		||||
        # This means that we'll potentially return the same data twice with a
 | 
			
		||||
        # different stream_id, and invalidate cache more often than necessary,
 | 
			
		||||
        # which is fine overall.
 | 
			
		||||
        now_stream_id = self.get_device_stream_token().stream
 | 
			
		||||
 | 
			
		||||
        # We need to be careful with the caching here, as we need to always
 | 
			
		||||
        # return *all* persisted devices, however there may be a lag between a
 | 
			
		||||
@ -164,8 +183,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		||||
            # have to check for potential invalidations after the
 | 
			
		||||
            # `now_stream_id`.
 | 
			
		||||
            sql = """
 | 
			
		||||
                SELECT user_id FROM device_lists_stream
 | 
			
		||||
                SELECT 1
 | 
			
		||||
                FROM device_lists_stream
 | 
			
		||||
                WHERE stream_id >= ? AND user_id = ?
 | 
			
		||||
                LIMIT 1
 | 
			
		||||
            """
 | 
			
		||||
            rows = await self.db_pool.execute(
 | 
			
		||||
                "get_e2e_device_keys_for_federation_query_check",
 | 
			
		||||
@ -1117,7 +1138,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @abc.abstractmethod
 | 
			
		||||
    def get_device_stream_token(self) -> int:
 | 
			
		||||
    def get_device_stream_token(self) -> MultiWriterStreamToken:
 | 
			
		||||
        """Get the current stream id from the _device_list_id_gen"""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
@ -1540,27 +1561,44 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 | 
			
		||||
            impl,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
 | 
			
		||||
        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
 | 
			
		||||
            log_kv(
 | 
			
		||||
                {
 | 
			
		||||
                    "message": "Deleting keys for device",
 | 
			
		||||
                    "device_id": device_id,
 | 
			
		||||
                    "user_id": user_id,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            self.db_pool.simple_delete_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="e2e_device_keys_json",
 | 
			
		||||
                keyvalues={"user_id": user_id, "device_id": device_id},
 | 
			
		||||
            )
 | 
			
		||||
            self.db_pool.simple_delete_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="e2e_one_time_keys_json",
 | 
			
		||||
                keyvalues={"user_id": user_id, "device_id": device_id},
 | 
			
		||||
            )
 | 
			
		||||
            self._invalidate_cache_and_stream(
 | 
			
		||||
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
 | 
			
		||||
            )
 | 
			
		||||
            self.db_pool.simple_delete_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="dehydrated_devices",
 | 
			
		||||
                keyvalues={"user_id": user_id, "device_id": device_id},
 | 
			
		||||
            )
 | 
			
		||||
            self.db_pool.simple_delete_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="e2e_fallback_keys_json",
 | 
			
		||||
                keyvalues={"user_id": user_id, "device_id": device_id},
 | 
			
		||||
            )
 | 
			
		||||
            self._invalidate_cache_and_stream(
 | 
			
		||||
                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        database: DatabasePool,
 | 
			
		||||
        db_conn: LoggingDatabaseConnection,
 | 
			
		||||
        hs: "HomeServer",
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(database, db_conn, hs)
 | 
			
		||||
 | 
			
		||||
        self._cross_signing_id_gen = MultiWriterIdGenerator(
 | 
			
		||||
            db_conn=db_conn,
 | 
			
		||||
            db=database,
 | 
			
		||||
            notifier=hs.get_replication_notifier(),
 | 
			
		||||
            stream_name="e2e_cross_signing_keys",
 | 
			
		||||
            instance_name=self._instance_name,
 | 
			
		||||
            tables=[
 | 
			
		||||
                ("e2e_cross_signing_keys", "instance_name", "stream_id"),
 | 
			
		||||
            ],
 | 
			
		||||
            sequence_name="e2e_cross_signing_keys_sequence",
 | 
			
		||||
            writers=["master"],
 | 
			
		||||
        await self.db_pool.runInteraction(
 | 
			
		||||
            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def set_e2e_device_keys(
 | 
			
		||||
@ -1754,3 +1792,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
 | 
			
		||||
            ],
 | 
			
		||||
            desc="add_e2e_signing_key",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        database: DatabasePool,
 | 
			
		||||
        db_conn: LoggingDatabaseConnection,
 | 
			
		||||
        hs: "HomeServer",
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(database, db_conn, hs)
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,6 @@ from typing import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
from immutabledict import immutabledict
 | 
			
		||||
 | 
			
		||||
from synapse.api.constants import EduTypes
 | 
			
		||||
from synapse.replication.tcp.streams import ReceiptsStream
 | 
			
		||||
@ -167,25 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 | 
			
		||||
    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
 | 
			
		||||
        """Get the current max stream ID for receipts stream"""
 | 
			
		||||
 | 
			
		||||
        min_pos = self._receipts_id_gen.get_current_token()
 | 
			
		||||
 | 
			
		||||
        positions = {}
 | 
			
		||||
        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
 | 
			
		||||
            # The `min_pos` is the minimum position that we know all instances
 | 
			
		||||
            # have finished persisting to, so we only care about instances whose
 | 
			
		||||
            # positions are ahead of that. (Instance positions can be behind the
 | 
			
		||||
            # min position as there are times we can work out that the minimum
 | 
			
		||||
            # position is ahead of the naive minimum across all current
 | 
			
		||||
            # positions. See MultiWriterIdGenerator for details)
 | 
			
		||||
            positions = {
 | 
			
		||||
                i: p
 | 
			
		||||
                for i, p in self._receipts_id_gen.get_positions().items()
 | 
			
		||||
                if p > min_pos
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return MultiWriterStreamToken(
 | 
			
		||||
            stream=min_pos, instance_map=immutabledict(positions)
 | 
			
		||||
        )
 | 
			
		||||
        return MultiWriterStreamToken.from_generator(self._receipts_id_gen)
 | 
			
		||||
 | 
			
		||||
    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
 | 
			
		||||
        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 | 
			
		||||
 | 
			
		||||
@ -2093,6 +2093,58 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
 | 
			
		||||
            "replace_refresh_token", _replace_refresh_token_txn
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def set_device_for_refresh_token(
 | 
			
		||||
        self, user_id: str, old_device_id: str, device_id: str
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Moves refresh tokens from old device to current device
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: The user of the devices.
 | 
			
		||||
            old_device_id: The old device.
 | 
			
		||||
            device_id: The new device ID.
 | 
			
		||||
        Returns:
 | 
			
		||||
            None
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.simple_update(
 | 
			
		||||
            "refresh_tokens",
 | 
			
		||||
            keyvalues={"user_id": user_id, "device_id": old_device_id},
 | 
			
		||||
            updatevalues={"device_id": device_id},
 | 
			
		||||
            desc="set_device_for_refresh_token",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _set_device_for_access_token_txn(
 | 
			
		||||
        self, txn: LoggingTransaction, token: str, device_id: str
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        old_device_id = self.db_pool.simple_select_one_onecol_txn(
 | 
			
		||||
            txn, "access_tokens", {"token": token}, "device_id"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.db_pool.simple_update_txn(
 | 
			
		||||
            txn, "access_tokens", {"token": token}, {"device_id": device_id}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
 | 
			
		||||
 | 
			
		||||
        return old_device_id
 | 
			
		||||
 | 
			
		||||
    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
 | 
			
		||||
        """Sets the device ID associated with an access token.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            token: The access token to modify.
 | 
			
		||||
            device_id: The new device ID.
 | 
			
		||||
        Returns:
 | 
			
		||||
            The old device ID associated with the access token.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction(
 | 
			
		||||
            "set_device_for_access_token",
 | 
			
		||||
            self._set_device_for_access_token_txn,
 | 
			
		||||
            token,
 | 
			
		||||
            device_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def add_login_token_to_user(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
@ -2396,6 +2448,154 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore):
 | 
			
		||||
        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 | 
			
		||||
        self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
 | 
			
		||||
 | 
			
		||||
    async def user_delete_access_tokens(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        except_token_id: Optional[int] = None,
 | 
			
		||||
        device_id: Optional[str] = None,
 | 
			
		||||
    ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
        """
 | 
			
		||||
        Invalidate access and refresh tokens belonging to a user
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: ID of user the tokens belong to
 | 
			
		||||
            except_token_id: access_tokens ID which should *not* be deleted
 | 
			
		||||
            device_id: ID of device the tokens are associated with.
 | 
			
		||||
                If None, tokens associated with any device (or no device) will
 | 
			
		||||
                be deleted
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (token, token id, device id) for each of the deleted tokens
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
            keyvalues = {"user_id": user_id}
 | 
			
		||||
            if device_id is not None:
 | 
			
		||||
                keyvalues["device_id"] = device_id
 | 
			
		||||
 | 
			
		||||
            items = keyvalues.items()
 | 
			
		||||
            where_clause = " AND ".join(k + " = ?" for k, _ in items)
 | 
			
		||||
            values: List[Union[str, int]] = [v for _, v in items]
 | 
			
		||||
            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
 | 
			
		||||
            # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
 | 
			
		||||
            # clause and values before we handle that. This seems to be only used in the "set password" handler.
 | 
			
		||||
            refresh_where_clause = where_clause
 | 
			
		||||
            refresh_values = values.copy()
 | 
			
		||||
            if except_token_id:
 | 
			
		||||
                # TODO: support that for refresh tokens
 | 
			
		||||
                where_clause += " AND id != ?"
 | 
			
		||||
                values.append(except_token_id)
 | 
			
		||||
 | 
			
		||||
            txn.execute(
 | 
			
		||||
                "SELECT token, id, device_id FROM access_tokens WHERE %s"
 | 
			
		||||
                % where_clause,
 | 
			
		||||
                values,
 | 
			
		||||
            )
 | 
			
		||||
            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream_bulk(
 | 
			
		||||
                txn,
 | 
			
		||||
                self.get_user_by_access_token,
 | 
			
		||||
                [(token,) for token, _, _ in tokens_and_devices],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 | 
			
		||||
 | 
			
		||||
            txn.execute(
 | 
			
		||||
                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
 | 
			
		||||
                refresh_values,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return tokens_and_devices
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 | 
			
		||||
 | 
			
		||||
    async def user_delete_access_tokens_for_devices(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        device_ids: StrCollection,
 | 
			
		||||
    ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
        """
 | 
			
		||||
        Invalidate access and refresh tokens belonging to a user
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: ID of user the tokens belong to
 | 
			
		||||
            device_ids: The devices to delete tokens for.
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (token, token id, device id) for each of the deleted tokens
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def user_delete_access_tokens_for_devices_txn(
 | 
			
		||||
            txn: LoggingTransaction, batch_device_ids: StrCollection
 | 
			
		||||
        ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
            self.db_pool.simple_delete_many_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="refresh_tokens",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                column="device_id",
 | 
			
		||||
                values=batch_device_ids,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            clause, args = make_in_list_sql_clause(
 | 
			
		||||
                txn.database_engine, "device_id", batch_device_ids
 | 
			
		||||
            )
 | 
			
		||||
            args.append(user_id)
 | 
			
		||||
 | 
			
		||||
            if self.database_engine.supports_returning:
 | 
			
		||||
                sql = f"""
 | 
			
		||||
                    DELETE FROM access_tokens
 | 
			
		||||
                    WHERE {clause} AND user_id = ?
 | 
			
		||||
                    RETURNING token, id, device_id
 | 
			
		||||
                """
 | 
			
		||||
                txn.execute(sql, args)
 | 
			
		||||
                tokens_and_devices = txn.fetchall()
 | 
			
		||||
            else:
 | 
			
		||||
                tokens_and_devices = self.db_pool.simple_select_many_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="access_tokens",
 | 
			
		||||
                    column="device_id",
 | 
			
		||||
                    iterable=batch_device_ids,
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=("token", "id", "device_id"),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                self.db_pool.simple_delete_many_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="access_tokens",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    column="device_id",
 | 
			
		||||
                    values=batch_device_ids,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream_bulk(
 | 
			
		||||
                txn,
 | 
			
		||||
                self.get_user_by_access_token,
 | 
			
		||||
                [(t[0],) for t in tokens_and_devices],
 | 
			
		||||
            )
 | 
			
		||||
            return tokens_and_devices
 | 
			
		||||
 | 
			
		||||
        results = []
 | 
			
		||||
        for batch_device_ids in batch_iter(device_ids, 1000):
 | 
			
		||||
            tokens_and_devices = await self.db_pool.runInteraction(
 | 
			
		||||
                "user_delete_access_tokens_for_devices",
 | 
			
		||||
                user_delete_access_tokens_for_devices_txn,
 | 
			
		||||
                batch_device_ids,
 | 
			
		||||
            )
 | 
			
		||||
            results.extend(tokens_and_devices)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    async def delete_access_token(self, access_token: str) -> None:
 | 
			
		||||
        def f(txn: LoggingTransaction) -> None:
 | 
			
		||||
            self.db_pool.simple_delete_one_txn(
 | 
			
		||||
                txn, table="access_tokens", keyvalues={"token": access_token}
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream(
 | 
			
		||||
                txn, self.get_user_by_access_token, (access_token,)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.runInteraction("delete_access_token", f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -2620,58 +2820,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 | 
			
		||||
 | 
			
		||||
        return next_id
 | 
			
		||||
 | 
			
		||||
    async def set_device_for_refresh_token(
 | 
			
		||||
        self, user_id: str, old_device_id: str, device_id: str
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Moves refresh tokens from old device to current device
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: The user of the devices.
 | 
			
		||||
            old_device_id: The old device.
 | 
			
		||||
            device_id: The new device ID.
 | 
			
		||||
        Returns:
 | 
			
		||||
            None
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.simple_update(
 | 
			
		||||
            "refresh_tokens",
 | 
			
		||||
            keyvalues={"user_id": user_id, "device_id": old_device_id},
 | 
			
		||||
            updatevalues={"device_id": device_id},
 | 
			
		||||
            desc="set_device_for_refresh_token",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _set_device_for_access_token_txn(
 | 
			
		||||
        self, txn: LoggingTransaction, token: str, device_id: str
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        old_device_id = self.db_pool.simple_select_one_onecol_txn(
 | 
			
		||||
            txn, "access_tokens", {"token": token}, "device_id"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.db_pool.simple_update_txn(
 | 
			
		||||
            txn, "access_tokens", {"token": token}, {"device_id": device_id}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
 | 
			
		||||
 | 
			
		||||
        return old_device_id
 | 
			
		||||
 | 
			
		||||
    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
 | 
			
		||||
        """Sets the device ID associated with an access token.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            token: The access token to modify.
 | 
			
		||||
            device_id: The new device ID.
 | 
			
		||||
        Returns:
 | 
			
		||||
            The old device ID associated with the access token.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction(
 | 
			
		||||
            "set_device_for_access_token",
 | 
			
		||||
            self._set_device_for_access_token_txn,
 | 
			
		||||
            token,
 | 
			
		||||
            device_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def user_set_password_hash(
 | 
			
		||||
        self, user_id: str, password_hash: Optional[str]
 | 
			
		||||
    ) -> None:
 | 
			
		||||
@ -2743,162 +2891,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
 | 
			
		||||
 | 
			
		||||
    async def user_delete_access_tokens(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        except_token_id: Optional[int] = None,
 | 
			
		||||
        device_id: Optional[str] = None,
 | 
			
		||||
    ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
        """
 | 
			
		||||
        Invalidate access and refresh tokens belonging to a user
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: ID of user the tokens belong to
 | 
			
		||||
            except_token_id: access_tokens ID which should *not* be deleted
 | 
			
		||||
            device_id: ID of device the tokens are associated with.
 | 
			
		||||
                If None, tokens associated with any device (or no device) will
 | 
			
		||||
                be deleted
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (token, token id, device id) for each of the deleted tokens
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
            keyvalues = {"user_id": user_id}
 | 
			
		||||
            if device_id is not None:
 | 
			
		||||
                keyvalues["device_id"] = device_id
 | 
			
		||||
 | 
			
		||||
            items = keyvalues.items()
 | 
			
		||||
            where_clause = " AND ".join(k + " = ?" for k, _ in items)
 | 
			
		||||
            values: List[Union[str, int]] = [v for _, v in items]
 | 
			
		||||
            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
 | 
			
		||||
            # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
 | 
			
		||||
            # clause and values before we handle that. This seems to be only used in the "set password" handler.
 | 
			
		||||
            refresh_where_clause = where_clause
 | 
			
		||||
            refresh_values = values.copy()
 | 
			
		||||
            if except_token_id:
 | 
			
		||||
                # TODO: support that for refresh tokens
 | 
			
		||||
                where_clause += " AND id != ?"
 | 
			
		||||
                values.append(except_token_id)
 | 
			
		||||
 | 
			
		||||
            txn.execute(
 | 
			
		||||
                "SELECT token, id, device_id FROM access_tokens WHERE %s"
 | 
			
		||||
                % where_clause,
 | 
			
		||||
                values,
 | 
			
		||||
            )
 | 
			
		||||
            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream_bulk(
 | 
			
		||||
                txn,
 | 
			
		||||
                self.get_user_by_access_token,
 | 
			
		||||
                [(token,) for token, _, _ in tokens_and_devices],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 | 
			
		||||
 | 
			
		||||
            txn.execute(
 | 
			
		||||
                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
 | 
			
		||||
                refresh_values,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return tokens_and_devices
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 | 
			
		||||
 | 
			
		||||
    async def user_delete_access_tokens_for_devices(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        device_ids: StrCollection,
 | 
			
		||||
    ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
        """
 | 
			
		||||
        Invalidate access and refresh tokens belonging to a user
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id: ID of user the tokens belong to
 | 
			
		||||
            device_ids: The devices to delete tokens for.
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of (token, token id, device id) for each of the deleted tokens
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def user_delete_access_tokens_for_devices_txn(
 | 
			
		||||
            txn: LoggingTransaction, batch_device_ids: StrCollection
 | 
			
		||||
        ) -> List[Tuple[str, int, Optional[str]]]:
 | 
			
		||||
            self.db_pool.simple_delete_many_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="refresh_tokens",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                column="device_id",
 | 
			
		||||
                values=batch_device_ids,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            clause, args = make_in_list_sql_clause(
 | 
			
		||||
                txn.database_engine, "device_id", batch_device_ids
 | 
			
		||||
            )
 | 
			
		||||
            args.append(user_id)
 | 
			
		||||
 | 
			
		||||
            if self.database_engine.supports_returning:
 | 
			
		||||
                sql = f"""
 | 
			
		||||
                    DELETE FROM access_tokens
 | 
			
		||||
                    WHERE {clause} AND user_id = ?
 | 
			
		||||
                    RETURNING token, id, device_id
 | 
			
		||||
                """
 | 
			
		||||
                txn.execute(sql, args)
 | 
			
		||||
                tokens_and_devices = txn.fetchall()
 | 
			
		||||
            else:
 | 
			
		||||
                tokens_and_devices = self.db_pool.simple_select_many_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="access_tokens",
 | 
			
		||||
                    column="device_id",
 | 
			
		||||
                    iterable=batch_device_ids,
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=("token", "id", "device_id"),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                self.db_pool.simple_delete_many_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="access_tokens",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    column="device_id",
 | 
			
		||||
                    values=batch_device_ids,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream_bulk(
 | 
			
		||||
                txn,
 | 
			
		||||
                self.get_user_by_access_token,
 | 
			
		||||
                [(t[0],) for t in tokens_and_devices],
 | 
			
		||||
            )
 | 
			
		||||
            return tokens_and_devices
 | 
			
		||||
 | 
			
		||||
        results = []
 | 
			
		||||
        for batch_device_ids in batch_iter(device_ids, 1000):
 | 
			
		||||
            tokens_and_devices = await self.db_pool.runInteraction(
 | 
			
		||||
                "user_delete_access_tokens_for_devices",
 | 
			
		||||
                user_delete_access_tokens_for_devices_txn,
 | 
			
		||||
                batch_device_ids,
 | 
			
		||||
            )
 | 
			
		||||
            results.extend(tokens_and_devices)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    async def delete_access_token(self, access_token: str) -> None:
 | 
			
		||||
        def f(txn: LoggingTransaction) -> None:
 | 
			
		||||
            self.db_pool.simple_delete_one_txn(
 | 
			
		||||
                txn, table="access_tokens", keyvalues={"token": access_token}
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self._invalidate_cache_and_stream(
 | 
			
		||||
                txn, self.get_user_by_access_token, (access_token,)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.runInteraction("delete_access_token", f)
 | 
			
		||||
 | 
			
		||||
    async def delete_refresh_token(self, refresh_token: str) -> None:
 | 
			
		||||
        def f(txn: LoggingTransaction) -> None:
 | 
			
		||||
            self.db_pool.simple_delete_one_txn(
 | 
			
		||||
                txn, table="refresh_tokens", keyvalues={"token": refresh_token}
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        await self.db_pool.runInteraction("delete_refresh_token", f)
 | 
			
		||||
 | 
			
		||||
    async def add_user_pending_deactivation(self, user_id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Adds a user to the table of users who need to be parted from all the rooms they're
 | 
			
		||||
 | 
			
		||||
@ -324,7 +324,7 @@ class RelationsWorkerStore(SQLBaseStore):
 | 
			
		||||
                        account_data_key=0,
 | 
			
		||||
                        push_rules_key=0,
 | 
			
		||||
                        to_device_key=0,
 | 
			
		||||
                        device_list_key=0,
 | 
			
		||||
                        device_list_key=MultiWriterStreamToken(stream=0),
 | 
			
		||||
                        groups_key=0,
 | 
			
		||||
                        un_partial_stated_rooms_key=0,
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
@ -1574,6 +1574,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		||||
        """Get the event ID of the initial join that started the partial
 | 
			
		||||
        join, and the device list stream ID at the point we started the partial
 | 
			
		||||
        join.
 | 
			
		||||
 | 
			
		||||
        This only returns the minimum device list stream ID at the time of
 | 
			
		||||
        joining, not the full device list stream token. The only impact of this
 | 
			
		||||
        is that we may be sending again device list updates that we've already
 | 
			
		||||
        sent to some destinations, which is harmless.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        return cast(
 | 
			
		||||
 | 
			
		||||
@ -61,7 +61,6 @@ from typing import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
from immutabledict import immutabledict
 | 
			
		||||
from typing_extensions import assert_never
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
@ -657,23 +656,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 | 
			
		||||
        component.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        min_pos = self._stream_id_gen.get_current_token()
 | 
			
		||||
 | 
			
		||||
        positions = {}
 | 
			
		||||
        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
 | 
			
		||||
            # The `min_pos` is the minimum position that we know all instances
 | 
			
		||||
            # have finished persisting to, so we only care about instances whose
 | 
			
		||||
            # positions are ahead of that. (Instance positions can be behind the
 | 
			
		||||
            # min position as there are times we can work out that the minimum
 | 
			
		||||
            # position is ahead of the naive minimum across all current
 | 
			
		||||
            # positions. See MultiWriterIdGenerator for details)
 | 
			
		||||
            positions = {
 | 
			
		||||
                i: p
 | 
			
		||||
                for i, p in self._stream_id_gen.get_positions().items()
 | 
			
		||||
                if p > min_pos
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
 | 
			
		||||
        return RoomStreamToken.from_generator(self._stream_id_gen)
 | 
			
		||||
 | 
			
		||||
    def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
 | 
			
		||||
        return self._stream_id_gen
 | 
			
		||||
 | 
			
		||||
@ -203,7 +203,7 @@ class EventSources:
 | 
			
		||||
            account_data_key=0,
 | 
			
		||||
            push_rules_key=0,
 | 
			
		||||
            to_device_key=0,
 | 
			
		||||
            device_list_key=0,
 | 
			
		||||
            device_list_key=MultiWriterStreamToken(stream=0),
 | 
			
		||||
            groups_key=0,
 | 
			
		||||
            un_partial_stated_rooms_key=0,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -75,6 +75,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from synapse.appservice.api import ApplicationService
 | 
			
		||||
    from synapse.storage.databases.main import DataStore, PurgeEventsStore
 | 
			
		||||
    from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 | 
			
		||||
    from synapse.storage.util.id_generators import MultiWriterIdGenerator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
@ -570,6 +571,25 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_generator(cls, generator: "MultiWriterIdGenerator") -> Self:
 | 
			
		||||
        """Get the current token out of a MultiWriterIdGenerator"""
 | 
			
		||||
 | 
			
		||||
        # The `min_pos` is the minimum position that we know all instances
 | 
			
		||||
        # have finished persisting to, so we only care about instances whose
 | 
			
		||||
        # positions are ahead of that. (Instance positions can be behind the
 | 
			
		||||
        # min position as there are times we can work out that the minimum
 | 
			
		||||
        # position is ahead of the naive minimum across all current
 | 
			
		||||
        # positions. See MultiWriterIdGenerator for details)
 | 
			
		||||
        min_pos = generator.get_current_token()
 | 
			
		||||
        positions = {
 | 
			
		||||
            instance: position
 | 
			
		||||
            for instance, position in generator.get_positions().items()
 | 
			
		||||
            if position > min_pos
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return cls(stream=min_pos, instance_map=immutabledict(positions))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@attr.s(frozen=True, slots=True, order=False)
 | 
			
		||||
class RoomStreamToken(AbstractMultiWriterStreamToken):
 | 
			
		||||
@ -980,7 +1000,9 @@ class StreamToken:
 | 
			
		||||
    account_data_key: int
 | 
			
		||||
    push_rules_key: int
 | 
			
		||||
    to_device_key: int
 | 
			
		||||
    device_list_key: int
 | 
			
		||||
    device_list_key: MultiWriterStreamToken = attr.ib(
 | 
			
		||||
        validator=attr.validators.instance_of(MultiWriterStreamToken)
 | 
			
		||||
    )
 | 
			
		||||
    # Note that the groups key is no longer used and may have bogus values.
 | 
			
		||||
    groups_key: int
 | 
			
		||||
    un_partial_stated_rooms_key: int
 | 
			
		||||
@ -1021,7 +1043,9 @@ class StreamToken:
 | 
			
		||||
                account_data_key=int(account_data_key),
 | 
			
		||||
                push_rules_key=int(push_rules_key),
 | 
			
		||||
                to_device_key=int(to_device_key),
 | 
			
		||||
                device_list_key=int(device_list_key),
 | 
			
		||||
                device_list_key=await MultiWriterStreamToken.parse(
 | 
			
		||||
                    store, device_list_key
 | 
			
		||||
                ),
 | 
			
		||||
                groups_key=int(groups_key),
 | 
			
		||||
                un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
 | 
			
		||||
            )
 | 
			
		||||
@ -1040,7 +1064,7 @@ class StreamToken:
 | 
			
		||||
                str(self.account_data_key),
 | 
			
		||||
                str(self.push_rules_key),
 | 
			
		||||
                str(self.to_device_key),
 | 
			
		||||
                str(self.device_list_key),
 | 
			
		||||
                await self.device_list_key.to_string(store),
 | 
			
		||||
                # Note that the groups key is no longer used, but it is still
 | 
			
		||||
                # serialized so that there will not be confusion in the future
 | 
			
		||||
                # if additional tokens are added.
 | 
			
		||||
@ -1069,6 +1093,12 @@ class StreamToken:
 | 
			
		||||
                StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
 | 
			
		||||
            )
 | 
			
		||||
            return new_token
 | 
			
		||||
        elif key == StreamKeyType.DEVICE_LIST:
 | 
			
		||||
            new_token = self.copy_and_replace(
 | 
			
		||||
                StreamKeyType.DEVICE_LIST,
 | 
			
		||||
                self.device_list_key.copy_and_advance(new_value),
 | 
			
		||||
            )
 | 
			
		||||
            return new_token
 | 
			
		||||
 | 
			
		||||
        new_token = self.copy_and_replace(key, new_value)
 | 
			
		||||
        new_id = new_token.get_field(key)
 | 
			
		||||
@ -1087,7 +1117,11 @@ class StreamToken:
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    def get_field(
 | 
			
		||||
        self, key: Literal[StreamKeyType.RECEIPT]
 | 
			
		||||
        self,
 | 
			
		||||
        key: Literal[
 | 
			
		||||
            StreamKeyType.RECEIPT,
 | 
			
		||||
            StreamKeyType.DEVICE_LIST,
 | 
			
		||||
        ],
 | 
			
		||||
    ) -> MultiWriterStreamToken: ...
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
@ -1095,7 +1129,6 @@ class StreamToken:
 | 
			
		||||
        self,
 | 
			
		||||
        key: Literal[
 | 
			
		||||
            StreamKeyType.ACCOUNT_DATA,
 | 
			
		||||
            StreamKeyType.DEVICE_LIST,
 | 
			
		||||
            StreamKeyType.PRESENCE,
 | 
			
		||||
            StreamKeyType.PUSH_RULES,
 | 
			
		||||
            StreamKeyType.TO_DEVICE,
 | 
			
		||||
@ -1161,7 +1194,16 @@ class StreamToken:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
StreamToken.START = StreamToken(
 | 
			
		||||
    RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
 | 
			
		||||
    room_key=RoomStreamToken(stream=0),
 | 
			
		||||
    presence_key=0,
 | 
			
		||||
    typing_key=0,
 | 
			
		||||
    receipt_key=MultiWriterStreamToken(stream=0),
 | 
			
		||||
    account_data_key=0,
 | 
			
		||||
    push_rules_key=0,
 | 
			
		||||
    to_device_key=0,
 | 
			
		||||
    device_list_key=MultiWriterStreamToken(stream=0),
 | 
			
		||||
    groups_key=0,
 | 
			
		||||
    un_partial_stated_rooms_key=0,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -30,7 +30,7 @@ from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
 | 
			
		||||
from synapse.api.presence import UserPresenceState
 | 
			
		||||
from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU
 | 
			
		||||
from synapse.federation.units import Transaction
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.device import DeviceListUpdater, DeviceWriterHandler
 | 
			
		||||
from synapse.rest import admin
 | 
			
		||||
from synapse.rest.client import login
 | 
			
		||||
from synapse.server import HomeServer
 | 
			
		||||
@ -500,7 +500,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 | 
			
		||||
        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room  # type: ignore[assignment]
 | 
			
		||||
 | 
			
		||||
        device_handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(device_handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(device_handler, DeviceWriterHandler)
 | 
			
		||||
        self.device_handler = device_handler
 | 
			
		||||
 | 
			
		||||
        # whenever send_transaction is called, record the edu data
 | 
			
		||||
@ -554,6 +554,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 | 
			
		||||
            "devices": [{"device_id": "D1"}],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        assert isinstance(self.device_handler.device_list_updater, DeviceListUpdater)
 | 
			
		||||
 | 
			
		||||
        self.get_success(
 | 
			
		||||
            self.device_handler.device_list_updater.incoming_device_list_update(
 | 
			
		||||
                "host2",
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,7 @@ from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
from synapse.api.constants import RoomEncryptionAlgorithms
 | 
			
		||||
from synapse.api.errors import NotFoundError, SynapseError
 | 
			
		||||
from synapse.appservice import ApplicationService
 | 
			
		||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 | 
			
		||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceWriterHandler
 | 
			
		||||
from synapse.rest import admin
 | 
			
		||||
from synapse.rest.client import devices, login, register
 | 
			
		||||
from synapse.server import HomeServer
 | 
			
		||||
@ -53,7 +53,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
            application_service_api=self.appservice_api,
 | 
			
		||||
        )
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(handler, DeviceWriterHandler)
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
        self.store = hs.get_datastores().main
 | 
			
		||||
        self.device_message_handler = hs.get_device_message_handler()
 | 
			
		||||
@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
 | 
			
		||||
        # queue a bunch of messages in the inbox
 | 
			
		||||
        requester = create_requester(sender, device_id=DEVICE_ID)
 | 
			
		||||
        for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
 | 
			
		||||
        for i in range(DeviceWriterHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.device_message_handler.send_device_message(
 | 
			
		||||
                    requester, "message_type", {receiver: {"*": {"val": i}}}
 | 
			
		||||
@ -462,7 +462,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 | 
			
		||||
        hs = self.setup_test_homeserver("server")
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(handler, DeviceWriterHandler)
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
        self.message_handler = hs.get_device_message_handler()
 | 
			
		||||
        self.registration = hs.get_registration_handler()
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
from synapse.api.constants import RoomEncryptionAlgorithms
 | 
			
		||||
from synapse.api.errors import Codes, SynapseError
 | 
			
		||||
from synapse.appservice import ApplicationService
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.device import DeviceWriterHandler
 | 
			
		||||
from synapse.server import HomeServer
 | 
			
		||||
from synapse.storage.databases.main.appservice import _make_exclusive_regex
 | 
			
		||||
from synapse.types import JsonDict, UserID
 | 
			
		||||
@ -856,7 +856,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 | 
			
		||||
 | 
			
		||||
        device_handler = self.hs.get_device_handler()
 | 
			
		||||
        assert isinstance(device_handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(device_handler, DeviceWriterHandler)
 | 
			
		||||
        e = self.get_failure(
 | 
			
		||||
            device_handler.check_device_registered(
 | 
			
		||||
                user_id=local_user,
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ from synapse.api.constants import EduTypes, EventTypes
 | 
			
		||||
from synapse.api.errors import NotFoundError
 | 
			
		||||
from synapse.events import EventBase
 | 
			
		||||
from synapse.federation.units import Transaction
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.device import DeviceWriterHandler
 | 
			
		||||
from synapse.handlers.presence import UserPresenceState
 | 
			
		||||
from synapse.handlers.push_rules import InvalidRuleException
 | 
			
		||||
from synapse.module_api import ModuleApi
 | 
			
		||||
@ -819,7 +819,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
 | 
			
		||||
 | 
			
		||||
        # Delete the device.
 | 
			
		||||
        device_handler = self.hs.get_device_handler()
 | 
			
		||||
        assert isinstance(device_handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(device_handler, DeviceWriterHandler)
 | 
			
		||||
        self.get_success(device_handler.delete_devices(user_id, [device_id]))
 | 
			
		||||
 | 
			
		||||
        # Check that the callback was called and the pushers still existed.
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
 | 
			
		||||
import synapse.rest.admin
 | 
			
		||||
from synapse.api.errors import Codes
 | 
			
		||||
from synapse.handlers.device import DeviceHandler
 | 
			
		||||
from synapse.handlers.device import DeviceWriterHandler
 | 
			
		||||
from synapse.rest.client import devices, login
 | 
			
		||||
from synapse.server import HomeServer
 | 
			
		||||
from synapse.util import Clock
 | 
			
		||||
@ -42,7 +42,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
 | 
			
		||||
    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 | 
			
		||||
        handler = hs.get_device_handler()
 | 
			
		||||
        assert isinstance(handler, DeviceHandler)
 | 
			
		||||
        assert isinstance(handler, DeviceWriterHandler)
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
 | 
			
		||||
        self.admin_user = self.register_user("admin", "pass", admin=True)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user