Compare commits

...

5 Commits

Author SHA1 Message Date
Kegan Dougal
aac3c846a8 Use a tri-state for soft failed to communicate when we need to cache invalidate 2025-10-02 16:47:45 +01:00
Kegan Dougal
888ab79b3b Add NewServerJoined replication command
Currently emits on joins and logs on receive, WIP.
2025-10-02 15:25:47 +01:00
Kegan Dougal
aa45bf7c3a Add msc4354 to /versions response 2025-10-02 10:46:44 +01:00
Kegan Dougal
15453d4e6e JSON false not str 2025-10-02 09:30:21 +01:00
Kegan Dougal
78c40973f4 SQLite specific soft-failure update code 2025-10-02 09:26:33 +01:00
14 changed files with 185 additions and 58 deletions

View File

@ -372,3 +372,10 @@ class StickyEvent:
QUERY_PARAM_NAME: Final = "org.matrix.msc4354.sticky_duration_ms"
FIELD_NAME: Final = "msc4354_sticky"
MAX_DURATION_MS: Final = 3600000 # 1 hour
# for the database
class StickyEventSoftFailed(enum.IntEnum):
FALSE = 0
TRUE = 1
FORMER_TRUE = 2

View File

@ -213,6 +213,11 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# This should never get called.
raise NotImplementedError()
def notify_new_server_joined(self, server: str, room_id: str) -> None:
"""As per FederationSender"""
# This should never get called.
raise NotImplementedError()
def build_and_send_edu(
self,
destination: str,

View File

@ -239,6 +239,13 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
@abc.abstractmethod
def notify_new_server_joined(self, server: str, room_id: str) -> None:
"""This gets called when we a new server has joined a room. We might
want to send out some events to this server.
"""
raise NotImplementedError()
@abc.abstractmethod
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room
@ -488,6 +495,9 @@ class FederationSender(AbstractFederationSender):
self._per_destination_queues[destination] = queue
return queue
def notify_new_server_joined(self, server: str, room_id: str) -> None:
print(f"FEDSENDER: new server joined: server={server} room={room_id}")
def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.

View File

@ -933,6 +933,11 @@ class Notifier:
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
def notify_new_server_joined(self, server: str, room_id: str) -> None:
# Inform the federation_sender that it may need to send events to the new server.
if self.federation_sender:
self.federation_sender.notify_new_server_joined(server, room_id)
def add_lock_released_callback(
self, callback: Callable[[str, str, str], None]
) -> None:

View File

@ -266,7 +266,6 @@ class ReplicationDataHandler:
users=[row.user_id for row in rows],
)
elif stream_name == StickyEventsStream.NAME:
print(f"STICKY_EVENTS on_rdata {token} => {rows}")
self.notifier.on_new_event(
StreamKeyType.STICKY_EVENTS,
token,

View File

@ -462,6 +462,32 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"
class NewServerJoinedCommand(Command):
"""Sent when a worker has detected that a new remote server has joined a room.
Format::
NEW_SERVER_JOINED <server> <room_id>
"""
NAME = "NEW_SERVER_JOINED"
__slots__ = ["server", "room_id"]
def __init__(self, server: str, room_id: str):
self.server = server
self.room_id = room_id
@classmethod
def from_line(
cls: Type["NewServerJoinedCommand"], line: str
) -> "NewServerJoinedCommand":
server, room_id = line.split(" ")
return cls(server, room_id)
def to_line(self) -> str:
return "%s %s" % (self.server, self.room_id)
class LockReleasedCommand(Command):
"""Sent to inform other instances that a given lock has been dropped.
@ -517,6 +543,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
FederationAckCommand,
UserIpCommand,
RemoteServerUpCommand,
NewServerJoinedCommand,
ClearUserSyncsCommand,
LockReleasedCommand,
NewActiveTaskCommand,
@ -533,6 +560,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
RemoteServerUpCommand.NAME,
NewServerJoinedCommand.NAME,
LockReleasedCommand.NAME,
)
@ -547,6 +575,7 @@ VALID_CLIENT_COMMANDS = (
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
NewServerJoinedCommand.NAME,
LockReleasedCommand.NAME,
)

View File

@ -48,6 +48,7 @@ from synapse.replication.tcp.commands import (
FederationAckCommand,
LockReleasedCommand,
NewActiveTaskCommand,
NewServerJoinedCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@ -764,6 +765,12 @@ class ReplicationCommandHandler:
"""Called when get a new REMOTE_SERVER_UP command."""
self._notifier.notify_remote_server_up(cmd.data)
def on_NEW_SERVER_JOINED(
self, conn: IReplicationConnection, cmd: NewServerJoinedCommand
) -> None:
"""Called when get a new NEW_SERVER_JOINED command."""
self._notifier.notify_new_server_joined(cmd.server, cmd.room_id)
def on_LOCK_RELEASED(
self, conn: IReplicationConnection, cmd: LockReleasedCommand
) -> None:
@ -886,6 +893,9 @@ class ReplicationCommandHandler:
def send_remote_server_up(self, server: str) -> None:
self.send_command(RemoteServerUpCommand(server))
def send_new_server_joined(self, server: str, room_id: str) -> None:
self.send_command(NewServerJoinedCommand(server, room_id))
def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
"""Called when a new update is available to stream to Redis subscribers.

View File

@ -34,7 +34,7 @@ from typing import (
import attr
from synapse.api.constants import AccountDataTypes
from synapse.api.constants import AccountDataTypes, StickyEventSoftFailed
from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
@ -768,16 +768,18 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
return rows, rows[-1][0], len(updates) == limit
@attr.s(slots=True, auto_attribs=True)
class StickyEventsStreamRow:
"""Stream to inform workers about changes to sticky events."""
room_id: str
event_id: str # The sticky event ID
soft_failed_status: StickyEventSoftFailed
class StickyEventsStream(_StreamFromIdGen):
"""A sticky event was changed."""
@attr.s(slots=True, auto_attribs=True)
class StickyEventsStreamRow:
"""Stream to inform workers about changes to sticky events."""
room_id: str
event_id: str # The sticky event ID
NAME = "sticky_events"
ROW_TYPE = StickyEventsStreamRow
@ -799,9 +801,9 @@ class StickyEventsStream(_StreamFromIdGen):
(
stream_id,
# These are the args to `StickyEventsStreamRow`
(room_id, event_id),
(room_id, event_id, soft_failed),
)
for stream_id, room_id, event_id in updates
for stream_id, room_id, event_id, soft_failed in updates
]
if not rows:

View File

@ -182,6 +182,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc4306": self.config.experimental.msc4306_enabled,
# MSC4169: Backwards-compatible redaction sending using `/send`
"com.beeper.msc4169": self.config.experimental.msc4169_enabled,
# MSC4354: Sticky events
"org.matrix.msc4354": self.config.experimental.msc4354_enabled,
},
},
)

