# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Any, AsyncContextManager, AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
import asyncio
import logging
import os
import re
import sqlite3

from yarl import URL
import aiosqlite

from .connection import LoggingConnection
from .database import Database
from .scheme import Scheme
from .upgrade import UpgradeTable

# Matches asyncpg-style positional placeholders ($1, $2, ...). Group 1 captures the
# parameter number so the placeholder can be rewritten to SQLite's numbered "?N" form.
POSITIONAL_PARAM_PATTERN = re.compile(r"\$(\d+)")


# True while the current context is inside TxnConnection.transaction(). Nested
# transaction() calls check it to join the outer transaction instead of issuing
# another BEGIN.
# NOTE(review): the flag is per-context, not per-connection -- code that is already in
# a transaction and then uses a *different* connection will also see True. Confirm that
# this never happens with the pool.
in_transaction = ContextVar("in_transaction", default=False)


class TxnConnection(aiosqlite.Connection):
    """An aiosqlite connection exposing an asyncpg-like query API.

    Queries may use asyncpg-style ``$N`` placeholders. They are rewritten to SQLite's
    numbered ``?N`` parameters before execution. The ``timeout`` keyword accepted by the
    query methods exists only for signature compatibility and is ignored.
    """

    def __init__(self, path: str, **kwargs) -> None:
        """Set up a connection to the SQLite database at *path*.

        The sqlite3 connection itself is created by ``connector`` when aiosqlite starts
        the connection (callers ``await TxnConnection(...)``). Extra keyword arguments
        are passed to :func:`sqlite3.connect`. ``isolation_level=None`` puts sqlite3 in
        autocommit mode, so a transaction exists only when :meth:`transaction` issues an
        explicit ``BEGIN``.
        """

        def connector() -> sqlite3.Connection:
            return sqlite3.connect(
                path, detect_types=sqlite3.PARSE_DECLTYPES, isolation_level=None, **kwargs
            )

        super().__init__(connector, iter_chunk_size=64)

    @asynccontextmanager
    async def transaction(self) -> AsyncIterator[None]:
        """Run the body of the ``async with`` block inside a database transaction.

        The transaction commits when the body finishes normally. It rolls back when the
        body raises, including on task cancellation. If the current context is already
        inside a transaction (per the ``in_transaction`` context variable), the block
        simply joins it: no new ``BEGIN``, and the outer block does the commit or rollback.
        """
        if in_transaction.get():
            yield
            return
        await self.execute("BEGIN TRANSACTION")
        token = in_transaction.set(True)
        try:
            yield
        except BaseException:
            # BaseException, not Exception: asyncio.CancelledError (and
            # KeyboardInterrupt) don't derive from Exception. Without a rollback
            # here, a cancelled task would hand the connection back to the pool
            # with the transaction still open, and the next BEGIN on it would fail.
            await self.rollback()
            raise
        else:
            await self.commit()
        finally:
            in_transaction.reset(token)

    def __execute(self, query: str, *args: Any):
        """Translate ``$N`` placeholders and start executing *query* with *args*.

        The aiosqlite result is returned un-awaited. Callers either await it to get the
        cursor, or use it as an async context manager.
        """
        query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
        return super().execute(query, args)

    async def execute(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> aiosqlite.Cursor:
        """Execute *query* with positional parameters *args* and return the cursor."""
        return await self.__execute(query, *args)

    async def executemany(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> aiosqlite.Cursor:
        """Execute *query* once per parameter set.

        *args* is forwarded positionally to aiosqlite's ``executemany``. The expected call
        shape is therefore ``executemany(query, rows)``, where *rows* is an iterable of
        parameter sequences (as with asyncpg).
        """
        query = POSITIONAL_PARAM_PATTERN.sub(r"?\1", query)
        return await super().executemany(query, *args)

    async def fetch(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> list[sqlite3.Row]:
        """Execute *query* and return all result rows as a list.

        The rows are ``sqlite3.Row`` objects only if ``row_factory`` has been set to it.
        """
        async with self.__execute(query, *args) as cursor:
            return list(await cursor.fetchall())

    async def fetchrow(
        self, query: str, *args: Any, timeout: float | None = None
    ) -> sqlite3.Row | None:
        """Execute *query* and return its first row, or ``None`` if there are no rows."""
        async with self.__execute(query, *args) as cursor:
            return await cursor.fetchone()

    async def fetchval(
        self, query: str, *args: Any, column: int = 0, timeout: float | None = None
    ) -> Any:
        """Execute *query* and return column *column* of its first row.

        Returns ``None`` when there are no rows. The result is indistinguishable from a
        row whose value in that column is SQL NULL.
        """
        row = await self.fetchrow(query, *args)
        if row is None:
            return None
        return row[column]


class SQLiteDatabase(Database):
    """:class:`Database` implementation backed by a fixed-size pool of SQLite connections.

    All pooled connections are opened in :meth:`start` and handed out through an
    :class:`asyncio.Queue`. The pool can also be redirected to another
    :class:`SQLiteDatabase` with :meth:`override_pool`. In that case no connections
    of its own are opened.
    """

    scheme = Scheme.SQLITE
    _parent: SQLiteDatabase | None  # database whose pool is used after override_pool()
    _pool: asyncio.Queue[TxnConnection]  # idle connections; maxsize is the pool size
    _stopped: bool  # set by stop(); the pool can't be restarted afterwards
    _conns: int  # connections opened by start() and not yet closed by stop()
    _init_commands: list[str]  # SQL executed on every connection when it is opened

    def __init__(
        self,
        url: URL,
        upgrade_table: UpgradeTable,
        db_args: dict[str, Any] | None = None,
        log: logging.Logger | None = None,
        owner_name: str | None = None,
        ignore_foreign_tables: bool = True,
    ) -> None:
        """Configure the pool. No connections are opened until :meth:`start`.

        The following keys are popped from ``self._db_args`` (presumably filled from
        *db_args* by the base class):

        - ``min_size``: number of pooled connections. The default is 1, and values
          below 1 are raised to 1.
        - ``max_size``: accepted but ignored.
        - ``init_commands``: SQL run on each new connection, completed with default
          PRAGMAs.

        Whatever remains in ``self._db_args`` is passed to :func:`sqlite3.connect`.
        """
        super().__init__(
            url,
            db_args=db_args,
            upgrade_table=upgrade_table,
            log=log,
            owner_name=owner_name,
            ignore_foreign_tables=ignore_foreign_tables,
        )
        self._parent = None
        self._path = url.path
        # A Queue with maxsize <= 0 is unbounded, and start() would then open zero
        # connections, so every acquire() would wait forever. Keep at least one.
        pool_size = self._db_args.pop("min_size", None) or 1
        self._pool = asyncio.Queue(max(pool_size, 1))
        self._db_args.pop("max_size", None)
        self._stopped = False
        self._conns = 0
        self._init_commands = self._add_missing_pragmas(self._db_args.pop("init_commands", []))

    @staticmethod
    def _add_missing_pragmas(init_commands: list[str]) -> list[str]:
        """Return a copy of *init_commands* with default PRAGMAs appended where unset.

        The defaults are ``foreign_keys = ON``, ``journal_mode = WAL`` and
        ``busy_timeout = 5000``. ``synchronous = NORMAL`` is added as well when the
        connection ends up in WAL mode. Detection is case-insensitive, so a user's
        ``pragma foreign_keys = off`` is respected rather than overridden by an appended
        default. The input list is not modified.
        """
        commands = list(init_commands)
        has_foreign_keys = False
        has_journal_mode = False
        has_synchronous = False
        has_busy_timeout = False
        uses_wal = False
        for cmd in commands:
            lowered = cmd.lower()
            if "pragma" not in lowered:
                continue
            if "foreign_keys" in lowered:
                has_foreign_keys = True
            elif "journal_mode" in lowered:
                has_journal_mode = True
                # The last journal_mode command is the one that takes effect.
                uses_wal = "wal" in lowered
            elif "synchronous" in lowered:
                has_synchronous = True
            elif "busy_timeout" in lowered:
                has_busy_timeout = True
        if not has_foreign_keys:
            commands.append("PRAGMA foreign_keys = ON")
        if not has_journal_mode:
            commands.append("PRAGMA journal_mode = WAL")
            uses_wal = True
        if not has_synchronous and uses_wal:
            commands.append("PRAGMA synchronous = NORMAL")
        if not has_busy_timeout:
            commands.append("PRAGMA busy_timeout = 5000")
        return commands

    def override_pool(self, db: Database) -> None:
        """Use *db*'s connection pool instead of opening connections of our own."""
        assert isinstance(db, SQLiteDatabase)
        self._parent = db

    async def start(self) -> None:
        """Open all pooled connections, run the init commands, then start the base class.

        If the pool was overridden, only the base class start runs.

        Raises:
            RuntimeError: if the pool was already started or has been stopped.
        """
        if self._parent:
            await super().start()
            return
        if self._conns:
            raise RuntimeError("database pool has already been started")
        elif self._stopped:
            raise RuntimeError("database pool can't be restarted")
        self.log.debug(f"Connecting to {self.url}")
        self.log.debug(f"Database connection init commands: {self._init_commands}")
        # Only advisory: warn early about permission problems, but still try to connect.
        if os.path.exists(self._path):
            if not os.access(self._path, os.W_OK):
                self.log.warning("Database file doesn't seem writable")
        elif not os.access(os.path.dirname(os.path.abspath(self._path)), os.W_OK):
            self.log.warning("Database file doesn't exist and directory doesn't seem writable")
        for _ in range(self._pool.maxsize):
            conn = await TxnConnection(self._path, **self._db_args)
            try:
                if self._init_commands:
                    cur = await conn.cursor()
                    try:
                        for command in self._init_commands:
                            self.log.trace("Executing init command: %s", command)
                            await cur.execute(command)
                    finally:
                        await cur.close()
                    await conn.commit()
                conn.row_factory = sqlite3.Row
            except BaseException:
                # Don't leak the freshly opened connection if an init command fails
                # (or start() is cancelled). Connections already in the pool are
                # counted in self._conns, so stop() still closes them.
                await conn.close()
                raise
            self._pool.put_nowait(conn)
            self._conns += 1
        await super().start()

    async def stop(self) -> None:
        """Close every pooled connection. The pool can't be restarted afterwards.

        Connections are taken from the pool queue before closing, so this waits for
        any connection that is currently checked out to be returned. Does nothing when
        the pool was overridden.
        """
        if self._parent:
            return
        self._stopped = True
        while self._conns > 0:
            conn = await self._pool.get()
            self._conns -= 1
            await conn.close()

    def acquire_direct(self) -> AsyncContextManager[LoggingConnection]:
        """Return an async context manager that lends out a pooled connection.

        When the pool was overridden, the parent's ``acquire()`` is used instead.
        """
        if self._parent:
            return self._parent.acquire()
        return self._acquire()

    @asynccontextmanager
    async def _acquire(self) -> AsyncIterator[LoggingConnection]:
        """Borrow a connection for the ``async with`` body and always return it.

        Raises:
            RuntimeError: if the pool has been stopped.
        """
        if self._stopped:
            raise RuntimeError("database pool has been stopped")
        conn = await self._pool.get()
        try:
            yield LoggingConnection(self.scheme, conn, self.log)
        finally:
            self._pool.put_nowait(conn)


# Register this backend for both the "sqlite" and "sqlite3" URL schemes.
Database.schemes["sqlite"] = Database.schemes["sqlite3"] = SQLiteDatabase
