#!/usr/bin/env python3
"""
FastAI remote runner — friend-side WebSocket client.

Подключается к нашему backend'у по WS, авторизуется runner-токеном,
держит соединение и отвечает на heartbeat. На этом этапе — только handshake
и ping/pong; обработка query() придёт следующим шагом.

Запуск:
    FASTAI_URL=wss://your-fastai.example.com \\
    FASTAI_RUNNER_TOKEN=<token-from-ui> \\
    python3 runner.py

Зависимости: websockets>=12.

Behaviour:
- При starup: WS connect + hello handshake.
- Heartbeat: ping каждые HEARTBEAT_SEC, ждём pong в течение PONG_TIMEOUT.
- Если сервер закрыл соединение или сеть упала: reconnect с
  экспоненциальным backoff (1s → 2s → 4s → ... cap 60s), jitter ±20%.
- SIGINT / SIGTERM → graceful close.
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import signal
import sys
import time
from typing import Any, Optional
from urllib.parse import urlparse, urlunparse

try:
    import websockets
    from websockets.exceptions import ConnectionClosed, InvalidStatusCode
except ImportError:
    print("[FATAL] нужно установить websockets: pip install 'websockets>=12'", file=sys.stderr)
    sys.exit(2)

try:
    from claude_agent_sdk import query as _sdk_query
except ImportError:
    print("[FATAL] нужно установить claude-agent-sdk: pip install claude-agent-sdk", file=sys.stderr)
    sys.exit(2)

# Sibling-модуль protocol (runner/protocol.py). Импортируем относительно
# чтобы скрипт работал и как `python runner.py`, и как `python -m runner.runner`.
try:
    from . import protocol as _rprot
except ImportError:
    # запуск напрямую: добавим директорию скрипта в sys.path
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    import protocol as _rprot  # type: ignore


HEARTBEAT_SEC = float(os.environ.get("FASTAI_HEARTBEAT_SEC", "20"))
PONG_TIMEOUT = float(os.environ.get("FASTAI_PONG_TIMEOUT", "10"))
WORKDIR_BASE = os.path.expanduser(
    os.environ.get("FASTAI_WORKDIR_BASE", "~/fastai-runner/workdir")
)
BACKOFF_MIN = 1.0
BACKOFF_MAX = 60.0


def _build_ws_url(base: str, token: str) -> str:
    """http(s)://host/... → ws(s)://host/ws/runner?token=..."""
    p = urlparse(base)
    if p.scheme in ("http", "ws"):
        scheme = "ws"
    elif p.scheme in ("https", "wss"):
        scheme = "wss"
    else:
        raise ValueError(f"unsupported scheme in FASTAI_URL: {p.scheme!r}")
    return urlunparse((scheme, p.netloc, "/ws/runner", "", f"token={token}", ""))


