"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
# @generated-id: 3263d7502030

import re
import json
from dataclasses import dataclass, asdict
from typing import (
    Any,
    Callable,
    Generic,
    TypeVar,
    Optional,
    Generator,
    AsyncGenerator,
    Tuple,
)
import httpx

# Type produced by the caller-supplied decoder for each event; EventStream and
# EventStreamAsync are generic over it.
T = TypeVar("T")


class EventStream(Generic[T]):
    """Synchronous iterator over the decoded events of a streaming response.

    Wraps the generator returned by ``stream_events``. It can be used as a
    context manager: leaving the ``with`` block closes the response, and any
    later ``next()`` call stops iteration.
    """

    # Reference to the SDK client, kept so the client is not garbage collected
    # (which would terminate the underlying httpx client) while streaming.
    client_ref: Optional[object]
    response: httpx.Response
    generator: Generator[T, None, None]
    _closed: bool

    def __init__(
        self,
        response: httpx.Response,
        decoder: Callable[[str], T],
        sentinel: Optional[str] = None,
        client_ref: Optional[object] = None,
        data_required: bool = True,
    ):
        self.client_ref = client_ref
        self._closed = False
        self.response = response
        # Calling a generator function only builds the generator; nothing is
        # read from the response until the first next().
        self.generator = stream_events(
            response,
            decoder,
            sentinel,
            data_required=data_required,
        )

    def __iter__(self):
        return self

    def __next__(self):
        if not self._closed:
            return next(self.generator)
        raise StopIteration

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Flag first so iteration stops even if closing the response raises.
        self._closed = True
        self.response.close()


class EventStreamAsync(Generic[T]):
    """Asynchronous iterator over the decoded events of a streaming response.

    Wraps the async generator returned by ``stream_events_async``. It can be
    used as an async context manager: leaving the ``async with`` block closes
    the response with ``aclose()``, and any later ``__anext__`` call stops
    iteration.
    """

    # Reference to the SDK client, kept so the client is not garbage collected
    # (which would terminate the underlying httpx client) while streaming.
    client_ref: Optional[object]
    response: httpx.Response
    generator: AsyncGenerator[T, None]
    _closed: bool

    def __init__(
        self,
        response: httpx.Response,
        decoder: Callable[[str], T],
        sentinel: Optional[str] = None,
        client_ref: Optional[object] = None,
        data_required: bool = True,
    ):
        self.client_ref = client_ref
        self._closed = False
        self.response = response
        # Calling an async generator function only builds the generator;
        # nothing is read from the response until it is first awaited.
        self.generator = stream_events_async(
            response,
            decoder,
            sentinel,
            data_required=data_required,
        )

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._closed:
            return await self.generator.__anext__()
        raise StopAsyncIteration

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Flag first so iteration stops even if closing the response raises.
        self._closed = True
        await self.response.aclose()


@dataclass()
class ServerEvent:
    """Fields of one server-sent event, filled in by ``_parse_event``.

    Every field defaults to ``None``. The declaration order below is also the
    key order produced by ``dataclasses.asdict``.
    """

    id: Optional[str] = None  # event ID, carried over from earlier events if unset
    event: Optional[str] = None  # value of the "event" field
    data: Any = None  # JSON-decoded payload, or the raw text if it is not JSON
    retry: Optional[int] = None  # integer value of the "retry" field


# Byte sequences that end an event block: a line terminator (CRLF, CR or LF)
# immediately followed by an empty line. Entries sharing a prefix are listed
# longest first, so trying them in order picks the most complete terminator.
MESSAGE_BOUNDARIES = [
    b"\r\n\r\n",  # CRLF + CRLF
    b"\r\n\r",  # CRLF + CR
    b"\r\n\n",  # CRLF + LF
    b"\r\r\n",  # CR + CRLF
    b"\n\r\n",  # LF + CRLF
    b"\r\r",  # CR + CR
    b"\n\r",  # LF + CR
    b"\n\n",  # LF + LF
]

# UTF-8 encoding of U+FEFF, the byte order mark that may open the stream.
UTF8_BOM = "\ufeff".encode("utf-8")


