mirror of
https://github.com/element-hq/synapse.git
synced 2025-11-09 00:02:41 -05:00
Dedicated MAS API (#18520)
This introduces a dedicated API for MAS to consume. Companion PR on the MAS side: element-hq/matrix-authentication-service#4801 This has a few advantages over the previous admin API: - it works on workers (this will be documented once we stabilise MSC3861 as a whole) - it is more efficient because more focused - it propagates trace contexts from MAS - it is only accessible to MAS (through the shared secret) and will let us remove the weird hack that made this token 'admin' with a ghost '@__oidc_admin:' user The next MAS version should support it, but will be opt-in. The version after that should use this new API by default --------- Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
parent
875269eb53
commit
8a4e2e826d
1
changelog.d/18520.misc
Normal file
1
changelog.d/18520.misc
Normal file
@ -0,0 +1 @@
|
||||
Dedicated internal API for Matrix Authentication Service to Synapse communication.
|
||||
@ -48,6 +48,7 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
|
||||
conint,
|
||||
constr,
|
||||
parse_obj_as,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from pydantic.v1.error_wrappers import ErrorWrapper
|
||||
@ -68,6 +69,7 @@ else:
|
||||
conint,
|
||||
constr,
|
||||
parse_obj_as,
|
||||
root_validator,
|
||||
validator,
|
||||
)
|
||||
from pydantic.error_wrappers import ErrorWrapper
|
||||
@ -92,4 +94,5 @@ __all__ = (
|
||||
"StrictStr",
|
||||
"ValidationError",
|
||||
"validator",
|
||||
"root_validator",
|
||||
)
|
||||
|
||||
@ -369,6 +369,12 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
async def is_server_admin(self, requester: Requester) -> bool:
|
||||
return "urn:synapse:admin:*" in requester.scope
|
||||
|
||||
def _is_access_token_the_admin_token(self, token: str) -> bool:
|
||||
admin_token = self._admin_token()
|
||||
if admin_token is None:
|
||||
return False
|
||||
return token == admin_token
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
@ -434,7 +440,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
requester = await self.get_user_by_access_token(access_token, allow_expired)
|
||||
|
||||
# Do not record requests from MAS using the virtual `__oidc_admin` user.
|
||||
if access_token != self._admin_token():
|
||||
if not self._is_access_token_the_admin_token(access_token):
|
||||
await self._record_request(request, requester)
|
||||
|
||||
if not allow_guest and requester.is_guest:
|
||||
@ -470,13 +476,25 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
|
||||
raise UnrecognizedRequestError(code=404)
|
||||
|
||||
def is_request_using_the_admin_token(self, request: SynapseRequest) -> bool:
|
||||
"""
|
||||
Check if the request is using the admin token.
|
||||
|
||||
Args:
|
||||
request: The request to check.
|
||||
|
||||
Returns:
|
||||
True if the request is using the admin token, False otherwise.
|
||||
"""
|
||||
access_token = self.get_access_token_from_request(request)
|
||||
return self._is_access_token_the_admin_token(access_token)
|
||||
|
||||
async def get_user_by_access_token(
|
||||
self,
|
||||
token: str,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
admin_token = self._admin_token()
|
||||
if admin_token is not None and token == admin_token:
|
||||
if self._is_access_token_the_admin_token(token):
|
||||
# XXX: This is a temporary solution so that the admin API can be called by
|
||||
# the OIDC provider. This will be removed once we have OIDC client
|
||||
# credentials grant support in matrix-authentication-service.
|
||||
|
||||
@ -30,6 +30,7 @@ from synapse.rest.synapse.client.pick_username import pick_username_resource
|
||||
from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
|
||||
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
|
||||
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
|
||||
from synapse.rest.synapse.mas import MasResource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@ -60,6 +61,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
|
||||
from synapse.rest.synapse.client.jwks import JwksResource
|
||||
|
||||
resources["/_synapse/jwks"] = JwksResource(hs)
|
||||
resources["/_synapse/mas"] = MasResource(hs)
|
||||
|
||||
# provider-specific SSO bits. Only load these if they are enabled, since they
|
||||
# rely on optional dependencies.
|
||||
|
||||
71
synapse/rest/synapse/mas/__init__.py
Normal file
71
synapse/rest/synapse/mas/__init__.py
Normal file
@ -0,0 +1,71 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl_3.0.html>.
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.rest.synapse.mas.devices import (
|
||||
MasDeleteDeviceResource,
|
||||
MasSyncDevicesResource,
|
||||
MasUpdateDeviceDisplayNameResource,
|
||||
MasUpsertDeviceResource,
|
||||
)
|
||||
from synapse.rest.synapse.mas.users import (
|
||||
MasAllowCrossSigningResetResource,
|
||||
MasDeleteUserResource,
|
||||
MasIsLocalpartAvailableResource,
|
||||
MasProvisionUserResource,
|
||||
MasQueryUserResource,
|
||||
MasReactivateUserResource,
|
||||
MasSetDisplayNameResource,
|
||||
MasUnsetDisplayNameResource,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MasResource(Resource):
|
||||
"""
|
||||
Provides endpoints for MAS to manage user accounts and devices.
|
||||
|
||||
All endpoints are mounted under the path `/_synapse/mas/` and only work
|
||||
using the MAS admin token.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
Resource.__init__(self)
|
||||
self.putChild(b"query_user", MasQueryUserResource(hs))
|
||||
self.putChild(b"provision_user", MasProvisionUserResource(hs))
|
||||
self.putChild(b"is_localpart_available", MasIsLocalpartAvailableResource(hs))
|
||||
self.putChild(b"delete_user", MasDeleteUserResource(hs))
|
||||
self.putChild(b"upsert_device", MasUpsertDeviceResource(hs))
|
||||
self.putChild(b"delete_device", MasDeleteDeviceResource(hs))
|
||||
self.putChild(
|
||||
b"update_device_display_name", MasUpdateDeviceDisplayNameResource(hs)
|
||||
)
|
||||
self.putChild(b"sync_devices", MasSyncDevicesResource(hs))
|
||||
self.putChild(b"reactivate_user", MasReactivateUserResource(hs))
|
||||
self.putChild(b"set_displayname", MasSetDisplayNameResource(hs))
|
||||
self.putChild(b"unset_displayname", MasUnsetDisplayNameResource(hs))
|
||||
self.putChild(
|
||||
b"allow_cross_signing_reset", MasAllowCrossSigningResetResource(hs)
|
||||
)
|
||||
47
synapse/rest/synapse/mas/_base.py
Normal file
47
synapse/rest/synapse/mas/_base.py
Normal file
@ -0,0 +1,47 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl_3.0.html>.
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.generic_worker import GenericWorkerStore
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class MasBaseResource(DirectServeJsonResource):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
# Importing this module requires authlib, which is an optional
|
||||
# dependency but required if msc3861 is enabled
|
||||
from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
|
||||
|
||||
DirectServeJsonResource.__init__(self, extract_context=True)
|
||||
auth = hs.get_auth()
|
||||
assert isinstance(auth, MSC3861DelegatedAuth)
|
||||
self.msc3861_auth = auth
|
||||
self.store = cast("GenericWorkerStore", hs.get_datastores().main)
|
||||
self.hostname = hs.hostname
|
||||
|
||||
def assert_request_is_from_mas(self, request: "SynapseRequest") -> None:
|
||||
"""Assert that the request is coming from MAS itself, not a regular user.
|
||||
|
||||
Throws a 403 if the request is not coming from MAS.
|
||||
"""
|
||||
if not self.msc3861_auth.is_request_using_the_admin_token(request):
|
||||
raise SynapseError(403, "This endpoint must only be called by MAS")
|
||||
238
synapse/rest/synapse/mas/devices.py
Normal file
238
synapse/rest/synapse/mas/devices.py
Normal file
@ -0,0 +1,238 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl_3.0.html>.
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse._pydantic_compat import StrictStr
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.http.servlet import parse_and_validate_json_object_from_request
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types.rest import RequestBodyModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
from ._base import MasBaseResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MasUpsertDeviceResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to create or update user devices.
|
||||
|
||||
Takes a localpart, device ID, and optional display name to create new devices
|
||||
or update existing ones.
|
||||
|
||||
POST /_synapse/mas/upsert_device
|
||||
{"localpart": "alice", "device_id": "DEVICE123", "display_name": "Alice's Phone"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
device_id: StrictStr
|
||||
display_name: Optional[StrictStr]
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
inserted = await self.device_handler.upsert_device(
|
||||
user_id=str(user_id),
|
||||
device_id=body.device_id,
|
||||
display_name=body.display_name,
|
||||
)
|
||||
|
||||
return HTTPStatus.CREATED if inserted else HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasDeleteDeviceResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to delete user devices.
|
||||
|
||||
Takes a localpart and device ID to remove the specified device from the user's account.
|
||||
|
||||
POST /_synapse/mas/delete_device
|
||||
{"localpart": "alice", "device_id": "DEVICE123"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
device_id: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
await self.device_handler.delete_devices(
|
||||
user_id=str(user_id),
|
||||
device_ids=[body.device_id],
|
||||
)
|
||||
|
||||
return HTTPStatus.NO_CONTENT, {}
|
||||
|
||||
|
||||
class MasUpdateDeviceDisplayNameResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to update a device's display name.
|
||||
|
||||
Takes a localpart, device ID, and new display name to update the device's name.
|
||||
|
||||
POST /_synapse/mas/update_device_display_name
|
||||
{"localpart": "alice", "device_id": "DEVICE123", "display_name": "Alice's New Phone"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
device_id: StrictStr
|
||||
display_name: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
await self.device_handler.update_device(
|
||||
user_id=str(user_id),
|
||||
device_id=body.device_id,
|
||||
content={"display_name": body.display_name},
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasSyncDevicesResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to synchronize a user's complete device list.
|
||||
|
||||
Takes a localpart and a set of device IDs to ensure the user's device list
|
||||
matches the provided set by adding missing devices and removing extra ones.
|
||||
|
||||
POST /_synapse/mas/sync_devices
|
||||
{"localpart": "alice", "devices": ["DEVICE123", "DEVICE456"]}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
devices: set[StrictStr]
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
current_devices = await self.store.get_devices_by_user(user_id=str(user_id))
|
||||
current_devices_list = set(current_devices.keys())
|
||||
target_device_list = set(body.devices)
|
||||
|
||||
to_add = target_device_list - current_devices_list
|
||||
to_delete = current_devices_list - target_device_list
|
||||
|
||||
# Log what we're about to do to make it easier to debug if it stops
|
||||
# mid-way, as this can be a long operation if there are a lot of devices
|
||||
# to delete or to add.
|
||||
if to_add and to_delete:
|
||||
logger.info(
|
||||
"Syncing %d devices for user %s will add %d devices and delete %d devices",
|
||||
len(target_device_list),
|
||||
user_id,
|
||||
len(to_add),
|
||||
len(to_delete),
|
||||
)
|
||||
elif to_add:
|
||||
logger.info(
|
||||
"Syncing %d devices for user %s will add %d devices",
|
||||
len(target_device_list),
|
||||
user_id,
|
||||
len(to_add),
|
||||
)
|
||||
elif to_delete:
|
||||
logger.info(
|
||||
"Syncing %d devices for user %s will delete %d devices",
|
||||
len(target_device_list),
|
||||
user_id,
|
||||
len(to_delete),
|
||||
)
|
||||
|
||||
if to_delete:
|
||||
await self.device_handler.delete_devices(
|
||||
user_id=str(user_id), device_ids=to_delete
|
||||
)
|
||||
|
||||
for device_id in to_add:
|
||||
await self.device_handler.upsert_device(
|
||||
user_id=str(user_id),
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
467
synapse/rest/synapse/mas/users.py
Normal file
467
synapse/rest/synapse/mas/users.py
Normal file
@ -0,0 +1,467 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl_3.0.html>.
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, TypedDict
|
||||
|
||||
from synapse._pydantic_compat import StrictBool, StrictStr, root_validator
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
parse_and_validate_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.types import JsonDict, UserID, UserInfo, create_requester
|
||||
from synapse.types.rest import RequestBodyModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
from ._base import MasBaseResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MasQueryUserResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to query user information by localpart.
|
||||
|
||||
Takes a localpart parameter and returns user profile data including display name,
|
||||
avatar URL, and account status (suspended/deactivated).
|
||||
|
||||
GET /_synapse/mas/query_user?localpart=alice
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
class Response(TypedDict):
|
||||
user_id: str
|
||||
display_name: Optional[str]
|
||||
avatar_url: Optional[str]
|
||||
is_suspended: bool
|
||||
is_deactivated: bool
|
||||
|
||||
async def _async_render_GET(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, Response]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
localpart = parse_string(request, "localpart", required=True)
|
||||
user_id = UserID(localpart, self.hostname)
|
||||
|
||||
user: Optional[UserInfo] = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
profile = await self.store.get_profileinfo(user_id=user_id)
|
||||
|
||||
return HTTPStatus.OK, self.Response(
|
||||
user_id=user_id.to_string(),
|
||||
display_name=profile.display_name,
|
||||
avatar_url=profile.avatar_url,
|
||||
is_suspended=user.suspended,
|
||||
is_deactivated=user.is_deactivated,
|
||||
)
|
||||
|
||||
|
||||
class MasProvisionUserResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to create or update user accounts and their profile data.
|
||||
|
||||
Takes a localpart and optional profile fields (display name, avatar URL, email addresses).
|
||||
Can create new users or update existing ones by setting or unsetting profile fields.
|
||||
|
||||
POST /_synapse/mas/provision_user
|
||||
{"localpart": "alice", "set_displayname": "Alice", "set_emails": ["alice@example.com"]}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.identity_handler = hs.get_identity_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
|
||||
unset_displayname: StrictBool = False
|
||||
set_displayname: Optional[StrictStr] = None
|
||||
|
||||
unset_avatar_url: StrictBool = False
|
||||
set_avatar_url: Optional[StrictStr] = None
|
||||
|
||||
unset_emails: StrictBool = False
|
||||
set_emails: Optional[list[StrictStr]] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_exclusive(cls, values: Any) -> Any:
|
||||
if "unset_displayname" in values and "set_displayname" in values:
|
||||
raise ValueError(
|
||||
"Cannot specify both unset_displayname and set_displayname"
|
||||
)
|
||||
if "unset_avatar_url" in values and "set_avatar_url" in values:
|
||||
raise ValueError(
|
||||
"Cannot specify both unset_avatar_url and set_avatar_url"
|
||||
)
|
||||
if "unset_emails" in values and "set_emails" in values:
|
||||
raise ValueError("Cannot specify both unset_emails and set_emails")
|
||||
|
||||
return values
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
localpart = body.localpart
|
||||
user_id = UserID(localpart, self.hostname)
|
||||
|
||||
requester = create_requester(user_id=user_id)
|
||||
existing_user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if existing_user is None:
|
||||
created = True
|
||||
await self.registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=body.set_displayname,
|
||||
bind_emails=body.set_emails,
|
||||
by_admin=True,
|
||||
)
|
||||
else:
|
||||
created = False
|
||||
if body.unset_displayname:
|
||||
await self.profile_handler.set_displayname(
|
||||
target_user=user_id,
|
||||
requester=requester,
|
||||
new_displayname="",
|
||||
by_admin=True,
|
||||
)
|
||||
elif body.set_displayname is not None:
|
||||
await self.profile_handler.set_displayname(
|
||||
target_user=user_id,
|
||||
requester=requester,
|
||||
new_displayname=body.set_displayname,
|
||||
by_admin=True,
|
||||
)
|
||||
|
||||
new_email_list: Optional[set[str]] = None
|
||||
if body.unset_emails:
|
||||
new_email_list = set()
|
||||
elif body.set_emails is not None:
|
||||
new_email_list = set(body.set_emails)
|
||||
|
||||
if new_email_list is not None:
|
||||
medium = "email"
|
||||
current_threepid_list = await self.store.user_get_threepids(
|
||||
user_id=user_id.to_string()
|
||||
)
|
||||
current_email_list = {
|
||||
t.address for t in current_threepid_list if t.medium == medium
|
||||
}
|
||||
|
||||
to_delete = current_email_list - new_email_list
|
||||
to_add = new_email_list - current_email_list
|
||||
|
||||
for address in to_delete:
|
||||
await self.identity_handler.try_unbind_threepid(
|
||||
mxid=user_id.to_string(),
|
||||
medium=medium,
|
||||
address=address,
|
||||
id_server=None,
|
||||
)
|
||||
|
||||
await self.auth_handler.delete_local_threepid(
|
||||
user_id=user_id.to_string(),
|
||||
medium=medium,
|
||||
address=address,
|
||||
)
|
||||
|
||||
current_time = self.clock.time_msec()
|
||||
for address in to_add:
|
||||
await self.auth_handler.add_threepid(
|
||||
user_id=user_id.to_string(),
|
||||
medium=medium,
|
||||
address=address,
|
||||
validated_at=current_time,
|
||||
)
|
||||
|
||||
if body.unset_avatar_url:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
target_user=user_id,
|
||||
requester=requester,
|
||||
new_avatar_url="",
|
||||
by_admin=True,
|
||||
)
|
||||
elif body.set_avatar_url is not None:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
target_user=user_id,
|
||||
requester=requester,
|
||||
new_avatar_url=body.set_avatar_url,
|
||||
by_admin=True,
|
||||
)
|
||||
|
||||
return HTTPStatus.CREATED if created else HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasIsLocalpartAvailableResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to check if a localpart is available for user registration.
|
||||
|
||||
Takes a localpart parameter and validates its format and availability,
|
||||
checking for conflicts with existing users or application service namespaces.
|
||||
|
||||
GET /_synapse/mas/is_localpart_available?localpart=alice
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
async def _async_render_GET(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
localpart = parse_string(request, "localpart")
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Missing localpart")
|
||||
|
||||
await self.registration_handler.check_username(localpart)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasDeleteUserResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to delete/deactivate user accounts.
|
||||
|
||||
Takes a localpart and an erase flag to determine whether to deactivate
|
||||
the account and optionally erase user data for compliance purposes.
|
||||
|
||||
POST /_synapse/mas/delete_user
|
||||
{"localpart": "alice", "erase": true}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
erase: StrictBool
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
await self.deactivate_account_handler.deactivate_account(
|
||||
user_id=user_id.to_string(),
|
||||
erase_data=body.erase,
|
||||
requester=create_requester(user_id=user_id),
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasReactivateUserResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to reactivate previously deactivated user accounts.
|
||||
|
||||
Takes a localpart parameter to restore access to deactivated accounts.
|
||||
|
||||
POST /_synapse/mas/reactivate_user
|
||||
{"localpart": "alice"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
await self.deactivate_account_handler.activate_account(user_id=str(user_id))
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasSetDisplayNameResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to set a user's display name.
|
||||
|
||||
Takes a localpart and display name to update the user's profile.
|
||||
|
||||
POST /_synapse/mas/set_displayname
|
||||
{"localpart": "alice", "displayname": "Alice"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
displayname: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
requester = create_requester(user_id=user_id)
|
||||
|
||||
await self.profile_handler.set_displayname(
|
||||
target_user=requester.user,
|
||||
requester=requester,
|
||||
new_displayname=body.displayname,
|
||||
by_admin=True,
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasUnsetDisplayNameResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to clear a user's display name.
|
||||
|
||||
Takes a localpart parameter to remove the display name for the specified user.
|
||||
|
||||
POST /_synapse/mas/unset_displayname
|
||||
{"localpart": "alice"}
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
requester = create_requester(user_id=user_id)
|
||||
|
||||
await self.profile_handler.set_displayname(
|
||||
target_user=requester.user,
|
||||
requester=requester,
|
||||
new_displayname="",
|
||||
by_admin=True,
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class MasAllowCrossSigningResetResource(MasBaseResource):
|
||||
"""
|
||||
Endpoint for MAS to allow cross-signing key reset without user interaction.
|
||||
|
||||
Takes a localpart parameter to temporarily allow cross-signing key replacement
|
||||
without requiring User-Interactive Authentication (UIA).
|
||||
|
||||
POST /_synapse/mas/allow_cross_signing_reset
|
||||
{"localpart": "alice"}
|
||||
"""
|
||||
|
||||
REPLACEMENT_PERIOD_MS = 10 * 60 * 1000 # 10 minutes
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
MasBaseResource.__init__(self, hs)
|
||||
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
localpart: StrictStr
|
||||
|
||||
async def _async_render_POST(
|
||||
self, request: "SynapseRequest"
|
||||
) -> Tuple[int, JsonDict]:
|
||||
self.assert_request_is_from_mas(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
user_id = UserID(body.localpart, self.hostname)
|
||||
|
||||
# Check the user exists
|
||||
user = await self.store.get_user_by_id(user_id=str(user_id))
|
||||
if user is None:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
timestamp = (
|
||||
await self.store.allow_master_cross_signing_key_replacement_without_uia(
|
||||
user_id=str(user_id),
|
||||
duration_ms=self.REPLACEMENT_PERIOD_MS,
|
||||
)
|
||||
)
|
||||
|
||||
if timestamp is None:
|
||||
# If there are no cross-signing keys, this is a no-op, but we should log
|
||||
logger.warning(
|
||||
"User %s has no master cross-signing key", user_id.to_string()
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
12
tests/rest/synapse/mas/__init__.py
Normal file
12
tests/rest/synapse/mas/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
43
tests/rest/synapse/mas/_base.py
Normal file
43
tests/rest/synapse/mas/_base.py
Normal file
@ -0,0 +1,43 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class BaseTestCase(unittest.HomeserverTestCase):
|
||||
SHARED_SECRET = "shared_secret"
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["enable_registration"] = False
|
||||
config["experimental_features"] = {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": "https://example.com",
|
||||
"client_id": "dummy",
|
||||
"client_auth_method": "client_secret_basic",
|
||||
"client_secret": "dummy",
|
||||
"admin_token": self.SHARED_SECRET,
|
||||
}
|
||||
}
|
||||
return config
|
||||
|
||||
def create_resource_dict(self) -> dict[str, Resource]:
|
||||
base = super().create_resource_dict()
|
||||
base.update(build_synapse_client_resource_tree(self.hs))
|
||||
return base
|
||||
693
tests/rest/synapse/mas/test_devices.py
Normal file
693
tests/rest/synapse/mas/test_devices.py
Normal file
@ -0,0 +1,693 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import skip_unless
|
||||
from tests.utils import HAS_AUTHLIB
|
||||
|
||||
from ._base import BaseTestCase
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MasUpsertDeviceResource(BaseTestCase):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
# Create a user for testing
|
||||
self.alice_user_id = UserID("alice", "test")
|
||||
self.get_success(
|
||||
homeserver.get_registration_handler().register_user(
|
||||
localpart=self.alice_user_id.localpart,
|
||||
)
|
||||
)
|
||||
|
||||
def test_other_token(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token="other_token",
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["error"], "This endpoint must only be called by MAS"
|
||||
)
|
||||
|
||||
def test_upsert_device(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
# This created a new device, hence the 201 status code
|
||||
self.assertEqual(channel.code, 201, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the device exists
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
self.assertEqual(device["device_id"], "DEVICE1")
|
||||
self.assertIsNone(device["display_name"])
|
||||
|
||||
def test_update_existing_device(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
device_handler = self.hs.get_device_handler()
|
||||
|
||||
# Create an initial device
|
||||
self.get_success(
|
||||
device_handler.upsert_device(
|
||||
user_id=str(self.alice_user_id),
|
||||
device_id="DEVICE1",
|
||||
display_name="Old Name",
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
# This updated an existing device, hence the 200 status code
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the device was updated
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
self.assertEqual(device["display_name"], "New Name")
|
||||
|
||||
def test_upsert_device_with_display_name(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "Alice's Phone",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 201, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the device exists with correct display name
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
self.assertEqual(device["display_name"], "Alice's Phone")
|
||||
|
||||
def test_upsert_device_missing_localpart(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_upsert_device_missing_device_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_upsert_device_nonexistent_user(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/upsert_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "nonexistent",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
# We get a 404 here as the user doesn't exist
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MasDeleteDeviceResource(BaseTestCase):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
# Create a user and device for testing
|
||||
self.alice_user_id = UserID("alice", "test")
|
||||
self.get_success(
|
||||
homeserver.get_registration_handler().register_user(
|
||||
localpart=self.alice_user_id.localpart,
|
||||
)
|
||||
)
|
||||
|
||||
# Create a device
|
||||
device_handler = homeserver.get_device_handler()
|
||||
self.get_success(
|
||||
device_handler.upsert_device(
|
||||
user_id=str(self.alice_user_id),
|
||||
device_id="DEVICE1",
|
||||
display_name="Test Device",
|
||||
)
|
||||
)
|
||||
|
||||
def test_other_token(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token="other_token",
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["error"], "This endpoint must only be called by MAS"
|
||||
)
|
||||
|
||||
def test_delete_device(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
# Verify device exists before deletion
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 204)
|
||||
|
||||
# Verify the device no longer exists
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
self.assertIsNone(device)
|
||||
|
||||
def test_delete_nonexistent_device(self) -> None:
|
||||
# Deleting a non-existent device should be idempotent
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "NONEXISTENT",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 204)
|
||||
|
||||
def test_delete_device_missing_localpart(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_delete_device_missing_device_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_delete_device_nonexistent_user(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/delete_device",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "nonexistent",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
# Should fail on a non-existent user
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MasUpdateDeviceDisplayNameResource(BaseTestCase):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
# Create a user and device for testing
|
||||
self.alice_user_id = UserID("alice", "test")
|
||||
self.get_success(
|
||||
homeserver.get_registration_handler().register_user(
|
||||
localpart=self.alice_user_id.localpart,
|
||||
)
|
||||
)
|
||||
|
||||
# Create a device
|
||||
device_handler = homeserver.get_device_handler()
|
||||
self.get_success(
|
||||
device_handler.upsert_device(
|
||||
user_id=str(self.alice_user_id),
|
||||
device_id="DEVICE1",
|
||||
display_name="Old Name",
|
||||
)
|
||||
)
|
||||
|
||||
def test_other_token(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token="other_token",
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["error"], "This endpoint must only be called by MAS"
|
||||
)
|
||||
|
||||
def test_update_device_display_name(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
# Verify initial display name
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
self.assertEqual(device["display_name"], "Old Name")
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "Updated Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the display name was updated
|
||||
device = self.get_success(store.get_device(str(self.alice_user_id), "DEVICE1"))
|
||||
assert device is not None
|
||||
self.assertEqual(device["display_name"], "Updated Name")
|
||||
|
||||
def test_update_nonexistent_device(self) -> None:
|
||||
# Updating a non-existent device should fail
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "NONEXISTENT",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
|
||||
def test_update_device_display_name_missing_localpart(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_update_device_display_name_missing_device_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_update_device_display_name_missing_display_name(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"device_id": "DEVICE1",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_update_device_display_name_nonexistent_user(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/update_device_display_name",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "nonexistent",
|
||||
"device_id": "DEVICE1",
|
||||
"display_name": "New Name",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MasSyncDevicesResource(BaseTestCase):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
# Create a user for testing
|
||||
self.alice_user_id = UserID("alice", "test")
|
||||
self.get_success(
|
||||
homeserver.get_registration_handler().register_user(
|
||||
localpart=self.alice_user_id.localpart,
|
||||
)
|
||||
)
|
||||
|
||||
# Create some initial devices
|
||||
device_handler = homeserver.get_device_handler()
|
||||
for device_id in ["DEVICE1", "DEVICE2", "DEVICE3"]:
|
||||
self.get_success(
|
||||
device_handler.upsert_device(
|
||||
user_id=str(self.alice_user_id),
|
||||
device_id=device_id,
|
||||
display_name=f"Device {device_id}",
|
||||
)
|
||||
)
|
||||
|
||||
def test_other_token(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token="other_token",
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1", "DEVICE2"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 403, channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["error"], "This endpoint must only be called by MAS"
|
||||
)
|
||||
|
||||
def test_sync_devices_no_changes(self) -> None:
|
||||
# Sync with the same devices that already exist
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1", "DEVICE2", "DEVICE3"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify all devices still exist
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE2", "DEVICE3"})
|
||||
|
||||
def test_sync_devices_add_only(self) -> None:
|
||||
# Sync with additional devices
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1", "DEVICE2", "DEVICE3", "DEVICE4", "DEVICE5"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify new devices were added
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(
|
||||
set(devices.keys()), {"DEVICE1", "DEVICE2", "DEVICE3", "DEVICE4", "DEVICE5"}
|
||||
)
|
||||
|
||||
def test_sync_devices_delete_only(self) -> None:
|
||||
# Sync with fewer devices
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify devices were deleted
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(set(devices.keys()), {"DEVICE1"})
|
||||
|
||||
def test_sync_devices_add_and_delete(self) -> None:
|
||||
# Sync with a mix of additions and deletions
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1", "DEVICE4", "DEVICE5"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the correct devices exist
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE4", "DEVICE5"})
|
||||
|
||||
def test_sync_devices_empty_list(self) -> None:
|
||||
# Sync with empty device list (delete all devices)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": [],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify all devices were deleted
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(devices, {})
|
||||
|
||||
def test_sync_devices_for_new_user(self) -> None:
|
||||
# Test syncing devices for a user that doesn't have any devices yet
|
||||
bob_user_id = UserID("bob", "test")
|
||||
self.get_success(
|
||||
self.hs.get_registration_handler().register_user(
|
||||
localpart=bob_user_id.localpart,
|
||||
)
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "bob",
|
||||
"devices": ["DEVICE1", "DEVICE2"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify devices were created
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(bob_user_id)))
|
||||
self.assertEqual(set(devices.keys()), {"DEVICE1", "DEVICE2"})
|
||||
|
||||
def test_sync_devices_missing_localpart(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"devices": ["DEVICE1", "DEVICE2"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_sync_devices_missing_devices(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_sync_devices_invalid_devices_type(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": "not_a_list",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
|
||||
def test_sync_devices_nonexistent_user(self) -> None:
|
||||
# Test syncing devices for a user that doesn't exist
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "nonexistent",
|
||||
"devices": ["DEVICE1", "DEVICE2"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
|
||||
def test_sync_devices_duplicate_device_ids(self) -> None:
|
||||
# Test syncing with duplicate device IDs (sets should handle this)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_synapse/mas/sync_devices",
|
||||
shorthand=False,
|
||||
access_token=self.SHARED_SECRET,
|
||||
content={
|
||||
"localpart": "alice",
|
||||
"devices": ["DEVICE1", "DEVICE1", "DEVICE2"],
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# Verify the correct devices exist (duplicates should be handled)
|
||||
store = self.hs.get_datastores().main
|
||||
devices = self.get_success(store.get_devices_by_user(str(self.alice_user_id)))
|
||||
self.assertEqual(sorted(devices.keys()), ["DEVICE1", "DEVICE2"])
|
||||
1399
tests/rest/synapse/mas/test_users.py
Normal file
1399
tests/rest/synapse/mas/test_users.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user