Fix bad deferred logcontext handling (#19180)

These aren't issues I personally ran into; I went through the codebase looking for every Deferred `.callback`, `.errback`, and `.cancel` call and wrapped them in `PreserveLoggingContext()`.

This spawned from wanting to solve
https://github.com/element-hq/synapse/issues/19165, although it's unconfirmed
whether this change has any effect on that issue.

For an explanation of the fix, see the [*Deferred
callbacks*](3b59ac3b69/docs/log_contexts.md (deferred-callbacks))
section of our logcontext docs (this change applies solution 2 from that
section).
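
For reference, this is roughly what "solution 2" looks like in practice. It's a minimal sketch rather than code from this PR, and the `fire_result` / `wait_for_result` helpers are made-up names for illustration: the side that fires the deferred drops to the sentinel logcontext first, and the side that awaits it uses `make_deferred_yieldable` to get its own logcontext back afterwards.

```python
from twisted.internet import defer

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable


def fire_result(deferred: "defer.Deferred[int]", value: int) -> None:
    """Hypothetical helper: resolve `deferred` while we're inside a logcontext.

    Calling `deferred.callback(...)` directly would run the attached callbacks
    in *our* current logcontext, which they don't own and may finish or leak.
    """
    if not deferred.called:
        with PreserveLoggingContext():
            # The callbacks now start from the sentinel context, as they expect.
            deferred.callback(value)


async def wait_for_result(deferred: "defer.Deferred[int]") -> int:
    # The awaiting side pairs with this: `make_deferred_yieldable` restores
    # our own logcontext once the deferred completes.
    return await make_deferred_yieldable(deferred)
```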
Commit edc0de9fa0 (parent 8da8d4b4f5) by Eric Eastwood, 2025-11-14 11:21:15 -06:00, committed via GitHub.
7 changed files with 81 additions and 49 deletions.

changelog.d/19180.misc (new file)

@@ -0,0 +1 @@
+Fix bad deferred logcontext handling across the codebase.


@@ -184,9 +184,7 @@ class WorkerLocksHandler:
         locks: Collection[WaitingLock | WaitingMultiLock],
     ) -> None:
         for lock in locks:
-            deferred = lock.deferred
-            if not deferred.called:
-                deferred.callback(None)
+            lock.release_lock()

         self._clock.call_later(
             0,
@@ -215,6 +213,12 @@ class WaitingLock:
             lambda: start_active_span("WaitingLock.lock")
         )

+    def release_lock(self) -> None:
+        """Release the lock (by resolving the deferred)"""
+        if not self.deferred.called:
+            with PreserveLoggingContext():
+                self.deferred.callback(None)
+
     async def __aenter__(self) -> None:
         self._lock_span.__enter__()

@@ -298,6 +302,12 @@ class WaitingMultiLock:
             lambda: start_active_span("WaitingLock.lock")
         )

+    def release_lock(self) -> None:
+        """Release the lock (by resolving the deferred)"""
+        if not self.deferred.called:
+            with PreserveLoggingContext():
+                self.deferred.callback(None)
+
     async def __aenter__(self) -> None:
         self._lock_span.__enter__()
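
Both `WaitingLock` and `WaitingMultiLock` gain the identical helper, which keeps the two safety concerns in one place: the `deferred.called` guard (firing a deferred twice raises `defer.AlreadyCalledError`) and the logcontext reset. A standalone sketch of the same shape, with a hypothetical `Waiter` class standing in for the real lock objects:

```python
from twisted.internet import defer

from synapse.logging.context import PreserveLoggingContext


class Waiter:
    """Toy stand-in for WaitingLock/WaitingMultiLock (not Synapse code)."""

    def __init__(self) -> None:
        self.deferred: "defer.Deferred[None]" = defer.Deferred()

    def release_lock(self) -> None:
        """Resolve the deferred exactly once, in the sentinel logcontext."""
        if not self.deferred.called:
            with PreserveLoggingContext():
                self.deferred.callback(None)


waiter = Waiter()
waiter.release_lock()
waiter.release_lock()  # no-op: the `called` guard avoids defer.AlreadyCalledError
```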


@@ -77,7 +77,11 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
 from synapse.http.proxyagent import ProxyAgent
 from synapse.http.replicationagent import ReplicationAgent
 from synapse.http.types import QueryParams
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.types import ISynapseReactor, StrSequence
@@ -1036,6 +1040,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
         Report a max size exceed error and disconnect the first time this is called.
         """
         if not self.deferred.called:
-            self.deferred.errback(BodyExceededMaxSize())
+            with PreserveLoggingContext():
+                self.deferred.errback(BodyExceededMaxSize())
             # Close the connection (forcefully) since all the data will get
             # discarded anyway.
@@ -1135,6 +1140,7 @@ class _MultipartParserProtocol(protocol.Protocol):
                 logger.warning(
                     "Exception encountered writing file data to stream: %s", e
                 )
-                self.deferred.errback()
+                with PreserveLoggingContext():
+                    self.deferred.errback()
             self.file_length += end - start
@@ -1147,6 +1153,7 @@ class _MultipartParserProtocol(protocol.Protocol):
         self.total_length += len(incoming_data)
         if self.max_length is not None and self.total_length >= self.max_length:
-            self.deferred.errback(BodyExceededMaxSize())
+            with PreserveLoggingContext():
+                self.deferred.errback(BodyExceededMaxSize())
             # Close the connection (forcefully) since all the data will get
             # discarded anyway.
@@ -1157,6 +1164,7 @@ class _MultipartParserProtocol(protocol.Protocol):
             self.parser.write(incoming_data)
         except Exception as e:
             logger.warning("Exception writing to multipart parser: %s", e)
-            self.deferred.errback()
+            with PreserveLoggingContext():
+                self.deferred.errback()
             return
@@ -1167,8 +1175,10 @@ class _MultipartParserProtocol(protocol.Protocol):
         if reason.check(ResponseDone):
             self.multipart_response.length = self.file_length
-            self.deferred.callback(self.multipart_response)
+            with PreserveLoggingContext():
+                self.deferred.callback(self.multipart_response)
         else:
-            self.deferred.errback(reason)
+            with PreserveLoggingContext():
+                self.deferred.errback(reason)
@@ -1193,6 +1203,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         try:
             self.stream.write(data)
         except Exception:
-            self.deferred.errback()
+            with PreserveLoggingContext():
+                self.deferred.errback()
             return
@@ -1201,6 +1212,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         # connection. dataReceived might be called again if data was received
         # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
-            self.deferred.errback(BodyExceededMaxSize())
+            with PreserveLoggingContext():
+                self.deferred.errback(BodyExceededMaxSize())
             # Close the connection (forcefully) since all the data will get
             # discarded anyway.
@@ -1213,6 +1225,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
             return

         if reason.check(ResponseDone):
-            self.deferred.callback(self.length)
+            with PreserveLoggingContext():
+                self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
             # This applies to requests which don't set `Content-Length` or a
@@ -1222,8 +1235,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
            # behavior is expected of some servers (like YouTube), let's ignore it.
            # Stolen from https://github.com/twisted/treq/pull/49/files
            # http://twistedmatrix.com/trac/ticket/4840
-            self.deferred.callback(self.length)
+            with PreserveLoggingContext():
+                self.deferred.callback(self.length)
         else:
-            self.deferred.errback(reason)
+            with PreserveLoggingContext():
+                self.deferred.errback(reason)
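
Every protocol in this file follows the same recipe: the protocol owns a deferred that some caller is awaiting, and it resolves that deferred from Twisted's `dataReceived`/`connectionLost` hooks, which may be invoked while an unrelated logcontext is active. A stripped-down sketch of that shape, using a hypothetical `_LengthCounter` protocol and `read_length` caller rather than the real classes:

```python
from twisted.internet import defer, protocol
from twisted.internet.protocol import connectionDone
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable


class _LengthCounter(protocol.Protocol):
    """Counts response bytes and fires `deferred` with the total (sketch only)."""

    def __init__(self, deferred: "defer.Deferred[int]") -> None:
        self.deferred = deferred
        self.length = 0

    def dataReceived(self, data: bytes) -> None:
        self.length += len(data)

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        if self.deferred.called:
            return
        # Fire the deferred in the sentinel context; whoever awaits it via
        # `make_deferred_yieldable` will restore their own context afterwards.
        if reason.check(ResponseDone):
            with PreserveLoggingContext():
                self.deferred.callback(self.length)
        else:
            with PreserveLoggingContext():
                self.deferred.errback(reason)


async def read_length(response) -> int:
    """Hypothetical caller: feed an IResponse body into the protocol and await it."""
    d: "defer.Deferred[int]" = defer.Deferred()
    response.deliverBody(_LengthCounter(d))
    return await make_deferred_yieldable(d)
```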


@@ -41,6 +41,8 @@ from twisted.internet.protocol import ClientFactory, connectionDone
 from twisted.python.failure import Failure
 from twisted.web import http

+from synapse.logging.context import PreserveLoggingContext
+
 logger = logging.getLogger(__name__)

@@ -176,6 +178,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
     def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
-            self.on_connection.errback(reason)
+            with PreserveLoggingContext():
+                self.on_connection.errback(reason)
         if isinstance(self.wrapped_factory, ClientFactory):
             return self.wrapped_factory.clientConnectionFailed(connector, reason)
@@ -183,6 +186,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
     def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
-            self.on_connection.errback(reason)
+            with PreserveLoggingContext():
+                self.on_connection.errback(reason)
         if isinstance(self.wrapped_factory, ClientFactory):
             return self.wrapped_factory.clientConnectionLost(connector, reason)
@@ -238,6 +242,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         self.http_setup_client.connectionLost(reason)
         if not self.connected_deferred.called:
-            self.connected_deferred.errback(reason)
+            with PreserveLoggingContext():
+                self.connected_deferred.errback(reason)

     def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
@@ -245,6 +250,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         assert self.transport is not None
         self.wrapped_protocol.makeConnection(self.transport)
-        self.connected_deferred.callback(self.wrapped_protocol)
+        with PreserveLoggingContext():
+            self.connected_deferred.callback(self.wrapped_protocol)

         # Get any pending data from the http buf and forward it to the original protocol
@@ -303,6 +309,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
     def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
-        self.on_connected.callback(None)
+        with PreserveLoggingContext():
+            self.on_connected.callback(None)

     def handleResponse(self, body: bytes) -> None:


@@ -45,6 +45,7 @@ from synapse.api.errors import Codes, cs_error
 from synapse.http.server import finish_request, respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import (
+    PreserveLoggingContext,
     defer_to_threadpool,
     make_deferred_yieldable,
     run_in_background,
@@ -753,6 +754,7 @@ class ThreadedFileSender:
         self.wakeup_event.set()

         if not self.deferred.called:
-            self.deferred.errback(
-                ConsumerRequestedStopError("Consumer asked us to stop producing")
-            )
+            with PreserveLoggingContext():
+                self.deferred.errback(
+                    ConsumerRequestedStopError("Consumer asked us to stop producing")
+                )
@@ -809,6 +811,7 @@ class ThreadedFileSender:
         self.consumer = None

         if not self.deferred.called:
-            self.deferred.errback(failure)
+            with PreserveLoggingContext():
+                self.deferred.errback(failure)

     def _finish(self) -> None:
@@ -823,4 +826,5 @@ class ThreadedFileSender:
         self.consumer = None

         if not self.deferred.called:
-            self.deferred.callback(None)
+            with PreserveLoggingContext():
+                self.deferred.callback(None)


@@ -813,6 +813,7 @@ def timeout_deferred(
         # will have errbacked new_d, but in case it hasn't, errback it now.
         if not new_d.called:
-            new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))
+            with PreserveLoggingContext():
+                new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))

     # We don't track these calls since they are short.
@@ -840,10 +841,12 @@ def timeout_deferred(
     def success_cb(val: _T) -> None:
         if not new_d.called:
-            new_d.callback(val)
+            with PreserveLoggingContext():
+                new_d.callback(val)

     def failure_cb(val: Failure) -> None:
         if not new_d.called:
-            new_d.errback(val)
+            with PreserveLoggingContext():
+                new_d.errback(val)

     deferred.addCallbacks(success_cb, failure_cb)
@@ -946,6 +949,7 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
        # propagating. we then `unpause` it once the wrapped deferred completes, to
        # propagate the exception.
        new_deferred.pause()
-       new_deferred.errback(Failure(CancelledError()))
+       with PreserveLoggingContext():
+           new_deferred.errback(Failure(CancelledError()))

        deferred.addBoth(lambda _: new_deferred.unpause())
@@ -978,15 +982,6 @@ class AwakenableSleeper:
         """Sleep for the given number of milliseconds, or return if the given
         `name` is explicitly woken up.
         """
-        # Create a deferred that gets called in N seconds
-        sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
-        call = self._clock.call_later(
-            delay_ms / 1000,
-            sleep_deferred.callback,
-            None,
-        )
-
         # Create a deferred that will get called if `wake` is called with
         # the same `name`.
         stream_set = self._streams.setdefault(name, set())
@@ -996,13 +991,14 @@ class AwakenableSleeper:
         try:
             # Wait for either the delay or for `wake` to be called.
             await make_deferred_yieldable(
-                defer.DeferredList(
-                    [sleep_deferred, notify_deferred],
-                    fireOnOneCallback=True,
-                    fireOnOneErrback=True,
-                    consumeErrors=True,
+                timeout_deferred(
+                    deferred=stop_cancellation(notify_deferred),
+                    timeout=delay_ms / 1000,
+                    clock=self._clock,
                 )
             )
+        except defer.TimeoutError:
+            pass
         finally:
             # Clean up the state
             curr_stream_set = self._streams.get(name)
@@ -1011,10 +1007,6 @@ class AwakenableSleeper:
             if len(curr_stream_set) == 0:
                 self._streams.pop(name)

-            # Cancel the sleep if we were woken up
-            if call.active():
-                call.cancel()

 class DeferredEvent:
     """Like threading.Event but for async code"""


@@ -39,6 +39,7 @@ from prometheus_client import Gauge
 from twisted.internet import defer
 from twisted.python.failure import Failure

+from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.lrucache import LruCache
@@ -514,6 +515,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]):
             cache._completed_callback(value, self, key)

         if self._deferred:
-            self._deferred.callback(result)
+            with PreserveLoggingContext():
+                self._deferred.callback(result)

     def error_bulk(
@@ -524,4 +526,5 @@ class CacheMultipleEntries(CacheEntry[KT, VT]):
             cache._error_callback(failure, self, key)

         if self._deferred:
-            self._deferred.errback(failure)
+            with PreserveLoggingContext():
+                self._deferred.errback(failure)