From 379356c0ea2f5d0a0d65d900efdf157417fa6947 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 4 Jun 2025 11:33:10 +0100 Subject: [PATCH] Add media repository callbacks to module API to control media upload size (#18457) Adds new callbacks for media related functionality: - `get_media_config_for_user` - `is_user_allowed_to_upload_media_of_size` --- changelog.d/18457.feature | 1 + docs/SUMMARY.md | 1 + docs/modules/media_repository_callbacks.md | 56 ++++++++++++++ synapse/module_api/__init__.py | 20 +++++ synapse/module_api/callbacks/__init__.py | 4 + .../callbacks/media_repository_callbacks.py | 76 +++++++++++++++++++ synapse/rest/client/media.py | 11 ++- synapse/rest/media/config_resource.py | 11 ++- synapse/rest/media/upload_resource.py | 24 ++++-- tests/media/test_media_storage.py | 39 ++++++++++ tests/rest/client/test_media.py | 57 ++++++++++++++ 11 files changed, 291 insertions(+), 9 deletions(-) create mode 100644 changelog.d/18457.feature create mode 100644 docs/modules/media_repository_callbacks.md create mode 100644 synapse/module_api/callbacks/media_repository_callbacks.py diff --git a/changelog.d/18457.feature b/changelog.d/18457.feature new file mode 100644 index 0000000000..76374dc3cd --- /dev/null +++ b/changelog.d/18457.feature @@ -0,0 +1 @@ +Add new module API callbacks that allow overriding of media repository maximum upload size. 
\ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index fd91d9fa11..abb1d5603c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -49,6 +49,7 @@ - [Background update controller callbacks](modules/background_update_controller_callbacks.md) - [Account data callbacks](modules/account_data_callbacks.md) - [Add extra fields to client events unsigned section callbacks](modules/add_extra_fields_to_client_events_unsigned.md) + - [Media repository](modules/media_repository_callbacks.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Workers](workers.md) - [Using `synctl` with Workers](synctl_workers.md) diff --git a/docs/modules/media_repository_callbacks.md b/docs/modules/media_repository_callbacks.md new file mode 100644 index 0000000000..0cee3384bb --- /dev/null +++ b/docs/modules/media_repository_callbacks.md @@ -0,0 +1,56 @@ +# Media repository callbacks + +Media repository callbacks allow module developers to customise the behaviour of the +media repository on a per user basis. Media repository callbacks can be registered +using the module API's `register_media_repository_callbacks` method. + +The available media repository callbacks are: + +### `get_media_config_for_user` + +_First introduced in Synapse v1.132.0_ + +```python +async def get_media_config_for_user(user_id: str) -> Optional[JsonDict] +``` + +Called when processing a request from a client for the +[media config endpoint](https://spec.matrix.org/latest/client-server-api/#get_matrixclientv1mediaconfig). + +The arguments passed to this callback are: + +* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request. + +If the callback returns a dictionary then it will be used as the body of the response to the +client. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. 
The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + +If no module returns a non-`None` value then the default media config will be returned. + +### `is_user_allowed_to_upload_media_of_size` + +_First introduced in Synapse v1.132.0_ + +```python +async def is_user_allowed_to_upload_media_of_size(user_id: str, size: int) -> bool +``` + +Called before media is accepted for upload from a user, in case the module needs to +enforce a different limit for the particular user. + +The arguments passed to this callback are: + +* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request. +* `size`: The size in bytes of the media that the user is requesting to upload. + +If the module returns `False`, the current request will be denied with the error code +`M_TOO_LARGE` and the HTTP status code 413. + +If multiple modules implement this callback, they will be considered in order. If a callback +returns `True`, Synapse falls through to the next one. The value of the first callback that +returns `False` will be used. If this happens, Synapse will not call any of the subsequent +implementations of this callback. 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 7834da759c..4ecdf0f3bb 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -90,6 +90,10 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( ON_USER_LOGIN_CALLBACK, ON_USER_REGISTRATION_CALLBACK, ) +from synapse.module_api.callbacks.media_repository_callbacks import ( + GET_MEDIA_CONFIG_FOR_USER_CALLBACK, + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK, +) from synapse.module_api.callbacks.spamchecker_callbacks import ( CHECK_EVENT_FOR_SPAM_CALLBACK, CHECK_LOGIN_FOR_SPAM_CALLBACK, @@ -360,6 +364,22 @@ class ModuleApi: on_legacy_admin_request=on_legacy_admin_request, ) + def register_media_repository_callbacks( + self, + *, + get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, + is_user_allowed_to_upload_media_of_size: Optional[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = None, + ) -> None: + """Registers callbacks for media repository capabilities. + Added in Synapse v1.132.0. 
+ """ + return self._callbacks.media_repository.register_callbacks( + get_media_config_for_user=get_media_config_for_user, + is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size, + ) + def register_third_party_rules_callbacks( self, *, diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py index c20d9543fb..a36c0fc7c6 100644 --- a/synapse/module_api/callbacks/__init__.py +++ b/synapse/module_api/callbacks/__init__.py @@ -27,6 +27,9 @@ if TYPE_CHECKING: from synapse.module_api.callbacks.account_validity_callbacks import ( AccountValidityModuleApiCallbacks, ) +from synapse.module_api.callbacks.media_repository_callbacks import ( + MediaRepositoryModuleApiCallbacks, +) from synapse.module_api.callbacks.spamchecker_callbacks import ( SpamCheckerModuleApiCallbacks, ) @@ -38,5 +41,6 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( class ModuleApiCallbacks: def __init__(self, hs: "HomeServer") -> None: self.account_validity = AccountValidityModuleApiCallbacks() + self.media_repository = MediaRepositoryModuleApiCallbacks(hs) self.spam_checker = SpamCheckerModuleApiCallbacks(hs) self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs) diff --git a/synapse/module_api/callbacks/media_repository_callbacks.py b/synapse/module_api/callbacks/media_repository_callbacks.py new file mode 100644 index 0000000000..6fa80a8eab --- /dev/null +++ b/synapse/module_api/callbacks/media_repository_callbacks.py @@ -0,0 +1,76 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional + +from synapse.types import JsonDict +from synapse.util.async_helpers import delay_cancellation +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict]]] + +IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]] + + +class MediaRepositoryModuleApiCallbacks: + def __init__(self, hs: "HomeServer") -> None: + self.clock = hs.get_clock() + self._get_media_config_for_user_callbacks: List[ + GET_MEDIA_CONFIG_FOR_USER_CALLBACK + ] = [] + self._is_user_allowed_to_upload_media_of_size_callbacks: List[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = [] + + def register_callbacks( + self, + get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, + is_user_allowed_to_upload_media_of_size: Optional[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = None, + ) -> None: + """Register callbacks from module for each hook.""" + if get_media_config_for_user is not None: + self._get_media_config_for_user_callbacks.append(get_media_config_for_user) + + if is_user_allowed_to_upload_media_of_size is not None: + self._is_user_allowed_to_upload_media_of_size_callbacks.append( + is_user_allowed_to_upload_media_of_size + ) + + async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]: + for callback in self._get_media_config_for_user_callbacks: + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + res: Optional[JsonDict] = await delay_cancellation(callback(user_id)) + if res: + return res + + return None + + async def is_user_allowed_to_upload_media_of_size( + self, user_id: str, size: int + ) -> bool: + for callback in self._is_user_allowed_to_upload_media_of_size_callbacks: + with Measure(self.clock, 
f"{callback.__module__}.{callback.__qualname__}"): + res: bool = await delay_cancellation(callback(user_id, size)) + if not res: + return res + + return True diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 25b302370f..4c044ae900 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -102,10 +102,17 @@ class MediaConfigResource(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.media.max_upload_size} + self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository async def on_GET(self, request: SynapseRequest) -> None: - await self.auth.get_user_by_req(request) - respond_with_json(request, 200, self.limits_dict, send_cors=True) + requester = await self.auth.get_user_by_req(request) + user_specific_config = ( + await self.media_repository_callbacks.get_media_config_for_user( + requester.user.to_string(), + ) + ) + response = user_specific_config if user_specific_config else self.limits_dict + respond_with_json(request, 200, response, send_cors=True) class ThumbnailResource(RestServlet): diff --git a/synapse/rest/media/config_resource.py b/synapse/rest/media/config_resource.py index 80462d65d3..b014e91bdb 100644 --- a/synapse/rest/media/config_resource.py +++ b/synapse/rest/media/config_resource.py @@ -40,7 +40,14 @@ class MediaConfigResource(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.media.max_upload_size} + self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository async def on_GET(self, request: SynapseRequest) -> None: - await self.auth.get_user_by_req(request) - respond_with_json(request, 200, self.limits_dict, send_cors=True) + requester = await self.auth.get_user_by_req(request) + user_specific_config = ( + await self.media_repository_callbacks.get_media_config_for_user( + requester.user.to_string() + ) + ) + response = 
user_specific_config if user_specific_config else self.limits_dict + respond_with_json(request, 200, response, send_cors=True) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 359d006f04..572f7897fd 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -50,9 +50,12 @@ class BaseUploadServlet(RestServlet): self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size + self._media_repository_callbacks = ( + hs.get_module_api_callbacks().media_repository + ) - def _get_file_metadata( - self, request: SynapseRequest + async def _get_file_metadata( + self, request: SynapseRequest, user_id: str ) -> Tuple[int, Optional[str], str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: @@ -67,7 +70,14 @@ class BaseUploadServlet(RestServlet): code=413, errcode=Codes.TOO_LARGE, ) - + if not await self._media_repository_callbacks.is_user_allowed_to_upload_media_of_size( + user_id, content_length + ): + raise SynapseError( + msg="Upload request body is too large", + code=413, + errcode=Codes.TOO_LARGE, + ) args: Dict[bytes, List[bytes]] = request.args # type: ignore upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: @@ -104,7 +114,9 @@ class UploadServlet(BaseUploadServlet): async def on_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) - content_length, upload_name, media_type = self._get_file_metadata(request) + content_length, upload_name, media_type = await self._get_file_metadata( + request, requester.user.to_string() + ) try: content: IO = request.content # type: ignore @@ -152,7 +164,9 @@ class AsyncUploadServlet(BaseUploadServlet): async with lock: await self.media_repo.verify_can_upload(media_id, requester.user) - content_length, upload_name, media_type = self._get_file_metadata(request) + content_length, 
upload_name, media_type = await self._get_file_metadata( + request, requester.user.to_string() + ) try: content: IO = request.content # type: ignore diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 31dc32d67e..2f7cf4569b 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -1360,3 +1360,42 @@ class MediaHashesTestCase(unittest.HomeserverTestCase): store_media.sha256, SMALL_PNG_SHA256, ) + + +class MediaRepoSizeModuleCallbackTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + self.mock_result = True # Allow all uploads by default + + hs.get_module_api().register_media_repository_callbacks( + is_user_allowed_to_upload_media_of_size=self.is_user_allowed_to_upload_media_of_size, + ) + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + async def is_user_allowed_to_upload_media_of_size( + self, user_id: str, size: int + ) -> bool: + self.last_user_id = user_id + self.last_size = size + return self.mock_result + + def test_upload_allowed(self) -> None: + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200) + assert self.last_user_id == self.user + assert self.last_size == len(SMALL_PNG) + + def test_upload_not_allowed(self) -> None: + self.mock_result = False + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=413) + assert self.last_user_id == self.user + assert self.last_size == len(SMALL_PNG) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 9ad8ecf1cd..6ee761e44b 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -1618,6 
+1618,63 @@ class MediaConfigTest(unittest.HomeserverTestCase): ) +class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase): + servlets = [ + media.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = self.register_user("user", "password") + self.tok = self.login("user", "password") + + hs.get_module_api().register_media_repository_callbacks( + get_media_config_for_user=self.get_media_config_for_user, + ) + + async def get_media_config_for_user( + self, + user_id: str, + ) -> Optional[JsonDict]: + # We echo back the user_id and set a custom upload size. + return {"m.upload.size": 1024, "user_id": user_id} + + def test_media_config(self) -> None: + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/config", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["m.upload.size"], 1024) + self.assertEqual(channel.json_body["user_id"], self.user) + + class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): servlets = [ media.register_servlets,