async def stream_events_async(
    response: httpx.Response,
    decoder: Callable[[str], T],
    sentinel: Optional[str] = None,
    data_required: bool = True,
) -> AsyncGenerator[T, None]:
    """Asynchronously yield decoded events from a server-sent-events response.

    The body is read chunk by chunk and split into event blocks at the byte
    sequences in ``MESSAGE_BOUNDARIES``. Each block is handed to
    ``_parse_event``. A UTF-8 byte order mark is dropped only when it opens
    the stream, so the result does not depend on how the body is chunked.

    Args:
        response: Response whose body is the event stream.
        decoder: Turns the JSON text built for each event into ``T``.
        sentinel: Data value that marks the end of the stream. When it is
            received, the response is closed and iteration stops.
        data_required: Forwarded to ``_parse_event``. When true, events
            without a ``data`` field are skipped.

    Yields:
        One decoded value per published event. Bytes left over when the body
        ends without a trailing blank line are parsed as a final event.
    """
    # Ordered alternation: at each offset the first listed boundary that
    # matches wins, exactly like trying MESSAGE_BOUNDARIES one after another.
    boundary_re = re.compile(b"|".join(re.escape(seq) for seq in MESSAGE_BOUNDARIES))
    # A boundary starting this close to the end of the buffer may only be
    # completed by the next chunk, so those bytes get searched again.
    tail = max(len(seq) for seq in MESSAGE_BOUNDARIES) - 1
    buffer = bytearray()
    position = 0  # start of the event block still being accumulated
    scan_from = 0  # bytes before this offset cannot start a new boundary
    at_stream_start = True
    event_id: Optional[str] = None
    async for chunk in response.aiter_bytes():
        buffer += chunk
        if at_stream_start:
            if len(buffer) < len(UTF8_BOM) and UTF8_BOM.startswith(buffer):
                # Too few bytes yet to tell whether the stream opens with a BOM.
                continue
            at_stream_start = False
            if buffer.startswith(UTF8_BOM):
                del buffer[: len(UTF8_BOM)]

        while True:
            match = boundary_re.search(buffer, scan_from)
            if match is None:
                break
            block = buffer[position : match.start()]
            position = scan_from = match.end()
            event, discard, event_id = _parse_event(
                raw=block,
                decoder=decoder,
                sentinel=sentinel,
                event_id=event_id,
                data_required=data_required,
            )
            if event is not None:
                yield event
            if discard:
                await response.aclose()
                return

        # Resume the next search near the end rather than re-scanning the
        # whole pending block, which made large multi-chunk events quadratic.
        scan_from = max(position, len(buffer) - tail)
        if position > 0:
            del buffer[:position]
            scan_from -= position
            position = 0

    event, _, _ = _parse_event(
        raw=buffer,
        decoder=decoder,
        sentinel=sentinel,
        event_id=event_id,
        data_required=data_required,
    )
    if event is not None:
        yield event


def stream_events(
    response: httpx.Response,
    decoder: Callable[[str], T],
    sentinel: Optional[str] = None,
    data_required: bool = True,
) -> Generator[T, None, None]:
    """Yield decoded events from a server-sent-events response.

    The body is read chunk by chunk and split into event blocks at the byte
    sequences in ``MESSAGE_BOUNDARIES``. Each block is handed to
    ``_parse_event``. A UTF-8 byte order mark is dropped only when it opens
    the stream, so the result does not depend on how the body is chunked.

    Args:
        response: Response whose body is the event stream.
        decoder: Turns the JSON text built for each event into ``T``.
        sentinel: Data value that marks the end of the stream. When it is
            received, the response is closed and iteration stops.
        data_required: Forwarded to ``_parse_event``. When true, events
            without a ``data`` field are skipped.

    Yields:
        One decoded value per published event. Bytes left over when the body
        ends without a trailing blank line are parsed as a final event.
    """
    # Ordered alternation: at each offset the first listed boundary that
    # matches wins, exactly like trying MESSAGE_BOUNDARIES one after another.
    boundary_re = re.compile(b"|".join(re.escape(seq) for seq in MESSAGE_BOUNDARIES))
    # A boundary starting this close to the end of the buffer may only be
    # completed by the next chunk, so those bytes get searched again.
    tail = max(len(seq) for seq in MESSAGE_BOUNDARIES) - 1
    buffer = bytearray()
    position = 0  # start of the event block still being accumulated
    scan_from = 0  # bytes before this offset cannot start a new boundary
    at_stream_start = True
    event_id: Optional[str] = None
    for chunk in response.iter_bytes():
        buffer += chunk
        if at_stream_start:
            if len(buffer) < len(UTF8_BOM) and UTF8_BOM.startswith(buffer):
                # Too few bytes yet to tell whether the stream opens with a BOM.
                continue
            at_stream_start = False
            if buffer.startswith(UTF8_BOM):
                del buffer[: len(UTF8_BOM)]

        while True:
            match = boundary_re.search(buffer, scan_from)
            if match is None:
                break
            block = buffer[position : match.start()]
            position = scan_from = match.end()
            event, discard, event_id = _parse_event(
                raw=block,
                decoder=decoder,
                sentinel=sentinel,
                event_id=event_id,
                data_required=data_required,
            )
            if event is not None:
                yield event
            if discard:
                response.close()
                return

        # Resume the next search near the end rather than re-scanning the
        # whole pending block, which made large multi-chunk events quadratic.
        scan_from = max(position, len(buffer) - tail)
        if position > 0:
            del buffer[:position]
            scan_from -= position
            position = 0

    event, _, _ = _parse_event(
        raw=buffer,
        decoder=decoder,
        sentinel=sentinel,
        event_id=event_id,
        data_required=data_required,
    )
    if event is not None:
        yield event


