MSC4140: Remove auth from delayed event management endpoints (#19152)

As per recent proposals in MSC4140, remove authentication for
restarting/cancelling/sending a delayed event, and give each of those
actions its own endpoint. (The original consolidated endpoint is still
supported for backwards compatibility.)

### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [x] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct (run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))

---------

Co-authored-by: Half-Shot <will@half-shot.uk>
This commit is contained in:
Andrew Ferrazzutti 2025-11-13 13:56:17 -05:00 committed by GitHub
parent 4494cc0694
commit 9e23cded8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 198 additions and 212 deletions

View File

@ -0,0 +1 @@
Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.

View File

@ -58,6 +58,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import FilteringWorkerStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
@ -273,6 +274,7 @@ class Store(
RelationsWorkerStore,
EventFederationWorkerStore,
SlidingSyncStore,
DelayedEventsStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View File

@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import ShadowBanError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
@ -29,11 +30,9 @@ from synapse.replication.http.delayed_events import (
)
from synapse.storage.databases.main.delayed_events import (
DelayedEventDetails,
DelayID,
EventType,
StateKey,
Timestamp,
UserLocalpart,
)
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
@ -399,96 +398,63 @@ class DelayedEventsHandler:
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def cancel(self, requester: Requester, delay_id: str) -> None:
async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
"""
Cancels the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,
(requester.user.to_string(), requester.device_id),
None, request.getClientAddress().host
)
await make_deferred_yieldable(self._initialized_from_db)
next_send_ts = await self._store.cancel_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
next_send_ts = await self._store.cancel_delayed_event(delay_id)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def restart(self, requester: Requester, delay_id: str) -> None:
async def restart(self, request: SynapseRequest, delay_id: str) -> None:
"""
Restarts the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,
(requester.user.to_string(), requester.device_id),
None, request.getClientAddress().host
)
await make_deferred_yieldable(self._initialized_from_db)
next_send_ts = await self._store.restart_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
current_ts=self._get_current_ts(),
delay_id, self._get_current_ts()
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def send(self, requester: Requester, delay_id: str) -> None:
async def send(self, request: SynapseRequest, delay_id: str) -> None:
"""
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
# Use standard request limiter for sending delayed events on-demand,
# as an on-demand send is similar to sending a regular event.
await self._request_ratelimiter.ratelimit(requester)
await self._delayed_event_mgmt_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await make_deferred_yieldable(self._initialized_from_db)
event, next_send_ts = await self._store.process_target_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
event, next_send_ts = await self._store.process_target_delayed_event(delay_id)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
await self._send_event(
DelayedEventDetails(
delay_id=DelayID(delay_id),
user_localpart=UserLocalpart(requester.user.localpart),
room_id=event.room_id,
type=event.type,
state_key=event.state_key,
origin_server_ts=event.origin_server_ts,
content=event.content,
device_id=event.device_id,
)
)
await self._send_event(event)
async def _send_on_timeout(self) -> None:
self._next_delayed_event_call = None
@ -611,9 +577,7 @@ class DelayedEventsHandler:
finally:
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
try:
await self._store.delete_processed_delayed_event(
event.delay_id, event.user_localpart
)
await self._store.delete_processed_delayed_event(event.delay_id)
except Exception:
logger.exception("Failed to delete processed delayed event")

View File

@ -47,14 +47,11 @@ class UpdateDelayedEventServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
try:
action = str(body["action"])
@ -75,11 +72,65 @@ class UpdateDelayedEventServlet(RestServlet):
)
if enum_action == _UpdateDelayedEventAction.CANCEL:
await self.delayed_events_handler.cancel(requester, delay_id)
await self.delayed_events_handler.cancel(request, delay_id)
elif enum_action == _UpdateDelayedEventAction.RESTART:
await self.delayed_events_handler.restart(requester, delay_id)
await self.delayed_events_handler.restart(request, delay_id)
elif enum_action == _UpdateDelayedEventAction.SEND:
await self.delayed_events_handler.send(requester, delay_id)
await self.delayed_events_handler.send(request, delay_id)
return 200, {}
class CancelDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.cancel(request, delay_id)
return 200, {}
class RestartDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.restart(request, delay_id)
return 200, {}
class SendDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.send(request, delay_id)
return 200, {}
@ -108,4 +159,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
UpdateDelayedEventServlet(hs).register(http_server)
CancelDelayedEventServlet(hs).register(http_server)
RestartDelayedEventServlet(hs).register(http_server)
SendDelayedEventServlet(hs).register(http_server)
DelayedEventsServlet(hs).register(http_server)

View File

@ -13,18 +13,26 @@
#
import logging
from typing import NewType
from typing import TYPE_CHECKING, NewType
import attr
from synapse.api.errors import NotFoundError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction, StoreError
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
StoreError,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomID
from synapse.util import stringutils
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -55,6 +63,27 @@ class DelayedEventDetails(EventDetails):
class DelayedEventsStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Set delayed events to be uniquely identifiable by their delay_id.
# In practice, delay_ids are already unique because they are generated
# from cryptographically strong random strings.
# Therefore, adding this constraint is not expected to ever fail,
# despite the current pkey technically allowing non-unique delay_ids.
self.db_pool.updates.register_background_index_update(
update_name="delayed_events_idx",
index_name="delayed_events_idx",
table="delayed_events",
columns=("delay_id",),
unique=True,
)
async def get_delayed_events_stream_pos(self) -> int:
"""
Gets the stream position of the background process to watch for state events
@ -134,9 +163,7 @@ class DelayedEventsStore(SQLBaseStore):
async def restart_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
current_ts: Timestamp,
) -> Timestamp:
"""
@ -145,7 +172,6 @@ class DelayedEventsStore(SQLBaseStore):
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
current_ts: The current time, which will be used to calculate the new send time.
Returns: The send time of the next delayed event to be sent,
@ -163,13 +189,11 @@ class DelayedEventsStore(SQLBaseStore):
"""
UPDATE delayed_events
SET send_ts = ? + delay
WHERE delay_id = ? AND user_localpart = ?
AND NOT is_processed
WHERE delay_id = ? AND NOT is_processed
""",
(
current_ts,
delay_id,
user_localpart,
),
)
if txn.rowcount == 0:
@ -319,21 +343,15 @@ class DelayedEventsStore(SQLBaseStore):
async def process_target_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> tuple[
EventDetails,
DelayedEventDetails,
Timestamp | None,
]:
"""
Marks for processing the matching delayed event, regardless of its timeout time,
as long as it has not already been marked as such.
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
Returns: The details of the matching delayed event,
and the send time of the next delayed event to be sent, if any.
@ -344,39 +362,38 @@ class DelayedEventsStore(SQLBaseStore):
def process_target_delayed_event_txn(
txn: LoggingTransaction,
) -> tuple[
EventDetails,
DelayedEventDetails,
Timestamp | None,
]:
txn.execute(
"""
UPDATE delayed_events
SET is_processed = TRUE
WHERE delay_id = ? AND user_localpart = ?
AND NOT is_processed
WHERE delay_id = ? AND NOT is_processed
RETURNING
room_id,
event_type,
state_key,
origin_server_ts,
content,
device_id
device_id,
user_localpart
""",
(
delay_id,
user_localpart,
),
(delay_id,),
)
row = txn.fetchone()
if row is None:
raise NotFoundError("Delayed event not found")
event = EventDetails(
event = DelayedEventDetails(
RoomID.from_string(row[0]),
EventType(row[1]),
StateKey(row[2]) if row[2] is not None else None,
Timestamp(row[3]) if row[3] is not None else None,
db_to_json(row[4]),
DeviceID(row[5]) if row[5] is not None else None,
DelayID(delay_id),
UserLocalpart(row[6]),
)
return event, self._get_next_delayed_event_send_ts_txn(txn)
@ -385,19 +402,10 @@ class DelayedEventsStore(SQLBaseStore):
"process_target_delayed_event", process_target_delayed_event_txn
)
async def cancel_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> Timestamp | None:
async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
"""
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
Returns: The send time of the next delayed event to be sent, if any.
Raises:
@ -413,7 +421,6 @@ class DelayedEventsStore(SQLBaseStore):
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": False,
},
)
@ -473,11 +480,7 @@ class DelayedEventsStore(SQLBaseStore):
"cancel_delayed_state_events", cancel_delayed_state_events_txn
)
async def delete_processed_delayed_event(
self,
delay_id: DelayID,
user_localpart: UserLocalpart,
) -> None:
async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
"""
Delete the matching delayed event, as long as it has been marked as processed.
@ -488,7 +491,6 @@ class DelayedEventsStore(SQLBaseStore):
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": True,
},
desc="delete_processed_delayed_event",
@ -554,7 +556,7 @@ def _generate_delay_id() -> DelayID:
# We use the following format for delay IDs:
# syd_<random string>
# They are scoped to user localparts, so it is possible for
# the same ID to exist for multiple users.
# They are not scoped to user localparts, but the random string
# is expected to be sufficiently random to be globally unique.
return DelayID(f"syd_{stringutils.random_string(20)}")

