import ssl

import sniffio
from httpcore import (
    AsyncConnectionPool,
    Origin,
    AsyncConnectionInterface,
    Request,
    Response,
    default_ssl_context,
    AsyncHTTP11Connection,
    ConnectionNotAvailable,
)
from httpcore import AsyncNetworkStream
from httpcore._synchronization import AsyncLock
from python_socks import ProxyType, parse_proxy_url


class AsyncProxy(AsyncConnectionPool):
    def __init__(
        self,
        *,
        proxy_type: ProxyType,
        proxy_host: str,
        proxy_port: int,
        username=None,
        password=None,
        rdns=None,
        proxy_ssl: ssl.SSLContext = None,
        loop=None,
        **kwargs,
    ):
        self._proxy_type = proxy_type
        self._proxy_host = proxy_host
        self._proxy_port = proxy_port
        self._username = username
        self._password = password
        self._rdns = rdns
        self._proxy_ssl = proxy_ssl
        self._loop = loop

        super().__init__(**kwargs)

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        return AsyncProxyConnection(
            proxy_type=self._proxy_type,
            proxy_host=self._proxy_host,
            proxy_port=self._proxy_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
            proxy_ssl=self._proxy_ssl,
            loop=self._loop,
            remote_origin=origin,
            ssl_context=self._ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
        )

    @classmethod
    def from_url(cls, url, **kwargs):
        proxy_type, host, port, username, password = parse_proxy_url(url)
        return cls(
            proxy_type=proxy_type,
            proxy_host=host,
            proxy_port=port,
            username=username,
            password=password,
            **kwargs,
        )


class AsyncProxyConnection(AsyncConnectionInterface):
    def __init__(
        self,
        *,
        proxy_type: ProxyType,
        proxy_host: str,
        proxy_port: int,
        username=None,
        password=None,
        rdns=None,
        proxy_ssl: ssl.SSLContext = None,
        loop=None,
        remote_origin: Origin,
        ssl_context: ssl.SSLContext,
        keepalive_expiry: float = None,
        http1: bool = True,
        http2: bool = False,
    ) -> None:

        if ssl_context is None:  # pragma: no cover
            ssl_context = default_ssl_context()

        self._proxy_type = proxy_type
        self._proxy_host = proxy_host
        self._proxy_port = proxy_port
        self._username = username
        self._password = password
        self._rdns = rdns
        self._proxy_ssl = proxy_ssl
        self._loop = loop

        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2

        self._connect_lock = AsyncLock()
        self._connection = None
        self._connect_failed: bool = False

    async def handle_async_request(self, request: Request) -> Response:
        timeouts = request.extensions.get('timeout', {})
        timeout = timeouts.get('connect', None)

        try:
            async with self._connect_lock:
                if self._connection is None:
                    stream = await self._connect_via_proxy(
                        origin=self._remote_origin,
                        connect_timeout=timeout,
                    )

                    ssl_object = stream.get_extra_info("ssl_object")
                    http2_negotiated = (
                        ssl_object is not None and ssl_object.selected_alpn_protocol() == "h2"
                    )
                    if http2_negotiated or (self._http2 and not self._http1):
                        from httpcore import AsyncHTTP2Connection

                        self._connection = AsyncHTTP2Connection(
                            origin=self._remote_origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
                    else:
                        self._connection = AsyncHTTP11Connection(
                            origin=self._remote_origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
                elif not self._connection.is_available():  # pragma: no cover
                    raise ConnectionNotAvailable()
        except BaseException as exc:
            self._connect_failed = True
            raise exc

        return await self._connection.handle_async_request(request)

    async def _connect_via_proxy(self, origin, connect_timeout) -> AsyncNetworkStream:
        scheme, hostname, port = origin.scheme, origin.host, origin.port

        ssl_context = self._ssl_context if scheme == b'https' else None
        host = hostname.decode('ascii')  # ?

        return await self._open_stream(
            host=host,
            port=port,
            connect_timeout=connect_timeout,
            ssl_context=ssl_context,
        )

    async def _open_stream(self, host, port, connect_timeout, ssl_context):
        backend = sniffio.current_async_library()

        if backend == 'asyncio':
            return await self._open_aio_stream(host, port, connect_timeout, ssl_context)

        if backend == 'trio':
            return await self._open_trio_stream(host, port, connect_timeout, ssl_context)

        # Curio support has been dropped in httpcore 0.14.0
        # if backend == 'curio':
        #     return await self._open_curio_stream(host, port, connect_timeout, ssl_context)

        raise RuntimeError(f'Unsupported concurrency backend {backend!r}')  # pragma: no cover

    async def _open_aio_stream(self, host, port, connect_timeout, ssl_context):
        from httpcore._backends.anyio import AnyIOStream
        from python_socks.async_.anyio import Proxy

        proxy = Proxy.create(
            proxy_type=self._proxy_type,
            host=self._proxy_host,
            port=self._proxy_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
            proxy_ssl=self._proxy_ssl,
        )

        proxy_stream = await proxy.connect(
            host,
            port,
            dest_ssl=ssl_context,
            timeout=connect_timeout,
        )

        return AnyIOStream(proxy_stream.anyio_stream)

    async def _open_trio_stream(self, host, port, connect_timeout, ssl_context):
        from httpcore._backends.trio import TrioStream
        from python_socks.async_.trio.v2 import Proxy

        proxy = Proxy.create(
            proxy_type=self._proxy_type,
            host=self._proxy_host,
            port=self._proxy_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
            proxy_ssl=self._proxy_ssl,
        )

        proxy_stream = await proxy.connect(
            host,
            port,
            dest_ssl=ssl_context,
            timeout=connect_timeout,
        )

        return TrioStream(proxy_stream.trio_stream)

    async def aclose(self) -> None:
        if self._connection is not None:
            await self._connection.aclose()

    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._remote_origin

    def is_available(self) -> bool:
        if self._connection is None:  # pragma: no cover
            # return self._http2 and (self._remote_origin.scheme == b"https" or not self._http1)
            return False
        return self._connection.is_available()

    def has_expired(self) -> bool:
        if self._connection is None:
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:  # pragma: no cover
        if self._connection is None:
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()