class RunnerClient:
    def __init__(self, url: str, log: logging.Logger):
        self.url = url
        self.log = log
        self.runner_id: Optional[str] = None
        self.stop_event = asyncio.Event()
        self._last_pong_at: float = 0.0
        # qid → asyncio.Task запущенного _run_query; используется
        # для отмены при получении query_cancel и cleanup при disconnect.
        self._active_queries: dict[str, asyncio.Task] = {}
        # Параметры anthropic-proxy, присланные сервером в hello. Используются
        # для подмены ANTHROPIC_BASE_URL / CLAUDE_CODE_OAUTH_TOKEN в env CLI-
        # subprocess'а перед запуском SDK query — так настоящий OAuth-токен
        # пользователя никогда не оседает на friend-side.
        self.proxy_base_url: Optional[str] = None
        self.proxy_token: Optional[str] = None
        # Сериализация патча os.environ — конкурентные query будут идти
        # последовательно. Для текущего MVP (одна сессия = один query) это
        # достаточно; параллельность можно вернуть когда SDK научится
        # принимать env как параметр.
        self._sdk_env_lock = asyncio.Lock()

    async def run_forever(self) -> None:
        """Главный цикл: connect → loop → reconnect с backoff."""
        backoff = BACKOFF_MIN
        while not self.stop_event.is_set():
            try:
                await self._connect_once()
                # чистый close сервером → reconnect без задержки
                backoff = BACKOFF_MIN
            except InvalidStatusCode as e:
                # 4401 = bad token, 4403 = forbidden, etc — фатально, нет смысла перевыкручиваться
                if 4400 <= e.status_code < 4500:
                    self.log.error("auth/protocol error %s — abort: %s", e.status_code, e)
                    return
                self.log.warning("HTTP %s on handshake: %s", e.status_code, e)
            except (OSError, ConnectionClosed) as e:
                self.log.warning("connection lost: %s", e)
            except Exception as e:
                self.log.exception("unexpected error in connection loop: %s", e)

            if self.stop_event.is_set():
                return
            sleep_for = backoff * (1 + random.uniform(-0.2, 0.2))
            self.log.info("reconnect in %.1fs", sleep_for)
            try:
                await asyncio.wait_for(self.stop_event.wait(), timeout=sleep_for)
                return
            except asyncio.TimeoutError:
                pass
            backoff = min(backoff * 2, BACKOFF_MAX)

    async def _connect_once(self) -> None:
        """Одна попытка connect+hello+heartbeat-loop. Возврат = clean close."""
        self.log.info("connecting to %s", _redact_token(self.url))
        async with websockets.connect(self.url, max_size=None) as ws:
            # ждём hello
            hello_raw = await asyncio.wait_for(ws.recv(), timeout=15)
            hello = json.loads(hello_raw)
            if hello.get("type") != "hello":
                raise RuntimeError(f"expected hello, got {hello!r}")
            self.runner_id = hello.get("runner_id")
            self.proxy_base_url = hello.get("proxy_base_url") or None
            self.proxy_token = hello.get("proxy_token") or None
            self.log.info(
                "connected: runner_id=%s server_time=%s proxy=%s",
                self.runner_id, hello.get("server_time"),
                "on" if (self.proxy_base_url and self.proxy_token) else "off",
            )
            self._last_pong_at = time.time()

            # параллельно: heartbeat sender и message receiver
            send_task = asyncio.create_task(self._heartbeat_sender(ws))
            recv_task = asyncio.create_task(self._receiver(ws))
            try:
                done, pending = await asyncio.wait(
                    [send_task, recv_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for t in pending:
                    t.cancel()
                for t in done:
                    exc = t.exception()
                    if exc:
                        raise exc
            finally:
                for t in (send_task, recv_task):
                    if not t.done():
                        t.cancel()
                        try:
                            await t
                        except (asyncio.CancelledError, Exception):
                            pass
                # Отменяем все live query'и на этом соединении: сервер всё равно
                # уже выкинул их по disconnect-bail-out'у, а локальный SDK висит.
                for qid, qt in list(self._active_queries.items()):
                    if not qt.done():
                        qt.cancel()
                        try:
                            await qt
                        except (asyncio.CancelledError, Exception):
                            pass
                self._active_queries.clear()

    async def _heartbeat_sender(self, ws) -> None:
        """Каждые HEARTBEAT_SEC отправляем ping, проверяем что pong не опаздывает."""
        while not self.stop_event.is_set():
            await asyncio.sleep(HEARTBEAT_SEC)
            # pong timeout check
            if time.time() - self._last_pong_at > HEARTBEAT_SEC + PONG_TIMEOUT:
                self.log.warning("pong timeout — closing")
                await ws.close(code=4001, reason="pong timeout")
                return
            await ws.send(json.dumps({"type": "ping", "ts": int(time.time())}))

    async def _receiver(self, ws) -> None:
        async for raw in ws:
            try:
                msg = json.loads(raw)
            except Exception:
                self.log.debug("non-json message: %r", raw[:200])
                continue
            mtype = msg.get("type")
            if mtype == "pong":
                self._last_pong_at = time.time()
            elif mtype == "query_start":
                qid = msg.get("qid")
                prompt = msg.get("prompt") or ""
                opts_dict = msg.get("options") or {}
                session_id = msg.get("session_id") or ""
                if not qid:
                    self.log.warning("query_start without qid, dropping")
                    continue
                if qid in self._active_queries:
                    self.log.warning("query_start for active qid=%s, ignoring duplicate", qid)
                    continue
                task = asyncio.create_task(self._run_query(ws, qid, prompt, opts_dict, session_id))
                self._active_queries[qid] = task
                # cleanup при завершении
                task.add_done_callback(lambda t, qid=qid: self._active_queries.pop(qid, None))
            elif mtype == "query_cancel":
                qid = msg.get("qid")
                task = self._active_queries.get(qid)
                if task and not task.done():
                    self.log.info("cancelling query qid=%s", qid)
                    task.cancel()
            else:
                self.log.debug("unhandled message type=%r", mtype)

    async def _run_query(self, ws, qid: str, prompt: str, opts_dict: dict, session_id: str = "") -> None:
        """Запускает claude_agent_sdk.query() локально и стримит результаты обратно
        серверу через query_message-фреймы. По завершении — query_end, на ошибке —
        query_error. CancelledError пробрасывается тихо (юзер нажал stop)."""
        self.log.info("query_start qid=%s prompt_len=%d session=%s", qid, len(prompt), session_id or "-")
        try:
            try:
                options = _rprot.decode_options(opts_dict)
            except Exception as e:
                await self._safe_send(ws, _rprot.dump(_rprot.frame_query_error(
                    qid, kind="OptionsDecodeError", error=str(e),
                )))
                return

            # cwd-hint от бэкенда (например /home/fastai/app/reports) на NJ
            # не существует. Подменяем на изолированный per-session workdir
            # под WORKDIR_BASE и создаём папку, если её ещё нет.
            session_dir = session_id or qid
            local_cwd = os.path.join(WORKDIR_BASE, session_dir)
            try:
                os.makedirs(local_cwd, exist_ok=True)
                # Подкаталоги, которые ждёт системный промпт бэкенда
                # (./reports и ./tmp). CLI пишет туда относительно cwd.
                os.makedirs(os.path.join(local_cwd, "reports"), exist_ok=True)
                os.makedirs(os.path.join(local_cwd, "tmp"), exist_ok=True)
            except OSError as e:
                await self._safe_send(ws, _rprot.dump(_rprot.frame_query_error(
                    qid, kind="WorkdirCreateError", error=str(e),
                )))
                return
            options.cwd = local_cwd
            self.log.info("query cwd patched qid=%s cwd=%s", qid, local_cwd)

            # SDK запускает Claude CLI subprocess и наследует env родителя.
            # Под локом подменяем ANTHROPIC_BASE_URL + токены на наш прокси,
            # после завершения query — восстанавливаем прежние значения.
            async with self._sdk_env_lock:
                env_restore: dict[str, Optional[str]] = {}
                if self.proxy_base_url and self.proxy_token and session_id:
                    base = self.proxy_base_url.rstrip("/") + "/" + session_id
                    for k, v in (
                        ("ANTHROPIC_BASE_URL", base),
                        ("CLAUDE_CODE_OAUTH_TOKEN", self.proxy_token),
                        ("ANTHROPIC_AUTH_TOKEN", self.proxy_token),
                    ):
                        env_restore[k] = os.environ.get(k)
                        os.environ[k] = v

                try:
                    await self._run_sdk_query(ws, qid, prompt, options)
                finally:
                    for k, prev in env_restore.items():
                        if prev is None:
                            os.environ.pop(k, None)
                        else:
                            os.environ[k] = prev
        except asyncio.CancelledError:
            # уже залогали и отправили error внутри _run_sdk_query
            return

    async def _run_sdk_query(self, ws, qid: str, prompt: str, options: Any) -> None:
        """Внутренний цикл: SDK → query_message-фреймы. Вынесен из _run_query
        чтобы env-patch жил в одном месте и не дублировался."""
        try:
            try:
                async for sdk_msg in _sdk_query(prompt=prompt, options=options):
                    try:
                        frame = _rprot.frame_query_message(qid, sdk_msg)
                    except Exception as e:
                        # один битый message не должен валить весь query
                        self.log.warning("failed to encode msg for qid=%s: %s", qid, e)
                        continue
                    await ws.send(_rprot.dump(frame))
                await self._safe_send(ws, _rprot.dump(_rprot.frame_query_end(qid)))
                self.log.info("query_end qid=%s", qid)
            except asyncio.CancelledError:
                self.log.info("query cancelled qid=%s", qid)
                # сообщаем серверу что query прервана; затем пробрасываем выше
                await self._safe_send(ws, _rprot.dump(_rprot.frame_query_error(
                    qid, kind="CancelledError", error="query cancelled",
                )))
                raise
            except Exception as e:
                self.log.exception("query failed qid=%s", qid)
                await self._safe_send(ws, _rprot.dump(_rprot.frame_query_error(
                    qid, kind=type(e).__name__, error=str(e),
                )))
        except asyncio.CancelledError:
            # уже залогали и отправили error выше — просто завершаемся
            return

    async def _safe_send(self, ws, payload: str) -> None:
        """Отправка с подавлением ошибок (WS уже мог закрыться)."""
        try:
            await ws.send(payload)
        except Exception as e:
            self.log.debug("send failed: %s", e)


def _redact_token(url: str) -> str:
    """Скрывает значение token=... в логах."""
    if "token=" not in url:
        return url
    head, _, tail = url.partition("token=")
    cutoff = tail.find("&")
    suffix = tail[cutoff:] if cutoff >= 0 else ""
    return f"{head}token=***{suffix}"


def _setup_logging() -> logging.Logger:
    level = os.environ.get("FASTAI_LOG_LEVEL", "INFO").upper()
    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger("fastai-runner")


def main() -> int:
    log = _setup_logging()

    base = os.environ.get("FASTAI_URL")
    token = os.environ.get("FASTAI_RUNNER_TOKEN")
    if not base or not token:
        log.error("FASTAI_URL и FASTAI_RUNNER_TOKEN обязательны")
        return 2

    try:
        url = _build_ws_url(base, token)
    except ValueError as e:
        log.error("bad FASTAI_URL: %s", e)
        return 2

    client = RunnerClient(url, log)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    def _on_signal(signame: str) -> None:
        log.info("received %s, shutting down", signame)
        client.stop_event.set()

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _on_signal, sig.name)
        except NotImplementedError:
            # Windows / restricted env — fallback
            signal.signal(sig, lambda *_: client.stop_event.set())

    try:
        loop.run_until_complete(client.run_forever())
    finally:
        loop.close()
    log.info("bye")
    return 0


if __name__ == "__main__":
    sys.exit(main())
