Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Claim fallback keys in bulk #16570

Merged
merged 9 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16570.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the performance of claiming encryption keys.
10 changes: 10 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,16 @@ def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
self._do_execute(self.txn.execute, sql, parameters)

def executemany(self, sql: str, *args: Any) -> None:
"""Repeatedly execute the same piece of SQL with different parameters.

See https://peps.python.org/pep-0249/#executemany. Note in particular that

> Use of this method for an operation which produces one or more result sets
> constitutes undefined behavior

so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
DELETE FROM... RETURNING.
"""
# TODO: we should add a type for *args here. Looking at Cursor.executemany
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
Expand Down
60 changes: 60 additions & 0 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -1260,6 +1261,65 @@ async def claim_e2e_fallback_keys(
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
if isinstance(self.database_engine, PostgresEngine):
return await self.db_pool.runInteraction(
"_claim_e2e_fallback_keys_bulk",
self._claim_e2e_fallback_keys_bulk_txn,
query_list,
db_autocommit=True,
)
# Use an UPDATE FROM... RETURNING combined with a VALUES block to do
# everything in one query. Note: this is also supported in SQLite 3.33.0,
# (see https://www.sqlite.org/lang_update.html#update_from), but we do not
# have an equivalent of psycopg2's execute_values to do this in one query.
else:
return await self._claim_e2e_fallback_keys_simple(query_list)

def _claim_e2e_fallback_keys_bulk_txn(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Efficient implementation of claim_e2e_fallback_keys for Postgres.

Safe to autocommit: this is a single query.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}

sql = """
WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
VALUES ?
)
UPDATE e2e_fallback_keys_json k
SET used = used OR mark_as_used
FROM claims
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
Comment on lines +1290 to +1297
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The form WITH ... UPDATE ... is non-standard, according to https://www.postgresql.org/docs/11/sql-update.html#id-1.9.3.182.10:

This command conforms to the SQL standard, except that the FROM and RETURNING clauses are PostgreSQL extensions, as is the ability to use WITH with UPDATE.

"""
claimed_keys = cast(
List[Tuple[str, str, str, str, str]],
txn.execute_values(sql, query_list),
)

seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

return results

async def _claim_e2e_fallback_keys_simple(
self,
query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
Comment on lines +1318 to +1321
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch -- this was doing 2 queries for each item in the input list.

"""Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
Expand Down
Loading