Compare commits

...

2 Commits

Author SHA1 Message Date
Devon Hudson
24b38733df
Don't return empty fields in response 2025-10-02 17:23:30 -06:00
Devon Hudson
4602b56643
Stub in early db queries to get tests going 2025-10-02 17:11:14 -06:00
6 changed files with 225 additions and 15 deletions

View File

@ -109,8 +109,6 @@ class RelationsHandler:
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.

View File

@ -61,6 +61,7 @@ _ThreadSubscription: TypeAlias = (
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -1005,13 +1006,32 @@ class SlidingSyncExtensionHandler:
if not threads_request.enabled:
return None
limit = threads_request.limit
# TODO: is the `room_key` the right thing to use here?
# ie. does it translate into /relations
# TODO: use new function to get thread updates
updates, prev_batch = await self.store.get_thread_updates_for_user(
user_id=sync_config.user.to_string(),
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=limit,
)
if len(updates) == 0:
return None
# TODO: implement
_limit = threads_request.limit
prev_batch = None
thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {}
for thread_root_id, room_id in updates:
thread_updates.setdefault(room_id, {})[thread_root_id] = _ThreadUpdate(
thread_root=None,
prev_batch=None,
)
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=None,
updates=thread_updates,
prev_batch=prev_batch,
)

View File

@ -1134,7 +1134,6 @@ def _serialise_thread_subscriptions(
return out
# TODO: is this necessary for serialization?
def _serialise_threads(
threads: SlidingSyncResult.Extensions.ThreadsExtension,
) -> JsonDict:
@ -1143,17 +1142,16 @@ def _serialise_threads(
if threads.updates:
out["updates"] = {
room_id: {
thread_root_id: {
"thread_root": update.thread_root,
"prev_batch": update.prev_batch,
}
thread_root_id: attr.asdict(
update, filter=lambda _attr, v: v is not None
)
for thread_root_id, update in thread_updates.items()
}
for room_id, thread_updates in threads.updates.items()
}
if threads.prev_batch:
out["prev_batch"] = threads.prev_batch.to_string()
out["prev_batch"] = str(threads.prev_batch)
return out

View File

@ -1118,6 +1118,97 @@ class RelationsWorkerStore(SQLBaseStore):
"get_related_thread_id", _get_related_thread_id
)
async def get_thread_updates_for_user(
self,
*,
user_id: str,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
limit: int = 5,
) -> Tuple[Sequence[Tuple[str, str]], Optional[int]]:
# TODO: comment
"""Get a list of updates threads, ordered by stream ordering of their
latest reply.
Args:
user_id: Only fetch threads for rooms that the `user` is `join`ed to.
from_token: Fetch rows from a previous next_batch, or from the start if None.
to_token: Fetch rows from a previous prev_batch, or from the stream end if None.
limit: Only fetch the most recent `limit` threads.
Returns:
A tuple of:
A list of thread root event IDs.
The next_batch, if one exists.
"""
# Ensure bad limits aren't being passed in.
assert limit >= 0
# Generate the pagination clause, if necessary.
#
# Find any threads where the latest reply is between the stream ordering bounds.
pagination_clause = ""
pagination_args: List[str] = []
if from_token:
from_bound = from_token.room_key.stream
pagination_clause += " AND stream_ordering > ?"
pagination_args.append(str(from_bound))
if to_token:
to_bound = to_token.room_key.stream
pagination_clause += " AND stream_ordering <= ?"
pagination_args.append(str(to_bound))
# TODO: get room_ids somehow...
# seems inefficient as we have to basically query for every single joined room
# id don't we?
# How would a specific thread_updates table be any better?
# There must be something somewhere that already does a query which has a
# "filter by all rooms that a user is joined to" clause.
sql = f"""
SELECT thread_id, room_id, latest_event_id, stream_ordering
FROM threads
WHERE
room_id LIKE ?
{pagination_clause}
ORDER BY stream_ordering DESC
LIMIT ?
"""
# sql = """
# SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
# FROM event_relations
# INNER JOIN events USING (event_id)
# WHERE relates_to_id = ? AND %s
# ORDER BY topological_ordering %s, stream_ordering %s
# LIMIT ?
# """ % (
# " AND ".join(where_clause),
# order,
# order,
# )
def _get_thread_updates_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[str, str]], Optional[int]]:
txn.execute(sql, ("%", *pagination_args, limit + 1))
rows = cast(List[Tuple[str, str, str, int]], txn.fetchall())
thread_ids = [(r[0], r[1]) for r in rows]
# If there are more events, generate the next pagination key from the
# last thread which will be returned.
next_token = None
if len(thread_ids) > limit:
# TODO: why -2?
next_token = rows[-2][3]
return thread_ids[:limit], next_token
return await self.db_pool.runInteraction(
"get_thread_updates_for_user", _get_thread_updates_for_user_txn
)
class RelationsStore(RelationsWorkerStore):
pass

View File

@ -405,15 +405,18 @@ class SlidingSyncResult:
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdates:
class ThreadUpdate:
# TODO: comment
thread_root: Optional[EventBase]
# TODO: comment
prev_batch: Optional[StreamToken]
updates: Optional[Mapping[str, Mapping[str, ThreadUpdates]]]
prev_batch: Optional[ThreadSubscriptionsToken]
def __bool__(self) -> bool:
return bool(self.thread_root) or bool(self.prev_batch)
updates: Optional[Mapping[str, Mapping[str, ThreadUpdate]]]
prev_batch: Optional[int]
def __bool__(self) -> bool:
return bool(self.updates) or bool(self.prev_batch)

View File

@ -16,6 +16,7 @@ import logging
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import RelationTypes
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
@ -107,3 +108,102 @@ class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
response_body["extensions"],
response_body,
)
def test_threads_initial_sync(self) -> None:
"""
Test threads appear in initial sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
_latest_event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": user1_id,
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)["event_id"]
# # get the baseline stream_id of the thread_subscriptions stream
# # before we write any data.
# # Required because the initial value differs between SQLite and Postgres.
# base = self.store.get_max_thread_subscriptions_stream_id()
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{"updates": {room_id: {thread_root_id: {}}}},
)
def test_threads_incremental_sync(self) -> None:
"""
Test new thread updates appear in incremental sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the room events stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
# base = self.store.get_room_max_stream_ordering()
# Initial sync
_, sync_pos = self.do_sync(sync_body, tok=user1_tok)
logger.info("Synced to: %r, now subscribing to thread", sync_pos)
# Do thing
_latest_event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={
"msgtype": "m.text",
"body": user1_id,
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
tok=user1_tok,
)["event_id"]
# Incremental sync
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
logger.info("Synced to: %r", sync_pos)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{"updates": {room_id: {thread_root_id: {}}}},
)