# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Awaitable, Callable, Optional, cast
import functools
import inspect
import logging

from mautrix.util.logging import TraceLogger

from .. import async_db
from .connection import LoggingConnection
from .errors import UnsupportedDatabaseVersion
from .scheme import Scheme

# An upgrade function that is awaited with the connection and the database scheme.
# A truthy integer result is used by UpgradeTable.upgrade as the version that was reached.
Upgrade = Callable[[LoggingConnection, Scheme], Awaitable[Optional[int]]]
# Same as Upgrade, but for functions that only accept the connection.
UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]]


async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None:
    """Upgrade function that does nothing.

    Both arguments are ignored. ``UpgradeTable.register`` uses this function to
    fill the slots below an explicitly requested index.
    """
    return None


def _wrap_upgrade(fn: UpgradeWithoutScheme | Upgrade) -> Upgrade:
    """Adapt an upgrade function to the two-argument ``(conn, scheme)`` form.

    A function whose signature has exactly one parameter is wrapped in a
    coroutine function that accepts the scheme and discards it. Anything else
    is returned unchanged. Wrappers created here are recognised and returned
    as-is, so calling this function on its own result is safe.

    Args:
        fn: The upgrade function, taking either ``(conn)`` or ``(conn, scheme)``.

    Returns:
        A callable that can be awaited as ``fn(conn, scheme)``.
    """
    # ``functools.wraps`` sets ``__wrapped__`` and ``inspect.signature`` follows
    # it, so a wrapper created below still reports the one-parameter signature
    # of the function it wraps. Without this check such a wrapper would be
    # wrapped a second time and then called with only the connection.
    if getattr(fn, "__mau_db_upgrade_wrapper__", False):
        return fn

    if len(inspect.signature(fn).parameters) != 1:
        return fn

    _wrapped: UpgradeWithoutScheme = cast("UpgradeWithoutScheme", fn)

    @functools.wraps(_wrapped)
    async def _wrapper(conn: LoggingConnection, _: Scheme) -> Optional[int]:
        return await _wrapped(conn)

    _wrapper.__mau_db_upgrade_wrapper__ = True
    return _wrapper


