"""
Group-level statistics service.

Fetches anonymous, aggregate statistics for all events of a given group and
caches the results server-side:
  - Events that are NOT currently published → cached indefinitely (they won't
    change without a new publication cycle).
  - Events that ARE currently published → cached for PUBLISHED_CACHE_TTL seconds
    (4 hours) so live data stays reasonably fresh.

Only "safe" / anonymous data (counts, distributions) is stored; no PII.
"""

import logging
import math
import time
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

from src.services.api_service import fetch_event_tn_data
from src.services.campflow import get_campflow_client

if TYPE_CHECKING:
    from src.group_settings.models import GroupConfig

logger = logging.getLogger(__name__)

# ─── Cache configuration ─────────────────────────────────────────────────────
# Freshness window for cached stats of events that are currently published;
# older entries are refetched (see _is_entry_valid).
PUBLISHED_CACHE_TTL = 4 * 3600  # 4 hours in seconds
# Non-published events are stored forever (until the process restarts) because
# their data cannot change from the API while they're not live.

# Module-level, in-process cache (lives in this process's memory only, so it is
# emptied on restart and not shared between worker processes).
# Structure:
#   group_email → { list_id → entry dict }
# Each entry dict is written by get_group_event_stats (via _store_cached_entry)
# and holds at least:
#   "stats"     – the anonymous per-event stats dict
#   "cached_at" – time.time() timestamp of when the entry was stored
_group_stats_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}


def _get_cached_entry(group_email: str, list_id: str) -> Optional[Dict[str, Any]]:
    """Look up the cached entry for one event of one group.

    Returns None when the group has no cache bucket yet or the event has no
    entry in it.
    """
    try:
        per_group = _group_stats_cache[group_email]
    except KeyError:
        return None
    return per_group.get(list_id)


def _store_cached_entry(group_email: str, list_id: str, entry: Dict[str, Any]) -> None:
    """Save *entry* as the cache entry for *list_id* under *group_email*.

    The group's bucket is created on first use; an existing entry for the same
    event is overwritten.
    """
    _group_stats_cache.setdefault(group_email, {})[list_id] = entry


def _is_entry_valid(entry: Dict[str, Any], is_published: bool) -> bool:
    """Return True if the cached entry is still fresh enough to use."""
    if entry is None:
        return False
    if not is_published:
        # Non-published events: cache never expires
        return True
    # Published events: respect TTL
    age = time.time() - entry.get("cached_at", 0)
    return age < PUBLISHED_CACHE_TTL


# ─── Anonymous stats extraction ───────────────────────────────────────────────

