# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import timedelta

from asyncpg import UniqueViolationError

from mautrix.client.state_store import SyncStore
from mautrix.client.state_store.asyncpg import PgStateStore
from mautrix.errors import GroupSessionWithheldError
from mautrix.types import (
    CrossSigner,
    CrossSigningUsage,
    DeviceID,
    DeviceIdentity,
    EventID,
    IdentityKey,
    RoomID,
    RoomKeyWithheldCode,
    SessionID,
    SigningKey,
    SyncToken,
    TOFUSigningKey,
    TrustState,
    UserID,
)
from mautrix.util.async_db import Database, Scheme
from mautrix.util.logging import TraceLogger

from ... import InboundGroupSession, OlmAccount, OutboundGroupSession, RatchetSafety, Session
from ..abstract import CryptoStore, StateStore
from .upgrade import upgrade_table

# sqlite3/aiosqlite are treated as optional: if either import fails, placeholders are defined
# so that the rest of this module can still reference the same names.
try:
    from sqlite3 import IntegrityError, sqlite_version_info as sqlite_version

    from aiosqlite import Cursor
except ImportError:
    # No cursor type is available; code that checks results against Cursor tests for None first.
    Cursor = None
    # Compares lower than any real version tuple, so version-gated SQLite features are
    # treated as unavailable.
    sqlite_version = (0, 0, 0)

    # Stand-in so that `except IntegrityError` clauses in this module stay valid.
    class IntegrityError(Exception):
        pass


class PgCryptoStateStore(PgStateStore, StateStore):
    """
    State store combining the client module's ``PgStateStore`` with the crypto module's
    ``StateStore`` interface.

    The class body is intentionally empty: it adds no methods of its own and only ensures that
    the PgStateStore in the client module implements the StateStore methods needed by the
    crypto module.
    """