def _parse_event(
    *,
    raw: bytearray,
    decoder: Callable[[str], T],
    sentinel: Optional[str] = None,
    event_id: Optional[str] = None,
    data_required: bool = True,
) -> Tuple[Optional[T], bool, Optional[str]]:
    """Parse a single event block and decode it.

    Recognised fields are ``event``, ``data``, ``id`` and ``retry``. Comment
    lines (starting with ``:``) and unknown fields are ignored. Multiple
    ``data`` lines are joined with newlines. The result is parsed as JSON
    when possible and otherwise kept as a string.

    Args:
        raw: Bytes of one event block, without its terminating blank line.
        decoder: Receives a JSON object string with the event's ``id``,
            ``event``, ``data`` and ``retry`` entries and returns the caller's
            type. Entries whose value is ``None`` are dropped, except a
            ``data`` payload that decoded to JSON ``null``.
        sentinel: Data value that signals the end of the stream.
        event_id: ID in effect before this block. It is kept unless the block
            sets a new, valid one.
        data_required: If true, a block with fields but no ``data`` produces
            no value.

    Returns:
        ``(value, stop, event_id)``, where:

        - ``value`` is the decoded value, or ``None``.
        - ``stop`` is ``True`` when the sentinel was received and the stream
          should end.
        - ``event_id`` is the ID in effect after this block.
    """
    # Decode leniently: SSE streams are decoded as UTF-8 with replacement
    # characters, so one malformed byte must not abort the whole stream.
    block = raw.decode("utf-8", errors="replace")
    publish = False
    event = ServerEvent()
    data = ""
    for line in re.split(r"\r?\n|\r", block):
        if not line:
            continue

        field, _, value = line.partition(":")
        if not field:
            # A line starting with ":" is a comment.
            continue
        if value.startswith(" "):
            # A single space after the colon is not part of the value.
            value = value[1:]

        if field == "event":
            event.event = value
            publish = True
        elif field == "data":
            data += value + "\n"
            publish = True
        elif field == "id":
            publish = True
            # IDs containing NUL are ignored.
            if "\x00" not in value:
                event_id = value
        elif field == "retry":
            # ASCII digits only: str.isdigit() is also true for characters
            # such as "²" that int() rejects with ValueError.
            if value.isascii() and value.isdigit():
                event.retry = int(value)
            publish = True

    event.id = event_id

    if sentinel and data == f"{sentinel}\n":
        return None, True, event_id

    # Skip data-less events when data is required
    if not data and publish and data_required:
        return None, False, event_id

    if data:
        data = data[:-1]  # drop the newline appended after the last data line
        try:
            event.data = json.loads(data)
        except json.JSONDecodeError:
            event.data = data

    out = None
    if publish:
        out_dict = {
            k: v
            for k, v in asdict(event).items()
            if v is not None or (k == "data" and data)
        }
        out = decoder(json.dumps(out_dict))

    return out, False, event_id


def _peek_sequence(position: int, buffer: bytearray, sequence: bytes):
    if len(sequence) > (len(buffer) - position):
        return None

    for i, seq in enumerate(sequence):
        if buffer[position + i] != seq:
            return None

    return sequence
