# Copyright 2025 Daytona Platforms Inc.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import re

import websockets
from websockets.sync.client import connect

from daytona_toolbox_api_client import (
    CodeRunRequest,
    Command,
    CreateSessionRequest,
    ExecuteRequest,
    ProcessApi,
    PtyCreateRequest,
    PtyResizeRequest,
    PtySessionInfo,
    Session,
    SessionSendInputRequest,
)

from .._utils.errors import intercept_errors
from .._utils.otel_decorator import with_instrumentation
from .._utils.stream import std_demux_stream
from .._utils.timeout import http_timeout
from ..common.charts import parse_chart
from ..common.process import (
    CodeRunParams,
    ExecuteResponse,
    ExecutionArtifacts,
    OutputHandler,
    SessionCommandLogsResponse,
    SessionExecuteRequest,
    SessionExecuteResponse,
)
from ..common.pty import PtySize
from ..handle.pty_handle import PtyHandle


class Process:
    """Handles process and code execution within a Sandbox."""

    def __init__(
        self,
        language: str,
        api_client: ProcessApi,
    ):
        """Initialize a new Process instance.

        Args:
            api_client (ProcessApi): API client for process operations.
        """
        self._language: str = language
        self._api_client: ProcessApi = api_client

    @intercept_errors(message_prefix="Failed to execute command: ")
    @with_instrumentation()
    def exec(
        self,
        command: str,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> ExecuteResponse:
        """Execute a shell command in the Sandbox.

        Args:
            command (str): Shell command to execute.
            cwd (str | None): Working directory for command execution. If not
                specified, uses the sandbox working directory.
            env (dict[str, str] | None): Environment variables to set for the command.
            timeout (int | None): Maximum time in seconds to wait for the command
                to complete.

        Returns:
            ExecuteResponse: Command execution results containing:
                - exit_code: The command's exit status
                - result: Standard output from the command
                - artifacts: ExecutionArtifacts object containing `stdout` (same as result)
                and `charts` (matplotlib charts metadata)

        Example:
            ```python
            # Simple command
            response = sandbox.process.exec("echo 'Hello'")
            print(response.artifacts.stdout)  # Prints: Hello

            # Command with working directory
            result = sandbox.process.exec("ls", cwd="workspace/src")

            # Command with timeout
            result = sandbox.process.exec("sleep 10", timeout=5)
            ```
        """
        execute_request = ExecuteRequest(command=command, cwd=cwd, timeout=timeout, envs=env)

        response = self._api_client.execute_command(
            request=execute_request,
            _request_timeout=http_timeout(timeout + 5 if timeout else None),
        )

        result = response.result or ""
        artifacts = ExecutionArtifacts(stdout=result, charts=[])

        return ExecuteResponse.model_construct(
            exit_code=(
                response.exit_code if response.exit_code is not None else response.additional_properties.get("code")
            ),
            result=result,
            artifacts=artifacts,
            additional_properties=response.additional_properties,
        )

    @with_instrumentation()
    def code_run(
        self,
        code: str,
        params: CodeRunParams | None = None,
        timeout: int | None = None,
    ) -> ExecuteResponse:
        """Executes code in the Sandbox using the appropriate language runtime.

        Args:
            code (str): Code to execute.
            params (CodeRunParams | None): Parameters for code execution.
            timeout (int | None): Maximum time in seconds to wait for the code
                to complete.

        Returns:
            ExecuteResponse: Code execution result containing:
                - exit_code: The execution's exit status
                - result: Standard output from the code
                - artifacts: ExecutionArtifacts object containing `stdout` (same as result)
                and `charts` (matplotlib charts metadata)

        Example:
            ```python
            # Run Python code
            response = sandbox.process.code_run('''
                x = 10
                y = 20
                print(f"Sum: {x + y}")
            ''')
            print(response.artifacts.stdout)  # Prints: Sum: 30
            ```

            Matplotlib charts are automatically detected and returned in the `charts` field
            of the `ExecutionArtifacts` object.
            ```python
            code = '''
            import matplotlib.pyplot as plt
            import numpy as np

            x = np.linspace(0, 10, 30)
            y = np.sin(x)

            plt.figure(figsize=(8, 5))
            plt.plot(x, y, 'b-', linewidth=2)
            plt.title('Line Chart')
            plt.xlabel('X-axis (seconds)')
            plt.ylabel('Y-axis (amplitude)')
            plt.grid(True)
            plt.show()
            '''

            response = sandbox.process.code_run(code)
            chart = response.artifacts.charts[0]

            print(f"Type: {chart.type}")
            print(f"Title: {chart.title}")
            if chart.type == ChartType.LINE and isinstance(chart, LineChart):
                print(f"X Label: {chart.x_label}")
                print(f"Y Label: {chart.y_label}")
                print(f"X Ticks: {chart.x_ticks}")
                print(f"X Tick Labels: {chart.x_tick_labels}")
                print(f"X Scale: {chart.x_scale}")
                print(f"Y Ticks: {chart.y_ticks}")
                print(f"Y Tick Labels: {chart.y_tick_labels}")
                print(f"Y Scale: {chart.y_scale}")
                print("Elements:")
                for element in chart.elements:
                    print(f"Label: {element.label}")
                    print(f"Points: {element.points}")
            ```
        """
        code_run_params = params or CodeRunParams()
        code_run_request = CodeRunRequest(
            code=code,
            language=self._language,
            argv=code_run_params.argv,
            envs=code_run_params.env,
            timeout=timeout,
        )

        response = self._api_client.code_run(
            request=code_run_request,
            _request_timeout=http_timeout(timeout + 5 if timeout else None),
        )

        stdout = response.result or ""
        charts = []
        if response.artifacts and response.artifacts.charts:
            charts = [parse_chart(chart) for chart in response.artifacts.charts]
        artifacts = ExecutionArtifacts(stdout=stdout, charts=charts)

        # TODO: Remove model_construct once everything is migrated to pydantic # pylint: disable=fixme
        return ExecuteResponse.model_construct(
            exit_code=(
                response.exit_code if response.exit_code is not None else response.additional_properties.get("code")
            ),
            result=stdout,
            artifacts=artifacts,
            additional_properties=response.additional_properties,
        )

    @intercept_errors(message_prefix="Failed to create session: ")
    @with_instrumentation()
    def create_session(self, session_id: str) -> None:
        """Creates a new long-running background session in the Sandbox.

        Sessions are background processes that maintain state between commands, making them ideal for
        scenarios requiring multiple related commands or persistent environment setup. You can run
        long-running commands and monitor process status.

        Args:
            session_id (str): Unique identifier for the new session.

        Example:
            ```python
            # Create a new session
            session_id = "my-session"
            sandbox.process.create_session(session_id)
            session = sandbox.process.get_session(session_id)
            # Do work...
            sandbox.process.delete_session(session_id)
            ```
        """
        request = CreateSessionRequest(session_id=session_id)
        self._api_client.create_session(request=request)

    @intercept_errors(message_prefix="Failed to get session: ")
    def get_session(self, session_id: str) -> Session:
        """Gets a session in the Sandbox.

        Args:
            session_id (str): Unique identifier of the session to retrieve.

        Returns:
            Session: Session information including:
                - session_id: The session's unique identifier
                - commands: List of commands executed in the session

        Example:
            ```python
            session = sandbox.process.get_session("my-session")
            for cmd in session.commands:
                print(f"Command: {cmd.command}")
            ```
        """
        return self._api_client.get_session(session_id=session_id)

    @intercept_errors(message_prefix="Failed to get sandbox entrypoint session: ")
    def get_entrypoint_session(self) -> Session:
        """Gets the sandbox entrypoint session.

        Returns:
            Session: Entrypoint session information including:
                - session_id: The entrypoint session's unique identifier
                - commands: List of commands executed in the entrypoint session

        Example:
            ```python
            session = sandbox.process.get_entrypoint_session()
            for cmd in session.commands:
                print(f"Command: {cmd.command}")
            ```
        """
        return self._api_client.get_entrypoint_session()

    @intercept_errors(message_prefix="Failed to get session command: ")
    @with_instrumentation()
    def get_session_command(self, session_id: str, command_id: str) -> Command:
        """Gets information about a specific command executed in a session.

        Args:
            session_id (str): Unique identifier of the session.
            command_id (str): Unique identifier of the command.

        Returns:
            Command: Command information including:
                - id: The command's unique identifier
                - command: The executed command string
                - exit_code: Command's exit status (if completed)

        Example:
            ```python
            cmd = sandbox.process.get_session_command("my-session", "cmd-123")
            if cmd.exit_code == 0:
                print(f"Command {cmd.command} completed successfully")
            ```
        """
        return self._api_client.get_session_command(session_id=session_id, command_id=command_id)

    @intercept_errors(message_prefix="Failed to execute session command: ")
    @with_instrumentation()
    def execute_session_command(
        self,
        session_id: str,
        req: SessionExecuteRequest,
        timeout: int | None = None,
    ) -> SessionExecuteResponse:
        """Executes a command in the session.

        Args:
            session_id (str): Unique identifier of the session to use.
            req (SessionExecuteRequest): Command execution request containing:
                - command: The command to execute
                - run_async: Whether to execute asynchronously

        Returns:
            SessionExecuteResponse: Command execution results containing:
                - cmd_id: Unique identifier for the executed command
                - output: Combined command output (stdout and stderr) (if synchronous execution)
                - stdout: Standard output from the command
                - stderr: Standard error from the command
                - exit_code: Command exit status (if synchronous execution)

        Example:
            ```python
            # Execute commands in sequence, maintaining state
            session_id = "my-session"

            # Change directory
            req = SessionExecuteRequest(command="cd /workspace")
            sandbox.process.execute_session_command(session_id, req)

            # Create a file
            req = SessionExecuteRequest(command="echo 'Hello' > test.txt")
            sandbox.process.execute_session_command(session_id, req)

            # Read the file
            req = SessionExecuteRequest(command="cat test.txt")
            result = sandbox.process.execute_session_command(session_id, req)
            print(f"Command stdout: {result.stdout}")
            print(f"Command stderr: {result.stderr}")
            ```
        """
        response = self._api_client.session_execute_command(
            session_id=session_id,
            request=req,
            _request_timeout=http_timeout(timeout + 5 if timeout else None),
        )

        return SessionExecuteResponse.model_construct(
            cmd_id=response.cmd_id,
            output=response.output,
            stdout=response.stdout or "",
            stderr=response.stderr or "",
            exit_code=response.exit_code,
            additional_properties=response.additional_properties,
        )

    @intercept_errors(message_prefix="Failed to get session command logs: ")
    @with_instrumentation()
    def get_session_command_logs(self, session_id: str, command_id: str) -> SessionCommandLogsResponse:
        """Get the logs for a command executed in a session.

        Args:
            session_id (str): Unique identifier of the session.
            command_id (str): Unique identifier of the command.

        Returns:
            SessionCommandLogsResponse: Command logs including:
                - output: Combined command output (stdout and stderr)
                - stdout: Standard output from the command
                - stderr: Standard error from the command

        Example:
            ```python
            logs = sandbox.process.get_session_command_logs(
                "my-session",
                "cmd-123"
            )
            print(f"Command stdout: {logs.stdout}")
            print(f"Command stderr: {logs.stderr}")
            ```
        """
        response = self._api_client.get_session_command_logs(session_id=session_id, command_id=command_id)

        return SessionCommandLogsResponse(output=response.output, stdout=response.stdout, stderr=response.stderr)

    @intercept_errors(message_prefix="Failed to get session command logs: ")
    async def get_session_command_logs_async(
        self, session_id: str, command_id: str, on_stdout: OutputHandler[str], on_stderr: OutputHandler[str]
    ) -> None:
        """Asynchronously retrieves and processes the logs for a command executed in a session as they become available.

        Accepts both sync and async callbacks. Async callbacks are awaited.
        Blocking synchronous operations inside callbacks may cause WebSocket
        disconnections — use async callbacks and async libraries to avoid this.

        Args:
            session_id (str): Unique identifier of the session.
            command_id (str): Unique identifier of the command.
            on_stdout (OutputHandler[str]): Callback function to handle stdout log chunks as they arrive.
            on_stderr (OutputHandler[str]): Callback function to handle stderr log chunks as they arrive.

        Example:
            ```python
            await sandbox.process.get_session_command_logs_async(
                "my-session",
                "cmd-123",
                lambda log: print(f"[STDOUT]: {log}"),
                lambda log: print(f"[STDERR]: {log}"),
            )
            ```
        """

        _, url, headers, *_ = self._api_client._get_session_command_logs_serialize(
            session_id=session_id,
            command_id=command_id,
            follow=True,
            _request_auth=None,
            _content_type=None,
            _headers=None,
            _host_index=None,
        )

        url = re.sub(r"^http", "ws", url)

        async with websockets.connect(url, additional_headers=headers) as ws:
            await std_demux_stream(ws, on_stdout, on_stderr)

    @intercept_errors(message_prefix="Failed to get entrypoint logs: ")
    @with_instrumentation()
    def get_entrypoint_logs(self) -> SessionCommandLogsResponse:
        """Get the logs for the entrypoint session.

        Returns:
            SessionCommandLogsResponse: Command logs including:
                - output: Combined command output (stdout and stderr)
                - stdout: Standard output from the command
                - stderr: Standard error from the command

        Example:
            ```python
            logs = sandbox.process.get_entrypoint_logs()
            print(f"Command stdout: {logs.stdout}")
            print(f"Command stderr: {logs.stderr}")
            ```
        """
        response = self._api_client.get_entrypoint_logs()

        return SessionCommandLogsResponse(output=response.output, stdout=response.stdout, stderr=response.stderr)

    @intercept_errors(message_prefix="Failed to get entrypoint logs: ")
    async def get_entrypoint_logs_async(self, on_stdout: OutputHandler[str], on_stderr: OutputHandler[str]) -> None:
        """Asynchronously retrieves and processes the logs for the entrypoint session as they become available.

        Args:
            on_stdout OutputHandler[str]: Callback function to handle stdout log chunks as they arrive.
            on_stderr OutputHandler[str]: Callback function to handle stderr log chunks as they arrive.

        Example:
            ```python
            await sandbox.process.get_entrypoint_logs_async(
                lambda log: print(f"[STDOUT]: {log}"),
                lambda log: print(f"[STDERR]: {log}"),
            )
            ```
        """

        _, url, headers, *_ = self._api_client._get_entrypoint_logs_serialize(
            follow=True,
            _request_auth=None,
            _content_type=None,
            _headers=None,
            _host_index=None,
        )

        url = re.sub(r"^http", "ws", url)

        async with websockets.connect(url, additional_headers=headers) as ws:
            await std_demux_stream(ws, on_stdout, on_stderr)

    @intercept_errors(message_prefix="Failed to send session command input: ")
    def send_session_command_input(self, session_id: str, command_id: str, data: str) -> None:
        """Sends input data to a command executed in a session.

        Args:
            session_id (str): Unique identifier of the session.
            command_id (str): Unique identifier of the command.
            data (str): Input data to send.
        """
        self._api_client.send_input(
            session_id=session_id, command_id=command_id, request=SessionSendInputRequest(data=data)
        )

    @intercept_errors(message_prefix="Failed to list sessions: ")
    @with_instrumentation()
    def list_sessions(self) -> list[Session]:
        """Lists all sessions in the Sandbox.

        Returns:
            list[Session]: List of all sessions in the Sandbox.

        Example:
            ```python
            sessions = sandbox.process.list_sessions()
            for session in sessions:
                print(f"Session {session.session_id}:")
                print(f"  Commands: {len(session.commands)}")
            ```
        """
        return self._api_client.list_sessions()

    @intercept_errors(message_prefix="Failed to delete session: ")
    @with_instrumentation()
    def delete_session(self, session_id: str) -> None:
        """Terminates and removes a session from the Sandbox, cleaning up any resources
        associated with it.

        Args:
            session_id (str): Unique identifier of the session to delete.

        Example:
            ```python
            # Create and use a session
            sandbox.process.create_session("temp-session")
            # ... use the session ...

            # Clean up when done
            sandbox.process.delete_session("temp-session")
            ```
        """
        self._api_client.delete_session(session_id=session_id)

    @intercept_errors(message_prefix="Failed to create PTY session: ")
    @with_instrumentation()
    def create_pty_session(
        self,
        id: str,
        cwd: str | None = None,
        envs: dict[str, str] | None = None,
        pty_size: PtySize | None = None,
    ) -> PtyHandle:
        """Creates a new PTY (pseudo-terminal) session in the Sandbox.

        Creates an interactive terminal session that can execute commands and handle user input.
        The PTY session behaves like a real terminal, supporting features like command history.

        Args:
            id: Unique identifier for the PTY session. Must be unique within the Sandbox.
            cwd: Working directory for the PTY session. Defaults to the sandbox's working directory.
            env: Environment variables to set in the PTY session. These will be merged with
                the Sandbox's default environment variables.
            pty_size: Terminal size configuration. Defaults to 80x24 if not specified.

        Returns:
            PtyHandle: Handle for managing the created PTY session. Use this to send input,
                           receive output, resize the terminal, and manage the session lifecycle.

        Raises:
            DaytonaError: If the PTY session creation fails or the session ID is already in use.
        """
        response = self._api_client.create_pty_session(
            request=PtyCreateRequest(
                id=id,
                cwd=cwd,
                envs=envs,
                cols=pty_size.cols if pty_size else None,
                rows=pty_size.rows if pty_size else None,
                lazy_start=True,
            ),
        )

        return self.connect_pty_session(
            response.session_id,
        )

    @intercept_errors(message_prefix="Failed to connect PTY session: ")
    @with_instrumentation()
    def connect_pty_session(
        self,
        session_id: str,
    ) -> PtyHandle:
        """Connects to an existing PTY session in the Sandbox.

        Establishes a WebSocket connection to an existing PTY session, allowing you to
        interact with a previously created terminal session.

        Args:
            session_id: Unique identifier of the PTY session to connect to.

        Returns:
            PtyHandle: Handle for managing the connected PTY session.

        Raises:
            DaytonaError: If the PTY session doesn't exist or connection fails.
        """
        _, url, headers, *_ = self._api_client._connect_pty_session_serialize(
            session_id=session_id,
            _request_auth=None,
            _content_type=None,
            _headers=None,
            _host_index=None,
        )
        url = re.sub(r"^http", "ws", url)

        ws = connect(url, additional_headers=headers)

        # Create resize and kill handlers
        def resize_handler(pty_size: PtySize) -> PtySessionInfo:
            return self.resize_pty_session(session_id, pty_size)

        def kill_handler() -> None:
            self.kill_pty_session(session_id)

        handle = PtyHandle(
            ws,
            session_id=session_id,
            handle_resize=resize_handler,
            handle_kill=kill_handler,
        )
        handle.wait_for_connection()
        return handle

    @intercept_errors(message_prefix="Failed to list PTY sessions: ")
    @with_instrumentation()
    def list_pty_sessions(self) -> list[PtySessionInfo]:
        """Lists all PTY sessions in the Sandbox.

        Retrieves information about all PTY sessions in this Sandbox.

        Returns:
            list[PtySessionInfo]: List of PTY session information objects containing
                                details about each session's state, creation time, and configuration.

        Example:
            ```python
            # List all PTY sessions
            sessions = sandbox.process.list_pty_sessions()

            for session in sessions:
                print(f"Session ID: {session.id}")
                print(f"Active: {session.active}")
                print(f"Created: {session.created_at}")
            ```
        """
        return (self._api_client.list_pty_sessions()).sessions

    @intercept_errors(message_prefix="Failed to get PTY session info: ")
    @with_instrumentation()
    def get_pty_session_info(self, session_id: str) -> PtySessionInfo:
        """Gets detailed information about a specific PTY session.

        Retrieves comprehensive information about a PTY session including its current state,
        configuration, and metadata.

        Args:
            session_id: Unique identifier of the PTY session to retrieve information for.

        Returns:
            PtySessionInfo: Detailed information about the PTY session including ID, state,
                           creation time, working directory, environment variables, and more.

        Raises:
            DaytonaError: If the PTY session doesn't exist.

        Example:
            ```python
            # Get details about a specific PTY session
            session_info = sandbox.process.get_pty_session_info("my-session")

            print(f"Session ID: {session_info.id}")
            print(f"Active: {session_info.active}")
            print(f"Working Directory: {session_info.cwd}")
            print(f"Terminal Size: {session_info.cols}x{session_info.rows}")
            ```
        """
        return self._api_client.get_pty_session(session_id=session_id)

    @intercept_errors(message_prefix="Failed to kill PTY session: ")
    @with_instrumentation()
    def kill_pty_session(self, session_id: str) -> None:
        """Kills a PTY session and terminates its associated process.

        Forcefully terminates the PTY session and cleans up all associated resources.
        This will close any active connections and kill the underlying shell process.
        This operation is irreversible. Any unsaved work in the terminal session will be lost.

        Args:
            session_id: Unique identifier of the PTY session to kill.

        Raises:
            DaytonaError: If the PTY session doesn't exist or cannot be killed.

        Example:
            ```python
            # Kill a specific PTY session
            sandbox.process.kill_pty_session("my-session")

            # Verify the session no longer exists
            pty_sessions = sandbox.process.list_pty_sessions()
            for pty_session in pty_sessions:
                print(f"PTY session: {pty_session.id}")
            ```
        """
        _ = self._api_client.delete_pty_session(session_id=session_id)

    @intercept_errors(message_prefix="Failed to resize PTY session: ")
    @with_instrumentation()
    def resize_pty_session(self, session_id: str, pty_size: PtySize) -> PtySessionInfo:
        """Resizes a PTY session's terminal dimensions.

        Changes the terminal size of an active PTY session. This is useful when the
        client terminal is resized or when you need to adjust the display for different
        output requirements.

        Args:
            session_id: Unique identifier of the PTY session to resize.
            pty_size: New terminal dimensions containing the desired columns and rows.

        Returns:
            PtySessionInfo: Updated session information reflecting the new terminal size.

        Raises:
            DaytonaError: If the PTY session doesn't exist or resize operation fails.

        Example:
            ```python
            from daytona.common.pty import PtySize

            # Resize a PTY session to a larger terminal
            new_size = PtySize(rows=40, cols=150)
            updated_info = sandbox.process.resize_pty_session("my-session", new_size)

            print(f"Terminal resized to {updated_info.cols}x{updated_info.rows}")

            # You can also use the PtyHandle's resize method
            pty_handle.resize(new_size)
            ```
        """
        return self._api_client.resize_pty_session(
            session_id=session_id,
            request=PtyResizeRequest(cols=pty_size.cols, rows=pty_size.rows),
        )
