import time, json, statistics, threading
from collections import deque, defaultdict
from typing import Optional, Dict, Tuple
from flask import Blueprint, Response, request, g, has_request_context
from flask_login import login_required
from sqlalchemy import event
from sqlalchemy.engine import Engine

import logging
# Module-wide logger used by Monitoring for its [SLOW ROUTE] / [SLOW SQL] warnings.
logger = logging.getLogger(name="monitor")

class Monitoring:
    """Lightweight in-process latency monitor for a Flask app (+ optional SQLAlchemy).

    Durations (seconds, measured with ``time.perf_counter``) are kept in bounded
    sliding windows (``deque(maxlen=window_size)``):

    * ``sql_times``         - SQL cursor executions (after ``attach_sqlalchemy``),
    * ``route_times``       - recorded HTTP requests,
    * ``user_route_times``  - requests per user label,
    * ``user_route_detail`` - requests per (user label, "METHOD /rule") pair.

    Requests / statements slower than the configured thresholds are logged as
    warnings via the module ``logger``; a JSON report can be served at
    ``endpoint_url``. Shared state is guarded by ``self._lock``.
    """

    # Suffixes treated as static assets by ``_is_static_path`` (case-sensitive).
    _STATIC_SUFFIXES = (
        ".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg",
        ".woff", ".woff2", ".map",
    )
    # Statistics reported for every user / (user, route) row, in output order.
    _STAT_KEYS = ("count", "mean", "p50", "p95", "p99")
    # Statistics that are durations and therefore subject to unit conversion.
    _DURATION_KEYS = ("mean", "p50", "p95", "p99")
    # Accepted values of the ``sort`` query parameter; unknown values fall back to "p95".
    _SORT_KEYS = ("p95", "p99", "mean", "count")
    # Slow SQL statements are truncated to this many characters in the log.
    _MAX_SQL_LOG_LEN = 750

    def __init__(
        self,
        app,
        *,
        window_size: int = 10_000,
        slow_query_threshold: float = 0.5,   # seconds
        slow_route_threshold: float = 2.0,   # seconds
        register_endpoint: bool = True,
        endpoint_url: str = "/metrics",
        filter_static: bool = True,
        sample_rate: float = 1.0,
        require_login: bool = True,
    ):
        """Install request-timing hooks on ``app`` and optionally the metrics endpoint.

        Args:
            app: Flask application to instrument.
            window_size: maximum number of samples kept in each sliding window.
            slow_query_threshold: SQL duration (seconds) above which a warning is logged.
            slow_route_threshold: request duration (seconds) above which a warning is logged.
            register_endpoint: register a GET JSON endpoint at ``endpoint_url``.
            endpoint_url: URL rule of the metrics endpoint (also excluded from timing).
            filter_static: skip requests whose path looks like a static asset.
            sample_rate: fraction of requests recorded; clamped to [0.0, 1.0].
            require_login: wrap the endpoint in ``flask_login.login_required``.
        """
        self.app = app
        self.window_size = window_size
        self.slow_query_threshold = slow_query_threshold
        self.slow_route_threshold = slow_route_threshold
        self.filter_static = filter_static
        self.sample_rate = max(0.0, min(1.0, float(sample_rate)))
        self.endpoint_url = endpoint_url
        self.require_login = require_login

        # Shared state; all mutation happens under self._lock.
        self._lock = threading.Lock()
        self.sql_times = deque(maxlen=window_size)
        self.route_times = deque(maxlen=window_size)
        # NOTE(review): keys are never evicted, so memory grows with the number of
        # distinct users and (user, route) pairs seen during the process lifetime.
        self.user_route_times: Dict[str, deque] = defaultdict(lambda: deque(maxlen=window_size))  # user -> durations
        self.user_route_detail: Dict[Tuple[str, str], deque] = defaultdict(lambda: deque(maxlen=window_size))  # (user, route) -> durations

        self._sql_attached = False
        self._register_flask_hooks()

        if register_endpoint:
            bp = Blueprint("metrics_monitor", __name__)

            def metrics_json():
                return self._metrics_endpoint()

            view = metrics_json
            if require_login:
                view = login_required(view)
            bp.add_url_rule(self.endpoint_url, view_func=view, methods=["GET"])
            app.register_blueprint(bp)

    # --------- helpers ----------
    @staticmethod
    def _now() -> float:
        """Monotonic high-resolution clock used for all duration measurements."""
        return time.perf_counter()

    @staticmethod
    def _percentile(data, p):
        """Return the ``p``-th percentile (integer 1..99) of ``data``; None if empty.

        Uses ``statistics.quantiles(..., method="inclusive")``; a single sample is
        returned as-is because ``quantiles`` needs at least two points.
        """
        n = len(data)
        if n == 0:
            return None
        if n == 1:
            return data[0]
        # quantiles(n=100) returns the 99 cut points; index p-1 is the p-th percentile.
        # (quantiles sorts its input itself, so no pre-sorting is needed.)
        return statistics.quantiles(data, n=100, method="inclusive")[p - 1]

    @staticmethod
    def _route_key() -> str:
        """Return ``"METHOD <url rule>"`` for the current request.

        NOTE(review): requests that matched no rule (e.g. 404s) fall back to the
        raw path, so each distinct unmatched path becomes its own key.
        """
        rule = getattr(request, "url_rule", None)
        pattern = rule.rule if rule else request.path
        return f"{request.method} {pattern}"

    @staticmethod
    def _is_static_path(path: str) -> bool:
        """Heuristically decide whether ``path`` is a static-asset request.

        True for ``/favicon.ico``, any path containing ``/static/``, or one ending
        in a suffix from ``_STATIC_SUFFIXES`` (case-sensitive). ``None`` counts as "".
        """
        path = path or ""
        return (
            path == "/favicon.ico"
            or "/static/" in path
            or path.endswith(Monitoring._STATIC_SUFFIXES)
        )

    @staticmethod
    def _user_label() -> str:
        """Return the current user id as a string, or ``"anonymous"``.

        Reads only the pre-computed ``g._k2_ctx.user_id`` (set by the application's
        own before_request) and deliberately does NOT access ``current_user``, so
        the flask_login user_loader is never triggered from here.
        """
        try:
            ctx = getattr(g, "_k2_ctx", None)
            # Initialise uid explicitly: with no ctx the original code hit an
            # UnboundLocalError that was merely swallowed by the except below.
            uid = getattr(ctx, "user_id", None) if ctx else None
            return str(uid) if uid else "anonymous"
        except Exception:
            # Best effort: metrics must never break a request (e.g. no app context).
            return "anonymous"

    def _calc_stats(self, dq) -> dict:
        """Return count/mean/p50/p95/p99 (seconds) of the samples in ``dq``; {} if empty."""
        data = list(dq)
        if not data:
            return {}
        return {
            "count": len(data),
            "mean": statistics.fmean(data),
            "p50": statistics.median(data),
            "p95": self._percentile(data, 95),
            "p99": self._percentile(data, 99),
        }

    def _snapshot_stats(self, dq) -> dict:
        """Copy ``dq`` under the lock, then compute its stats without holding the lock."""
        with self._lock:
            data = list(dq)
        return self._calc_stats(data)

    @classmethod
    def _stat_fields(cls, stats: dict) -> dict:
        """Project ``stats`` onto ``_STAT_KEYS``, using None for missing values."""
        return {key: stats.get(key) for key in cls._STAT_KEYS}

    @staticmethod
    def _parse_limit(raw, default: int = 20) -> int:
        """Parse the ``limit`` query parameter.

        Non-integer input falls back to ``default``; negative values are clamped
        to 0 (they would otherwise slice rows off the end of the list).
        """
        try:
            value = int(raw)
        except (TypeError, ValueError):
            return default
        return max(0, value)

    @classmethod
    def _convert_durations(cls, row: dict, units: str) -> dict:
        """Return a copy of ``row`` with duration fields in ``units``.

        ``"ms"`` multiplies seconds by 1000 (unrounded); any other unit keeps
        seconds rounded to 3 decimals. None stays None; an empty row stays {}.
        """
        if not row:
            return {}
        out = dict(row)
        for key in cls._DURATION_KEYS:
            value = out.get(key)
            if value is None:
                out[key] = None
            elif units == "ms":
                out[key] = float(value) * 1000.0
            else:
                out[key] = round(float(value), 3)
        return out

    # --------- flask hooks ----------
    def _register_flask_hooks(self):
        """Register before/after/teardown request hooks that time each request."""
        import random  # used only for request sampling; imported once here

        app = self.app

        @app.before_request
        def _metrics_before_request():
            # Dedicated timer attribute; does not touch the application's g.start_time.
            # IMPORTANT: do NOT access current_user here - the user id/login is
            # prepared by the application's own before_request, where that is safe.
            setattr(g, "_metrics_route_start", self._now())

        @app.after_request
        def _metrics_after_request(response):
            # Remember the status for the teardown hook. If this hook did not run,
            # teardown falls back to 500 (exception) or 200.
            setattr(g, "_metrics_status_code", getattr(response, "status_code", 200))
            return response

        @app.teardown_request
        def _metrics_teardown_request(exc):
            start = getattr(g, "_metrics_route_start", None)
            if start is None:
                return
            # Take the measurement first so the bookkeeping below is not counted.
            duration = self._now() - start

            path = getattr(request, "path", "")
            # Static-asset filter.
            if self.filter_static and self._is_static_path(path):
                return
            # Never measure the metrics endpoint itself.
            if path == self.endpoint_url:
                return
            # Sampling: decide before doing any further work for dropped requests.
            if self.sample_rate < 1.0 and random.random() >= self.sample_rate:
                return

            user = self._user_label()
            route = self._route_key()

            with self._lock:
                self.route_times.append(duration)
                self.user_route_times[user].append(duration)
                self.user_route_detail[(user, route)].append(duration)

            if duration > self.slow_route_threshold:
                status = getattr(g, "_metrics_status_code", 500 if exc else 200)
                req_id = request.headers.get("X-Request-ID", "-")
                logger.warning(
                    "[SLOW ROUTE] req_id=%s user=%s | %s status=%s took %.3fs",
                    req_id, user, route, status, duration
                )

    # --------- sqlalchemy listeners ----------
    def attach_sqlalchemy(self, engine: Optional["Engine"] = None):
        """Attach SQL timing listeners to a SQLAlchemy ``Engine``.

        Call AFTER SQLAlchemy has been initialised (once an engine exists). With
        ``engine=None`` the listeners are registered on the ``Engine`` class
        itself, i.e. for every engine. Subsequent calls are no-ops.
        """
        if self._sql_attached:
            return

        target = engine or Engine

        @event.listens_for(target, "before_cursor_execute")
        def _metrics_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # SQLAlchemy documents ``context`` as possibly None; nothing to attach to then.
            if context is not None:
                context._metrics_query_start = self._now()

        @event.listens_for(target, "after_cursor_execute")
        def _metrics_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            start = getattr(context, "_metrics_query_start", None)
            if start is None:
                # No start timestamp (no context, or listener attached mid-statement):
                # skip instead of recording a bogus 0.0 that would drag the stats down.
                return
            total = self._now() - start

            in_request = has_request_context()
            # Static-asset request: do not record or log its SQL at all
            # (and so never look up the user for it).
            if in_request and self.filter_static and self._is_static_path(getattr(request, "path", "")):
                return

            with self._lock:
                self.sql_times.append(total)

            if total <= self.slow_query_threshold:
                return

            route = user = req_id = "-"
            if in_request:
                route = f"{request.method} {request.path}"
                user = self._user_label()  # without current_user
                req_id = request.headers.get("X-Request-ID", "-")

            stmt = (statement or "").replace("\n", " ")
            if len(stmt) > self._MAX_SQL_LOG_LEN:
                stmt = stmt[:self._MAX_SQL_LOG_LEN] + " …(truncated)"
            # Bound parameters are never logged (may contain sensitive data).
            logger.warning(
                "[SLOW SQL] %.3fs | user=%s | route=%s | req_id=%s | %s | params=%s",
                total, user, route, req_id, stmt, "<masked>"
            )

        self._sql_attached = True

    # --------- /metrics endpoint ----------
    def _metrics_endpoint(self):
        """Build the JSON metrics report for the current request.

        Query parameters:
            sort:   p95 | p99 | mean | count - ordering of ``details``
                    (default and fallback for unknown values: p95).
            limit:  max number of ``details`` rows (default 20; invalid -> 20,
                    negative -> 0).
            pretty: "1" to indent the JSON output.
            units:  ms | s - milliseconds (unrounded) or seconds rounded to 3
                    decimals; any value other than "ms" is treated as "s".
        """
        args = request.args
        sort_by = args.get("sort", "p95")
        if sort_by not in self._SORT_KEYS:
            sort_by = "p95"
        limit = self._parse_limit(args.get("limit", "20"))
        pretty = args.get("pretty", "0") == "1"
        units = "ms" if args.get("units", "ms") == "ms" else "s"

        sql_stats = self._snapshot_stats(self.sql_times)
        route_stats = self._snapshot_stats(self.route_times)

        # Take references to the per-key deques under the lock (teardown mutates
        # the dicts). Each deque is then copied individually in _snapshot_stats,
        # so the lock every request teardown needs is never held while computing stats.
        with self._lock:
            user_items = list(self.user_route_times.items())
            detail_items = list(self.user_route_detail.items())

        users_summary = [
            {"user": user, **self._stat_fields(self._snapshot_stats(times))}
            for user, times in user_items
        ]

        details = []
        for (user, route), times in detail_items:
            method, path = route.split(" ", 1) if " " in route else ("", route)
            details.append({
                "user": user,
                "method": method,
                "route": path,
                **self._stat_fields(self._snapshot_stats(times)),
            })
        # Missing statistics (None) are treated as 0 for ordering.
        details.sort(key=lambda row: row[sort_by] or 0, reverse=True)
        details = details[:limit]

        payload = {
            "meta": {
                "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                "window_size": self.window_size,
                "units": units,
                "thresholds": {
                    "slow_query_seconds": self.slow_query_threshold,
                    "slow_route_seconds": self.slow_route_threshold,
                },
                "sort": sort_by,
                "limit": limit,
                "endpoint_url": self.endpoint_url,
                "require_login": self.require_login,
            },
            "overview": {
                "sql_queries": self._convert_durations(sql_stats, units),
                "http_routes": self._convert_durations(route_stats, units),
                "users": [self._convert_durations(row, units) for row in users_summary],
            },
            "details": [self._convert_durations(row, units) for row in details],
        }

        if pretty:
            body = json.dumps(payload, ensure_ascii=False, indent=2)
        else:
            body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
        return Response(body, mimetype="application/json")
