mirror of
				https://github.com/element-hq/synapse.git
				synced 2025-11-04 00:01:22 -05:00 
			
		
		
		
	Rewrite store_server_verify_key to store several keys at once (#5234)
Storing server keys hammered the database a bit. This replaces the implementation which stored a single key, with one which can do many updates at once.
This commit is contained in:
		
							parent
							
								
									85d1e03b9d
								
							
						
					
					
						commit
						2e052110ee
					
				
							
								
								
									
										1
									
								
								changelog.d/5234.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/5234.misc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
Rewrite store_server_verify_key to store several keys at once.
 | 
			
		||||
@ -453,10 +453,11 @@ class Keyring(object):
 | 
			
		||||
            raise_from(KeyLookupError("Remote server returned an error"), e)
 | 
			
		||||
 | 
			
		||||
        keys = {}
 | 
			
		||||
        added_keys = []
 | 
			
		||||
 | 
			
		||||
        responses = query_response["server_keys"]
 | 
			
		||||
        time_now_ms = self.clock.time_msec()
 | 
			
		||||
 | 
			
		||||
        for response in responses:
 | 
			
		||||
        for response in query_response["server_keys"]:
 | 
			
		||||
            if (
 | 
			
		||||
                u"signatures" not in response
 | 
			
		||||
                or perspective_name not in response[u"signatures"]
 | 
			
		||||
@ -492,21 +493,13 @@ class Keyring(object):
 | 
			
		||||
            )
 | 
			
		||||
            server_name = response["server_name"]
 | 
			
		||||
 | 
			
		||||
            added_keys.extend(
 | 
			
		||||
                (server_name, key_id, key) for key_id, key in processed_response.items()
 | 
			
		||||
            )
 | 
			
		||||
            keys.setdefault(server_name, {}).update(processed_response)
 | 
			
		||||
 | 
			
		||||
        yield logcontext.make_deferred_yieldable(
 | 
			
		||||
            defer.gatherResults(
 | 
			
		||||
                [
 | 
			
		||||
                    run_in_background(
 | 
			
		||||
                        self.store_keys,
 | 
			
		||||
                        server_name=server_name,
 | 
			
		||||
                        from_server=perspective_name,
 | 
			
		||||
                        verify_keys=response_keys,
 | 
			
		||||
                    )
 | 
			
		||||
                    for server_name, response_keys in keys.items()
 | 
			
		||||
                ],
 | 
			
		||||
                consumeErrors=True,
 | 
			
		||||
            ).addErrback(unwrapFirstError)
 | 
			
		||||
        yield self.store.store_server_verify_keys(
 | 
			
		||||
            perspective_name, time_now_ms, added_keys
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(keys)
 | 
			
		||||
@ -519,6 +512,7 @@ class Keyring(object):
 | 
			
		||||
            if requested_key_id in keys:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            time_now_ms = self.clock.time_msec()
 | 
			
		||||
            try:
 | 
			
		||||
                response = yield self.client.get_json(
 | 
			
		||||
                    destination=server_name,
 | 
			
		||||
@ -548,12 +542,13 @@ class Keyring(object):
 | 
			
		||||
                requested_ids=[requested_key_id],
 | 
			
		||||
                response_json=response,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            yield self.store.store_server_verify_keys(
 | 
			
		||||
                server_name,
 | 
			
		||||
                time_now_ms,
 | 
			
		||||
                ((server_name, key_id, key) for key_id, key in response_keys.items()),
 | 
			
		||||
            )
 | 
			
		||||
            keys.update(response_keys)
 | 
			
		||||
 | 
			
		||||
        yield self.store_keys(
 | 
			
		||||
            server_name=server_name, from_server=server_name, verify_keys=keys
 | 
			
		||||
        )
 | 
			
		||||
        defer.returnValue({server_name: keys})
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
@ -650,32 +645,6 @@ class Keyring(object):
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(response_keys)
 | 
			
		||||
 | 
			
		||||
    def store_keys(self, server_name, from_server, verify_keys):
 | 
			
		||||
        """Store a collection of verify keys for a given server
 | 
			
		||||
        Args:
 | 
			
		||||
            server_name(str): The name of the server the keys are for.
 | 
			
		||||
            from_server(str): The server the keys were downloaded from.
 | 
			
		||||
            verify_keys(dict): A mapping of key_id to VerifyKey.
 | 
			
		||||
        Returns:
 | 
			
		||||
            A deferred that completes when the keys are stored.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO(markjh): Store whether the keys have expired.
 | 
			
		||||
        return logcontext.make_deferred_yieldable(
 | 
			
		||||
            defer.gatherResults(
 | 
			
		||||
                [
 | 
			
		||||
                    run_in_background(
 | 
			
		||||
                        self.store.store_server_verify_key,
 | 
			
		||||
                        server_name,
 | 
			
		||||
                        server_name,
 | 
			
		||||
                        key.time_added,
 | 
			
		||||
                        key,
 | 
			
		||||
                    )
 | 
			
		||||
                    for key_id, key in verify_keys.items()
 | 
			
		||||
                ],
 | 
			
		||||
                consumeErrors=True,
 | 
			
		||||
            ).addErrback(unwrapFirstError)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@defer.inlineCallbacks
 | 
			
		||||
def _handle_key_deferred(verify_request):
 | 
			
		||||
 | 
			
		||||
@ -84,38 +84,51 @@ class KeyStore(SQLBaseStore):
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction("get_server_verify_keys", _txn)
 | 
			
		||||
 | 
			
		||||
    def store_server_verify_key(
 | 
			
		||||
        self, server_name, from_server, time_now_ms, verify_key
 | 
			
		||||
    ):
 | 
			
		||||
        """Stores a NACL verification key for the given server.
 | 
			
		||||
    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
 | 
			
		||||
        """Stores NACL verification keys for remote servers.
 | 
			
		||||
        Args:
 | 
			
		||||
            server_name (str): The name of the server.
 | 
			
		||||
            from_server (str): Where the verification key was looked up
 | 
			
		||||
            time_now_ms (int): The time now in milliseconds
 | 
			
		||||
            verify_key (nacl.signing.VerifyKey): The NACL verify key.
 | 
			
		||||
            from_server (str): Where the verification keys were looked up
 | 
			
		||||
            ts_added_ms (int): The time to record that the key was added
 | 
			
		||||
            verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
 | 
			
		||||
                keys to be stored. Each entry is a triplet of
 | 
			
		||||
                (server_name, key_id, key).
 | 
			
		||||
        """
 | 
			
		||||
        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 | 
			
		||||
 | 
			
		||||
        # XXX fix this to not need a lock (#3819)
 | 
			
		||||
        def _txn(txn):
 | 
			
		||||
            self._simple_upsert_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="server_signature_keys",
 | 
			
		||||
                keyvalues={"server_name": server_name, "key_id": key_id},
 | 
			
		||||
                values={
 | 
			
		||||
                    "from_server": from_server,
 | 
			
		||||
                    "ts_added_ms": time_now_ms,
 | 
			
		||||
                    "verify_key": db_binary_type(verify_key.encode()),
 | 
			
		||||
                },
 | 
			
		||||
        key_values = []
 | 
			
		||||
        value_values = []
 | 
			
		||||
        invalidations = []
 | 
			
		||||
        for server_name, key_id, verify_key in verify_keys:
 | 
			
		||||
            key_values.append((server_name, key_id))
 | 
			
		||||
            value_values.append(
 | 
			
		||||
                (
 | 
			
		||||
                    from_server,
 | 
			
		||||
                    ts_added_ms,
 | 
			
		||||
                    db_binary_type(verify_key.encode()),
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            # invalidate takes a tuple corresponding to the params of
 | 
			
		||||
            # _get_server_verify_key. _get_server_verify_key only takes one
 | 
			
		||||
            # param, which is itself the 2-tuple (server_name, key_id).
 | 
			
		||||
            txn.call_after(
 | 
			
		||||
                self._get_server_verify_key.invalidate, ((server_name, key_id),)
 | 
			
		||||
            )
 | 
			
		||||
            invalidations.append((server_name, key_id))
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction("store_server_verify_key", _txn)
 | 
			
		||||
        def _invalidate(res):
 | 
			
		||||
            f = self._get_server_verify_key.invalidate
 | 
			
		||||
            for i in invalidations:
 | 
			
		||||
                f((i, ))
 | 
			
		||||
            return res
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "store_server_verify_keys",
 | 
			
		||||
            self._simple_upsert_many_txn,
 | 
			
		||||
            table="server_signature_keys",
 | 
			
		||||
            key_names=("server_name", "key_id"),
 | 
			
		||||
            key_values=key_values,
 | 
			
		||||
            value_names=(
 | 
			
		||||
                "from_server",
 | 
			
		||||
                "ts_added_ms",
 | 
			
		||||
                "verify_key",
 | 
			
		||||
            ),
 | 
			
		||||
            value_values=value_values,
 | 
			
		||||
        ).addCallback(_invalidate)
 | 
			
		||||
 | 
			
		||||
    def store_server_keys_json(
 | 
			
		||||
        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
 | 
			
		||||
 | 
			
		||||
@ -192,8 +192,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
        kr = keyring.Keyring(self.hs)
 | 
			
		||||
 | 
			
		||||
        key1 = signedjson.key.generate_signing_key(1)
 | 
			
		||||
        r = self.hs.datastore.store_server_verify_key(
 | 
			
		||||
            "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
 | 
			
		||||
        key1_id = "%s:%s" % (key1.alg, key1.version)
 | 
			
		||||
 | 
			
		||||
        r = self.hs.datastore.store_server_verify_keys(
 | 
			
		||||
            "server9",
 | 
			
		||||
            time.time() * 1000,
 | 
			
		||||
            [
 | 
			
		||||
                (
 | 
			
		||||
                    "server9",
 | 
			
		||||
                    key1_id,
 | 
			
		||||
                    signedjson.key.get_verify_key(key1),
 | 
			
		||||
                ),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        self.get_success(r)
 | 
			
		||||
        json1 = {}
 | 
			
		||||
 | 
			
		||||
@ -31,23 +31,32 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 | 
			
		||||
    def test_get_server_verify_keys(self):
 | 
			
		||||
        store = self.hs.get_datastore()
 | 
			
		||||
 | 
			
		||||
        d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
 | 
			
		||||
        self.get_success(d)
 | 
			
		||||
        d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
 | 
			
		||||
        key_id_1 = "ed25519:key1"
 | 
			
		||||
        key_id_2 = "ed25519:KEY_ID_2"
 | 
			
		||||
        d = store.store_server_verify_keys(
 | 
			
		||||
            "from_server",
 | 
			
		||||
            10,
 | 
			
		||||
            [
 | 
			
		||||
                ("server1", key_id_1, KEY_1),
 | 
			
		||||
                ("server1", key_id_2, KEY_2),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        self.get_success(d)
 | 
			
		||||
 | 
			
		||||
        d = store.get_server_verify_keys(
 | 
			
		||||
            [
 | 
			
		||||
                ("server1", "ed25519:key1"),
 | 
			
		||||
                ("server1", "ed25519:key2"),
 | 
			
		||||
                ("server1", "ed25519:key3"),
 | 
			
		||||
            ]
 | 
			
		||||
            [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
 | 
			
		||||
        )
 | 
			
		||||
        res = self.get_success(d)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(res.keys()), 3)
 | 
			
		||||
        self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
 | 
			
		||||
        self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
 | 
			
		||||
        res1 = res[("server1", key_id_1)]
 | 
			
		||||
        self.assertEqual(res1, KEY_1)
 | 
			
		||||
        self.assertEqual(res1.version, "key1")
 | 
			
		||||
 | 
			
		||||
        res2 = res[("server1", key_id_2)]
 | 
			
		||||
        self.assertEqual(res2, KEY_2)
 | 
			
		||||
        # version comes from the ID it was stored with
 | 
			
		||||
        self.assertEqual(res2.version, "KEY_ID_2")
 | 
			
		||||
 | 
			
		||||
        # non-existent result gives None
 | 
			
		||||
        self.assertIsNone(res[("server1", "ed25519:key3")])
 | 
			
		||||
@ -60,9 +69,14 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 | 
			
		||||
        key_id_1 = "ed25519:key1"
 | 
			
		||||
        key_id_2 = "ed25519:key2"
 | 
			
		||||
 | 
			
		||||
        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
 | 
			
		||||
        self.get_success(d)
 | 
			
		||||
        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
 | 
			
		||||
        d = store.store_server_verify_keys(
 | 
			
		||||
            "from_server",
 | 
			
		||||
            0,
 | 
			
		||||
            [
 | 
			
		||||
                ("srv1", key_id_1, KEY_1),
 | 
			
		||||
                ("srv1", key_id_2, KEY_2),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        self.get_success(d)
 | 
			
		||||
 | 
			
		||||
        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
 | 
			
		||||
@ -81,7 +95,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 | 
			
		||||
        new_key_2 = signedjson.key.get_verify_key(
 | 
			
		||||
            signedjson.key.generate_signing_key("key2")
 | 
			
		||||
        )
 | 
			
		||||
        d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
 | 
			
		||||
        d = store.store_server_verify_keys(
 | 
			
		||||
            "from_server", 10, [("srv1", key_id_2, new_key_2)]
 | 
			
		||||
        )
 | 
			
		||||
        self.get_success(d)
 | 
			
		||||
 | 
			
		||||
        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user