mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-10 00:02:09 -05:00
MSC4140: Remove auth from delayed event management endpoints (#19152)
As per recent proposals in MSC4140, remove authentication for restarting/cancelling/sending a delayed event, and give each of those actions its own endpoint. (The original consolidated endpoint is still supported for backwards compatibility.) ### Pull Request Checklist <!-- Please read https://element-hq.github.io/synapse/latest/development/contributing_guide.html before submitting your pull request --> * [x] Pull request is based on the develop branch * [x] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should: - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.". - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. * [x] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) --------- Co-authored-by: Half-Shot <will@half-shot.uk>
This commit is contained in:
parent
4494cc0694
commit
9e23cded8f
1
changelog.d/19152.feature
Normal file
1
changelog.d/19152.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.
|
||||||
@ -58,6 +58,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
|
|||||||
from synapse.storage.databases.main import FilteringWorkerStore
|
from synapse.storage.databases.main import FilteringWorkerStore
|
||||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||||
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
|
||||||
|
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
|
||||||
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
|
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
|
||||||
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
|
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
|
||||||
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
|
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
|
||||||
@ -273,6 +274,7 @@ class Store(
|
|||||||
RelationsWorkerStore,
|
RelationsWorkerStore,
|
||||||
EventFederationWorkerStore,
|
EventFederationWorkerStore,
|
||||||
SlidingSyncStore,
|
SlidingSyncStore,
|
||||||
|
DelayedEventsStore,
|
||||||
):
|
):
|
||||||
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
||||||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes
|
|||||||
from synapse.api.errors import ShadowBanError, SynapseError
|
from synapse.api.errors import ShadowBanError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
|
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.logging.opentracing import set_tag
|
from synapse.logging.opentracing import set_tag
|
||||||
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
|
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
|
||||||
@ -29,11 +30,9 @@ from synapse.replication.http.delayed_events import (
|
|||||||
)
|
)
|
||||||
from synapse.storage.databases.main.delayed_events import (
|
from synapse.storage.databases.main.delayed_events import (
|
||||||
DelayedEventDetails,
|
DelayedEventDetails,
|
||||||
DelayID,
|
|
||||||
EventType,
|
EventType,
|
||||||
StateKey,
|
StateKey,
|
||||||
Timestamp,
|
Timestamp,
|
||||||
UserLocalpart,
|
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.state_deltas import StateDelta
|
from synapse.storage.databases.main.state_deltas import StateDelta
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
@ -399,96 +398,63 @@ class DelayedEventsHandler:
|
|||||||
if self._next_send_ts_changed(next_send_ts):
|
if self._next_send_ts_changed(next_send_ts):
|
||||||
self._schedule_next_at(next_send_ts)
|
self._schedule_next_at(next_send_ts)
|
||||||
|
|
||||||
async def cancel(self, requester: Requester, delay_id: str) -> None:
|
async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels the scheduled delivery of the matching delayed event.
|
Cancels the scheduled delivery of the matching delayed event.
|
||||||
|
|
||||||
Args:
|
|
||||||
requester: The owner of the delayed event to act on.
|
|
||||||
delay_id: The ID of the delayed event to act on.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: if no matching delayed event could be found.
|
NotFoundError: if no matching delayed event could be found.
|
||||||
"""
|
"""
|
||||||
assert self._is_master
|
assert self._is_master
|
||||||
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||||
requester,
|
None, request.getClientAddress().host
|
||||||
(requester.user.to_string(), requester.device_id),
|
|
||||||
)
|
)
|
||||||
await make_deferred_yieldable(self._initialized_from_db)
|
await make_deferred_yieldable(self._initialized_from_db)
|
||||||
|
|
||||||
next_send_ts = await self._store.cancel_delayed_event(
|
next_send_ts = await self._store.cancel_delayed_event(delay_id)
|
||||||
delay_id=delay_id,
|
|
||||||
user_localpart=requester.user.localpart,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._next_send_ts_changed(next_send_ts):
|
if self._next_send_ts_changed(next_send_ts):
|
||||||
self._schedule_next_at_or_none(next_send_ts)
|
self._schedule_next_at_or_none(next_send_ts)
|
||||||
|
|
||||||
async def restart(self, requester: Requester, delay_id: str) -> None:
|
async def restart(self, request: SynapseRequest, delay_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Restarts the scheduled delivery of the matching delayed event.
|
Restarts the scheduled delivery of the matching delayed event.
|
||||||
|
|
||||||
Args:
|
|
||||||
requester: The owner of the delayed event to act on.
|
|
||||||
delay_id: The ID of the delayed event to act on.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: if no matching delayed event could be found.
|
NotFoundError: if no matching delayed event could be found.
|
||||||
"""
|
"""
|
||||||
assert self._is_master
|
assert self._is_master
|
||||||
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||||
requester,
|
None, request.getClientAddress().host
|
||||||
(requester.user.to_string(), requester.device_id),
|
|
||||||
)
|
)
|
||||||
await make_deferred_yieldable(self._initialized_from_db)
|
await make_deferred_yieldable(self._initialized_from_db)
|
||||||
|
|
||||||
next_send_ts = await self._store.restart_delayed_event(
|
next_send_ts = await self._store.restart_delayed_event(
|
||||||
delay_id=delay_id,
|
delay_id, self._get_current_ts()
|
||||||
user_localpart=requester.user.localpart,
|
|
||||||
current_ts=self._get_current_ts(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._next_send_ts_changed(next_send_ts):
|
if self._next_send_ts_changed(next_send_ts):
|
||||||
self._schedule_next_at(next_send_ts)
|
self._schedule_next_at(next_send_ts)
|
||||||
|
|
||||||
async def send(self, requester: Requester, delay_id: str) -> None:
|
async def send(self, request: SynapseRequest, delay_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
|
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
|
||||||
|
|
||||||
Args:
|
|
||||||
requester: The owner of the delayed event to act on.
|
|
||||||
delay_id: The ID of the delayed event to act on.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: if no matching delayed event could be found.
|
NotFoundError: if no matching delayed event could be found.
|
||||||
"""
|
"""
|
||||||
assert self._is_master
|
assert self._is_master
|
||||||
# Use standard request limiter for sending delayed events on-demand,
|
await self._delayed_event_mgmt_ratelimiter.ratelimit(
|
||||||
# as an on-demand send is similar to sending a regular event.
|
None, request.getClientAddress().host
|
||||||
await self._request_ratelimiter.ratelimit(requester)
|
)
|
||||||
await make_deferred_yieldable(self._initialized_from_db)
|
await make_deferred_yieldable(self._initialized_from_db)
|
||||||
|
|
||||||
event, next_send_ts = await self._store.process_target_delayed_event(
|
event, next_send_ts = await self._store.process_target_delayed_event(delay_id)
|
||||||
delay_id=delay_id,
|
|
||||||
user_localpart=requester.user.localpart,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._next_send_ts_changed(next_send_ts):
|
if self._next_send_ts_changed(next_send_ts):
|
||||||
self._schedule_next_at_or_none(next_send_ts)
|
self._schedule_next_at_or_none(next_send_ts)
|
||||||
|
|
||||||
await self._send_event(
|
await self._send_event(event)
|
||||||
DelayedEventDetails(
|
|
||||||
delay_id=DelayID(delay_id),
|
|
||||||
user_localpart=UserLocalpart(requester.user.localpart),
|
|
||||||
room_id=event.room_id,
|
|
||||||
type=event.type,
|
|
||||||
state_key=event.state_key,
|
|
||||||
origin_server_ts=event.origin_server_ts,
|
|
||||||
content=event.content,
|
|
||||||
device_id=event.device_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _send_on_timeout(self) -> None:
|
async def _send_on_timeout(self) -> None:
|
||||||
self._next_delayed_event_call = None
|
self._next_delayed_event_call = None
|
||||||
@ -611,9 +577,7 @@ class DelayedEventsHandler:
|
|||||||
finally:
|
finally:
|
||||||
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
|
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
|
||||||
try:
|
try:
|
||||||
await self._store.delete_processed_delayed_event(
|
await self._store.delete_processed_delayed_event(event.delay_id)
|
||||||
event.delay_id, event.user_localpart
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to delete processed delayed event")
|
logger.exception("Failed to delete processed delayed event")
|
||||||
|
|
||||||
|
|||||||
@ -47,14 +47,11 @@ class UpdateDelayedEventServlet(RestServlet):
|
|||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.auth = hs.get_auth()
|
|
||||||
self.delayed_events_handler = hs.get_delayed_events_handler()
|
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: SynapseRequest, delay_id: str
|
self, request: SynapseRequest, delay_id: str
|
||||||
) -> tuple[int, JsonDict]:
|
) -> tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
action = str(body["action"])
|
action = str(body["action"])
|
||||||
@ -75,11 +72,65 @@ class UpdateDelayedEventServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if enum_action == _UpdateDelayedEventAction.CANCEL:
|
if enum_action == _UpdateDelayedEventAction.CANCEL:
|
||||||
await self.delayed_events_handler.cancel(requester, delay_id)
|
await self.delayed_events_handler.cancel(request, delay_id)
|
||||||
elif enum_action == _UpdateDelayedEventAction.RESTART:
|
elif enum_action == _UpdateDelayedEventAction.RESTART:
|
||||||
await self.delayed_events_handler.restart(requester, delay_id)
|
await self.delayed_events_handler.restart(request, delay_id)
|
||||||
elif enum_action == _UpdateDelayedEventAction.SEND:
|
elif enum_action == _UpdateDelayedEventAction.SEND:
|
||||||
await self.delayed_events_handler.send(requester, delay_id)
|
await self.delayed_events_handler.send(request, delay_id)
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
class CancelDelayedEventServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
|
||||||
|
releases=(),
|
||||||
|
)
|
||||||
|
CATEGORY = "Delayed event management requests"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, delay_id: str
|
||||||
|
) -> tuple[int, JsonDict]:
|
||||||
|
await self.delayed_events_handler.cancel(request, delay_id)
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
class RestartDelayedEventServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
|
||||||
|
releases=(),
|
||||||
|
)
|
||||||
|
CATEGORY = "Delayed event management requests"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, delay_id: str
|
||||||
|
) -> tuple[int, JsonDict]:
|
||||||
|
await self.delayed_events_handler.restart(request, delay_id)
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
class SendDelayedEventServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
|
||||||
|
releases=(),
|
||||||
|
)
|
||||||
|
CATEGORY = "Delayed event management requests"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.delayed_events_handler = hs.get_delayed_events_handler()
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, delay_id: str
|
||||||
|
) -> tuple[int, JsonDict]:
|
||||||
|
await self.delayed_events_handler.send(request, delay_id)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
@ -108,4 +159,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||||||
# The following can't currently be instantiated on workers.
|
# The following can't currently be instantiated on workers.
|
||||||
if hs.config.worker.worker_app is None:
|
if hs.config.worker.worker_app is None:
|
||||||
UpdateDelayedEventServlet(hs).register(http_server)
|
UpdateDelayedEventServlet(hs).register(http_server)
|
||||||
|
CancelDelayedEventServlet(hs).register(http_server)
|
||||||
|
RestartDelayedEventServlet(hs).register(http_server)
|
||||||
|
SendDelayedEventServlet(hs).register(http_server)
|
||||||
DelayedEventsServlet(hs).register(http_server)
|
DelayedEventsServlet(hs).register(http_server)
|
||||||
|
|||||||
@ -13,18 +13,26 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import NewType
|
from typing import TYPE_CHECKING, NewType
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.errors import NotFoundError
|
from synapse.api.errors import NotFoundError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import LoggingTransaction, StoreError
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
StoreError,
|
||||||
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.types import JsonDict, RoomID
|
from synapse.types import JsonDict, RoomID
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.json import json_encoder
|
from synapse.util.json import json_encoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -55,6 +63,27 @@ class DelayedEventDetails(EventDetails):
|
|||||||
|
|
||||||
|
|
||||||
class DelayedEventsStore(SQLBaseStore):
|
class DelayedEventsStore(SQLBaseStore):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
# Set delayed events to be uniquely identifiable by their delay_id.
|
||||||
|
# In practice, delay_ids are already unique because they are generated
|
||||||
|
# from cryptographically strong random strings.
|
||||||
|
# Therefore, adding this constraint is not expected to ever fail,
|
||||||
|
# despite the current pkey technically allowing non-unique delay_ids.
|
||||||
|
self.db_pool.updates.register_background_index_update(
|
||||||
|
update_name="delayed_events_idx",
|
||||||
|
index_name="delayed_events_idx",
|
||||||
|
table="delayed_events",
|
||||||
|
columns=("delay_id",),
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_delayed_events_stream_pos(self) -> int:
|
async def get_delayed_events_stream_pos(self) -> int:
|
||||||
"""
|
"""
|
||||||
Gets the stream position of the background process to watch for state events
|
Gets the stream position of the background process to watch for state events
|
||||||
@ -134,9 +163,7 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
async def restart_delayed_event(
|
async def restart_delayed_event(
|
||||||
self,
|
self,
|
||||||
*,
|
|
||||||
delay_id: str,
|
delay_id: str,
|
||||||
user_localpart: str,
|
|
||||||
current_ts: Timestamp,
|
current_ts: Timestamp,
|
||||||
) -> Timestamp:
|
) -> Timestamp:
|
||||||
"""
|
"""
|
||||||
@ -145,7 +172,6 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
delay_id: The ID of the delayed event to restart.
|
delay_id: The ID of the delayed event to restart.
|
||||||
user_localpart: The localpart of the delayed event's owner.
|
|
||||||
current_ts: The current time, which will be used to calculate the new send time.
|
current_ts: The current time, which will be used to calculate the new send time.
|
||||||
|
|
||||||
Returns: The send time of the next delayed event to be sent,
|
Returns: The send time of the next delayed event to be sent,
|
||||||
@ -163,13 +189,11 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
UPDATE delayed_events
|
UPDATE delayed_events
|
||||||
SET send_ts = ? + delay
|
SET send_ts = ? + delay
|
||||||
WHERE delay_id = ? AND user_localpart = ?
|
WHERE delay_id = ? AND NOT is_processed
|
||||||
AND NOT is_processed
|
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
current_ts,
|
current_ts,
|
||||||
delay_id,
|
delay_id,
|
||||||
user_localpart,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if txn.rowcount == 0:
|
if txn.rowcount == 0:
|
||||||
@ -319,21 +343,15 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
async def process_target_delayed_event(
|
async def process_target_delayed_event(
|
||||||
self,
|
self,
|
||||||
*,
|
|
||||||
delay_id: str,
|
delay_id: str,
|
||||||
user_localpart: str,
|
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
EventDetails,
|
DelayedEventDetails,
|
||||||
Timestamp | None,
|
Timestamp | None,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Marks for processing the matching delayed event, regardless of its timeout time,
|
Marks for processing the matching delayed event, regardless of its timeout time,
|
||||||
as long as it has not already been marked as such.
|
as long as it has not already been marked as such.
|
||||||
|
|
||||||
Args:
|
|
||||||
delay_id: The ID of the delayed event to restart.
|
|
||||||
user_localpart: The localpart of the delayed event's owner.
|
|
||||||
|
|
||||||
Returns: The details of the matching delayed event,
|
Returns: The details of the matching delayed event,
|
||||||
and the send time of the next delayed event to be sent, if any.
|
and the send time of the next delayed event to be sent, if any.
|
||||||
|
|
||||||
@ -344,39 +362,38 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
def process_target_delayed_event_txn(
|
def process_target_delayed_event_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
EventDetails,
|
DelayedEventDetails,
|
||||||
Timestamp | None,
|
Timestamp | None,
|
||||||
]:
|
]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE delayed_events
|
UPDATE delayed_events
|
||||||
SET is_processed = TRUE
|
SET is_processed = TRUE
|
||||||
WHERE delay_id = ? AND user_localpart = ?
|
WHERE delay_id = ? AND NOT is_processed
|
||||||
AND NOT is_processed
|
|
||||||
RETURNING
|
RETURNING
|
||||||
room_id,
|
room_id,
|
||||||
event_type,
|
event_type,
|
||||||
state_key,
|
state_key,
|
||||||
origin_server_ts,
|
origin_server_ts,
|
||||||
content,
|
content,
|
||||||
device_id
|
device_id,
|
||||||
|
user_localpart
|
||||||
""",
|
""",
|
||||||
(
|
(delay_id,),
|
||||||
delay_id,
|
|
||||||
user_localpart,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
raise NotFoundError("Delayed event not found")
|
raise NotFoundError("Delayed event not found")
|
||||||
|
|
||||||
event = EventDetails(
|
event = DelayedEventDetails(
|
||||||
RoomID.from_string(row[0]),
|
RoomID.from_string(row[0]),
|
||||||
EventType(row[1]),
|
EventType(row[1]),
|
||||||
StateKey(row[2]) if row[2] is not None else None,
|
StateKey(row[2]) if row[2] is not None else None,
|
||||||
Timestamp(row[3]) if row[3] is not None else None,
|
Timestamp(row[3]) if row[3] is not None else None,
|
||||||
db_to_json(row[4]),
|
db_to_json(row[4]),
|
||||||
DeviceID(row[5]) if row[5] is not None else None,
|
DeviceID(row[5]) if row[5] is not None else None,
|
||||||
|
DelayID(delay_id),
|
||||||
|
UserLocalpart(row[6]),
|
||||||
)
|
)
|
||||||
|
|
||||||
return event, self._get_next_delayed_event_send_ts_txn(txn)
|
return event, self._get_next_delayed_event_send_ts_txn(txn)
|
||||||
@ -385,19 +402,10 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
"process_target_delayed_event", process_target_delayed_event_txn
|
"process_target_delayed_event", process_target_delayed_event_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def cancel_delayed_event(
|
async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
|
||||||
self,
|
|
||||||
*,
|
|
||||||
delay_id: str,
|
|
||||||
user_localpart: str,
|
|
||||||
) -> Timestamp | None:
|
|
||||||
"""
|
"""
|
||||||
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
|
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
|
||||||
|
|
||||||
Args:
|
|
||||||
delay_id: The ID of the delayed event to restart.
|
|
||||||
user_localpart: The localpart of the delayed event's owner.
|
|
||||||
|
|
||||||
Returns: The send time of the next delayed event to be sent, if any.
|
Returns: The send time of the next delayed event to be sent, if any.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -413,7 +421,6 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
table="delayed_events",
|
table="delayed_events",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"delay_id": delay_id,
|
"delay_id": delay_id,
|
||||||
"user_localpart": user_localpart,
|
|
||||||
"is_processed": False,
|
"is_processed": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -473,11 +480,7 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
"cancel_delayed_state_events", cancel_delayed_state_events_txn
|
"cancel_delayed_state_events", cancel_delayed_state_events_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_processed_delayed_event(
|
async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
|
||||||
self,
|
|
||||||
delay_id: DelayID,
|
|
||||||
user_localpart: UserLocalpart,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Delete the matching delayed event, as long as it has been marked as processed.
|
Delete the matching delayed event, as long as it has been marked as processed.
|
||||||
|
|
||||||
@ -488,7 +491,6 @@ class DelayedEventsStore(SQLBaseStore):
|
|||||||
table="delayed_events",
|
table="delayed_events",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"delay_id": delay_id,
|
"delay_id": delay_id,
|
||||||
"user_localpart": user_localpart,
|
|
||||||
"is_processed": True,
|
"is_processed": True,
|
||||||
},
|
},
|
||||||
desc="delete_processed_delayed_event",
|
desc="delete_processed_delayed_event",
|
||||||
@ -554,7 +556,7 @@ def _generate_delay_id() -> DelayID:
|
|||||||
|
|
||||||
# We use the following format for delay IDs:
|
# We use the following format for delay IDs:
|
||||||
# syd_<random string>
|
# syd_<random string>
|
||||||
# They are scoped to user localparts, so it is possible for
|
# They are not scoped to user localparts, but the random string
|
||||||
# the same ID to exist for multiple users.
|
# is expected to be sufficiently random to be globally unique.
|
||||||
|
|
||||||
return DelayID(f"syd_{stringutils.random_string(20)}")
|
return DelayID(f"syd_{stringutils.random_string(20)}")
|
||||||
|
|||||||
@ -19,7 +19,7 @@
|
|||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
SCHEMA_VERSION = 92 # remember to update the list below when updating
|
SCHEMA_VERSION = 93 # remember to update the list below when updating
|
||||||
"""Represents the expectations made by the codebase about the database schema
|
"""Represents the expectations made by the codebase about the database schema
|
||||||
|
|
||||||
This should be incremented whenever the codebase changes its requirements on the
|
This should be incremented whenever the codebase changes its requirements on the
|
||||||
@ -168,11 +168,15 @@ Changes in SCHEMA_VERSION = 91
|
|||||||
|
|
||||||
Changes in SCHEMA_VERSION = 92
|
Changes in SCHEMA_VERSION = 92
|
||||||
- Cleaned up a trigger that was added in #18260 and then reverted.
|
- Cleaned up a trigger that was added in #18260 and then reverted.
|
||||||
|
|
||||||
|
Changes in SCHEMA_VERSION = 93
|
||||||
|
- MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
SCHEMA_COMPAT_VERSION = (
|
SCHEMA_COMPAT_VERSION = (
|
||||||
# Transitive links are no longer written to `event_auth_chain_links`
|
# Transitive links are no longer written to `event_auth_chain_links`
|
||||||
|
# TODO: On the next compat bump, update the primary key of `delayed_events`
|
||||||
84
|
84
|
||||||
)
|
)
|
||||||
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
||||||
|
|||||||
@ -0,0 +1,15 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 Element Creations, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(9301, 'delayed_events_idx', '{}');
|
||||||
@ -28,6 +28,7 @@ from synapse.types import JsonDict
|
|||||||
from synapse.util.clock import Clock
|
from synapse.util.clock import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.server import FakeChannel
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
|
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
|
||||||
@ -127,6 +128,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||||
|
|
||||||
|
def test_get_delayed_events_auth(self) -> None:
|
||||||
|
channel = self.make_request("GET", PATH_PREFIX)
|
||||||
|
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
|
||||||
|
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||||
)
|
)
|
||||||
@ -154,7 +159,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"{PATH_PREFIX}/",
|
f"{PATH_PREFIX}/",
|
||||||
access_token=self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||||
|
|
||||||
@ -162,7 +166,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"{PATH_PREFIX}/abc",
|
f"{PATH_PREFIX}/abc",
|
||||||
access_token=self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -175,7 +178,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
"POST",
|
"POST",
|
||||||
f"{PATH_PREFIX}/abc",
|
f"{PATH_PREFIX}/abc",
|
||||||
{},
|
{},
|
||||||
self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -188,7 +190,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
"POST",
|
"POST",
|
||||||
f"{PATH_PREFIX}/abc",
|
f"{PATH_PREFIX}/abc",
|
||||||
{"action": "oops"},
|
{"action": "oops"},
|
||||||
self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -196,17 +197,21 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
channel.json_body["errcode"],
|
channel.json_body["errcode"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand(["cancel", "restart", "send"])
|
@parameterized.expand(
|
||||||
def test_update_delayed_event_without_match(self, action: str) -> None:
|
(
|
||||||
channel = self.make_request(
|
(action, action_in_path)
|
||||||
"POST",
|
for action in ("cancel", "restart", "send")
|
||||||
f"{PATH_PREFIX}/abc",
|
for action_in_path in (True, False)
|
||||||
{"action": action},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
def test_update_delayed_event_without_match(
|
||||||
|
self, action: str, action_in_path: bool
|
||||||
|
) -> None:
|
||||||
|
channel = self._update_delayed_event("abc", action, action_in_path)
|
||||||
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
|
||||||
|
|
||||||
def test_cancel_delayed_state_event(self) -> None:
|
@parameterized.expand((True, False))
|
||||||
|
def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
|
||||||
state_key = "to_never_send"
|
state_key = "to_never_send"
|
||||||
|
|
||||||
setter_key = "setter"
|
setter_key = "setter"
|
||||||
@ -221,7 +226,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
|
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
events = self._get_delayed_events()
|
events = self._get_delayed_events()
|
||||||
@ -236,12 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
expect_code=HTTPStatus.NOT_FOUND,
|
expect_code=HTTPStatus.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_id}",
|
|
||||||
{"action": "cancel"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
self.assertListEqual([], self._get_delayed_events())
|
self.assertListEqual([], self._get_delayed_events())
|
||||||
|
|
||||||
@ -254,10 +254,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
expect_code=HTTPStatus.NOT_FOUND,
|
expect_code=HTTPStatus.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand((True, False))
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||||
)
|
)
|
||||||
def test_cancel_delayed_event_ratelimit(self) -> None:
|
def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||||
delay_ids = []
|
delay_ids = []
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -268,38 +269,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
delay_ids.append(delay_id)
|
delay_ids.append(delay_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "cancel"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
args = (
|
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "cancel"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||||
|
|
||||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
@parameterized.expand((True, False))
|
||||||
self.get_success(
|
def test_send_delayed_state_event(self, action_in_path: bool) -> None:
|
||||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
|
||||||
self.user1_user_id, 0, 0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the request isn't ratelimited anymore.
|
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
|
||||||
|
|
||||||
def test_send_delayed_state_event(self) -> None:
|
|
||||||
state_key = "to_send_on_request"
|
state_key = "to_send_on_request"
|
||||||
|
|
||||||
setter_key = "setter"
|
setter_key = "setter"
|
||||||
@ -314,7 +294,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
|
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
events = self._get_delayed_events()
|
events = self._get_delayed_events()
|
||||||
@ -329,12 +309,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
expect_code=HTTPStatus.NOT_FOUND,
|
expect_code=HTTPStatus.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(delay_id, "send", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_id}",
|
|
||||||
{"action": "send"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
self.assertListEqual([], self._get_delayed_events())
|
self.assertListEqual([], self._get_delayed_events())
|
||||||
content = self.helper.get_state(
|
content = self.helper.get_state(
|
||||||
@ -345,8 +320,9 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||||
|
|
||||||
@unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
|
@parameterized.expand((True, False))
|
||||||
def test_send_delayed_event_ratelimit(self) -> None:
|
@unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
|
||||||
|
def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||||
delay_ids = []
|
delay_ids = []
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -357,38 +333,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
delay_ids.append(delay_id)
|
delay_ids.append(delay_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "send"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
args = (
|
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "send"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||||
|
|
||||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
@parameterized.expand((True, False))
|
||||||
self.get_success(
|
def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
|
||||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
|
||||||
self.user1_user_id, 0, 0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the request isn't ratelimited anymore.
|
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
|
||||||
|
|
||||||
def test_restart_delayed_state_event(self) -> None:
|
|
||||||
state_key = "to_send_on_restarted_timeout"
|
state_key = "to_send_on_restarted_timeout"
|
||||||
|
|
||||||
setter_key = "setter"
|
setter_key = "setter"
|
||||||
@ -403,7 +358,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
|
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
events = self._get_delayed_events()
|
events = self._get_delayed_events()
|
||||||
@ -418,12 +373,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
expect_code=HTTPStatus.NOT_FOUND,
|
expect_code=HTTPStatus.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(delay_id, "restart", action_in_path)
|
||||||
"POST",
|
|
||||||
f"{PATH_PREFIX}/{delay_id}",
|
|
||||||
{"action": "restart"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
@ -449,10 +399,11 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(setter_expected, content.get(setter_key), content)
|
self.assertEqual(setter_expected, content.get(setter_key), content)
|
||||||
|
|
||||||
|
@parameterized.expand((True, False))
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
|
||||||
)
|
)
|
||||||
def test_restart_delayed_event_ratelimit(self) -> None:
|
def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
|
||||||
delay_ids = []
|
delay_ids = []
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -463,37 +414,19 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
delay_id = channel.json_body.get("delay_id")
|
delay_id = channel.json_body.get("delay_id")
|
||||||
self.assertIsNotNone(delay_id)
|
assert delay_id is not None
|
||||||
delay_ids.append(delay_id)
|
delay_ids.append(delay_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self._update_delayed_event(
|
||||||
"POST",
|
delay_ids.pop(0), "restart", action_in_path
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "restart"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
args = (
|
channel = self._update_delayed_event(
|
||||||
"POST",
|
delay_ids.pop(0), "restart", action_in_path
|
||||||
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
|
|
||||||
{"action": "restart"},
|
|
||||||
self.user1_access_token,
|
|
||||||
)
|
)
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
|
||||||
|
|
||||||
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
|
|
||||||
self.get_success(
|
|
||||||
self.hs.get_datastores().main.set_ratelimit_for_user(
|
|
||||||
self.user1_user_id, 0, 0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test that the request isn't ratelimited anymore.
|
|
||||||
channel = self.make_request(*args)
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
|
||||||
|
|
||||||
def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
|
def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
|
||||||
self,
|
self,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -598,6 +531,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
def _update_delayed_event(
|
||||||
|
self, delay_id: str, action: str, action_in_path: bool
|
||||||
|
) -> FakeChannel:
|
||||||
|
path = f"{PATH_PREFIX}/{delay_id}"
|
||||||
|
body = {}
|
||||||
|
if action_in_path:
|
||||||
|
path += f"/{action}"
|
||||||
|
else:
|
||||||
|
body["action"] = action
|
||||||
|
return self.make_request("POST", path, body)
|
||||||
|
|
||||||
|
|
||||||
def _get_path_for_delayed_state(
|
def _get_path_for_delayed_state(
|
||||||
room_id: str, event_type: str, state_key: str, delay_ms: int
|
room_id: str, event_type: str, state_key: str, delay_ms: int
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user