# Copyright 2025 Daytona Platforms Inc.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import re
from typing import Any, NoReturn

from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
from websockets.sync.client import connect

from daytona_toolbox_api_client import CreateContextRequest, InterpreterApi, InterpreterContext

from .._utils.errors import intercept_errors
from ..common.code_interpreter import ExecutionError, ExecutionResult, OutputMessage
from ..common.errors import DaytonaConnectionError, DaytonaTimeoutError
from ..common.process import OutputHandler

# Close code (in the RFC 6455 application-reserved 4000-4999 range) that, when
# observed on the interpreter websocket, is reported as DaytonaTimeoutError.
WEBSOCKET_TIMEOUT_CODE: int = 4008


class CodeInterpreter:
    """Handles code interpretation and execution within a Sandbox. Currently supports only Python.

    This class provides methods to execute code in isolated interpreter contexts,
    manage contexts, and stream execution output via callbacks. If subsequent code executions
    are performed in the same context, the variables, imports, and functions defined in
    the previous execution will be available.

    For other languages, use the `code_run` method from the `Process` interface,
    or execute the appropriate command directly in the sandbox terminal.
    """

    def __init__(
        self,
        api_client: InterpreterApi,
    ):
        """Initialize a new CodeInterpreter instance.

        Args:
            api_client: API client for interpreter operations.
        """
        self._api_client: InterpreterApi = api_client

    @intercept_errors(message_prefix="Failed to run code: ")
    def run_code(
        self,
        code: str,
        *,
        context: InterpreterContext | None = None,
        on_stdout: OutputHandler[OutputMessage] | None = None,
        on_stderr: OutputHandler[OutputMessage] | None = None,
        on_error: OutputHandler[ExecutionError] | None = None,
        envs: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> ExecutionResult:
        """Execute Python code in the sandbox.

        By default, code runs in the default shared context which persists variables,
        imports, and functions across executions. To run in an isolated context,
        create a new context with `create_context()` and pass it as `context` argument.

        Output is streamed over a websocket: each stdout/stderr chunk is passed to the
        matching callback (if any) and also accumulated in the returned result.

        Args:
            code (str): Code to execute.
            context (InterpreterContext | None): Context to run code in. If not provided, uses default context.
            on_stdout (OutputHandler[OutputMessage] | None): Callback for stdout messages.
            on_stderr (OutputHandler[OutputMessage] | None): Callback for stderr messages.
            on_error (OutputHandler[ExecutionError] | None): Callback for execution errors
                (e.g., syntax errors, runtime errors).
            envs (dict[str, str] | None): Environment variables for this execution.
            timeout (int | None): Timeout in seconds. 0 means no timeout. Default is 10 minutes.
                The value is sent to the server; it is omitted from the request when None.

        Returns:
            ExecutionResult: Result object containing stdout, stderr and error if any.

        Raises:
            DaytonaTimeoutError: If execution times out.
            DaytonaError: If execution fails due to communication or other SDK errors.

        Examples:
            ```python
            def handle_stdout(msg: OutputMessage):
                print(f"STDOUT: {msg.output}", end="")

            def handle_stderr(msg: OutputMessage):
                print(f"STDERR: {msg.output}", end="")

            def handle_error(err: ExecutionError):
                print(f"ERROR: {err.name}: {err.value}")

            code = '''
            import sys
            import time
            for i in range(5):
                print(i)
                time.sleep(1)
            sys.stderr.write("Counting done!")
            '''
            result = sandbox.code_interpreter.run_code(
                code=code,
                on_stdout=handle_stdout,
                on_stderr=handle_stderr,
                on_error=handle_error,
                timeout=10
            )
            ```
        """
        # Reuse the generated client's serializer so the URL and headers match the
        # REST client's configuration, then switch the scheme (http -> ws, https -> wss).
        _, url, headers, *_ = self._api_client._execute_interpreter_code_serialize(
            _request_auth=None,
            _content_type=None,
            _headers=None,
            _host_index=None,
        )
        url = re.sub(r"^http", "ws", url)

        request = self._build_execute_request(code, context=context, envs=envs, timeout=timeout)
        result = ExecutionResult()

        try:
            with connect(url, additional_headers=headers) as websocket:
                # The execution request must be the first message on the socket.
                websocket.send(json.dumps(request))

                # Iterating the connection yields messages until the peer closes
                # normally (the loop simply ends); an abnormal close raises
                # ConnectionClosedError, handled below.
                for message in websocket:
                    self._handle_chunk(
                        json.loads(message),
                        result,
                        on_stdout=on_stdout,
                        on_stderr=on_stderr,
                        on_error=on_error,
                    )
        except ConnectionClosedOK:
            # A normal close (e.g. one surfacing from send()) is not an error:
            # return whatever output was collected.
            return result
        except ConnectionClosed as e:
            self._raise_from_ws_close(e)

        return result

    @staticmethod
    def _build_execute_request(
        code: str,
        *,
        context: InterpreterContext | None,
        envs: dict[str, str] | None,
        timeout: int | None,
    ) -> dict[str, str | int | dict[str, str]]:
        """Build the JSON payload sent as the first websocket message; unset options are omitted."""
        request: dict[str, str | int | dict[str, str]] = {"code": code}
        if context is not None:
            request["contextId"] = context.id
        if envs is not None:
            request["envs"] = envs
        if timeout is not None:
            request["timeout"] = timeout
        return request

    @staticmethod
    def _handle_chunk(
        chunk: dict[str, Any],
        result: ExecutionResult,
        *,
        on_stdout: OutputHandler[OutputMessage] | None,
        on_stderr: OutputHandler[OutputMessage] | None,
        on_error: OutputHandler[ExecutionError] | None,
    ) -> None:
        """Fold one decoded streaming chunk into `result` and invoke the matching callback.

        Chunks whose `type` is not "stdout", "stderr" or "error" are ignored.
        """
        chunk_type = chunk.get("type")

        if chunk_type == "stdout":
            # `or ""` also covers an explicit JSON null, which would break `+=`.
            stdout = chunk.get("text") or ""
            if on_stdout:
                _ = on_stdout(OutputMessage(output=stdout))
            result.stdout += stdout

        elif chunk_type == "stderr":
            stderr = chunk.get("text") or ""
            if on_stderr:
                _ = on_stderr(OutputMessage(output=stderr))
            result.stderr += stderr

        elif chunk_type == "error":
            error = ExecutionError(
                name=chunk.get("name", ""),
                value=chunk.get("value", ""),
                traceback=chunk.get("traceback", ""),
            )
            result.error = error
            if on_error:
                _ = on_error(error)

    @intercept_errors(message_prefix="Failed to create interpreter context: ")
    def create_context(
        self,
        cwd: str | None = None,
    ) -> InterpreterContext:
        """Create a new isolated interpreter context.

        Contexts provide isolated execution environments with their own global namespace.
        Variables, imports, and functions defined in one context don't affect others.

        Args:
            cwd (str | None): Working directory for the context. If not specified, uses sandbox working directory.

        Returns:
            InterpreterContext: The created context with its ID and metadata.

        Raises:
            DaytonaError: If context creation fails.

        Examples:
            ```python
            # Create isolated context
            ctx = sandbox.code_interpreter.create_context()

            # Execute code in this context
            sandbox.code_interpreter.run_code("x = 100", context=ctx)

            # Variable only exists in this context
            result = sandbox.code_interpreter.run_code("print(x)", context=ctx)  # OK

            # Won't see the variable in default context
            result = sandbox.code_interpreter.run_code("print(x)")  # NameError

            # Clean up
            sandbox.code_interpreter.delete_context(ctx)
            ```
        """
        return self._api_client.create_interpreter_context(request=CreateContextRequest(cwd=cwd))

    @intercept_errors(message_prefix="Failed to list interpreter contexts: ")
    def list_contexts(self) -> list[InterpreterContext]:
        """List all user-created interpreter contexts.

        The default context is not included in this list. Only contexts created
        via `create_context()` are returned.

        Returns:
            list[InterpreterContext]: List of context objects (empty if the API returns none).

        Raises:
            DaytonaError: If listing fails.

        Examples:
            ```python
            contexts = sandbox.code_interpreter.list_contexts()
            for ctx in contexts:
                print(f"Context {ctx.id}: {ctx.language} at {ctx.cwd}")
            ```
        """
        return self._api_client.list_interpreter_contexts().contexts or []

    @intercept_errors(message_prefix="Failed to delete interpreter context: ")
    def delete_context(self, context: InterpreterContext) -> None:
        """Delete an interpreter context and shut down all associated processes.

        This permanently removes the context and all its state (variables, imports, etc.).
        The default context cannot be deleted.

        Args:
            context (InterpreterContext): Context to delete.

        Raises:
            DaytonaError: If deletion fails or context not found.

        Examples:
            ```python
            ctx = sandbox.code_interpreter.create_context()
            # ... use context ...
            sandbox.code_interpreter.delete_context(ctx)
            ```
        """
        _ = self._api_client.delete_interpreter_context(id=context.id)

    def _raise_from_ws_close(self, error: ConnectionClosed) -> NoReturn:
        """Translate a websocket close event into a Daytona error. Never returns normally.

        The close frame received from the peer takes precedence; the frame this client
        sent is used only when none was received. The original `ConnectionClosed`
        is chained as the cause of the raised error.

        Args:
            error: The websocket close exception to translate.

        Raises:
            DaytonaTimeoutError: If the close code equals `WEBSOCKET_TIMEOUT_CODE`.
            DaytonaConnectionError: For any other close (including no close frame at all).
        """
        close_frame = error.rcvd if error.rcvd is not None else error.sent
        code = close_frame.code if close_frame is not None else None
        reason = close_frame.reason if close_frame is not None else None

        if code == WEBSOCKET_TIMEOUT_CODE:
            raise DaytonaTimeoutError(
                "Execution timed out: operation exceeded the configured `timeout`. Provide a larger value if needed."
            ) from error

        # An empty close reason falls back to the generic message.
        detail = reason or "WebSocket connection closed unexpectedly"
        if code is not None:
            detail = f"{detail} (close code {code})"
        raise DaytonaConnectionError(detail) from error