View File

@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 92 # remember to update the list below when updating
SCHEMA_VERSION = 93 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -168,11 +168,15 @@ Changes in SCHEMA_VERSION = 91
Changes in SCHEMA_VERSION = 92
- Cleaned up a trigger that was added in #18260 and then reverted.
Changes in SCHEMA_VERSION = 93
- MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
"""
SCHEMA_COMPAT_VERSION = (
# Transitive links are no longer written to `event_auth_chain_links`
# TODO: On the next compat bump, update the primary key of `delayed_events`
84
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat

View File

@ -0,0 +1,15 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 Element Creations, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(9301, 'delayed_events_idx', '{}');

View File

@ -28,6 +28,7 @@ from synapse.types import JsonDict
from synapse.util.clock import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import HomeserverTestCase
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
@ -127,6 +128,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_get_delayed_events_auth(self) -> None:
channel = self.make_request("GET", PATH_PREFIX)
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
@ -154,7 +159,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/",
access_token=self.user1_access_token,
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
@ -162,7 +166,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
access_token=self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@ -175,7 +178,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST",
f"{PATH_PREFIX}/abc",
{},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@ -188,7 +190,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST",
f"{PATH_PREFIX}/abc",
{"action": "oops"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@ -196,17 +197,21 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel.json_body["errcode"],
)
@parameterized.expand(["cancel", "restart", "send"])
def test_update_delayed_event_without_match(self, action: str) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
{"action": action},
self.user1_access_token,
@parameterized.expand(
(
(action, action_in_path)
for action in ("cancel", "restart", "send")
for action_in_path in (True, False)
)
)
def test_update_delayed_event_without_match(
self, action: str, action_in_path: bool
) -> None:
channel = self._update_delayed_event("abc", action, action_in_path)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
def test_cancel_delayed_state_event(self) -> None:
@parameterized.expand((True, False))
def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
state_key = "to_never_send"
setter_key = "setter"
@ -221,7 +226,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
self.reactor.advance(1)
events = self._get_delayed_events()
@ -236,12 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "cancel"},
self.user1_access_token,
)
channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
@ -254,10 +254,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
@parameterized.expand((True, False))
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
def test_cancel_delayed_event_ratelimit(self) -> None:
def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@ -268,38 +269,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
delay_ids.append(delay_id)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"},
self.user1_access_token,
)
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"},
self.user1_access_token,
)
channel = self.make_request(*args)
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_send_delayed_state_event(self) -> None:
@parameterized.expand((True, False))
def test_send_delayed_state_event(self, action_in_path: bool) -> None:
state_key = "to_send_on_request"
setter_key = "setter"
@ -314,7 +294,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
self.reactor.advance(1)
events = self._get_delayed_events()
@ -329,12 +309,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "send"},
self.user1_access_token,
)
channel = self._update_delayed_event(delay_id, "send", action_in_path)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state(
@ -345,8 +320,9 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
@unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
def test_send_delayed_event_ratelimit(self) -> None:
@parameterized.expand((True, False))
@unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@ -357,38 +333,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
delay_ids.append(delay_id)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"},
self.user1_access_token,
)
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"},
self.user1_access_token,
)
channel = self.make_request(*args)
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_restart_delayed_state_event(self) -> None:
@parameterized.expand((True, False))
def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
state_key = "to_send_on_restarted_timeout"
setter_key = "setter"
@ -403,7 +358,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
self.reactor.advance(1)
events = self._get_delayed_events()
@ -418,12 +373,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "restart"},
self.user1_access_token,
)
channel = self._update_delayed_event(delay_id, "restart", action_in_path)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.reactor.advance(1)
@ -449,10 +399,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
@parameterized.expand((True, False))
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
def test_restart_delayed_event_ratelimit(self) -> None:
def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@ -463,37 +414,19 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
assert delay_id is not None
delay_ids.append(delay_id)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"},
self.user1_access_token,
channel = self._update_delayed_event(
delay_ids.pop(0), "restart", action_in_path
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"},
self.user1_access_token,
channel = self._update_delayed_event(
delay_ids.pop(0), "restart", action_in_path
)
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
self,
) -> None:
@ -598,6 +531,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
return content
def _update_delayed_event(
self, delay_id: str, action: str, action_in_path: bool
) -> FakeChannel:
path = f"{PATH_PREFIX}/{delay_id}"
body = {}
if action_in_path:
path += f"/{action}"
else:
body["action"] = action
return self.make_request("POST", path, body)
def _get_path_for_delayed_state(
room_id: str, event_type: str, state_key: str, delay_ms: int