class PgCryptoStore(CryptoStore, SyncStore):
    upgrade_table = upgrade_table

    db: Database
    account_id: str
    pickle_key: str
    log: TraceLogger

    _sync_token: SyncToken | None
    _device_id: DeviceID | None
    _account: OlmAccount | None
    _olm_cache: dict[IdentityKey, dict[SessionID, Session]]

    def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
        self.db = db
        self.account_id = account_id
        self.pickle_key = pickle_key
        self.log = db.log

        self._sync_token = None
        self._device_id = DeviceID("")
        self._account = None
        self._olm_cache = defaultdict(lambda: {})

    @asynccontextmanager
    async def transaction(self) -> None:
        """
        Async context manager that keeps a database transaction open for the ``async with`` body.

        A connection is acquired from the database and a transaction is started on it; both are
        exited when the body finishes. Nothing is yielded to the caller.
        """
        async with self.db.acquire() as conn, conn.transaction():
            # NOTE(review): the connection is not handed to the caller, so whether queries made
            # through ``self.db`` inside the block join this transaction depends on how
            # ``Database.acquire`` is implemented — confirm.
            yield

    async def delete(self) -> None:
        tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session")
        async with self.db.acquire() as conn, conn.transaction():
            for table in tables:
                await conn.execute(f"DELETE FROM {table} WHERE account_id=$1", self.account_id)

    async def get_device_id(self) -> DeviceID | None:
        q = "SELECT device_id FROM crypto_account WHERE account_id=$1"
        device_id = await self.db.fetchval(q, self.account_id)
        self._device_id = device_id or self._device_id
        return self._device_id

    async def put_device_id(self, device_id: DeviceID) -> None:
        q = "UPDATE crypto_account SET device_id=$1 WHERE account_id=$2"
        await self.db.fetchval(q, device_id, self.account_id)
        self._device_id = device_id

    async def put_next_batch(self, next_batch: SyncToken) -> None:
        self._sync_token = next_batch
        q = "UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2"
        await self.db.execute(q, self._sync_token, self.account_id)

    async def get_next_batch(self) -> SyncToken:
        if self._sync_token is None:
            q = "SELECT sync_token FROM crypto_account WHERE account_id=$1"
            self._sync_token = await self.db.fetchval(q, self.account_id)
        return self._sync_token

    async def put_account(self, account: OlmAccount) -> None:
        self._account = account
        pickle = account.pickle(self.pickle_key)
        q = """
        INSERT INTO crypto_account (account_id, device_id, shared, sync_token, account)
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (account_id) DO UPDATE
            SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account
        """
        await self.db.execute(
            q,
            self.account_id,
            self._device_id or "",
            account.shared,
            self._sync_token or "",
            pickle,
        )

    async def get_account(self) -> OlmAccount:
        if self._account is None:
            q = "SELECT shared, account, device_id FROM crypto_account WHERE account_id=$1"
            row = await self.db.fetchrow(q, self.account_id)
            if row is not None:
                self._account = OlmAccount.from_pickle(
                    row["account"], passphrase=self.pickle_key, shared=row["shared"]
                )
        return self._account

    async def has_session(self, key: IdentityKey) -> bool:
        if len(self._olm_cache[key]) > 0:
            return True
        q = "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2"
        val = await self.db.fetchval(q, key, self.account_id)
        return val is not None

    async def get_sessions(self, key: IdentityKey) -> list[Session]:
        q = """
        SELECT session_id, session, created_at, last_encrypted, last_decrypted
        FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
        ORDER BY last_decrypted DESC
        """
        rows = await self.db.fetch(q, key, self.account_id)
        sessions = []
        for row in rows:
            try:
                sess = self._olm_cache[key][row["session_id"]]
            except KeyError:
                sess = Session.from_pickle(
                    row["session"],
                    passphrase=self.pickle_key,
                    creation_time=row["created_at"],
                    last_encrypted=row["last_encrypted"],
                    last_decrypted=row["last_decrypted"],
                )
                self._olm_cache[key][SessionID(sess.id)] = sess
            sessions.append(sess)
        return sessions

    async def get_latest_session(self, key: IdentityKey) -> Session | None:
        q = """
        SELECT session_id, session, created_at, last_encrypted, last_decrypted
        FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2
        ORDER BY last_decrypted DESC LIMIT 1
        """
        row = await self.db.fetchrow(q, key, self.account_id)
        if row is None:
            return None
        try:
            return self._olm_cache[key][row["session_id"]]
        except KeyError:
            sess = Session.from_pickle(
                row["session"],
                passphrase=self.pickle_key,
                creation_time=row["created_at"],
                last_encrypted=row["last_encrypted"],
                last_decrypted=row["last_decrypted"],
            )
            self._olm_cache[key][SessionID(sess.id)] = sess
            return sess

    async def add_session(self, key: IdentityKey, session: Session) -> None:
        if session.id in self._olm_cache[key]:
            self.log.warning(f"Cache already contains Olm session with ID {session.id}")
        self._olm_cache[key][SessionID(session.id)] = session
        pickle = session.pickle(self.pickle_key)
        q = """
        INSERT INTO crypto_olm_session (
            session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id
        ) VALUES ($1, $2, $3, $4, $5, $6, $7)
        """
        await self.db.execute(
            q,
            session.id,
            key,
            pickle,
            session.creation_time,
            session.last_encrypted,
            session.last_decrypted,
            self.account_id,
        )

    async def update_session(self, key: IdentityKey, session: Session) -> None:
        try:
            assert self._olm_cache[key][SessionID(session.id)] == session
        except (KeyError, AssertionError) as e:
            self.log.warning(
                f"Cached olm session with ID {session.id} "
                f"isn't equal to the one being saved to the database ({e})"
            )
        pickle = session.pickle(self.pickle_key)
        q = """
        UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3
        WHERE session_id=$4 AND account_id=$5
        """
        await self.db.execute(
            q, pickle, session.last_encrypted, session.last_decrypted, session.id, self.account_id
        )

    async def put_group_session(
        self,
        room_id: RoomID,
        sender_key: IdentityKey,
        session_id: SessionID,
        session: InboundGroupSession,
    ) -> None:
        pickle = session.pickle(self.pickle_key)
        forwarding_chains = ",".join(session.forwarding_chain)
        q = """
        INSERT INTO crypto_megolm_inbound_session (
            session_id, sender_key, signing_key, room_id, session, forwarding_chains,
            ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
        ON CONFLICT (session_id, account_id) DO UPDATE
        SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
            signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
            forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety,
            received_at=excluded.received_at, max_age=excluded.max_age,
            max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
        """
        try:
            await self.db.execute(
                q,
                session_id,
                sender_key,
                session.signing_key,
                room_id,
                pickle,
                forwarding_chains,
                session.ratchet_safety.json(),
                session.received_at,
                int(session.max_age.total_seconds() * 1000) if session.max_age else None,
                session.max_messages,
                session.is_scheduled,
                self.account_id,
            )
        except (IntegrityError, UniqueViolationError):
            self.log.exception(f"Failed to insert megolm session {session_id}")

    async def get_group_session(
        self, room_id: RoomID, session_id: SessionID
    ) -> InboundGroupSession | None:
        q = """
        SELECT
            sender_key, signing_key, session, forwarding_chains, withheld_code,
            ratchet_safety, received_at, max_age, max_messages, is_scheduled
        FROM crypto_megolm_inbound_session
        WHERE room_id=$1 AND session_id=$2 AND account_id=$3
        """
        row = await self.db.fetchrow(q, room_id, session_id, self.account_id)
        if row is None:
            return None
        if row["withheld_code"] is not None:
            raise GroupSessionWithheldError(session_id, row["withheld_code"])
        forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else []
        return InboundGroupSession.from_pickle(
            row["session"],
            passphrase=self.pickle_key,
            signing_key=row["signing_key"],
            sender_key=row["sender_key"],
            room_id=room_id,
            forwarding_chain=forwarding_chain,
            ratchet_safety=RatchetSafety.parse_json(row["ratchet_safety"] or "{}"),
            received_at=row["received_at"],
            max_age=timedelta(milliseconds=row["max_age"]) if row["max_age"] else None,
            max_messages=row["max_messages"],
            is_scheduled=row["is_scheduled"],
        )

    async def redact_group_session(
        self, room_id: RoomID, session_id: SessionID, reason: str
    ) -> None:
        q = """
        UPDATE crypto_megolm_inbound_session
        SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
        WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
        """
        await self.db.execute(
            q,
            RoomKeyWithheldCode.BEEPER_REDACTED.value,
            f"Session redacted: {reason}",
            session_id,
            self.account_id,
        )

    async def redact_group_sessions(
        self, room_id: RoomID, sender_key: IdentityKey, reason: str
    ) -> list[SessionID]:
        if not room_id and not sender_key:
            raise ValueError("Either room_id or sender_key must be provided")
        q = """
        UPDATE crypto_megolm_inbound_session
        SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
        WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
          AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
        RETURNING session_id
        """
        rows = await self.db.fetch(
            q,
            RoomKeyWithheldCode.BEEPER_REDACTED.value,
            f"Session redacted: {reason}",
            room_id,
            sender_key,
            self.account_id,
        )
        return [row["session_id"] for row in rows]

    async def redact_expired_group_sessions(self) -> list[SessionID]:
        if self.db.scheme == Scheme.SQLITE:
            q = """
            UPDATE crypto_megolm_inbound_session
            SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
            WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
              AND received_at IS NOT NULL and max_age IS NOT NULL
              AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
            RETURNING session_id
            """
        elif self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = """
            UPDATE crypto_megolm_inbound_session
            SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
            WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
              AND received_at IS NOT NULL and max_age IS NOT NULL
              AND received_at + 2 * (max_age * interval '1 millisecond') < now()
            RETURNING session_id
            """
        else:
            raise RuntimeError(f"Unsupported dialect {self.db.scheme}")
        rows = await self.db.fetch(
            q,
            RoomKeyWithheldCode.BEEPER_REDACTED.value,
            f"Session redacted: expired",
            self.account_id,
        )
        return [row["session_id"] for row in rows]

    async def redact_outdated_group_sessions(self) -> list[SessionID]:
        q = """
        UPDATE crypto_megolm_inbound_session
        SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
        WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
        RETURNING session_id
        """
        rows = await self.db.fetch(
            q,
            RoomKeyWithheldCode.BEEPER_REDACTED.value,
            f"Session redacted: outdated",
            self.account_id,
        )
        return [row["session_id"] for row in rows]

    async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
        q = """
        SELECT COUNT(session) FROM crypto_megolm_inbound_session
        WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
        """
        count = await self.db.fetchval(q, room_id, session_id, self.account_id)
        return count > 0

    async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
        pickle = session.pickle(self.pickle_key)
        max_age = int(session.max_age.total_seconds() * 1000)
        q = """
        INSERT INTO crypto_megolm_outbound_session (
            room_id, session_id, session, shared, max_messages, message_count,
            max_age, created_at, last_used, account_id
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        ON CONFLICT (account_id, room_id) DO UPDATE
        SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared,
            max_messages=excluded.max_messages, message_count=excluded.message_count,
            max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used
        """
        await self.db.execute(
            q,
            session.room_id,
            session.id,
            pickle,
            session.shared,
            session.max_messages,
            session.message_count,
            max_age,
            session.creation_time,
            session.use_time,
            self.account_id,
        )

    async def update_outbound_group_session(self, session: OutboundGroupSession) -> None:
        pickle = session.pickle(self.pickle_key)
        q = """
        UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3
        WHERE room_id=$4 AND session_id=$5 AND account_id=$6
        """
        await self.db.execute(
            q,
            pickle,
            session.message_count,
            session.use_time,
            session.room_id,
            session.id,
            self.account_id,
        )

    async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSession | None:
        q = """
        SELECT room_id, session_id, session, shared, max_messages, message_count, max_age,
               created_at, last_used
        FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2
        """
        row = await self.db.fetchrow(q, room_id, self.account_id)
        if row is None:
            return None
        return OutboundGroupSession.from_pickle(
            row["session"],
            passphrase=self.pickle_key,
            room_id=row["room_id"],
            shared=row["shared"],
            max_messages=row["max_messages"],
            message_count=row["message_count"],
            max_age=timedelta(milliseconds=row["max_age"]),
            use_time=row["last_used"],
            creation_time=row["created_at"],
        )

    async def remove_outbound_group_session(self, room_id: RoomID) -> None:
        q = "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2"
        await self.db.execute(q, room_id, self.account_id)

    async def remove_outbound_group_sessions(self, rooms: list[RoomID]) -> None:
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = """
            DELETE FROM crypto_megolm_outbound_session WHERE account_id=$1 AND room_id=ANY($2)
            """
            await self.db.execute(q, self.account_id, rooms)
        else:
            params = ",".join(["?"] * len(rooms))
            q = f"""
            DELETE FROM crypto_megolm_outbound_session WHERE account_id=? AND room_id IN ({params})
            """
            await self.db.execute(q, self.account_id, *rooms)

    # Single-statement upsert for crypto_message_index: inserts the
    # (sender_key, session_id, index) entry, or, when that entry already exists, performs a
    # no-op update so that RETURNING yields the stored row for comparison by the caller.
    _validate_message_index_query = """
    INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
    VALUES ($1, $2, $3, $4, $5)
    -- have to update something so that RETURNING * always returns the row
    ON CONFLICT (sender_key, session_id, "index") DO UPDATE SET sender_key=excluded.sender_key
    RETURNING *
    """

    async def validate_message_index(
        self,
        sender_key: IdentityKey,
        session_id: SessionID,
        event_id: EventID,
        index: int,
        timestamp: int,
    ) -> bool:
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH) or (
            # RETURNING was added in SQLite 3.35.0 https://www.sqlite.org/lang_returning.html
            self.db.scheme == Scheme.SQLITE
            and sqlite_version >= (3, 35)
        ):
            row = await self.db.fetchrow(
                self._validate_message_index_query,
                sender_key,
                session_id,
                index,
                event_id,
                timestamp,
            )
            return row["event_id"] == event_id and row["timestamp"] == timestamp
        else:
            row = await self.db.fetchrow(
                "SELECT event_id, timestamp FROM crypto_message_index "
                'WHERE sender_key=$1 AND session_id=$2 AND "index"=$3',
                sender_key,
                session_id,
                index,
            )
            if row is not None:
                return row["event_id"] == event_id and row["timestamp"] == timestamp
            await self.db.execute(
                "INSERT INTO crypto_message_index(sender_key, session_id, "
                '                                 "index", event_id, timestamp) '
                "VALUES ($1, $2, $3, $4, $5)",
                sender_key,
                session_id,
                index,
                event_id,
                timestamp,
            )
            return True

    async def get_devices(self, user_id: UserID) -> dict[DeviceID, DeviceIdentity] | None:
        q = "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1"
        tracked_user_id = await self.db.fetchval(q, user_id)
        if tracked_user_id is None:
            return None
        q = """
        SELECT device_id, identity_key, signing_key, trust, deleted, name
        FROM crypto_device WHERE user_id=$1
        """
        rows = await self.db.fetch(q, user_id)
        result = {}
        for row in rows:
            result[row["device_id"]] = DeviceIdentity(
                user_id=user_id,
                device_id=row["device_id"],
                identity_key=row["identity_key"],
                signing_key=row["signing_key"],
                trust=TrustState(row["trust"]),
                deleted=row["deleted"],
                name=row["name"],
            )
        return result

    async def get_device(self, user_id: UserID, device_id: DeviceID) -> DeviceIdentity | None:
        q = """
        SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device
        WHERE user_id=$1 AND device_id=$2
        """
        row = await self.db.fetchrow(q, user_id, device_id)
        if row is None:
            return None
        return DeviceIdentity(
            user_id=user_id,
            device_id=device_id,
            name=row["name"],
            identity_key=row["identity_key"],
            signing_key=row["signing_key"],
            trust=TrustState(row["trust"]),
            deleted=row["deleted"],
        )

    async def find_device_by_key(
        self, user_id: UserID, identity_key: IdentityKey
    ) -> DeviceIdentity | None:
        q = """
        SELECT device_id, signing_key, trust, deleted, name FROM crypto_device
        WHERE user_id=$1 AND identity_key=$2
        """
        row = await self.db.fetchrow(
            q,
            user_id,
            identity_key,
        )
        if row is None:
            return None
        return DeviceIdentity(
            user_id=user_id,
            device_id=row["device_id"],
            name=row["name"],
            identity_key=identity_key,
            signing_key=row["signing_key"],
            trust=TrustState(row["trust"]),
            deleted=row["deleted"],
        )

    async def put_devices(self, user_id: UserID, devices: dict[DeviceID, DeviceIdentity]) -> None:
        data = [
            (
                user_id,
                device_id,
                identity.identity_key,
                identity.signing_key,
                identity.trust,
                identity.deleted,
                identity.name,
            )
            for device_id, identity in devices.items()
        ]
        columns = [
            "user_id",
            "device_id",
            "identity_key",
            "signing_key",
            "trust",
            "deleted",
            "name",
        ]
        async with self.db.acquire() as conn, conn.transaction():
            q = """
            INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING
            """
            await conn.execute(q, user_id)
            await conn.execute("DELETE FROM crypto_device WHERE user_id=$1", user_id)
            if self.db.scheme == Scheme.POSTGRES:
                await conn.copy_records_to_table("crypto_device", records=data, columns=columns)
            else:
                q = """
                INSERT INTO crypto_device (
                    user_id, device_id, identity_key, signing_key, trust, deleted, name
                ) VALUES ($1, $2, $3, $4, $5, $6, $7)
                """
                await conn.executemany(q, data)

    async def filter_tracked_users(self, users: list[UserID]) -> list[UserID]:
        if self.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
            q = "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)"
            rows = await self.db.fetch(q, users)
        else:
            params = ",".join(["?"] * len(users))
            q = f"SELECT user_id FROM crypto_tracked_user WHERE user_id IN ({params})"
            rows = await self.db.fetch(q, *users)
        return [row["user_id"] for row in rows]

    async def put_cross_signing_key(
        self, user_id: UserID, usage: CrossSigningUsage, key: SigningKey
    ) -> None:
        q = """
        INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
        VALUES ($1, $2, $3, $4)
        ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
        """
        try:
            await self.db.execute(q, user_id, usage.value, key, key)
        except Exception:
            self.log.exception(f"Failed to store cross-signing key {user_id}/{key}/{usage}")

    async def get_cross_signing_keys(
        self, user_id: UserID
    ) -> dict[CrossSigningUsage, TOFUSigningKey]:
        q = "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1"
        return {
            CrossSigningUsage(row["usage"]): TOFUSigningKey(
                key=SigningKey(row["key"]),
                first=SigningKey(row["first_seen_key"]),
            )
            for row in await self.db.fetch(q, user_id)
        }

    async def put_signature(
        self, target: CrossSigner, signer: CrossSigner, signature: str
    ) -> None:
        q = """
        INSERT INTO crypto_cross_signing_signatures (
            signed_user_id, signed_key, signer_user_id, signer_key, signature
        ) VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key)
            DO UPDATE SET signature=excluded.signature
        """
        signed_user_id, signed_key = target
        signer_user_id, signer_key = signer
        try:
            await self.db.execute(
                q, signed_user_id, signed_key, signer_user_id, signer_key, signature
            )
        except Exception:
            self.log.exception(
                f"Failed to store signature from {signer_user_id}/{signer_key} "
                f"for {signed_user_id}/{signed_key}"
            )

    async def is_key_signed_by(self, target: CrossSigner, signer: CrossSigner) -> bool:
        q = """
        SELECT EXISTS(
            SELECT 1 FROM crypto_cross_signing_signatures
            WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
        )
        """
        signed_user_id, signed_key = target
        signer_user_id, signer_key = signer
        return await self.db.fetchval(q, signed_user_id, signed_key, signer_user_id, signer_key)

    async def drop_signatures_by_key(self, signer: CrossSigner) -> int:
        signer_user_id, signer_key = signer
        q = "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2"
        try:
            res = await self.db.execute(q, signer_user_id, signer_key)
        except Exception:
            self.log.exception(
                f"Failed to drop old signatures made by replaced key {signer_user_id}/{signer_key}"
            )
            return -1
        if Cursor is not None and isinstance(res, Cursor):
            return res.rowcount
        elif (
            isinstance(res, str)
            and res.startswith("DELETE ")
            and (intPart := res[len("DELETE ") :]).isdecimal()
        ):
            return int(intPart)
        return -1
