Make room upgrades faster for rooms with many bans (#18574)

We do this by a) not pulling out all membership events, and b) batch
inserting bans.

One blocking concern is that this bypasses the `update_membership`
function, which otherwise all other membership events go via. In this
case it's fine (having audited what it is doing), but I'm hesitant to
set the precedent of bypassing it, given it has a lot of logic in there.

---------

Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
Erik Johnston 2025-08-04 10:42:52 +01:00 committed by GitHub
parent e16fbdcdcc
commit 72cd5cccf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 155 additions and 25 deletions

1
changelog.d/18574.misc Normal file
View File

@ -0,0 +1 @@
Speed up upgrading a room with large numbers of banned users.

View File

@ -22,7 +22,7 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple
from canonicaljson import encode_canonical_json
@ -55,7 +55,11 @@ from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.snapshot import (
EventContext,
UnpersistedEventContext,
UnpersistedEventContextBase,
)
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
@ -66,6 +70,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
JsonDict,
PersistedEventPosition,
Requester,
RoomAlias,
@ -1520,6 +1525,92 @@ class EventCreationHandler:
return result
async def create_and_send_new_client_events(
self,
requester: Requester,
room_id: str,
prev_event_id: str,
event_dicts: Sequence[JsonDict],
ratelimit: bool = True,
ignore_shadow_ban: bool = False,
) -> None:
"""Helper to create and send a batch of new client events.
This supports sending membership events in very limited circumstances
(namely that the event is valid as is and doesn't need federation
requests or anything). Callers should prefer to use `update_membership`,
which correctly handles membership events in all cases. We allow
sending membership events here as its useful when copying e.g. bans
between rooms.
All other events and state events are supported.
Args:
requester: The requester sending the events.
room_id: The room ID to send the events in.
prev_event_id: The event ID to use as the previous event for the first
of the events, must have already been persisted.
event_dicts: A sequence of event dictionaries to create and send.
ratelimit: Whether to rate limit this send.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send these events.
"""
if not event_dicts:
# Nothing to do.
return
state_groups = await self._storage_controllers.state.get_state_group_for_events(
[prev_event_id]
)
if prev_event_id not in state_groups:
# This should only happen if we got passed a prev event ID that
# hasn't been persisted yet.
raise Exception("Previous event ID not found ")
current_state_group = state_groups[prev_event_id]
state_map = await self._storage_controllers.state.get_state_ids_for_group(
current_state_group
)
events_and_contexts_to_send = []
state_map = dict(state_map)
depth = None
for event_dict in event_dicts:
event, context = await self.create_event(
requester=requester,
event_dict=event_dict,
prev_event_ids=[prev_event_id],
depth=depth,
# Take a copy to ensure each event gets a unique copy of
# state_map since it is modified below.
state_map=dict(state_map),
for_batch=True,
)
events_and_contexts_to_send.append((event, context))
prev_event_id = event.event_id
depth = event.depth + 1
if event.is_state():
# If this is a state event, we need to update the state map
# so that it can be used for the next event.
state_map[(event.type, event.state_key)] = event.event_id
datastore = self.hs.get_datastores().state
events_and_context = (
await UnpersistedEventContext.batch_persist_unpersisted_contexts(
events_and_contexts_to_send, room_id, current_state_group, datastore
)
)
await self.handle_new_client_event(
requester,
events_and_context,
ignore_shadow_ban=ignore_shadow_ban,
ratelimit=ratelimit,
)
async def _persist_events(
self,
requester: Requester,

View File

@ -94,6 +94,7 @@ from synapse.types.handlers import ShutdownRoomParams, ShutdownRoomResponse
from synapse.types.state import StateFilter
from synapse.util import stringutils
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client
@ -607,7 +608,7 @@ class RoomCreationHandler:
additional_fields=spam_check[1],
)
await self._send_events_for_new_room(
_, last_event_id, _ = await self._send_events_for_new_room(
requester,
new_room_id,
new_room_version,
@ -620,29 +621,32 @@ class RoomCreationHandler:
)
# Transfer membership events
old_room_member_state_ids = (
await self._storage_controllers.state.get_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
)
ban_event_ids = await self.store.get_ban_event_ids_in_room(old_room_id)
if ban_event_ids:
ban_events = await self.store.get_events_as_list(ban_event_ids)
# map from event_id to BaseEvent
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for old_event in old_room_member_state_events.values():
# Only transfer ban events
if (
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event.state_key),
new_room_id,
"ban",
# Add any banned users to the new room.
#
# Note generally we should send membership events via
# `update_membership`, however in this case its fine to bypass as
# these bans don't need any special treatment, i.e. the sender is in
# the room and they don't need any extra signatures, etc.
for batched_events in batch_iter(ban_events, 1000):
await self.event_creation_handler.create_and_send_new_client_events(
requester=requester,
room_id=new_room_id,
prev_event_id=last_event_id,
event_dicts=[
{
"type": EventTypes.Member,
"state_key": ban_event.state_key,
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": ban_event.content,
}
for ban_event in batched_events
],
ratelimit=False,
content=old_event.content,
)
# XXX invites/joins

View File

@ -1849,6 +1849,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"_get_room_participation_txn", _get_room_participation_txn, user_id, room_id
)
async def get_ban_event_ids_in_room(self, room_id: str) -> StrCollection:
"""Get all event IDs for ban events in the given room."""
return await self.db_pool.simple_select_onecol(
table="current_state_events",
keyvalues={
"room_id": room_id,
"type": EventTypes.Member,
"membership": Membership.BAN,
},
retcol="event_id",
desc="get_ban_event_ids_in_room",
)
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(

View File

@ -23,7 +23,7 @@ from unittest.mock import patch
from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventContentFields, EventTypes, RoomTypes
from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.rest import admin
from synapse.rest.client import login, room, room_upgrade_rest_servlet
@ -411,3 +411,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(expire_cache=False)
self.assertEqual(200, channel.code, channel.result)
def test_bans(self) -> None:
"""
Test that bans get copied over when upgrading a room.
"""
users_to_ban = ["@user2:test", "@user3:test", "@user4:test"]
for user in users_to_ban:
self.helper.ban(self.room_id, self.creator, user, tok=self.creator_token)
channel = self._upgrade_room(self.creator_token)
self.assertEqual(200, channel.code, channel.result)
for user in users_to_ban:
content = self.helper.get_state(
self.room_id,
event_type=EventTypes.Member,
state_key=user,
tok=self.creator_token,
)
self.assertEqual(content[EventContentFields.MEMBERSHIP], Membership.BAN)