# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import ClassVar, Literal, Mapping
from enum import Enum
from json.decoder import JSONDecodeError
from urllib.parse import quote as urllib_quote, urljoin as urllib_join
import asyncio
import inspect
import json
import logging
import platform
import time

from aiohttp import ClientResponse, ClientSession, __version__ as aiohttp_version
from aiohttp.client_exceptions import ClientError, ContentTypeError
from yarl import URL

from mautrix import __optional_imports__, __version__ as mautrix_version
from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error
from mautrix.util.async_body import AsyncBody, async_iter_bytes
from mautrix.util.logging import TraceLogger
from mautrix.util.opt_prometheus import Counter

if __optional_imports__:
    # Safe to import, but it's not actually needed, so don't force-import the whole types module.
    from mautrix.types import JSON, DeviceID, UserID

# Counter incremented once for every request attempt made by HTTPAPI.request
# (retries count as separate attempts). The "method" label is the ``metrics_method``
# argument passed to HTTPAPI.request.
API_CALLS = Counter(
    name="bridge_matrix_api_calls",
    documentation="The number of Matrix client API calls made",
    labelnames=("method",),
)
# Counter incremented every time a request attempt made by HTTPAPI.request raises,
# including attempts that are subsequently retried. Uses the same "method" label.
API_CALLS_FAILED = Counter(
    name="bridge_matrix_api_calls_failed",
    documentation="The number of Matrix client API calls which failed",
    labelnames=("method",),
)


class APIPath(Enum):
    """
    Path prefixes of the Matrix APIs known to this module.

    The values intentionally have no leading slash so they can be used nicely with yarl.
    Both ``str()`` and ``repr()`` of a member return the bare prefix string.
    """

    CLIENT = "_matrix/client"
    MEDIA = "_matrix/media"
    SYNAPSE_ADMIN = "_synapse/admin"

    def __str__(self) -> str:
        # Render as the plain prefix rather than the default ``APIPath.NAME`` form.
        return self.value

    def __repr__(self) -> str:
        return str(self)


class Method(Enum):
    """
    The HTTP methods used for API requests.

    Both ``str()`` and ``repr()`` of a member return the plain method name (e.g. ``GET``).
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"

    def __str__(self) -> str:
        # Render as the plain verb rather than the default ``Method.NAME`` form.
        return self.value

    def __repr__(self) -> str:
        return str(self)


class PathBuilder:
    """
    A utility class to build API paths.

    Attribute access (``builder.rooms``) appends the attribute name as a new path segment
    verbatim, while item access (``builder[room_id]``) appends the percent-encoded string
    form of the item. Each of these returns a new builder and leaves the original untouched.

    A builder compares equal to another builder with the same path and to the equivalent
    plain string, and hashes like that string, so builders can be used in sets and as
    dict keys. Don't reassign :attr:`path` while a builder is stored in such a container.

    Examples:
        >>> from mautrix.api import Path
        >>> room_id = "!foo:example.com"
        >>> event_id = "$bar:example.com"
        >>> str(Path.v3.rooms[room_id].event[event_id])
        '_matrix/client/v3/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com'
    """

    def __init__(self, path: str | APIPath = "") -> None:
        """
        Args:
            path: The initial path. It is stored as ``str(path)``.
        """
        self.path: str = str(path)

    def __str__(self) -> str:
        return self.path

    def __repr__(self) -> str:
        return self.path

    def __getattr__(self, append: str) -> PathBuilder:
        """
        Append an attribute name to the path as a new, unescaped segment.

        Raises:
            AttributeError: For dunder names and for ``path`` itself. This method only runs
                when normal attribute lookup fails, so those lookups are protocol probes
                (``copy``, ``pickle``, ``hasattr`` checks, ...) or accesses on an
                uninitialized instance rather than requests for a path segment.
        """
        if append is None:
            return self
        if append == "path" or (append.startswith("__") and append.endswith("__")):
            # Without this, e.g. copy.deepcopy() would receive a PathBuilder when probing
            # for ``__deepcopy__`` and then fail trying to call it.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {append!r}"
            )
        return PathBuilder(f"{self.path}/{append}")

    def raw(self, append: str) -> PathBuilder:
        """
        Directly append a string to the path.

        Unlike attribute and item access, no ``/`` separator is inserted and nothing is
        escaped.

        Args:
            append: The string to append. If ``None``, the builder itself is returned.
        """
        if append is None:
            return self
        return PathBuilder(self.path + append)

    def __eq__(self, other: PathBuilder | str) -> bool:
        return other.path == self.path if isinstance(other, PathBuilder) else other == self.path

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable. Hash like the path string
        # so that the hash agrees with the builder == str equality above.
        return hash(self.path)

    @staticmethod
    def _quote(string: str) -> str:
        """Percent-encode a string, including any ``/`` characters in it."""
        return urllib_quote(string, safe="")

    def __getitem__(self, append: str | int) -> PathBuilder:
        """
        Append a percent-encoded segment to the path.

        Args:
            append: The value to append. It's converted with ``str()`` before encoding.
                If ``None``, the builder itself is returned.
        """
        if append is None:
            return self
        return PathBuilder(f"{self.path}/{self._quote(str(append))}")

    def replace(self, find: str, replace: str) -> PathBuilder:
        """
        Get a new builder where every occurrence of ``find`` in the path is replaced.

        Args:
            find: The substring to look for.
            replace: The string to put in place of each occurrence.
        """
        return PathBuilder(self.path.replace(find, replace))


# Ready-made path builders for the known API prefixes. Each one is a separate instance so
# that it can carry its own ``__doc__``.
ClientPath = PathBuilder(APIPath.CLIENT)
ClientPath.__doc__ = """
A path builder with the standard client prefix (``/_matrix/client``, :attr:`APIPath.CLIENT`).
"""
Path = PathBuilder(APIPath.CLIENT)
Path.__doc__ = """A shorter alias for :attr:`ClientPath`"""
MediaPath = PathBuilder(APIPath.MEDIA)
MediaPath.__doc__ = """
A path builder with the standard media prefix (``/_matrix/media``, :attr:`APIPath.MEDIA`)

