MSC4140: Remove auth from delayed event management endpoints (#19152)

As per recent proposals in MSC4140, remove authentication for
restarting/cancelling/sending a delayed event, and give each of those
actions its own endpoint. (The original consolidated endpoint is still
supported for backwards compatibility.)
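
For illustration, a minimal sketch (not part of this change) of what the two request shapes look like from a client's point of view, assuming the unstable prefix used by the servlets and tests in this diff; the homeserver URL, HTTP client, and delay ID below are placeholders:

```python
import requests  # any HTTP client works; requests is used here for brevity

HOMESERVER = "https://matrix.example.org"  # placeholder homeserver
PREFIX = f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
delay_id = "syd_abcdefghijklmnopqrst"  # placeholder delay ID

# New style: the action is part of the path and no access token is required.
requests.post(f"{PREFIX}/{delay_id}/cancel", json={})

# Legacy style (still supported): the action is carried in the request body.
requests.post(f"{PREFIX}/{delay_id}", json={"action": "cancel"})
```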

### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should:
  - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
  - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry.
* [x] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))

---------

Co-authored-by: Half-Shot <will@half-shot.uk>
Authored by Andrew Ferrazzutti on 2025-11-13 13:56:17 -05:00; committed by GitHub.
Commit 9e23cded8f (parent 4494cc0694).
8 changed files with 198 additions and 212 deletions

Changelog entry (new file):

@@ -0,0 +1 @@
+Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.

Database port script (`Store` class):

@@ -58,6 +58,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
+from synapse.storage.databases.main.delayed_events import DelayedEventsStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
@@ -273,6 +274,7 @@ class Store(
     RelationsWorkerStore,
     EventFederationWorkerStore,
     SlidingSyncStore,
+    DelayedEventsStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

Delayed events handler (`DelayedEventsHandler`):

@@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import ShadowBanError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag
 from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
@@ -29,11 +30,9 @@ from synapse.replication.http.delayed_events import (
 )
 from synapse.storage.databases.main.delayed_events import (
     DelayedEventDetails,
-    DelayID,
     EventType,
     StateKey,
     Timestamp,
-    UserLocalpart,
 )
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
@@ -399,96 +398,63 @@ class DelayedEventsHandler:
         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)

-    async def cancel(self, requester: Requester, delay_id: str) -> None:
+    async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
         """
         Cancels the scheduled delivery of the matching delayed event.

-        Args:
-            requester: The owner of the delayed event to act on.
-            delay_id: The ID of the delayed event to act on.
-
         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
         await self._delayed_event_mgmt_ratelimiter.ratelimit(
-            requester,
-            (requester.user.to_string(), requester.device_id),
+            None, request.getClientAddress().host
         )
         await make_deferred_yieldable(self._initialized_from_db)

-        next_send_ts = await self._store.cancel_delayed_event(
-            delay_id=delay_id,
-            user_localpart=requester.user.localpart,
-        )
+        next_send_ts = await self._store.cancel_delayed_event(delay_id)

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at_or_none(next_send_ts)

-    async def restart(self, requester: Requester, delay_id: str) -> None:
+    async def restart(self, request: SynapseRequest, delay_id: str) -> None:
         """
         Restarts the scheduled delivery of the matching delayed event.

-        Args:
-            requester: The owner of the delayed event to act on.
-            delay_id: The ID of the delayed event to act on.
-
         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
         await self._delayed_event_mgmt_ratelimiter.ratelimit(
-            requester,
-            (requester.user.to_string(), requester.device_id),
+            None, request.getClientAddress().host
         )
         await make_deferred_yieldable(self._initialized_from_db)

         next_send_ts = await self._store.restart_delayed_event(
-            delay_id=delay_id,
-            user_localpart=requester.user.localpart,
-            current_ts=self._get_current_ts(),
+            delay_id, self._get_current_ts()
         )

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)

-    async def send(self, requester: Requester, delay_id: str) -> None:
+    async def send(self, request: SynapseRequest, delay_id: str) -> None:
         """
         Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.

-        Args:
-            requester: The owner of the delayed event to act on.
-            delay_id: The ID of the delayed event to act on.
-
         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
-        # Use standard request limiter for sending delayed events on-demand,
-        # as an on-demand send is similar to sending a regular event.
-        await self._request_ratelimiter.ratelimit(requester)
+        await self._delayed_event_mgmt_ratelimiter.ratelimit(
+            None, request.getClientAddress().host
+        )
         await make_deferred_yieldable(self._initialized_from_db)

-        event, next_send_ts = await self._store.process_target_delayed_event(
-            delay_id=delay_id,
-            user_localpart=requester.user.localpart,
-        )
+        event, next_send_ts = await self._store.process_target_delayed_event(delay_id)

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at_or_none(next_send_ts)

-        await self._send_event(
-            DelayedEventDetails(
-                delay_id=DelayID(delay_id),
-                user_localpart=UserLocalpart(requester.user.localpart),
-                room_id=event.room_id,
-                type=event.type,
-                state_key=event.state_key,
-                origin_server_ts=event.origin_server_ts,
-                content=event.content,
-                device_id=event.device_id,
-            )
-        )
+        await self._send_event(event)

     async def _send_on_timeout(self) -> None:
         self._next_delayed_event_call = None
@@ -611,9 +577,7 @@ class DelayedEventsHandler:
         finally:
             # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
             try:
-                await self._store.delete_processed_delayed_event(
-                    event.delay_id, event.user_localpart
-                )
+                await self._store.delete_processed_delayed_event(event.delay_id)
             except Exception:
                 logger.exception("Failed to delete processed delayed event")
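
Because these management actions are now unauthenticated, the handler can no longer key the `rc_delayed_event_mgmt` ratelimiter on the requester's user ID and device ID; it keys it on the client IP exposed by the Twisted request (`request.getClientAddress().host`) instead. A rough, self-contained illustration of that keying strategy (a toy token bucket, not Synapse's actual `Ratelimiter`):

```python
import time
from collections import defaultdict


class PerIpRateLimiter:
    """Toy token bucket keyed by client IP: allow up to `burst` requests at once,
    refilling at `rate` tokens per second. Illustration only."""

    def __init__(self, rate: float, burst: int) -> None:
        self.rate = rate
        self.burst = burst
        self._tokens: defaultdict[str, float] = defaultdict(lambda: float(burst))
        self._last_seen: dict[str, float] = {}

    def allow(self, ip: str) -> bool:
        now = time.monotonic()
        elapsed = now - self._last_seen.get(ip, now)
        self._last_seen[ip] = now
        # Refill proportionally to the time since the last request from this IP.
        self._tokens[ip] = min(self.burst, self._tokens[ip] + elapsed * self.rate)
        if self._tokens[ip] >= 1.0:
            self._tokens[ip] -= 1.0
            return True
        return False


limiter = PerIpRateLimiter(rate=0.5, burst=1)  # mirrors rc_delayed_event_mgmt in the tests
assert limiter.allow("203.0.113.7") is True
assert limiter.allow("203.0.113.7") is False  # immediate second request is rejected
```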

Delayed events client REST servlets:

@@ -47,14 +47,11 @@ class UpdateDelayedEventServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.auth = hs.get_auth()
         self.delayed_events_handler = hs.get_delayed_events_handler()

     async def on_POST(
         self, request: SynapseRequest, delay_id: str
     ) -> tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
         body = parse_json_object_from_request(request)
         try:
             action = str(body["action"])
@@ -75,11 +72,65 @@ class UpdateDelayedEventServlet(RestServlet):
         )

         if enum_action == _UpdateDelayedEventAction.CANCEL:
-            await self.delayed_events_handler.cancel(requester, delay_id)
+            await self.delayed_events_handler.cancel(request, delay_id)
         elif enum_action == _UpdateDelayedEventAction.RESTART:
-            await self.delayed_events_handler.restart(requester, delay_id)
+            await self.delayed_events_handler.restart(request, delay_id)
         elif enum_action == _UpdateDelayedEventAction.SEND:
-            await self.delayed_events_handler.send(requester, delay_id)
+            await self.delayed_events_handler.send(request, delay_id)
+        return 200, {}
+
+
+class CancelDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.cancel(request, delay_id)
+        return 200, {}
+
+
+class RestartDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.restart(request, delay_id)
+        return 200, {}
+
+
+class SendDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.send(request, delay_id)
         return 200, {}
@@ -108,4 +159,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     # The following can't currently be instantiated on workers.
     if hs.config.worker.worker_app is None:
         UpdateDelayedEventServlet(hs).register(http_server)
+        CancelDelayedEventServlet(hs).register(http_server)
+        RestartDelayedEventServlet(hs).register(http_server)
+        SendDelayedEventServlet(hs).register(http_server)
         DelayedEventsServlet(hs).register(http_server)

Delayed events storage (`DelayedEventsStore`):

@@ -13,18 +13,26 @@
 #
 import logging
-from typing import NewType
+from typing import TYPE_CHECKING, NewType

 import attr

 from synapse.api.errors import NotFoundError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import LoggingTransaction, StoreError
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    StoreError,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomID
 from synapse.util import stringutils
 from synapse.util.json import json_encoder

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
@@ -55,6 +63,27 @@ class DelayedEventDetails(EventDetails):

 class DelayedEventsStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Set delayed events to be uniquely identifiable by their delay_id.
+        # In practice, delay_ids are already unique because they are generated
+        # from cryptographically strong random strings.
+        # Therefore, adding this constraint is not expected to ever fail,
+        # despite the current pkey technically allowing non-unique delay_ids.
+        self.db_pool.updates.register_background_index_update(
+            update_name="delayed_events_idx",
+            index_name="delayed_events_idx",
+            table="delayed_events",
+            columns=("delay_id",),
+            unique=True,
+        )
+
     async def get_delayed_events_stream_pos(self) -> int:
         """
         Gets the stream position of the background process to watch for state events
@@ -134,9 +163,7 @@ class DelayedEventsStore(SQLBaseStore):
     async def restart_delayed_event(
         self,
-        *,
         delay_id: str,
-        user_localpart: str,
         current_ts: Timestamp,
     ) -> Timestamp:
         """
@@ -145,7 +172,6 @@
         Args:
             delay_id: The ID of the delayed event to restart.
-            user_localpart: The localpart of the delayed event's owner.
             current_ts: The current time, which will be used to calculate the new send time.

         Returns: The send time of the next delayed event to be sent,
@@ -163,13 +189,11 @@
                 """
                 UPDATE delayed_events
                 SET send_ts = ? + delay
-                WHERE delay_id = ? AND user_localpart = ?
-                    AND NOT is_processed
+                WHERE delay_id = ? AND NOT is_processed
                 """,
                 (
                     current_ts,
                     delay_id,
-                    user_localpart,
                 ),
             )
             if txn.rowcount == 0:
@@ -319,21 +343,15 @@
     async def process_target_delayed_event(
         self,
-        *,
         delay_id: str,
-        user_localpart: str,
     ) -> tuple[
-        EventDetails,
+        DelayedEventDetails,
         Timestamp | None,
     ]:
         """
         Marks for processing the matching delayed event, regardless of its timeout time,
         as long as it has not already been marked as such.

-        Args:
-            delay_id: The ID of the delayed event to restart.
-            user_localpart: The localpart of the delayed event's owner.
-
         Returns: The details of the matching delayed event,
         and the send time of the next delayed event to be sent, if any.
@@ -344,39 +362,38 @@
         def process_target_delayed_event_txn(
             txn: LoggingTransaction,
         ) -> tuple[
-            EventDetails,
+            DelayedEventDetails,
             Timestamp | None,
         ]:
             txn.execute(
                 """
                 UPDATE delayed_events
                 SET is_processed = TRUE
-                WHERE delay_id = ? AND user_localpart = ?
-                    AND NOT is_processed
+                WHERE delay_id = ? AND NOT is_processed
                 RETURNING
                     room_id,
                     event_type,
                     state_key,
                     origin_server_ts,
                     content,
-                    device_id
+                    device_id,
+                    user_localpart
                 """,
-                (
-                    delay_id,
-                    user_localpart,
-                ),
+                (delay_id,),
             )
             row = txn.fetchone()
             if row is None:
                 raise NotFoundError("Delayed event not found")

-            event = EventDetails(
+            event = DelayedEventDetails(
                 RoomID.from_string(row[0]),
                 EventType(row[1]),
                 StateKey(row[2]) if row[2] is not None else None,
                 Timestamp(row[3]) if row[3] is not None else None,
                 db_to_json(row[4]),
                 DeviceID(row[5]) if row[5] is not None else None,
+                DelayID(delay_id),
+                UserLocalpart(row[6]),
             )
             return event, self._get_next_delayed_event_send_ts_txn(txn)
@@ -385,19 +402,10 @@
             "process_target_delayed_event", process_target_delayed_event_txn
         )

-    async def cancel_delayed_event(
-        self,
-        *,
-        delay_id: str,
-        user_localpart: str,
-    ) -> Timestamp | None:
+    async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
         """
         Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.

-        Args:
-            delay_id: The ID of the delayed event to restart.
-            user_localpart: The localpart of the delayed event's owner.
-
         Returns: The send time of the next delayed event to be sent, if any.

         Raises:
@@ -413,7 +421,6 @@
                 table="delayed_events",
                 keyvalues={
                     "delay_id": delay_id,
-                    "user_localpart": user_localpart,
                     "is_processed": False,
                 },
             )
@@ -473,11 +480,7 @@
             "cancel_delayed_state_events", cancel_delayed_state_events_txn
         )

-    async def delete_processed_delayed_event(
-        self,
-        delay_id: DelayID,
-        user_localpart: UserLocalpart,
-    ) -> None:
+    async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
         """
         Delete the matching delayed event, as long as it has been marked as processed.
@@ -488,7 +491,6 @@
             table="delayed_events",
             keyvalues={
                 "delay_id": delay_id,
-                "user_localpart": user_localpart,
                 "is_processed": True,
             },
             desc="delete_processed_delayed_event",
@@ -554,7 +556,7 @@ def _generate_delay_id() -> DelayID:
     # We use the following format for delay IDs:
     #   syd_<random string>
-    # They are scoped to user localparts, so it is possible for
-    # the same ID to exist for multiple users.
+    # They are not scoped to user localparts, but the random string
+    # is expected to be sufficiently random to be globally unique.
    return DelayID(f"syd_{stringutils.random_string(20)}")
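
As a rough sanity check on the "globally unique" comment above (illustrative only; it assumes the random string draws from at least the 52 ASCII letters, which may not be the exact alphabet `stringutils.random_string` uses):

```python
# Rough, illustrative estimate of the delay-ID collision risk, assuming an
# alphabet of at least 52 characters (ASCII letters) and length-20 IDs.
alphabet_size = 52
id_length = 20
id_space = alphabet_size**id_length  # roughly 2.09e+34 possible IDs

# Birthday bound: with n random IDs, P(collision) is about n^2 / (2 * id_space).
n = 10**9  # a very generous number of delayed events ever created
p_collision = n**2 / (2 * id_space)
print(f"{id_space:.3e} possible IDs, collision probability ~ {p_collision:.3e}")
# => about 2.090e+34 possible IDs, collision probability ~ 2.4e-17
```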

Database schema version (`SCHEMA_VERSION`):

@@ -19,7 +19,7 @@
 #
 #

-SCHEMA_VERSION = 92  # remember to update the list below when updating
+SCHEMA_VERSION = 93  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema

 This should be incremented whenever the codebase changes its requirements on the
@@ -168,11 +168,15 @@ Changes in SCHEMA_VERSION = 91

 Changes in SCHEMA_VERSION = 92
     - Cleaned up a trigger that was added in #18260 and then reverted.
+
+Changes in SCHEMA_VERSION = 93
+    - MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
 """

 SCHEMA_COMPAT_VERSION = (
     # Transitive links are no longer written to `event_auth_chain_links`
+    # TODO: On the next compat bump, update the primary key of `delayed_events`
     84
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat

New schema delta (SQL, new file):

@@ -0,0 +1,15 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2025 Element Creations, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (9301, 'delayed_events_idx', '{}');
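
This delta only queues the `delayed_events_idx` background update; the unique index itself is built later by the background updater registered in `DelayedEventsStore.__init__`. A self-contained sketch of the constraint it introduces, using a deliberately simplified table definition that is not Synapse's real schema:

```python
# Standalone illustration of the uniqueness constraint the background update adds.
# The table definition here is simplified and is NOT Synapse's actual schema.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE delayed_events (delay_id TEXT, user_localpart TEXT)")
conn.execute("CREATE UNIQUE INDEX delayed_events_idx ON delayed_events (delay_id)")

conn.execute("INSERT INTO delayed_events VALUES ('syd_abc', 'alice')")
try:
    # A second row with the same delay_id now violates the unique index, even
    # though a per-user scoping of delay IDs would previously have allowed it.
    conn.execute("INSERT INTO delayed_events VALUES ('syd_abc', 'bob')")
except sqlite3.IntegrityError as e:
    print("rejected duplicate delay_id:", e)
```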

Delayed events REST tests (`DelayedEventsTestCase`):

@@ -28,6 +28,7 @@ from synapse.types import JsonDict
 from synapse.util.clock import Clock

 from tests import unittest
+from tests.server import FakeChannel
 from tests.unittest import HomeserverTestCase

 PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
@@ -127,6 +128,10 @@
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

+    def test_get_delayed_events_auth(self) -> None:
+        channel = self.make_request("GET", PATH_PREFIX)
+        self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
+
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
@@ -154,7 +159,6 @@
         channel = self.make_request(
             "POST",
             f"{PATH_PREFIX}/",
-            access_token=self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
@@ -162,7 +166,6 @@
         channel = self.make_request(
             "POST",
             f"{PATH_PREFIX}/abc",
-            access_token=self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -175,7 +178,6 @@
             "POST",
             f"{PATH_PREFIX}/abc",
             {},
-            self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -188,7 +190,6 @@
             "POST",
             f"{PATH_PREFIX}/abc",
             {"action": "oops"},
-            self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -196,17 +197,21 @@
             channel.json_body["errcode"],
         )

-    @parameterized.expand(["cancel", "restart", "send"])
-    def test_update_delayed_event_without_match(self, action: str) -> None:
-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/abc",
-            {"action": action},
-            self.user1_access_token,
-        )
+    @parameterized.expand(
+        (
+            (action, action_in_path)
+            for action in ("cancel", "restart", "send")
+            for action_in_path in (True, False)
+        )
+    )
+    def test_update_delayed_event_without_match(
+        self, action: str, action_in_path: bool
+    ) -> None:
+        channel = self._update_delayed_event("abc", action, action_in_path)
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)

-    def test_cancel_delayed_state_event(self) -> None:
+    @parameterized.expand((True, False))
+    def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
         state_key = "to_never_send"
         setter_key = "setter"
@@ -221,7 +226,7 @@
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        self.assertIsNotNone(delay_id)
+        assert delay_id is not None

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -236,12 +241,7 @@
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_id}",
-            {"action": "cancel"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertListEqual([], self._get_delayed_events())
@@ -254,10 +254,11 @@
             expect_code=HTTPStatus.NOT_FOUND,
         )

+    @parameterized.expand((True, False))
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
-    def test_cancel_delayed_event_ratelimit(self) -> None:
+    def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -268,38 +269,17 @@
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            self.assertIsNotNone(delay_id)
+            assert delay_id is not None
             delay_ids.append(delay_id)

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "cancel"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        args = (
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "cancel"},
-            self.user1_access_token,
-        )
-        channel = self.make_request(*args)
+        channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

-        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
-        self.get_success(
-            self.hs.get_datastores().main.set_ratelimit_for_user(
-                self.user1_user_id, 0, 0
-            )
-        )
-
-        # Test that the request isn't ratelimited anymore.
-        channel = self.make_request(*args)
-        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
-    def test_send_delayed_state_event(self) -> None:
+    @parameterized.expand((True, False))
+    def test_send_delayed_state_event(self, action_in_path: bool) -> None:
         state_key = "to_send_on_request"
         setter_key = "setter"
@@ -314,7 +294,7 @@
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        self.assertIsNotNone(delay_id)
+        assert delay_id is not None

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -329,12 +309,7 @@
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_id}",
-            {"action": "send"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(delay_id, "send", action_in_path)
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertListEqual([], self._get_delayed_events())
         content = self.helper.get_state(
@@ -345,8 +320,9 @@
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

-    @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
-    def test_send_delayed_event_ratelimit(self) -> None:
+    @parameterized.expand((True, False))
+    @unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
+    def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -357,38 +333,17 @@
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            self.assertIsNotNone(delay_id)
+            assert delay_id is not None
             delay_ids.append(delay_id)

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "send"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        args = (
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "send"},
-            self.user1_access_token,
-        )
-        channel = self.make_request(*args)
+        channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

-        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
-        self.get_success(
-            self.hs.get_datastores().main.set_ratelimit_for_user(
-                self.user1_user_id, 0, 0
-            )
-        )
-
-        # Test that the request isn't ratelimited anymore.
-        channel = self.make_request(*args)
-        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
-    def test_restart_delayed_state_event(self) -> None:
+    @parameterized.expand((True, False))
+    def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
         state_key = "to_send_on_restarted_timeout"
         setter_key = "setter"
@@ -403,7 +358,7 @@
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        self.assertIsNotNone(delay_id)
+        assert delay_id is not None

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -418,12 +373,7 @@
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_id}",
-            {"action": "restart"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(delay_id, "restart", action_in_path)
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

         self.reactor.advance(1)
@@ -449,10 +399,11 @@
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

+    @parameterized.expand((True, False))
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
-    def test_restart_delayed_event_ratelimit(self) -> None:
+    def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -463,37 +414,19 @@
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            self.assertIsNotNone(delay_id)
+            assert delay_id is not None
             delay_ids.append(delay_id)

-        channel = self.make_request(
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "restart"},
-            self.user1_access_token,
-        )
+        channel = self._update_delayed_event(
+            delay_ids.pop(0), "restart", action_in_path
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        args = (
-            "POST",
-            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
-            {"action": "restart"},
-            self.user1_access_token,
-        )
-        channel = self.make_request(*args)
+        channel = self._update_delayed_event(
+            delay_ids.pop(0), "restart", action_in_path
+        )
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

-        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
-        self.get_success(
-            self.hs.get_datastores().main.set_ratelimit_for_user(
-                self.user1_user_id, 0, 0
-            )
-        )
-
-        # Test that the request isn't ratelimited anymore.
-        channel = self.make_request(*args)
-        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
     def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
         self,
     ) -> None:
@@ -598,6 +531,17 @@
         return content

+    def _update_delayed_event(
+        self, delay_id: str, action: str, action_in_path: bool
+    ) -> FakeChannel:
+        path = f"{PATH_PREFIX}/{delay_id}"
+        body = {}
+        if action_in_path:
+            path += f"/{action}"
+        else:
+            body["action"] = action
+        return self.make_request("POST", path, body)
+

 def _get_path_for_delayed_state(
     room_id: str, event_type: str, state_key: str, delay_ms: int