diff --git a/changelog.d/18574.misc b/changelog.d/18574.misc new file mode 100644 index 0000000000..5b223f5a93 --- /dev/null +++ b/changelog.d/18574.misc @@ -0,0 +1 @@ +Speed up upgrading a room with large numbers of banned users. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8398832515..22301f9e63 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple from canonicaljson import encode_canonical_json @@ -55,7 +55,11 @@ from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext, UnpersistedEventContextBase +from synapse.events.snapshot import ( + EventContext, + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler @@ -66,6 +70,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( + JsonDict, PersistedEventPosition, Requester, RoomAlias, @@ -1520,6 +1525,92 @@ class EventCreationHandler: return result + async def create_and_send_new_client_events( + self, + requester: Requester, + room_id: str, + prev_event_id: str, + event_dicts: Sequence[JsonDict], + ratelimit: bool = True, + ignore_shadow_ban: bool = False, + ) -> None: + """Helper to create and send a batch of new client events. + + This supports sending membership events in very limited circumstances + (namely that the event is valid as is and doesn't need federation + requests or anything). Callers should prefer to use `update_membership`, + which correctly handles membership events in all cases. We allow + sending membership events here as its useful when copying e.g. bans + between rooms. + + All other events and state events are supported. + + Args: + requester: The requester sending the events. + room_id: The room ID to send the events in. + prev_event_id: The event ID to use as the previous event for the first + of the events, must have already been persisted. + event_dicts: A sequence of event dictionaries to create and send. + ratelimit: Whether to rate limit this send. + ignore_shadow_ban: True if shadow-banned users should be allowed to + send these events. + """ + + if not event_dicts: + # Nothing to do. + return + + state_groups = await self._storage_controllers.state.get_state_group_for_events( + [prev_event_id] + ) + if prev_event_id not in state_groups: + # This should only happen if we got passed a prev event ID that + # hasn't been persisted yet. + raise Exception("Previous event ID not found ") + + current_state_group = state_groups[prev_event_id] + state_map = await self._storage_controllers.state.get_state_ids_for_group( + current_state_group + ) + + events_and_contexts_to_send = [] + state_map = dict(state_map) + depth = None + + for event_dict in event_dicts: + event, context = await self.create_event( + requester=requester, + event_dict=event_dict, + prev_event_ids=[prev_event_id], + depth=depth, + # Take a copy to ensure each event gets a unique copy of + # state_map since it is modified below. + state_map=dict(state_map), + for_batch=True, + ) + events_and_contexts_to_send.append((event, context)) + + prev_event_id = event.event_id + depth = event.depth + 1 + if event.is_state(): + # If this is a state event, we need to update the state map + # so that it can be used for the next event. + state_map[(event.type, event.state_key)] = event.event_id + + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_and_contexts_to_send, room_id, current_state_group, datastore + ) + ) + + await self.handle_new_client_event( + requester, + events_and_context, + ignore_shadow_ban=ignore_shadow_ban, + ratelimit=ratelimit, + ) + async def _persist_events( self, requester: Requester, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 820140b28f..a1731752cf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -94,6 +94,7 @@ from synapse.types.handlers import ShutdownRoomParams, ShutdownRoomResponse from synapse.types.state import StateFilter from synapse.util import stringutils from synapse.util.caches.response_cache import ResponseCache +from synapse.util.iterutils import batch_iter from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client @@ -607,7 +608,7 @@ class RoomCreationHandler: additional_fields=spam_check[1], ) - await self._send_events_for_new_room( + _, last_event_id, _ = await self._send_events_for_new_room( requester, new_room_id, new_room_version, @@ -620,29 +621,32 @@ class RoomCreationHandler: ) # Transfer membership events - old_room_member_state_ids = ( - await self._storage_controllers.state.get_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) - ) - ) + ban_event_ids = await self.store.get_ban_event_ids_in_room(old_room_id) + if ban_event_ids: + ban_events = await self.store.get_events_as_list(ban_event_ids) - # map from event_id to BaseEvent - old_room_member_state_events = await self.store.get_events( - old_room_member_state_ids.values() - ) - for old_event in old_room_member_state_events.values(): - # Only transfer ban events - if ( - "membership" in old_event.content - and old_event.content["membership"] == "ban" - ): - await self.room_member_handler.update_membership( - requester, - UserID.from_string(old_event.state_key), - new_room_id, - "ban", + # Add any banned users to the new room. + # + # Note generally we should send membership events via + # `update_membership`, however in this case its fine to bypass as + # these bans don't need any special treatment, i.e. the sender is in + # the room and they don't need any extra signatures, etc. + for batched_events in batch_iter(ban_events, 1000): + await self.event_creation_handler.create_and_send_new_client_events( + requester=requester, + room_id=new_room_id, + prev_event_id=last_event_id, + event_dicts=[ + { + "type": EventTypes.Member, + "state_key": ban_event.state_key, + "room_id": new_room_id, + "sender": requester.user.to_string(), + "content": ban_event.content, + } + for ban_event in batched_events + ], ratelimit=False, - content=old_event.content, ) # XXX invites/joins diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 9c35a7837d..654250fadc 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1849,6 +1849,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "_get_room_participation_txn", _get_room_participation_txn, user_id, room_id ) + async def get_ban_event_ids_in_room(self, room_id: str) -> StrCollection: + """Get all event IDs for ban events in the given room.""" + return await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "room_id": room_id, + "type": EventTypes.Member, + "membership": Membership.BAN, + }, + retcol="event_id", + desc="get_ban_event_ids_in_room", + ) + class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__( diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index f3ed49f99b..66fddc5475 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -23,7 +23,7 @@ from unittest.mock import patch from twisted.internet.testing import MemoryReactor -from synapse.api.constants import EventContentFields, EventTypes, RoomTypes +from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet @@ -411,3 +411,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(expire_cache=False) self.assertEqual(200, channel.code, channel.result) + + def test_bans(self) -> None: + """ + Test that bans get copied over when upgrading a room. + """ + + users_to_ban = ["@user2:test", "@user3:test", "@user4:test"] + for user in users_to_ban: + self.helper.ban(self.room_id, self.creator, user, tok=self.creator_token) + + channel = self._upgrade_room(self.creator_token) + self.assertEqual(200, channel.code, channel.result) + + for user in users_to_ban: + content = self.helper.get_state( + self.room_id, + event_type=EventTypes.Member, + state_key=user, + tok=self.creator_token, + ) + self.assertEqual(content[EventContentFields.MEMBERSHIP], Membership.BAN)