Compare commits

...

5 Commits

Author        SHA1        Message                                                                           Date
Kegan Dougal  aac3c846a8  Use a tri-state for soft failed to communicate when we need to cache invalidate  2025-10-02 16:47:45 +01:00
Kegan Dougal  888ab79b3b  Add NewServerJoined replication command                                           2025-10-02 15:25:47 +01:00
                          Currently emits on joins and logs on receive, WIP.
Kegan Dougal  aa45bf7c3a  Add msc4354 to /versions response                                                 2025-10-02 10:46:44 +01:00
Kegan Dougal  15453d4e6e  JSON false not str                                                                2025-10-02 09:30:21 +01:00
Kegan Dougal  78c40973f4  SQLite specific soft-failure update code                                          2025-10-02 09:26:33 +01:00
14 changed files with 185 additions and 58 deletions

View File

@@ -372,3 +372,10 @@ class StickyEvent:
     QUERY_PARAM_NAME: Final = "org.matrix.msc4354.sticky_duration_ms"
     FIELD_NAME: Final = "msc4354_sticky"
     MAX_DURATION_MS: Final = 3600000  # 1 hour
+
+
+# for the database
+class StickyEventSoftFailed(enum.IntEnum):
+    FALSE = 0
+    TRUE = 1
+    FORMER_TRUE = 2
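
The tri-state exists so that replication consumers can tell "still soft-failed" apart from "soft-failure just cleared", which is the one transition that leaves stale entries in event caches. A standalone sketch of the intended consumer logic (the enum mirrors the diff above; `needs_cache_invalidation` is an illustrative name, not part of the PR):

    import enum

    class StickyEventSoftFailed(enum.IntEnum):
        FALSE = 0        # never soft-failed: cached copies are already correct
        TRUE = 1         # currently soft-failed: the event is hidden anyway
        FORMER_TRUE = 2  # was soft-failed, now passes: cached copies are stale

    def needs_cache_invalidation(status: StickyEventSoftFailed) -> bool:
        # Only the flip from soft-failed to not-soft-failed changes what a
        # cached get_event would return, so only FORMER_TRUE forces an eviction.
        return status == StickyEventSoftFailed.FORMER_TRUE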

View File

@@ -213,6 +213,11 @@ class FederationRemoteSendQueue(AbstractFederationSender):
         # This should never get called.
         raise NotImplementedError()
 
+    def notify_new_server_joined(self, server: str, room_id: str) -> None:
+        """As per FederationSender"""
+        # This should never get called.
+        raise NotImplementedError()
+
     def build_and_send_edu(
         self,
         destination: str,

View File

@@ -239,6 +239,13 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def notify_new_server_joined(self, server: str, room_id: str) -> None:
+        """This gets called when a new server has joined a room. We might
+        want to send out some events to this server.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
@@ -488,6 +495,9 @@ class FederationSender(AbstractFederationSender):
             self._per_destination_queues[destination] = queue
         return queue
 
+    def notify_new_server_joined(self, server: str, room_id: str) -> None:
+        print(f"FEDSENDER: new server joined: server={server} room={room_id}")
+
     def notify_new_events(self, max_token: RoomStreamToken) -> None:
         """This gets called when we have some new events we might want to
         send out to other servers.

View File

@@ -933,6 +933,11 @@ class Notifier:
         # that any in flight requests can be immediately retried.
         self._federation_client.wake_destination(server)
 
+    def notify_new_server_joined(self, server: str, room_id: str) -> None:
+        # Inform the federation_sender that it may need to send events to the new server.
+        if self.federation_sender:
+            self.federation_sender.notify_new_server_joined(server, room_id)
+
     def add_lock_released_callback(
         self, callback: Callable[[str, str, str], None]
     ) -> None:

View File

@@ -266,7 +266,6 @@ class ReplicationDataHandler:
                 users=[row.user_id for row in rows],
             )
         elif stream_name == StickyEventsStream.NAME:
-            print(f"STICKY_EVENTS on_rdata {token} => {rows}")
            self.notifier.on_new_event(
                 StreamKeyType.STICKY_EVENTS,
                 token,

View File

@@ -462,6 +462,32 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"
 
 
+class NewServerJoinedCommand(Command):
+    """Sent when a worker has detected that a new remote server has joined a room.
+
+    Format::
+
+        NEW_SERVER_JOINED <server> <room_id>
+    """
+
+    NAME = "NEW_SERVER_JOINED"
+    __slots__ = ["server", "room_id"]
+
+    def __init__(self, server: str, room_id: str):
+        self.server = server
+        self.room_id = room_id
+
+    @classmethod
+    def from_line(
+        cls: Type["NewServerJoinedCommand"], line: str
+    ) -> "NewServerJoinedCommand":
+        server, room_id = line.split(" ")
+        return cls(server, room_id)
+
+    def to_line(self) -> str:
+        return "%s %s" % (self.server, self.room_id)
+
+
 class LockReleasedCommand(Command):
     """Sent to inform other instances that a given lock has been dropped.
 
@@ -517,6 +543,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     FederationAckCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    NewServerJoinedCommand,
     ClearUserSyncsCommand,
     LockReleasedCommand,
     NewActiveTaskCommand,
@@ -533,6 +560,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     RemoteServerUpCommand.NAME,
+    NewServerJoinedCommand.NAME,
     LockReleasedCommand.NAME,
 )
@@ -547,6 +575,7 @@ VALID_CLIENT_COMMANDS = (
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
+    NewServerJoinedCommand.NAME,
     LockReleasedCommand.NAME,
 )
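
For clarity, the command's wire round-trip can be exercised standalone; a minimal sketch mirroring `from_line`/`to_line` above (the real class, with its `Command` base and stream framing, is in the diff):

    # Standalone mirror of NewServerJoinedCommand's serialization.
    class NewServerJoinedCommand:
        def __init__(self, server: str, room_id: str):
            self.server = server
            self.room_id = room_id

        @classmethod
        def from_line(cls, line: str) -> "NewServerJoinedCommand":
            server, room_id = line.split(" ")
            return cls(server, room_id)

        def to_line(self) -> str:
            return "%s %s" % (self.server, self.room_id)

    line = "remote.example.org !abc123:local.example.org"
    cmd = NewServerJoinedCommand.from_line(line)
    assert (cmd.server, cmd.room_id) == ("remote.example.org", "!abc123:local.example.org")
    assert cmd.to_line() == line

The single-space split is safe because valid Matrix server names and room IDs cannot contain spaces.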

View File

@@ -48,6 +48,7 @@ from synapse.replication.tcp.commands import (
     FederationAckCommand,
     LockReleasedCommand,
     NewActiveTaskCommand,
+    NewServerJoinedCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -764,6 +765,12 @@ class ReplicationCommandHandler:
         """Called when we get a new REMOTE_SERVER_UP command."""
         self._notifier.notify_remote_server_up(cmd.data)
 
+    def on_NEW_SERVER_JOINED(
+        self, conn: IReplicationConnection, cmd: NewServerJoinedCommand
+    ) -> None:
+        """Called when we get a new NEW_SERVER_JOINED command."""
+        self._notifier.notify_new_server_joined(cmd.server, cmd.room_id)
+
     def on_LOCK_RELEASED(
         self, conn: IReplicationConnection, cmd: LockReleasedCommand
     ) -> None:
@@ -886,6 +893,9 @@ class ReplicationCommandHandler:
     def send_remote_server_up(self, server: str) -> None:
         self.send_command(RemoteServerUpCommand(server))
 
+    def send_new_server_joined(self, server: str, room_id: str) -> None:
+        self.send_command(NewServerJoinedCommand(server, room_id))
+
     def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
         """Called when a new update is available to stream to Redis subscribers.

View File

@@ -34,7 +34,7 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import AccountDataTypes
+from synapse.api.constants import AccountDataTypes, StickyEventSoftFailed
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
 if TYPE_CHECKING:
@@ -768,16 +768,18 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
         return rows, rows[-1][0], len(updates) == limit
 
 
+@attr.s(slots=True, auto_attribs=True)
+class StickyEventsStreamRow:
+    """Stream to inform workers about changes to sticky events."""
+
+    room_id: str
+    event_id: str  # The sticky event ID
+    soft_failed_status: StickyEventSoftFailed
+
+
 class StickyEventsStream(_StreamFromIdGen):
     """A sticky event was changed."""
 
-    @attr.s(slots=True, auto_attribs=True)
-    class StickyEventsStreamRow:
-        """Stream to inform workers about changes to sticky events."""
-
-        room_id: str
-        event_id: str  # The sticky event ID
-
     NAME = "sticky_events"
     ROW_TYPE = StickyEventsStreamRow
@@ -799,9 +801,9 @@ class StickyEventsStream(_StreamFromIdGen):
             (
                 stream_id,
                 # These are the args to `StickyEventsStreamRow`
-                (room_id, event_id),
+                (room_id, event_id, soft_failed),
             )
-            for stream_id, room_id, event_id in updates
+            for stream_id, room_id, event_id, soft_failed in updates
         ]
 
         if not rows:
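
Each update now travels as `(stream_id, (room_id, event_id, soft_failed))` and the inner tuple is rebuilt positionally into `ROW_TYPE` on the receiving worker. Because `StickyEventSoftFailed` is an `IntEnum`, the raw integer from the `soft_failed` column compares equal to the enum members, so no explicit conversion is needed; a standalone sketch:

    import enum

    import attr

    class StickyEventSoftFailed(enum.IntEnum):
        FALSE = 0
        TRUE = 1
        FORMER_TRUE = 2

    @attr.s(slots=True, auto_attribs=True)
    class StickyEventsStreamRow:
        room_id: str
        event_id: str
        soft_failed_status: StickyEventSoftFailed

    # The inner tuple of an update becomes the row's positional arguments; the
    # soft_failed column arrives as a plain int but IntEnum makes `==` just work.
    inner = ("!room:example.org", "$event_abc", 2)
    row = StickyEventsStreamRow(*inner)
    assert row.soft_failed_status == StickyEventSoftFailed.FORMER_TRUE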
if not rows: if not rows:

View File

@@ -182,6 +182,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc4306": self.config.experimental.msc4306_enabled,
                     # MSC4169: Backwards-compatible redaction sending using `/send`
                     "com.beeper.msc4169": self.config.experimental.msc4169_enabled,
+                    # MSC4354: Sticky events
+                    "org.matrix.msc4354": self.config.experimental.msc4354_enabled,
                 },
             },
         )
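
Per commit 15453d4e6e ("JSON false not str"), the flag is advertised as a JSON boolean rather than the string "false". Roughly what a client should see from GET /_matrix/client/versions, sketched as the Python structure that gets serialized (values illustrative):

    import json

    # msc4354_enabled is a Python bool, so json.dumps emits true/false,
    # never the strings "true"/"false".
    body = {
        "unstable_features": {
            # ...other experimental flags...
            "org.matrix.msc4354": False,
        },
    }
    print(json.dumps(body))  # {"unstable_features": {"org.matrix.msc4354": false}}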

View File

@@ -1188,6 +1188,14 @@ class PersistEventsStore:
         if self.msc4354_sticky_events:
             self.store.insert_sticky_events_txn(txn, events_and_contexts)
 
+            for ev, _ in events_and_contexts:
+                if ev.type == "m.room.member" and ev.membership == "join":
+                    print(f"GOT JOIN FOR {ev.state_key}")
+                    domain = get_domain_from_id(ev.state_key)
+                    self.hs.get_notifier().notify_new_server_joined(domain, ev.room_id)
+                    self.hs.get_replication_command_handler().send_new_server_joined(
+                        domain, ev.room_id
+                    )
+
         # We only update the sliding sync tables for non-backfilled events.
         self._update_sliding_sync_tables_with_new_persisted_events_txn(
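
The server to notify is derived from the joining user's MXID: `get_domain_from_id` (from `synapse.types`) is approximately the following, re-stated here for clarity. Note this hook fires on every join, including joins from servers already in the room, which fits the "WIP" caveat in the commit message:

    def get_domain_from_id(string: str) -> str:
        # Matrix IDs look like @localpart:domain; the domain is everything
        # after the first colon (ports and further colons included).
        idx = string.find(":")
        if idx == -1:
            raise ValueError(f"Invalid ID: {string!r}")
        return string[idx + 1 :]

    assert get_domain_from_id("@alice:remote.example.org") == "remote.example.org"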

View File

@@ -45,7 +45,7 @@ from prometheus_client import Gauge
 
 from twisted.internet import defer
 
-from synapse.api.constants import Direction, EventTypes
+from synapse.api.constants import Direction, EventTypes, StickyEventSoftFailed
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
@@ -74,6 +74,10 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream
+from synapse.replication.tcp.streams._base import (
+    StickyEventsStream,
+    StickyEventsStreamRow,
+)
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -463,6 +467,12 @@ class EventsWorkerStore(SQLBaseStore):
                     # If the partial-stated event became rejected or unrejected
                     # when it wasn't before, we need to invalidate this cache.
                     self._invalidate_local_get_event_cache(row.event_id)
+        elif stream_name == StickyEventsStream.NAME:
+            for row in rows:
+                assert isinstance(row, StickyEventsStreamRow)
+                if row.soft_failed_status == StickyEventSoftFailed.FORMER_TRUE:
+                    # was soft-failed, now not, so invalidate caches
+                    self._invalidate_local_get_event_cache(row.event_id)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -28,7 +28,7 @@ from typing import (
 from twisted.internet.defer import Deferred
 
 from synapse import event_auth
-from synapse.api.constants import EventTypes, StickyEvent
+from synapse.api.constants import EventTypes, StickyEvent, StickyEventSoftFailed
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventPersistencePair
@@ -43,6 +43,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types.state import StateFilter
 from synapse.util.stringutils import shortstr
@@ -170,15 +171,15 @@ class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStor
         txn.execute(
             f"""
             SELECT stream_id, room_id, event_id FROM sticky_events
-            WHERE soft_failed=FALSE AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
+            WHERE soft_failed != ? AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
             """,
-            (now, from_id, to_id, *room_id_values),
+            (StickyEventSoftFailed.TRUE, now, from_id, to_id, *room_id_values),
         )
         return cast(List[Tuple[int, str, str]], txn.fetchall())
 
     async def get_updated_sticky_events(
         self, from_id: int, to_id: int, limit: int
-    ) -> List[Tuple[int, str, str]]:
+    ) -> List[Tuple[int, str, str, StickyEventSoftFailed]]:
         """Get updates to sticky events between two stream IDs.
 
         Args:
@@ -199,14 +200,14 @@
     def _get_updated_sticky_events_txn(
         self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int
-    ) -> List[Tuple[int, str, str]]:
+    ) -> List[Tuple[int, str, str, StickyEventSoftFailed]]:
         txn.execute(
             """
-            SELECT stream_id, room_id, event_id FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ?
+            SELECT stream_id, room_id, event_id, soft_failed FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ?
             """,
             (from_id, to_id, limit),
         )
-        return cast(List[Tuple[int, str, str]], txn.fetchall())
+        return cast(List[Tuple[int, str, str, StickyEventSoftFailed]], txn.fetchall())
 
     async def get_sticky_event_ids_sent_by_self(
         self, room_id: str, from_stream_pos: int
@@ -236,9 +237,9 @@
             """
             SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events
             INNER JOIN events ON events.event_id = sticky_events.event_id
-            WHERE soft_failed=FALSE AND expires_at > ? AND sticky_events.room_id = ?
+            WHERE soft_failed=? AND expires_at > ? AND sticky_events.room_id = ?
             """,
-            (now_ms, room_id),
+            (StickyEventSoftFailed.FALSE, now_ms, room_id),
         )
         rows = cast(List[Tuple[str, str, int]], txn.fetchall())
         return [
@@ -340,7 +341,9 @@
                     ev.event_id,
                     ev.sender,
                     expires_at,
-                    ev.internal_metadata.is_soft_failed(),
+                    StickyEventSoftFailed.TRUE
+                    if ev.internal_metadata.is_soft_failed()
+                    else StickyEventSoftFailed.FALSE,
                 )
                 for (ev, expires_at, stream_id) in sticky_events
             ],
@@ -425,7 +428,7 @@
             iterable=new_membership_changes,
             keyvalues={
                 "room_id": room_id,
-                "soft_failed": True,
+                "soft_failed": StickyEventSoftFailed.TRUE,
             },
             retcols=("event_id",),
             desc="_get_soft_failed_sticky_events_to_recheck_members",
@@ -456,7 +459,7 @@
             table="sticky_events",
             keyvalues={
                 "room_id": room_id,
-                "soft_failed": True,
+                "soft_failed": StickyEventSoftFailed.TRUE,
             },
             retcols=("event_id",),
             desc="_get_soft_failed_sticky_events_to_recheck",
@@ -531,39 +534,72 @@
             (event_id, self._sticky_events_id_gen.get_next_txn(txn))
             for event_id in passing_event_ids
         ]
-        values_placeholders = ", ".join(["(?, ?)"] * len(new_stream_ids))
+        # [event_id, stream_pos, event_id, stream_pos, ...]
         params = [p for pair in new_stream_ids for p in pair]
-        txn.execute(
-            f"""
-            UPDATE sticky_events AS se
-            SET
-                soft_failed = FALSE,
-                stream_id = v.stream_id
-            FROM (VALUES
-                {values_placeholders}
-            ) AS v(event_id, stream_id)
-            WHERE se.event_id = v.event_id;
-            """,
-            params,
-        )
-        # Also update the internal metadata on the event itself, so when we filter_events_for_client
-        # we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB...
-        clause, args = make_in_list_sql_clause(
-            txn.database_engine,
-            "event_id",
-            passing_event_ids,
-        )
-        txn.execute(
-            """
-            UPDATE event_json
-            SET internal_metadata = (
-                jsonb_set(internal_metadata::jsonb, '{soft_failed}', 'false'::jsonb)
-            )::text
-            WHERE %s
-            """
-            % clause,
-            args,
-        )
+        if isinstance(txn.database_engine, PostgresEngine):
+            values_placeholders = ", ".join(["(?, ?)"] * len(new_stream_ids))
+            txn.execute(
+                f"""
+                UPDATE sticky_events AS se
+                SET
+                    soft_failed = ?,
+                    stream_id = v.stream_id
+                FROM (VALUES
+                    {values_placeholders}
+                ) AS v(event_id, stream_id)
+                WHERE se.event_id = v.event_id;
+                """,
+                [StickyEventSoftFailed.FORMER_TRUE] + params,
+            )
+            # Also update the internal metadata on the event itself, so when we filter_events_for_client
+            # we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB...
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine,
+                "event_id",
+                passing_event_ids,
+            )
+            txn.execute(
+                """
+                UPDATE event_json
+                SET internal_metadata = (
+                    jsonb_set(internal_metadata::jsonb, '{soft_failed}', 'false'::jsonb)
+                )::text
+                WHERE %s
+                """
+                % clause,
+                args,
+            )
+        else:
+            # Use a CASE expression to update in bulk for sqlite
+            case_expr = " ".join(["WHEN ? THEN ? " for _ in new_stream_ids])
+            txn.execute(
+                f"""
+                UPDATE sticky_events
+                SET
+                    soft_failed = ?,
+                    stream_id = CASE event_id
+                        {case_expr}
+                        ELSE stream_id
+                    END
+                WHERE event_id IN ({",".join("?" * len(new_stream_ids))});
+                """,
+                [StickyEventSoftFailed.FORMER_TRUE]
+                + params
+                + [eid for eid, _ in new_stream_ids],
+            )
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine,
+                "event_id",
+                passing_event_ids,
+            )
+            txn.execute(
+                f"""
+                UPDATE event_json
+                SET internal_metadata = json_set(internal_metadata, '$.soft_failed', json('false'))
+                WHERE {clause}
+                """,
+                args,
+            )
         # finally, invalidate caches
         for event_id in passing_event_ids:
             self.invalidate_get_event_cache_after_txn(txn, event_id)
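
To make the SQLite branch concrete, here is what the f-string expands to for two passing events, with a quick standalone check of the placeholder/parameter alignment (event IDs invented):

    new_stream_ids = [("$e1", 11), ("$e2", 12)]
    params = [p for pair in new_stream_ids for p in pair]  # ["$e1", 11, "$e2", 12]
    case_expr = " ".join(["WHEN ? THEN ? " for _ in new_stream_ids])
    in_list = ",".join("?" * len(new_stream_ids))          # "?,?"
    # The statement becomes, whitespace aside:
    #   UPDATE sticky_events SET soft_failed = ?,
    #       stream_id = CASE event_id WHEN ? THEN ?  WHEN ? THEN ?  ELSE stream_id END
    #   WHERE event_id IN (?,?);
    # and the bound parameters line up as:
    #   [FORMER_TRUE, "$e1", 11, "$e2", 12, "$e1", "$e2"]
    assert len(params) + 1 + len(new_stream_ids) == 7  # one value per "?"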

View File

@@ -18,7 +18,11 @@ CREATE TABLE IF NOT EXISTS sticky_events(
     event_id TEXT NOT NULL,
     sender TEXT NOT NULL,
     expires_at BIGINT NOT NULL,
-    soft_failed BOOLEAN NOT NULL
+    -- 0=False, 1=True, 2=False-but-was-True
+    -- We need '2' to handle cache invalidation downstream.
+    -- Receiving a sticky event replication row with '2' will cause get_event
+    -- caches to be invalidated, so the soft-failure status can change.
+    soft_failed SMALLINT NOT NULL
 );
 
 -- for pulling out soft failed events by room

View File

@@ -15,4 +15,4 @@ CREATE SEQUENCE sticky_events_sequence;
 -- Synapse streams start at 2, because the default position is 1
 -- so any item inserted at position 1 is ignored.
 -- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712
-SELECT nextval('thread_subscriptions_sequence');
+SELECT nextval('sticky_events_sequence');