View File

@ -1188,6 +1188,14 @@ class PersistEventsStore:
if self.msc4354_sticky_events:
self.store.insert_sticky_events_txn(txn, events_and_contexts)
for ev, _ in events_and_contexts:
if ev.type == "m.room.member" and ev.membership == "join":
print(f"GOT JOIN FOR {ev.state_key}")
domain = get_domain_from_id(ev.state_key)
self.hs.get_notifier().notify_new_server_joined(domain, ev.room_id)
self.hs.get_replication_command_handler().send_new_server_joined(
domain, ev.room_id
)
# We only update the sliding sync tables for non-backfilled events.
self._update_sliding_sync_tables_with_new_persisted_events_txn(

View File

@ -45,7 +45,7 @@ from prometheus_client import Gauge
from twisted.internet import defer
from synapse.api.constants import Direction, EventTypes
from synapse.api.constants import Direction, EventTypes, StickyEventSoftFailed
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@ -74,6 +74,10 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream
from synapse.replication.tcp.streams._base import (
StickyEventsStream,
StickyEventsStreamRow,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@ -463,6 +467,12 @@ class EventsWorkerStore(SQLBaseStore):
# If the partial-stated event became rejected or unrejected
# when it wasn't before, we need to invalidate this cache.
self._invalidate_local_get_event_cache(row.event_id)
elif stream_name == StickyEventsStream.NAME:
for row in rows:
assert isinstance(row, StickyEventsStreamRow)
if row.soft_failed_status == StickyEventSoftFailed.FORMER_TRUE:
# was soft-failed, now not, so invalidate caches
self._invalidate_local_get_event_cache(row.event_id)
super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@ -28,7 +28,7 @@ from typing import (
from twisted.internet.defer import Deferred
from synapse import event_auth
from synapse.api.constants import EventTypes, StickyEvent
from synapse.api.constants import EventTypes, StickyEvent, StickyEventSoftFailed
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.events.snapshot import EventPersistencePair
@ -43,6 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types.state import StateFilter
from synapse.util.stringutils import shortstr
@ -170,15 +171,15 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
txn.execute(
f"""
SELECT stream_id, room_id, event_id FROM sticky_events
WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
WHERE soft_failed != ? AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
""",
(now, from_id, to_id, *room_id_values),
(StickyEventSoftFailed.TRUE, now, from_id, to_id, *room_id_values),
)
return cast(List[Tuple[int, str, str]], txn.fetchall())
async def get_updated_sticky_events(
self, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, StickyEventSoftFailed]]:
"""Get updates to sticky events between two stream IDs.
Args:
@ -199,14 +200,14 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
def _get_updated_sticky_events_txn(
self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, StickyEventSoftFailed]]:
txn.execute(
"""
SELECT stream_id, room_id, event_id FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ?
SELECT stream_id, room_id, event_id, soft_failed FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ?
""",
(from_id, to_id, limit),
)
return cast(List[Tuple[int, str, str]], txn.fetchall())
return cast(List[Tuple[int, str, str, StickyEventSoftFailed]], txn.fetchall())
async def get_sticky_event_ids_sent_by_self(
self, room_id: str, from_stream_pos: int
@ -236,9 +237,9 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
"""
SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events
INNER JOIN events ON events.event_id = sticky_events.event_id
WHERE soft_failed=FALSE AND expires_at > ? AND sticky_events.room_id = ?
WHERE soft_failed=? AND expires_at > ? AND sticky_events.room_id = ?
""",
(now_ms, room_id),
(StickyEventSoftFailed.FALSE, now_ms, room_id),
)
rows = cast(List[Tuple[str, str, int]], txn.fetchall())
return [
@ -340,7 +341,9 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
ev.event_id,
ev.sender,
expires_at,
ev.internal_metadata.is_soft_failed(),
StickyEventSoftFailed.TRUE
if ev.internal_metadata.is_soft_failed()
else StickyEventSoftFailed.FALSE,
)
for (ev, expires_at, stream_id) in sticky_events
],
@ -425,7 +428,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
iterable=new_membership_changes,
keyvalues={
"room_id": room_id,
"soft_failed": True,
"soft_failed": StickyEventSoftFailed.TRUE,
},
retcols=("event_id",),
desc="_get_soft_failed_sticky_events_to_recheck_members",
@ -456,7 +459,7 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
table="sticky_events",
keyvalues={
"room_id": room_id,
"soft_failed": True,
"soft_failed": StickyEventSoftFailed.TRUE,
},
retcols=("event_id",),
desc="_get_soft_failed_sticky_events_to_recheck",
@ -531,39 +534,72 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
(event_id, self._sticky_events_id_gen.get_next_txn(txn))
for event_id in passing_event_ids
]
values_placeholders = ", ".join(["(?, ?)"] * len(new_stream_ids))
# [event_id, stream_pos, event_id, stream_pos, ...]
params = [p for pair in new_stream_ids for p in pair]
txn.execute(
f"""
UPDATE sticky_events AS se
SET
soft_failed = FALSE,
stream_id = v.stream_id
FROM (VALUES
{values_placeholders}
) AS v(event_id, stream_id)
WHERE se.event_id = v.event_id;
""",
params,
)
# Also update the internal metadata on the event itself, so when we filter_events_for_client
# we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB...
clause, args = make_in_list_sql_clause(
txn.database_engine,
"event_id",
passing_event_ids,
)
txn.execute(
"""
UPDATE event_json
SET internal_metadata = (
jsonb_set(internal_metadata::jsonb, '{soft_failed}', 'false'::jsonb)
)::text
WHERE %s
"""
% clause,
args,
)
if isinstance(txn.database_engine, PostgresEngine):
values_placeholders = ", ".join(["(?, ?)"] * len(new_stream_ids))
txn.execute(
f"""
UPDATE sticky_events AS se
SET
soft_failed = ?,
stream_id = v.stream_id
FROM (VALUES
{values_placeholders}
) AS v(event_id, stream_id)
WHERE se.event_id = v.event_id;
""",
[StickyEventSoftFailed.FORMER_TRUE] + params,
)
# Also update the internal metadata on the event itself, so when we filter_events_for_client
# we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB...
clause, args = make_in_list_sql_clause(
txn.database_engine,
"event_id",
passing_event_ids,
)
txn.execute(
"""
UPDATE event_json
SET internal_metadata = (
jsonb_set(internal_metadata::jsonb, '{soft_failed}', 'false'::jsonb)
)::text
WHERE %s
"""
% clause,
args,
)
else:
# Use a CASE expression to update in bulk for sqlite
case_expr = " ".join(["WHEN ? THEN ? " for _ in new_stream_ids])
txn.execute(
f"""
UPDATE sticky_events
SET
soft_failed = ?,
stream_id = CASE event_id
{case_expr}
ELSE stream_id
END
WHERE event_id IN ({",".join("?" * len(new_stream_ids))});
""",
[StickyEventSoftFailed.FORMER_TRUE]
+ params
+ [eid for eid, _ in new_stream_ids],
)
clause, args = make_in_list_sql_clause(
txn.database_engine,
"event_id",
passing_event_ids,
)
txn.execute(
f"""
UPDATE event_json
SET internal_metadata = json_set(internal_metadata, '$.soft_failed', json('false'))
WHERE {clause}
""",
args,
)
# finally, invalidate caches
for event_id in passing_event_ids:
self.invalidate_get_event_cache_after_txn(txn, event_id)

View File

@ -18,7 +18,11 @@ CREATE TABLE IF NOT EXISTS sticky_events(
event_id TEXT NOT NULL,
sender TEXT NOT NULL,
expires_at BIGINT NOT NULL,
soft_failed BOOLEAN NOT NULL
-- 0=False, 1=True, 2=False-but-was-True
-- We need '2' to handle cache invalidation downstream.
-- Receiving a sticky event replication row with '2' will cause get_event
-- caches to be invalidated, so the soft-failure status can change.
soft_failed SMALLINT NOT NULL
);
-- for pulling out soft failed events by room

View File

@ -15,4 +15,4 @@ CREATE SEQUENCE sticky_events_sequence;
-- Synapse streams start at 2, because the default position is 1
-- so any item inserted at position 1 is ignored.
-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712
SELECT nextval('thread_subscriptions_sequence');
SELECT nextval('sticky_events_sequence');