Fix bad deferred logcontext handling (#19180)

These aren't problems I've personally run into; I just went through the codebase
looking for all of the Deferred `.callback`, `.errback`, and `.cancel` calls and
wrapped them with `PreserveLoggingContext()`.

This spawned from wanting to solve
https://github.com/element-hq/synapse/issues/19165, although it is unconfirmed
whether this change has any effect on that issue.

To explain the fix, see the [*Deferred
callbacks*](https://github.com/element-hq/synapse/blob/3b59ac3b69/docs/log_contexts.md#deferred-callbacks)
section of our logcontext docs for more info (specifically, this uses
solution 2).
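
For context, here is a minimal sketch of the pattern applied throughout this commit. It uses the real `PreserveLoggingContext` and `make_deferred_yieldable` helpers from `synapse.logging.context`, but the `fire_from_reactor_code` and `waiter` functions are purely illustrative and do not appear in the diff below.

```python
from twisted.internet import defer

from synapse.logging.context import (
    PreserveLoggingContext,
    make_deferred_yieldable,
)


def fire_from_reactor_code(deferred: "defer.Deferred[None]") -> None:
    """Resolve a deferred that some other code is waiting on."""
    # Bad: firing the deferred while our current logcontext is active means the
    # deferred's callbacks run in (and may hold on to) that context after it
    # has finished:
    #
    #     deferred.callback(None)
    #
    # Good ("solution 2" in docs/log_contexts.md): reset to the sentinel
    # context before firing, so the callbacks start from a clean slate.
    if not deferred.called:
        with PreserveLoggingContext():
            deferred.callback(None)


async def waiter(deferred: "defer.Deferred[None]") -> None:
    # The awaiting side wraps the deferred so that its own logcontext is
    # restored once the result arrives.
    await make_deferred_yieldable(deferred)
```

The firing side resets to the sentinel logcontext, so the awaiting side's `make_deferred_yieldable` can restore the correct context without leaking the caller's.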
Eric Eastwood authored on 2025-11-14 11:21:15 -06:00; committed by GitHub
parent 8da8d4b4f5
commit edc0de9fa0
7 changed files with 81 additions and 49 deletions

changelog.d/19180.misc (new file)

@@ -0,0 +1 @@
Fix bad deferred logcontext handling across the codebase.


@@ -184,9 +184,7 @@ class WorkerLocksHandler:
locks: Collection[WaitingLock | WaitingMultiLock],
) -> None:
for lock in locks:
deferred = lock.deferred
if not deferred.called:
deferred.callback(None)
lock.release_lock()
self._clock.call_later(
0,
@@ -215,6 +213,12 @@ class WaitingLock:
lambda: start_active_span("WaitingLock.lock")
)
def release_lock(self) -> None:
"""Release the lock (by resolving the deferred)"""
if not self.deferred.called:
with PreserveLoggingContext():
self.deferred.callback(None)
async def __aenter__(self) -> None:
self._lock_span.__enter__()
@@ -298,6 +302,12 @@ class WaitingMultiLock:
lambda: start_active_span("WaitingLock.lock")
)
def release_lock(self) -> None:
"""Release the lock (by resolving the deferred)"""
if not self.deferred.called:
with PreserveLoggingContext():
self.deferred.callback(None)
async def __aenter__(self) -> None:
self._lock_span.__enter__()


@@ -77,7 +77,11 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
from synapse.http.proxyagent import ProxyAgent
from synapse.http.replicationagent import ReplicationAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import ISynapseReactor, StrSequence
@@ -1036,7 +1040,8 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
Report a max size exceed error and disconnect the first time this is called.
"""
if not self.deferred.called:
self.deferred.errback(BodyExceededMaxSize())
with PreserveLoggingContext():
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
@@ -1135,7 +1140,8 @@ class _MultipartParserProtocol(protocol.Protocol):
logger.warning(
"Exception encountered writing file data to stream: %s", e
)
self.deferred.errback()
with PreserveLoggingContext():
self.deferred.errback()
self.file_length += end - start
callbacks: "multipart.MultipartCallbacks" = {
@@ -1147,7 +1153,8 @@ class _MultipartParserProtocol(protocol.Protocol):
self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
self.deferred.errback(BodyExceededMaxSize())
with PreserveLoggingContext():
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
@@ -1157,7 +1164,8 @@ class _MultipartParserProtocol(protocol.Protocol):
self.parser.write(incoming_data)
except Exception as e:
logger.warning("Exception writing to multipart parser: %s", e)
self.deferred.errback()
with PreserveLoggingContext():
self.deferred.errback()
return
def connectionLost(self, reason: Failure = connectionDone) -> None:
@@ -1167,9 +1175,11 @@ class _MultipartParserProtocol(protocol.Protocol):
if reason.check(ResponseDone):
self.multipart_response.length = self.file_length
self.deferred.callback(self.multipart_response)
with PreserveLoggingContext():
self.deferred.callback(self.multipart_response)
else:
self.deferred.errback(reason)
with PreserveLoggingContext():
self.deferred.errback(reason)
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
@@ -1193,7 +1203,8 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
try:
self.stream.write(data)
except Exception:
self.deferred.errback()
with PreserveLoggingContext():
self.deferred.errback()
return
self.length += len(data)
@@ -1201,7 +1212,8 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
# connection. dataReceived might be called again if data was received
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
with PreserveLoggingContext():
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
@@ -1213,7 +1225,8 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
return
if reason.check(ResponseDone):
self.deferred.callback(self.length)
with PreserveLoggingContext():
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
# This applies to requests which don't set `Content-Length` or a
# `Transfer-Encoding` in the response because in this case the end of the
@@ -1222,9 +1235,11 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
# behavior is expected of some servers (like YouTube), let's ignore it.
# Stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.length)
with PreserveLoggingContext():
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
with PreserveLoggingContext():
self.deferred.errback(reason)
def read_body_with_max_size(


@@ -41,6 +41,8 @@ from twisted.internet.protocol import ClientFactory, connectionDone
from twisted.python.failure import Failure
from twisted.web import http
from synapse.logging.context import PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -176,14 +178,16 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy failed: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
with PreserveLoggingContext():
self.on_connection.errback(reason)
if isinstance(self.wrapped_factory, ClientFactory):
return self.wrapped_factory.clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy lost: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
with PreserveLoggingContext():
self.on_connection.errback(reason)
if isinstance(self.wrapped_factory, ClientFactory):
return self.wrapped_factory.clientConnectionLost(connector, reason)
@@ -238,14 +242,16 @@ class HTTPConnectProtocol(protocol.Protocol):
self.http_setup_client.connectionLost(reason)
if not self.connected_deferred.called:
self.connected_deferred.errback(reason)
with PreserveLoggingContext():
self.connected_deferred.errback(reason)
def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
self.wrapped_connection_started = True
assert self.transport is not None
self.wrapped_protocol.makeConnection(self.transport)
self.connected_deferred.callback(self.wrapped_protocol)
with PreserveLoggingContext():
self.connected_deferred.callback(self.wrapped_protocol)
# Get any pending data from the http buf and forward it to the original protocol
buf = self.http_setup_client.clearLineBuffer()
@@ -303,7 +309,8 @@ class HTTPConnectSetupClient(http.HTTPClient):
def handleEndHeaders(self) -> None:
logger.debug("End Headers")
self.on_connected.callback(None)
with PreserveLoggingContext():
self.on_connected.callback(None)
def handleResponse(self, body: bytes) -> None:
pass


@@ -45,6 +45,7 @@ from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import (
PreserveLoggingContext,
defer_to_threadpool,
make_deferred_yieldable,
run_in_background,
@@ -753,9 +754,10 @@ class ThreadedFileSender:
self.wakeup_event.set()
if not self.deferred.called:
self.deferred.errback(
ConsumerRequestedStopError("Consumer asked us to stop producing")
)
with PreserveLoggingContext():
self.deferred.errback(
ConsumerRequestedStopError("Consumer asked us to stop producing")
)
async def start_read_loop(self) -> None:
"""This is the loop that drives reading/writing"""
@@ -809,7 +811,8 @@ class ThreadedFileSender:
self.consumer = None
if not self.deferred.called:
self.deferred.errback(failure)
with PreserveLoggingContext():
self.deferred.errback(failure)
def _finish(self) -> None:
"""Called when we have finished writing (either on success or
@@ -823,4 +826,5 @@ class ThreadedFileSender:
self.consumer = None
if not self.deferred.called:
self.deferred.callback(None)
with PreserveLoggingContext():
self.deferred.callback(None)


@@ -813,7 +813,8 @@ def timeout_deferred(
# will have errbacked new_d, but in case it hasn't, errback it now.
if not new_d.called:
new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))
with PreserveLoggingContext():
new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))
# We don't track these calls since they are short.
delayed_call = clock.call_later(
@@ -840,11 +841,13 @@ def timeout_deferred(
def success_cb(val: _T) -> None:
if not new_d.called:
new_d.callback(val)
with PreserveLoggingContext():
new_d.callback(val)
def failure_cb(val: Failure) -> None:
if not new_d.called:
new_d.errback(val)
with PreserveLoggingContext():
new_d.errback(val)
deferred.addCallbacks(success_cb, failure_cb)
@@ -946,7 +949,8 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
# propagating. we then `unpause` it once the wrapped deferred completes, to
# propagate the exception.
new_deferred.pause()
new_deferred.errback(Failure(CancelledError()))
with PreserveLoggingContext():
new_deferred.errback(Failure(CancelledError()))
deferred.addBoth(lambda _: new_deferred.unpause())
@@ -978,15 +982,6 @@ class AwakenableSleeper:
"""Sleep for the given number of milliseconds, or return if the given
`name` is explicitly woken up.
"""
# Create a deferred that gets called in N seconds
sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
call = self._clock.call_later(
delay_ms / 1000,
sleep_deferred.callback,
None,
)
# Create a deferred that will get called if `wake` is called with
# the same `name`.
stream_set = self._streams.setdefault(name, set())
@@ -996,13 +991,14 @@
try:
# Wait for either the delay or for `wake` to be called.
await make_deferred_yieldable(
defer.DeferredList(
[sleep_deferred, notify_deferred],
fireOnOneCallback=True,
fireOnOneErrback=True,
consumeErrors=True,
timeout_deferred(
deferred=stop_cancellation(notify_deferred),
timeout=delay_ms / 1000,
clock=self._clock,
)
)
except defer.TimeoutError:
pass
finally:
# Clean up the state
curr_stream_set = self._streams.get(name)
@@ -1011,10 +1007,6 @@
if len(curr_stream_set) == 0:
self._streams.pop(name)
# Cancel the sleep if we were woken up
if call.active():
call.cancel()
class DeferredEvent:
"""Like threading.Event but for async code"""


@@ -39,6 +39,7 @@ from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import SERVER_NAME_LABEL
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache
@@ -514,7 +515,8 @@ class CacheMultipleEntries(CacheEntry[KT, VT]):
cache._completed_callback(value, self, key)
if self._deferred:
self._deferred.callback(result)
with PreserveLoggingContext():
self._deferred.callback(result)
def error_bulk(
self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
@@ -524,4 +526,5 @@ class CacheMultipleEntries(CacheEntry[KT, VT]):
cache._error_callback(failure, self, key)
if self._deferred:
self._deferred.errback(failure)
with PreserveLoggingContext():
self._deferred.errback(failure)