class UpgradeTable:
    """An ordered list of database upgrade functions and the logic to run them.

    The function stored at index ``N`` of :attr:`upgrades` is the one executed
    when the database is at version ``N``. The current version is kept in a
    table named :attr:`version_table_name`, which :meth:`upgrade` creates if it
    does not exist.
    """

    # Upgrade functions, indexed by the version they upgrade from.
    upgrades: list[Upgrade]
    # When the stored version is higher than the number of upgrades: log a
    # warning and skip upgrading if true, raise UnsupportedDatabaseVersion if false.
    allow_unsupported: bool
    # Name used for this database in log messages and errors.
    database_name: str
    # Name of the table holding the current version.
    version_table_name: str
    log: TraceLogger

    def __init__(
        self,
        allow_unsupported: bool = False,
        version_table_name: str = "version",
        database_name: str = "database",
        log: logging.Logger | TraceLogger | None = None,
    ) -> None:
        """Create an empty upgrade table.

        Args:
            allow_unsupported: Whether a stored version higher than the number
                of registered upgrades is tolerated instead of raising.
            version_table_name: Name of the table that stores the version.
            database_name: Name used in log messages and errors.
            log: Logger to use; falls back to the ``mau.db.upgrade`` logger.
        """
        self.log = log or logging.getLogger("mau.db.upgrade")
        self.database_name = database_name
        self.version_table_name = version_table_name
        self.allow_unsupported = allow_unsupported
        self.upgrades = []

    def register(
        self,
        _outer_fn: Upgrade | UpgradeWithoutScheme | None = None,
        *,
        index: int = -1,
        description: str = "",
        transaction: bool = True,
        upgrades_to: int | Upgrade | None = None,
    ) -> Upgrade | Callable[[Upgrade | UpgradeWithoutScheme], Upgrade]:
        """Register an upgrade function, directly or as a decorator factory.

        Args:
            _outer_fn: The function to register. When omitted, a decorator that
                performs the registration is returned instead.
            index: Position in :attr:`upgrades`. ``-1`` (or the current length)
                appends; a higher index pads the gap with ``noop_upgrade``. A
                string passed here is treated as the description.
            description: Text appended to the log line when the upgrade runs.
            transaction: Whether the upgrade and the version save are run inside
                ``conn.transaction()``.
            upgrades_to: The version this upgrade leads to, either as an integer
                or as a callable that is awaited with the connection and scheme.
                When falsy, the next version is assumed.

        Returns:
            The registered (possibly wrapped) function, or the decorator.
        """
        if isinstance(index, str):
            # A string index is really the description; fall back to appending.
            index, description = -1, index

        def actually_register(fn: Upgrade | UpgradeWithoutScheme) -> Upgrade:
            wrapped = _wrap_upgrade(fn)
            wrapped.__mau_db_upgrade_description__ = description
            wrapped.__mau_db_upgrade_transaction__ = transaction
            if upgrades_to and not isinstance(upgrades_to, int):
                # Callable destinations get the same signature adaptation.
                destination = _wrap_upgrade(upgrades_to)
            else:
                destination = upgrades_to
            wrapped.__mau_db_upgrade_destination__ = destination

            registered = len(self.upgrades)
            if index in (-1, registered):
                self.upgrades.append(wrapped)
                return wrapped
            if registered <= index:
                # Fill every slot up to and including the requested index.
                self.upgrades.extend([noop_upgrade] * (index - registered + 1))
            self.upgrades[index] = wrapped
            return wrapped

        if _outer_fn:
            return actually_register(_outer_fn)
        return actually_register

    async def _save_version(self, conn: LoggingConnection, version: int) -> None:
        """Replace the contents of the version table with ``version``."""
        self.log.trace(f"Saving current version (v{version}) to database")
        table = self.version_table_name
        await conn.execute(f"DELETE FROM {table}")
        await conn.execute(f"INSERT INTO {table} (version) VALUES ($1)", version)

    async def upgrade(self, db: async_db.Database) -> None:
        """Run every registered upgrade the database has not yet had.

        Creates the version table if needed, reads the stored version (``0``
        when there is no row) and runs upgrades until the version reaches the
        number of registered upgrades.

        Raises:
            UnsupportedDatabaseVersion: If the stored version is higher than the
                number of registered upgrades and ``allow_unsupported`` is false.
        """
        await db.execute(
            f"""CREATE TABLE IF NOT EXISTS {self.version_table_name} (
                version INTEGER PRIMARY KEY
            )"""
        )
        row = await db.fetchrow(f"SELECT version FROM {self.version_table_name} LIMIT 1")
        version = 0
        if row:
            version = row["version"]

        latest = len(self.upgrades)
        if latest == version:
            self.log.debug(f"Database at v{version}, not upgrading")
            return
        if latest < version:
            error = UnsupportedDatabaseVersion(self.database_name, version, latest)
            if self.allow_unsupported:
                self.log.warning(str(error))
                return
            raise error

        async with db.acquire() as conn:

            async def run_and_save(step: Upgrade, expected: int) -> int:
                # A truthy return value from the upgrade overrides the expected version.
                reached = await step(conn, db.scheme) or expected
                await self._save_version(conn, reached)
                return reached

            while version < len(self.upgrades):
                start = version
                step = self.upgrades[start]
                target = getattr(step, "__mau_db_upgrade_destination__", None)
                if not target:
                    target = start + 1
                elif callable(target):
                    target = await target(conn, db.scheme)
                description = getattr(step, "__mau_db_upgrade_description__", None)
                note = f": {description}" if description else ""
                self.log.debug(
                    f"Upgrading {self.database_name} from v{start} to v{target}{note}"
                )
                if getattr(step, "__mau_db_upgrade_transaction__", True):
                    async with conn.transaction():
                        version = await run_and_save(step, target)
                else:
                    version = await run_and_save(step, target)
                if version != target:
                    self.log.warning(
                        f"Upgrading {self.database_name} actually went from v{start} "
                        f"to v{version}"
                    )


# Upgrade tables keyed by module name. Filled by register_upgrade_table_parent_module
# and searched by _find_upgrade_table using prefixes of a function's __module__.
upgrade_tables: dict[str, UpgradeTable] = {}


def register_upgrade_table_parent_module(name: str) -> None:
    """Create a fresh ``UpgradeTable`` and store it in ``upgrade_tables`` under ``name``.

    Any table previously stored under the same name is replaced.
    """
    table = UpgradeTable()
    upgrade_tables[name] = table


def _find_upgrade_table(fn: Upgrade) -> UpgradeTable:
    """Look up the upgrade table registered for the module that defines ``fn``.

    Prefixes of ``fn.__module__`` are tried from shortest to longest (``a``,
    ``a.b``, ``a.b.c``, ...) and the first one present in ``upgrade_tables``
    wins.

    Raises:
        ValueError: If ``fn`` has no ``__module__`` attribute.
        KeyError: If no prefix of the module name has a registered table.
    """
    try:
        module_path = fn.__module__
    except AttributeError as e:
        raise ValueError(
            "Registering upgrades without an UpgradeTable requires the function "
            "to have the __module__ attribute."
        ) from e

    segments = module_path.split(".")
    lookup_failure = None
    for depth in range(1, len(segments) + 1):
        candidate = ".".join(segments[:depth])
        try:
            return upgrade_tables[candidate]
        except KeyError as e:
            # Remember the most recent miss so it can be chained to the final error.
            lookup_failure = e
    raise KeyError(
        "Registering upgrades without an UpgradeTable requires you to register a parent "
        "module with register_upgrade_table_parent_module first."
    ) from lookup_failure


def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]:
    """Build a decorator that registers an upgrade in its module's upgrade table.

    The table is located with ``_find_upgrade_table`` when the decorator is
    applied, and ``index`` and ``description`` are forwarded to its ``register``
    method.

    Returns:
        A decorator that returns whatever the table's ``register`` method returns.
    """

    def actually_register(fn: Upgrade) -> Upgrade:
        table = _find_upgrade_table(fn)
        return table.register(fn, index=index, description=description)

    return actually_register