def _extract_anonymous_event_stats(
    tn_list: List[Dict[str, Any]],
    event_info: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Extract only aggregate / anonymous statistics from a participant list.
    No names, no emails, no addresses.
    """
    total_tn = len(tn_list)

    # ── Ticket / ÖPNV stats ──────────────────────────────────────────────────
    # Look for a column that contains "ticket" or "oepnv" in its (already cleaned) name
    ticket_col = None
    if tn_list:
        ticket_col = next(
            (k for k in tn_list[0].keys() if "ticket" in k.lower() or "oepnv" in k.lower()),
            None,
        )

    ticket_counts: Dict[str, int] = {}
    # Track both ticket and no-ticket counts per age group
    age_groups = {
        "<=14": {"ticket": 0, "no_ticket": 0},
        "15-25": {"ticket": 0, "no_ticket": 0},
        ">=26": {"ticket": 0, "no_ticket": 0}
    }
    
    has_ticket_data = ticket_col is not None
    ticket_count_total = 0

    if has_ticket_data:
        for tn in tn_list:
            age = tn.get("age_at_event")
            # Determine bracket
            bracket = None
            if age is not None:
                if age <= 14: bracket = "<=14"
                elif 15 <= age <= 25: bracket = "15-25"
                else: bracket = ">=26"

            ticket = tn.get(ticket_col)
            if isinstance(ticket, list) and ticket:
                ticket = ticket[0]
            
            if ticket:
                ticket_str = str(ticket)
                ticket_counts[ticket_str] = ticket_counts.get(ticket_str, 0) + 1
                
                # Check if it's actually a ticket
                if ticket_str.lower() not in ["kein ticket", "kein", "nein", "no", "none", "false"]:
                    ticket_count_total += 1
                    if bracket:
                        age_groups[bracket]["ticket"] += 1
                else:
                    if bracket:
                        age_groups[bracket]["no_ticket"] += 1
            else:
                # No value in ticket column = No ticket
                if bracket:
                    age_groups[bracket]["no_ticket"] += 1
        
        # Only keep events that actually have some ticket data
        if not any(v > 0 for v in ticket_counts.values()):
            has_ticket_data = False

    ticket_percentage = (ticket_count_total / total_tn * 100) if total_tn > 0 else 0

    return {
        "list_id": event_info.get("list_id"),
        "title": event_info.get("title", "Unbekannt"),
        "start_date": event_info.get("start_date"),
        "end_date": event_info.get("end_date"),
        "published": event_info.get("published", False),
        "total_participants": total_tn,
        "has_ticket_data": has_ticket_data,
        "ticket_counts": ticket_counts,
        "ticket_count_total": ticket_count_total,
        "ticket_percentage": round(ticket_percentage, 1),
        "age_groups": age_groups,
    }


# ─── Public API ──────────────────────────────────────────────────────────────

async def get_group_event_stats(
    group_config: "GroupConfig",
) -> List[Dict[str, Any]]:
    """
    Return a list of anonymous aggregate stats dicts, one per event that has
    participant data. Events with no participant data are excluded.

    Results are cached per-event according to the publication status. A cached
    entry is reused only when it was produced while the event had the same
    publication status it has now; on top of that, entries of published events
    expire after PUBLISHED_CACHE_TTL seconds. Without the status check, an
    event cached while live and later unpublished would keep its (possibly
    hours-old) live snapshot, including ``published: True``, until restart.

    Args:
        group_config: Group configuration. Its ``group_email`` keys the cache,
            and it is passed to ``get_campflow_client`` and
            ``fetch_event_tn_data``.

    Returns:
        The per-event stats dicts sorted by ``start_date``, oldest first
        (events without a start date sort first, assuming dates are strings).
        Events whose data could not be fetched are logged and skipped.
    """
    group_email = group_config.group_email
    client = get_campflow_client(group_config)

    # Fetch the event list (already cached by CampflowClient for 5 min)
    events = await client.get_events()

    result: List[Dict[str, Any]] = []

    for event in events:
        list_id = event.get("list_id")
        if not list_id:
            continue

        is_published = bool(event.get("published", False))

        # ── Check cache ──────────────────────────────────────────────────────
        cached = _get_cached_entry(group_email, list_id)
        if (
            cached
            and cached.get("published") == is_published
            and _is_entry_valid(cached, is_published)
        ):
            logger.debug("Cache hit for group=%s event=%s", group_email, list_id)
            result.append(cached["stats"])
            continue

        # ── Fetch fresh data ─────────────────────────────────────────────────
        logger.info("Fetching stats for group=%s event=%s", group_email, list_id)
        try:
            tn_list, _gf, event_info = await fetch_event_tn_data(list_id, group_config=group_config)
        except Exception as exc:
            # Best effort: one failing event must not break the whole overview.
            logger.error("Failed to fetch data for event %s: %s", list_id, exc)
            continue

        if not event_info:
            continue

        if not tn_list:
            # No participants yet — skip this event entirely
            continue

        # Merge published flag from the event list (more reliable than
        # event_info). Work on a copy so the dict returned by
        # fetch_event_tn_data is not modified in place.
        event_info = {**event_info, "published": is_published}

        stats = _extract_anonymous_event_stats(tn_list, event_info)

        # ── Store in cache ───────────────────────────────────────────────────
        # The publication status is recorded so a later status change
        # invalidates this entry (see the cache check above).
        _store_cached_entry(group_email, list_id, {
            "stats": stats,
            "cached_at": time.time(),
            "published": is_published,
        })

        result.append(stats)

    # Sort by start_date, oldest first
    result.sort(key=lambda e: e.get("start_date") or "")
    return result


# ─── Helper: compute mean + std from a list of numbers ───────────────────────

def _mean_std(values: List[float]) -> Tuple[Optional[float], Optional[float]]:
    n = len(values)
    if n == 0:
        return None, None
    mean = sum(values) / n
    if n == 1:
        return round(mean, 2), 0.0
    variance = sum((x - mean) ** 2 for x in values) / (n - 1)
    return round(mean, 2), round(math.sqrt(variance), 2)


def compute_cross_event_summary(events_stats: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Compute cross-event summary statistics (mean ± std) from the per-event data.

    Args:
        events_stats: Per-event dicts as produced by
            ``_extract_anonymous_event_stats`` (reads ``total_participants``,
            ``has_ticket_data``, ``ticket_count_total``, ``ticket_percentage``
            and ``age_groups``).

    Returns:
        A dict with:
          - ``tn_mean`` / ``tn_std``: participants per event, over all events;
          - ``ticket_mean`` / ``ticket_std``: ticket holders per event;
          - ``percentage_mean`` / ``percentage_std``: ticket-holder percentage;
          - ``age_group_stats``: per age bracket, mean/std of ticket and
            no-ticket counts;
          - ``total_events``: number of input events.
        All ticket-related values use only events with ticket data. Means/stds
        are None when no event contributes (as returned by ``_mean_std``).
    """
    tn_mean, tn_std = _mean_std([e["total_participants"] for e in events_stats])

    # Ticket-related statistics only make sense for events with ticket data.
    ticket_events = [e for e in events_stats if e.get("has_ticket_data")]

    # Ticket holders per event. Use ticket_count_total rather than summing
    # ticket_counts: the latter also counts negative answers such as
    # "Kein Ticket" and would disagree with ticket_percentage.
    ticket_mean, ticket_std = _mean_std([e["ticket_count_total"] for e in ticket_events])

    # Ticket percentage mean
    percentage_mean, percentage_std = _mean_std(
        [e["ticket_percentage"] for e in ticket_events]
    )

    # Age group means (both ticket and no_ticket)
    age_group_stats: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for bracket in ("<=14", "15-25", ">=26"):
        per_kind: Dict[str, Dict[str, Any]] = {}
        for kind in ("ticket", "no_ticket"):
            mean, std = _mean_std([e["age_groups"][bracket][kind] for e in ticket_events])
            per_kind[kind] = {"mean": mean, "std": std}
        age_group_stats[bracket] = per_kind

    return {
        "tn_mean": tn_mean,
        "tn_std": tn_std,
        "ticket_mean": ticket_mean,
        "ticket_std": ticket_std,
        "percentage_mean": percentage_mean,
        "percentage_std": percentage_std,
        "total_events": len(events_stats),
        "age_group_stats": age_group_stats,
    }