Examples:
    >>> from mautrix.api import MediaPath
    >>> str(MediaPath.v3.config)
    "_matrix/media/v3/config"
"""
SynapseAdminPath = PathBuilder(APIPath.SYNAPSE_ADMIN)
SynapseAdminPath.__doc__ = """
A path builder for synapse-specific admin API paths
(``/_synapse/admin``, :attr:`APIPath.SYNAPSE_ADMIN`)

Examples:
    >>> from mautrix.api import SynapseAdminPath
    >>> user_id = "@user:example.com"
    >>> str(SynapseAdminPath.v1.users[user_id].login)
    "_synapse/admin/v1/users/%40user%3Aexample.com/login"
"""

# The ID handed out to the most recent request. Shared by every HTTPAPI instance in the
# process, so request IDs in logs are unique across instances.
_req_id = 0


def _next_global_req_id() -> int:
    """Advance the module-wide request counter by one and return the new value."""
    global _req_id
    next_id = _req_id + 1
    _req_id = next_id
    return next_id


class HTTPAPI:
    """
    HTTPAPI is a simple asyncio Matrix API request sender.

    Requests are sent through an aiohttp :class:`ClientSession`. If no session is passed to
    the constructor, one is created with :attr:`default_ua` as its ``User-Agent`` header.
    :meth:`request` adds the ``Authorization`` header when :attr:`token` is set, logs every
    request to :attr:`log` and can retry failed requests with an increasing delay.
    """

    default_ua: ClassVar[str] = (
        f"mautrix-python/{mautrix_version} aiohttp/{aiohttp_version} "
        f"Python/{platform.python_version()}"
    )
    """
    The default value for the ``User-Agent`` header.

    You should prepend your program name and version here before creating any HTTPAPI instances
    in order to have proper user agents for all requests.
    """
    global_default_retry_count: ClassVar[int] = 0
    """The default retry count to use if an instance-specific value is not passed."""

    base_url: URL
    """The base URL of the homeserver's client-server API to use."""
    token: str
    """The access token to use in requests."""
    log: TraceLogger
    """The :class:`logging.Logger` instance to log requests with."""
    session: ClientSession
    """The aiohttp ClientSession instance to make requests with."""
    txn_id: int | None
    """A counter used for generating transaction IDs."""
    default_retry_count: int
    """The default retry count to use if a custom value is not passed to :meth:`request`"""

    as_user_id: UserID | None
    """An optional user ID to set as the user_id query parameter for appservice requests."""
    as_device_id: DeviceID | None
    """
    An optional device ID to set as the user_id query parameter for appservice requests (MSC3202).
    """

    def __init__(
        self,
        base_url: URL | str,
        token: str = "",
        *,
        client_session: ClientSession = None,
        default_retry_count: int = None,
        txn_id: int = 0,
        log: TraceLogger | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        as_user_id: UserID | None = None,
        as_device_id: UserID | None = None,
    ) -> None:
        """
        Args:
            base_url: The base URL of the homeserver's client-server API to use.
            token: The access token to use.
            client_session: The aiohttp client session to use.
            txn_id: The outgoing transaction ID to start with.
            log: The :class:`logging.Logger` instance to log requests with.
            default_retry_count: Default number of retries to do when encountering network errors.
            as_user_id: An optional user ID to set as the user_id query parameter for
                appservice requests.
            as_device_id: An optional device ID to set as the user_id query parameter for
                appservice requests (MSC3202).
        """
        self.base_url = URL(base_url)
        self.token = token
        self.log = log or logging.getLogger("mau.http")
        self.session = client_session or ClientSession(
            loop=loop, headers={"User-Agent": self.default_ua}
        )
        self.as_user_id = as_user_id
        self.as_device_id = as_device_id
        if txn_id is not None:
            self.txn_id = txn_id
        if default_retry_count is not None:
            self.default_retry_count = default_retry_count
        else:
            self.default_retry_count = self.global_default_retry_count

    async def _send(
        self,
        method: Method,
        url: URL,
        content: bytes | bytearray | str | AsyncBody,
        query_params: dict[str, str],
        headers: dict[str, str],
    ) -> tuple[JSON, ClientResponse]:
        request = self.session.request(
            str(method), url, data=content, params=query_params, headers=headers
        )
        async with request as response:
            if response.status < 200 or response.status >= 300:
                errcode = unstable_errcode = message = None
                try:
                    response_data = await response.json()
                    errcode = response_data["errcode"]
                    message = response_data["error"]
                    unstable_errcode = response_data.get("org.matrix.msc3848.unstable.errcode")
                except (JSONDecodeError, ContentTypeError, KeyError):
                    pass
                raise make_request_error(
                    http_status=response.status,
                    text=await response.text(),
                    errcode=errcode,
                    message=message,
                    unstable_errcode=unstable_errcode,
                )
            return await response.json(), response

    def _log_request(
        self,
        method: Method,
        url: URL,
        content: str | bytes | bytearray | AsyncBody | None,
        orig_content,
        query_params: dict[str, str],
        headers: dict[str, str],
        req_id: int,
        sensitive: bool,
    ) -> None:
        if not self.log:
            return
        if isinstance(content, (bytes, bytearray)):
            log_content = f"<{len(content)} bytes>"
        elif inspect.isasyncgen(content):
            size = headers.get("Content-Length", None)
            log_content = f"<{size} async bytes>" if size else f"<stream with unknown length>"
        elif sensitive:
            log_content = f"<{len(content)} sensitive bytes>"
        else:
            log_content = content
        as_user = query_params.get("user_id", None)
        level = 5 if url.path.endswith("/v3/sync") else 10
        self.log.log(
            level,
            f"req #{req_id}: {method} {url} {log_content}".strip(" "),
            extra={
                "matrix_http_request": {
                    "req_id": req_id,
                    "method": str(method),
                    "url": str(url),
                    "content": (
                        orig_content
                        if isinstance(orig_content, (dict, list)) and not sensitive
                        else log_content
                    ),
                    "user": as_user,
                }
            },
        )

    def _log_request_done(
        self, path: PathBuilder | str, req_id: int, duration: float, status: int
    ) -> None:
        level = 5 if path == Path.v3.sync else 10
        duration_str = f"{duration * 1000:.1f}ms" if duration < 1 else f"{duration:.3f}s"
        path_without_prefix = f"/{path}".replace("/_matrix/client", "")
        self.log.log(
            level,
            f"req #{req_id} ({path_without_prefix}) completed in {duration_str} "
            f"with status {status}",
        )

    def _full_path(self, path: PathBuilder | str) -> str:
        path = str(path)
        if path and path[0] == "/":
            path = path[1:]
        base_path = self.base_url.raw_path
        if base_path[-1] != "/":
            base_path += "/"
        return urllib_join(base_path, path)

    def log_download_request(self, url: URL, query_params: dict[str, str]) -> int:
        req_id = _next_global_req_id()
        self._log_request(Method.GET, url, None, None, query_params, {}, req_id, False)
        return req_id

    def log_download_request_done(
        self, url: URL, req_id: int, duration: float, status: int
    ) -> None:
        self._log_request_done(url.path.removeprefix("/_matrix/media/"), req_id, duration, status)

    async def request(
        self,
        method: Method,
        path: PathBuilder | str,
        content: dict | list | bytes | bytearray | str | AsyncBody | None = None,
        headers: dict[str, str] | None = None,
        query_params: Mapping[str, str] | None = None,
        retry_count: int | None = None,
        metrics_method: str = "",
        min_iter_size: int = 25 * 1024 * 1024,
        sensitive: bool = False,
    ) -> JSON:
        """
        Make a raw Matrix API request.

        Args:
            method: The HTTP method to use.
            path: The full API endpoint to call (including the _matrix/... prefix)
            content: The content to post as a dict/list (will be serialized as JSON)
                     or bytes/str (will be sent as-is).
            headers: A dict of HTTP headers to send. If the headers don't contain ``Content-Type``,
                     it'll be set to ``application/json``. The ``Authorization`` header is always
                     overridden if :attr:`token` is set.
            query_params: A dict of query parameters to send.
            retry_count: Number of times to retry if the homeserver isn't reachable.
                         Defaults to :attr:`default_retry_count`.
            metrics_method: Name of the method to include in Prometheus timing metrics.
            min_iter_size: If the request body is larger than this value, it will be passed to
                           aiohttp as an async iterable to stop it from copying the whole thing
                           in memory.
            sensitive: If True, the request content will not be logged.

        Returns:
            The parsed response JSON.
        """
        headers = headers or {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        query_params = query_params or {}
        if isinstance(query_params, dict):
            query_params = {k: v for k, v in query_params.items() if v is not None}
        if self.as_user_id:
            query_params["user_id"] = self.as_user_id
        if self.as_device_id:
            query_params["org.matrix.msc3202.device_id"] = self.as_device_id
            query_params["device_id"] = self.as_device_id

        if method != Method.GET:
            content = content or {}
            if "Content-Type" not in headers:
                headers["Content-Type"] = "application/json"
            orig_content = content
            is_json = headers.get("Content-Type", None) == "application/json"
            if is_json and isinstance(content, (dict, list)):
                content = json.dumps(content)
        else:
            orig_content = content = None
        full_url = self.base_url.with_path(self._full_path(path), encoded=True)
        req_id = _next_global_req_id()

        if retry_count is None:
            retry_count = self.default_retry_count
        if inspect.isasyncgen(content):
            # Can't retry with non-static body
            retry_count = 0
        do_fake_iter = content and hasattr(content, "__len__") and len(content) > min_iter_size
        if do_fake_iter:
            headers["Content-Length"] = str(len(content))
        backoff = 4
        log_url = full_url.with_query(query_params)
        while True:
            self._log_request(
                method, log_url, content, orig_content, query_params, headers, req_id, sensitive
            )
            API_CALLS.labels(method=metrics_method).inc()
            req_content = async_iter_bytes(content) if do_fake_iter else content
            start = time.monotonic()
            try:
                resp_data, resp = await self._send(
                    method, full_url, req_content, query_params, headers or {}
                )
                self._log_request_done(path, req_id, time.monotonic() - start, resp.status)
                return resp_data
            except MatrixRequestError as e:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                if retry_count > 0 and e.http_status in (502, 503, 504):
                    self.log.warning(
                        f"Request #{req_id} failed with HTTP {e.http_status}, "
                        f"retrying in {backoff} seconds"
                    )
                else:
                    self._log_request_done(path, req_id, time.monotonic() - start, e.http_status)
                    raise
            except ClientError as e:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                if retry_count > 0:
                    self.log.warning(
                        f"Request #{req_id} failed with {e}, retrying in {backoff} seconds"
                    )
                else:
                    raise MatrixConnectionError(str(e)) from e
            except Exception:
                API_CALLS_FAILED.labels(method=metrics_method).inc()
                raise
            await asyncio.sleep(backoff)
            backoff *= 2
            retry_count -= 1

    def get_txn_id(self) -> str:
        """Get a new unique transaction ID."""
        self.txn_id += 1
        return f"mautrix-python_{time.time_ns()}_{self.txn_id}"

    def get_download_url(
        self,
        mxc_uri: str,
        download_type: Literal["download", "thumbnail"] = "download",
        file_name: str | None = None,
        authenticated: bool = False,
    ) -> URL:
        """
        Get the full HTTP URL to download a ``mxc://`` URI.

        Args:
            mxc_uri: The MXC URI whose full URL to get.
            download_type: The type of download ("download" or "thumbnail").
            file_name: Optionally, a file name to include in the download URL.
            authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11.

        Returns:
            The full HTTP URL.

        Raises:
            ValueError: If `mxc_uri` doesn't begin with ``mxc://``.

        Examples:
            >>> api = HTTPAPI(base_url="https://matrix-client.matrix.org", ...)
            >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6")
            "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6"
            >>> api.get_download_url("mxc://matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6", file_name="hello.png")
            "https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
        """
        server_name, media_id = self.parse_mxc_uri(mxc_uri)
        if authenticated:
            url = self.base_url / str(APIPath.CLIENT) / "v1" / "media"
        else:
            url = self.base_url / str(APIPath.MEDIA) / "v3"
        url = url / download_type / server_name / media_id
        if file_name:
            url /= file_name
        return url

    @staticmethod
    def parse_mxc_uri(mxc_uri: str) -> tuple[str, str]:
        """
        Parse a ``mxc://`` URI.

        Args:
            mxc_uri: The MXC URI to parse.

        Returns:
            A tuple containing the server and media ID of the MXC URI.

        Raises:
            ValueError: If `mxc_uri` doesn't begin with ``mxc://``.
        """
        if mxc_uri.startswith("mxc://"):
            server_name, media_id = mxc_uri[6:].split("/")
            return server_name, media_id
        else:
            raise ValueError("MXC URI did not begin with `mxc://`")
