mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-10 00:02:09 -05:00
MSC4140: Remove auth from delayed event management endpoints (#19152)
As per recent proposals in MSC4140, remove authentication for restarting/cancelling/sending a delayed event, and give each of those actions its own endpoint. (The original consolidated endpoint is still supported for backwards compatibility.) ### Pull Request Checklist <!-- Please read https://element-hq.github.io/synapse/latest/development/contributing_guide.html before submitting your pull request --> * [x] Pull request is based on the develop branch * [x] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should: - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.". - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. * [x] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) --------- Co-authored-by: Half-Shot <will@half-shot.uk>
This commit is contained in:
parent
4494cc0694
commit
9e23cded8f
1
changelog.d/19152.feature
Normal file
1
changelog.d/19152.feature
Normal file
@ -0,0 +1 @@
|
||||
Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.
|
||||
@ -58,6 +58,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
|
||||
from synapse.storage.databases.main import FilteringWorkerStore
|
||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
|
||||
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
|
||||
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
|
||||
@ -273,6 +274,7 @@ class Store(
|
||||
RelationsWorkerStore,
|
||||
EventFederationWorkerStore,
|
||||
SlidingSyncStore,
|
||||
DelayedEventsStore,
|
||||
):
|
||||
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
||||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
||||
|
||||
@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import ShadowBanError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
|
||||
@ -29,11 +30,9 @@ from synapse.replication.http.delayed_events import (
|
||||
)
|
||||
from synapse.storage.databases.main.delayed_events import (
|
||||
DelayedEventDetails,
|
||||
DelayID,
|
||||
EventType,
|
||||
StateKey,
|
||||
Timestamp,
|
||||
UserLocalpart,
|
||||
)
|
||||
from synapse.storage.databases.main.state_deltas import StateDelta
|
||||
from synapse.types import (
|
||||
@ -399,96 +398,63 @@ class DelayedEventsHandler:
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
async def cancel(self, requester: Requester, delay_id: str) -> None:
|
||||
async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
|
||||
"""
|
||||
Cancels the scheduled delivery of the matching delayed event.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||
requester,
|
||||
(requester.user.to_string(), requester.device_id),
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await make_deferred_yieldable(self._initialized_from_db)
|
||||
|
||||
next_send_ts = await self._store.cancel_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
)
|
||||
next_send_ts = await self._store.cancel_delayed_event(delay_id)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at_or_none(next_send_ts)
|
||||
|
||||
async def restart(self, requester: Requester, delay_id: str) -> None:
|
||||
async def restart(self, request: SynapseRequest, delay_id: str) -> None:
|
||||
"""
|
||||
Restarts the scheduled delivery of the matching delayed event.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||
requester,
|
||||
(requester.user.to_string(), requester.device_id),
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await make_deferred_yieldable(self._initialized_from_db)
|
||||
|
||||
next_send_ts = await self._store.restart_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
current_ts=self._get_current_ts(),
|
||||
delay_id, self._get_current_ts()
|
||||
)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at(next_send_ts)
|
||||
|
||||
async def send(self, requester: Requester, delay_id: str) -> None:
|
||||
async def send(self, request: SynapseRequest, delay_id: str) -> None:
|
||||
"""
|
||||
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
|
||||
|
||||
Args:
|
||||
requester: The owner of the delayed event to act on.
|
||||
delay_id: The ID of the delayed event to act on.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no matching delayed event could be found.
|
||||
"""
|
||||
assert self._is_master
|
||||
# Use standard request limiter for sending delayed events on-demand,
|
||||
# as an on-demand send is similar to sending a regular event.
|
||||
await self._request_ratelimiter.ratelimit(requester)
|
||||
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await make_deferred_yieldable(self._initialized_from_db)
|
||||
|
||||
event, next_send_ts = await self._store.process_target_delayed_event(
|
||||
delay_id=delay_id,
|
||||
user_localpart=requester.user.localpart,
|
||||
)
|
||||
event, next_send_ts = await self._store.process_target_delayed_event(delay_id)
|
||||
|
||||
if self._next_send_ts_changed(next_send_ts):
|
||||
self._schedule_next_at_or_none(next_send_ts)
|
||||
|
||||
await self._send_event(
|
||||
DelayedEventDetails(
|
||||
delay_id=DelayID(delay_id),
|
||||
user_localpart=UserLocalpart(requester.user.localpart),
|
||||
room_id=event.room_id,
|
||||
type=event.type,
|
||||
state_key=event.state_key,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
content=event.content,
|
||||
device_id=event.device_id,
|
||||
)
|
||||
)
|
||||
await self._send_event(event)
|
||||
|
||||
async def _send_on_timeout(self) -> None:
|
||||
self._next_delayed_event_call = None
|
||||
@ -611,9 +577,7 @@ class DelayedEventsHandler:
|
||||
finally:
|
||||
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
|
||||
try:
|
||||
await self._store.delete_processed_delayed_event(
|
||||
event.delay_id, event.user_localpart
|
||||
)
|
||||
await self._store.delete_processed_delayed_event(event.delay_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete processed delayed event")
|
||||
|
||||
|
||||
@ -47,14 +47,11 @@ class UpdateDelayedEventServlet(RestServlet):
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, delay_id: str
|
||||
) -> tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
action = str(body["action"])
|
||||
@ -75,11 +72,65 @@ class UpdateDelayedEventServlet(RestServlet):
|
||||
)
|
||||
|
||||
if enum_action == _UpdateDelayedEventAction.CANCEL:
|
||||
await self.delayed_events_handler.cancel(requester, delay_id)
|
||||
await self.delayed_events_handler.cancel(request, delay_id)
|
||||
elif enum_action == _UpdateDelayedEventAction.RESTART:
|
||||
await self.delayed_events_handler.restart(requester, delay_id)
|
||||
await self.delayed_events_handler.restart(request, delay_id)
|
||||
elif enum_action == _UpdateDelayedEventAction.SEND:
|
||||
await self.delayed_events_handler.send(requester, delay_id)
|
||||
await self.delayed_events_handler.send(request, delay_id)
|
||||
return 200, {}
|
||||
|
||||
|
||||
class CancelDelayedEventServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Delayed event management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, delay_id: str
|
||||
) -> tuple[int, JsonDict]:
|
||||
await self.delayed_events_handler.cancel(request, delay_id)
|
||||
return 200, {}
|
||||
|
||||
|
||||
class RestartDelayedEventServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Delayed event management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, delay_id: str
|
||||
) -> tuple[int, JsonDict]:
|
||||
await self.delayed_events_handler.restart(request, delay_id)
|
||||
return 200, {}
|
||||
|
||||
|
||||
class SendDelayedEventServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Delayed event management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, delay_id: str
|
||||
) -> tuple[int, JsonDict]:
|
||||
await self.delayed_events_handler.send(request, delay_id)
|
||||
return 200, {}
|
||||
|
||||
|
||||
@ -108,4 +159,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
# The following can't currently be instantiated on workers.
|
||||
if hs.config.worker.worker_app is None:
|
||||
UpdateDelayedEventServlet(hs).register(http_server)
|
||||
CancelDelayedEventServlet(hs).register(http_server)
|
||||
RestartDelayedEventServlet(hs).register(http_server)
|
||||
SendDelayedEventServlet(hs).register(http_server)
|
||||
DelayedEventsServlet(hs).register(http_server)
|
||||
|
||||
@ -13,18 +13,26 @@
|
||||
#
|
||||
|
||||
import logging
|
||||
from typing import NewType
|
||||
from typing import TYPE_CHECKING, NewType
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction, StoreError
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
StoreError,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict, RoomID
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.json import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -55,6 +63,27 @@ class DelayedEventDetails(EventDetails):
|
||||
|
||||
|
||||
class DelayedEventsStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# Set delayed events to be uniquely identifiable by their delay_id.
|
||||
# In practice, delay_ids are already unique because they are generated
|
||||
# from cryptographically strong random strings.
|
||||
# Therefore, adding this constraint is not expected to ever fail,
|
||||
# despite the current pkey technically allowing non-unique delay_ids.
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
update_name="delayed_events_idx",
|
||||
index_name="delayed_events_idx",
|
||||
table="delayed_events",
|
||||
columns=("delay_id",),
|
||||
unique=True,
|
||||
)
|
||||
|
||||
async def get_delayed_events_stream_pos(self) -> int:
|
||||
"""
|
||||
Gets the stream position of the background process to watch for state events
|
||||
@ -134,9 +163,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
|
||||
async def restart_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
current_ts: Timestamp,
|
||||
) -> Timestamp:
|
||||
"""
|
||||
@ -145,7 +172,6 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
current_ts: The current time, which will be used to calculate the new send time.
|
||||
|
||||
Returns: The send time of the next delayed event to be sent,
|
||||
@ -163,13 +189,11 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"""
|
||||
UPDATE delayed_events
|
||||
SET send_ts = ? + delay
|
||||
WHERE delay_id = ? AND user_localpart = ?
|
||||
AND NOT is_processed
|
||||
WHERE delay_id = ? AND NOT is_processed
|
||||
""",
|
||||
(
|
||||
current_ts,
|
||||
delay_id,
|
||||
user_localpart,
|
||||
),
|
||||
)
|
||||
if txn.rowcount == 0:
|
||||
@ -319,21 +343,15 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
|
||||
async def process_target_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
) -> tuple[
|
||||
EventDetails,
|
||||
DelayedEventDetails,
|
||||
Timestamp | None,
|
||||
]:
|
||||
"""
|
||||
Marks for processing the matching delayed event, regardless of its timeout time,
|
||||
as long as it has not already been marked as such.
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
|
||||
Returns: The details of the matching delayed event,
|
||||
and the send time of the next delayed event to be sent, if any.
|
||||
|
||||
@ -344,39 +362,38 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
def process_target_delayed_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> tuple[
|
||||
EventDetails,
|
||||
DelayedEventDetails,
|
||||
Timestamp | None,
|
||||
]:
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE delayed_events
|
||||
SET is_processed = TRUE
|
||||
WHERE delay_id = ? AND user_localpart = ?
|
||||
AND NOT is_processed
|
||||
WHERE delay_id = ? AND NOT is_processed
|
||||
RETURNING
|
||||
room_id,
|
||||
event_type,
|
||||
state_key,
|
||||
origin_server_ts,
|
||||
content,
|
||||
device_id
|
||||
device_id,
|
||||
user_localpart
|
||||
""",
|
||||
(
|
||||
delay_id,
|
||||
user_localpart,
|
||||
),
|
||||
(delay_id,),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise NotFoundError("Delayed event not found")
|
||||
|
||||
event = EventDetails(
|
||||
event = DelayedEventDetails(
|
||||
RoomID.from_string(row[0]),
|
||||
EventType(row[1]),
|
||||
StateKey(row[2]) if row[2] is not None else None,
|
||||
Timestamp(row[3]) if row[3] is not None else None,
|
||||
db_to_json(row[4]),
|
||||
DeviceID(row[5]) if row[5] is not None else None,
|
||||
DelayID(delay_id),
|
||||
UserLocalpart(row[6]),
|
||||
)
|
||||
|
||||
return event, self._get_next_delayed_event_send_ts_txn(txn)
|
||||
@ -385,19 +402,10 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"process_target_delayed_event", process_target_delayed_event_txn
|
||||
)
|
||||
|
||||
async def cancel_delayed_event(
|
||||
self,
|
||||
*,
|
||||
delay_id: str,
|
||||
user_localpart: str,
|
||||
) -> Timestamp | None:
|
||||
async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
|
||||
"""
|
||||
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
|
||||
|
||||
Args:
|
||||
delay_id: The ID of the delayed event to restart.
|
||||
user_localpart: The localpart of the delayed event's owner.
|
||||
|
||||
Returns: The send time of the next delayed event to be sent, if any.
|
||||
|
||||
Raises:
|
||||
@ -413,7 +421,6 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"delay_id": delay_id,
|
||||
"user_localpart": user_localpart,
|
||||
"is_processed": False,
|
||||
},
|
||||
)
|
||||
@ -473,11 +480,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"cancel_delayed_state_events", cancel_delayed_state_events_txn
|
||||
)
|
||||
|
||||
async def delete_processed_delayed_event(
|
||||
self,
|
||||
delay_id: DelayID,
|
||||
user_localpart: UserLocalpart,
|
||||
) -> None:
|
||||
async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
|
||||
"""
|
||||
Delete the matching delayed event, as long as it has been marked as processed.
|
||||
|
||||
@ -488,7 +491,6 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
table="delayed_events",
|
||||
keyvalues={
|
||||
"delay_id": delay_id,
|
||||
"user_localpart": user_localpart,
|
||||
"is_processed": True,
|
||||
},
|
||||
desc="delete_processed_delayed_event",
|
||||
@ -554,7 +556,7 @@ def _generate_delay_id() -> DelayID:
|
||||
|
||||
# We use the following format for delay IDs:
|
||||
# syd_<random string>
|
||||
# They are scoped to user localparts, so it is possible for
|
||||
# the same ID to exist for multiple users.
|
||||
# They are not scoped to user localparts, but the random string
|
||||
# is expected to be sufficiently random to be globally unique.
|
||||
|
||||
return DelayID(f"syd_{stringutils.random_string(20)}")
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 92 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 93 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
@ -168,11 +168,15 @@ Changes in SCHEMA_VERSION = 91
|
||||
|
||||
Changes in SCHEMA_VERSION = 92
|
||||
- Cleaned up a trigger that was added in #18260 and then reverted.
|
||||
|
||||
Changes in SCHEMA_VERSION = 93
|
||||
- MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
|
||||
"""
|
||||
|
||||
|
||||
SCHEMA_COMPAT_VERSION = (
|
||||
# Transitive links are no longer written to `event_auth_chain_links`
|
||||
# TODO: On the next compat bump, update the primary key of `delayed_events`
|
||||
84
|
||||
)
|
||||
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
||||
|
||||
@ -0,0 +1,15 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 Element Creations, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(9301, 'delayed_events_idx', '{}');
|
||||
@ -28,6 +28,7 @@ from synapse.types import JsonDict
|
||||
from synapse.util.clock import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
|
||||
@ -127,6 +128,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
def test_get_delayed_events_auth(self) -> None:
|
||||
channel = self.make_request("GET", PATH_PREFIX)
|
||||
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||
)
|
||||
@ -154,7 +159,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/",
|
||||
access_token=self.user1_access_token,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||
|
||||
@ -162,7 +166,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
access_token=self.user1_access_token,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
@ -175,7 +178,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{},
|
||||
self.user1_access_token,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
@ -188,7 +190,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{"action": "oops"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(
|
||||
@ -196,17 +197,21 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
channel.json_body["errcode"],
|
||||
)
|
||||
|
||||
@parameterized.expand(["cancel", "restart", "send"])
|
||||
def test_update_delayed_event_without_match(self, action: str) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/abc",
|
||||
{"action": action},
|
||||
self.user1_access_token,
|
||||
@parameterized.expand(
|
||||
(
|
||||
(action, action_in_path)
|
||||
for action in ("cancel", "restart", "send")
|
||||
for action_in_path in (True, False)
|
||||
)
|
||||
)
|
||||
def test_update_delayed_event_without_match(
|
||||
self, action: str, action_in_path: bool
|
||||
) -> None:
|
||||
channel = self._update_delayed_event("abc", action, action_in_path)
|
||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||
|
||||
def test_cancel_delayed_state_event(self) -> None:
|
||||
@parameterized.expand((True, False))
|
||||
def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
|
||||
state_key = "to_never_send"
|
||||
|
||||
setter_key = "setter"
|
||||
@ -221,7 +226,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
@ -236,12 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "cancel"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
|
||||
@ -254,10 +254,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
@parameterized.expand((True, False))
|
||||
@unittest.override_config(
|
||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||
)
|
||||
def test_cancel_delayed_event_ratelimit(self) -> None:
|
||||
def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||
delay_ids = []
|
||||
for _ in range(2):
|
||||
channel = self.make_request(
|
||||
@ -268,38 +269,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
delay_ids.append(delay_id)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "cancel"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
args = (
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "cancel"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self.make_request(*args)
|
||||
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
|
||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||
|
||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
||||
self.user1_user_id, 0, 0
|
||||
)
|
||||
)
|
||||
|
||||
# Test that the request isn't ratelimited anymore.
|
||||
channel = self.make_request(*args)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
def test_send_delayed_state_event(self) -> None:
|
||||
@parameterized.expand((True, False))
|
||||
def test_send_delayed_state_event(self, action_in_path: bool) -> None:
|
||||
state_key = "to_send_on_request"
|
||||
|
||||
setter_key = "setter"
|
||||
@ -314,7 +294,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
@ -329,12 +309,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "send"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self._update_delayed_event(delay_id, "send", action_in_path)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertListEqual([], self._get_delayed_events())
|
||||
content = self.helper.get_state(
|
||||
@ -345,8 +320,9 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
@unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
|
||||
def test_send_delayed_event_ratelimit(self) -> None:
|
||||
@parameterized.expand((True, False))
|
||||
@unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
|
||||
def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||
delay_ids = []
|
||||
for _ in range(2):
|
||||
channel = self.make_request(
|
||||
@ -357,38 +333,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
delay_ids.append(delay_id)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "send"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
args = (
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "send"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self.make_request(*args)
|
||||
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
|
||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||
|
||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
||||
self.user1_user_id, 0, 0
|
||||
)
|
||||
)
|
||||
|
||||
# Test that the request isn't ratelimited anymore.
|
||||
channel = self.make_request(*args)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
def test_restart_delayed_state_event(self) -> None:
|
||||
@parameterized.expand((True, False))
|
||||
def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
|
||||
state_key = "to_send_on_restarted_timeout"
|
||||
|
||||
setter_key = "setter"
|
||||
@ -403,7 +358,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
|
||||
self.reactor.advance(1)
|
||||
events = self._get_delayed_events()
|
||||
@ -418,12 +373,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
expect_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_id}",
|
||||
{"action": "restart"},
|
||||
self.user1_access_token,
|
||||
)
|
||||
channel = self._update_delayed_event(delay_id, "restart", action_in_path)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
self.reactor.advance(1)
|
||||
@ -449,10 +399,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||
|
||||
@parameterized.expand((True, False))
|
||||
@unittest.override_config(
|
||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||
)
|
||||
def test_restart_delayed_event_ratelimit(self) -> None:
|
||||
def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||
delay_ids = []
|
||||
for _ in range(2):
|
||||
channel = self.make_request(
|
||||
@ -463,37 +414,19 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
delay_id = channel.json_body.get("delay_id")
|
||||
self.assertIsNotNone(delay_id)
|
||||
assert delay_id is not None
|
||||
delay_ids.append(delay_id)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "restart"},
|
||||
self.user1_access_token,
|
||||
channel = self._update_delayed_event(
|
||||
delay_ids.pop(0), "restart", action_in_path
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
args = (
|
||||
"POST",
|
||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
||||
{"action": "restart"},
|
||||
self.user1_access_token,
|
||||
channel = self._update_delayed_event(
|
||||
delay_ids.pop(0), "restart", action_in_path
|
||||
)
|
||||
channel = self.make_request(*args)
|
||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||
|
||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
||||
self.user1_user_id, 0, 0
|
||||
)
|
||||
)
|
||||
|
||||
# Test that the request isn't ratelimited anymore.
|
||||
channel = self.make_request(*args)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
|
||||
def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
|
||||
self,
|
||||
) -> None:
|
||||
@ -598,6 +531,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
||||
|
||||
return content
|
||||
|
||||
def _update_delayed_event(
|
||||
self, delay_id: str, action: str, action_in_path: bool
|
||||
) -> FakeChannel:
|
||||
path = f"{PATH_PREFIX}/{delay_id}"
|
||||
body = {}
|
||||
if action_in_path:
|
||||
path += f"/{action}"
|
||||
else:
|
||||
body["action"] = action
|
||||
return self.make_request("POST", path, body)
|
||||
|
||||
|
||||
def _get_path_for_delayed_state(
|
||||
room_id: str, event_type: str, state_key: str, delay_ms: int
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user