"""Pulls customers and subscriptions from Stripe."""

import os
import sys
from typing import Any, Optional

import stripe

# Subscription statuses that _scan_subscriptions treats as "reached a paid
# state" when deriving first_paid_at. `trialing`, `incomplete` and
# `incomplete_expired` are not in this set.
# NOTE(review): in Stripe, `paused` (trial ended without a payment method) and
# `canceled` (e.g. canceled while still trialing) can describe subscriptions
# that never collected a payment — confirm these belong here.
PAID_STATUSES = {"active", "canceled", "past_due", "paused", "unpaid"}


def fetch_customers(config: dict) -> list[dict[str, Any]]:
    """Return one record per Stripe customer with first-paid timestamp, current
    subscription status, and (optionally) details of any subscription on the
    configured "deal" price.

    Customers without an email, or whose email domain is in `filters.internal_domains`,
    are skipped — they can't be joined to other sources or shouldn't appear in the funnel.
    "first_paid_at" is derived from the earliest subscription that ever reached a paid
    status (if the subscription had a trial, the trial end is used). Scanning subscriptions
    is dramatically faster than scanning invoices for accounts with long billing history.

    Args:
        config: Parsed configuration. Reads `stripe.deal_price_id` and
            `filters.internal_domains`. The whole config, either section, or the
            domain list may be missing or null; each is then treated as empty.

    Returns:
        One dict per kept customer with keys: id, email, created, first_paid_at,
        subscription_status, deal_started_at, deal_status, deal_trial_end.
        Fields with no underlying data are None.

    Raises:
        KeyError: if the STRIPE_API_KEY environment variable is not set.
    """
    stripe.api_key = os.environ["STRIPE_API_KEY"]

    # A config parsed from YAML carries None for an empty file or section, so
    # guard every level with `or` instead of relying on .get() defaults.
    config = config or {}
    stripe_config = config.get("stripe") or {}
    deal_price_id = stripe_config.get("deal_price_id")
    filters = config.get("filters") or {}
    # Normalize like _is_internal normalizes an email's domain (strip +
    # lower-case), and tolerate entries written as "@example.com".
    internal_domains = {
        d.strip().lower().removeprefix("@")
        for d in (filters.get("internal_domains") or [])
    }

    print("Fetching subscriptions...", file=sys.stderr, flush=True)
    first_paid, latest_status, deal_subs = _scan_subscriptions(deal_price_id)
    print(
        f"  {len(first_paid)} customers have at least one paid subscription. "
        f"{len(deal_subs)} have a deal subscription.",
        file=sys.stderr,
        flush=True,
    )

    print("Fetching customers...", file=sys.stderr, flush=True)
    records: list[dict[str, Any]] = []
    skipped_no_email = 0
    skipped_internal = 0
    for customer in stripe.Customer.list(limit=100).auto_paging_iter():
        # Email is the join key to other sources; without it the record is useless.
        if not customer.email:
            skipped_no_email += 1
            continue
        if _is_internal(customer.email, internal_domains):
            skipped_internal += 1
            continue
        deal = deal_subs.get(customer.id)
        records.append(
            {
                "id": customer.id,
                "email": customer.email,
                "created": customer.created,
                "first_paid_at": first_paid.get(customer.id),
                "subscription_status": latest_status.get(customer.id),
                "deal_started_at": deal["started_at"] if deal else None,
                "deal_status": deal["status"] if deal else None,
                "deal_trial_end": deal["trial_end"] if deal else None,
            }
        )
    print(
        f"  {len(records)} customers fetched, "
        f"{skipped_no_email} skipped (no email), "
        f"{skipped_internal} skipped (internal domain).",
        file=sys.stderr,
        flush=True,
    )
    return records


def _is_internal(email: str, internal_domains: set[str]) -> bool:
    domain = email.rsplit("@", 1)[-1].strip().lower()
    return domain in internal_domains


def _scan_subscriptions(
    deal_price_id: Optional[str],
) -> tuple[dict[str, int], dict[str, str], dict[str, dict[str, Any]]]:
    """Single pass over all subscriptions (every status). Returns:
    - first_paid: customer_id -> earliest time one of their subscriptions reached a
      paid status. That time is the subscription's start_date (falling back to
      created), or its trial_end when that is later. Subscriptions that ended no
      later than their trial end are ignored because they never got past the trial,
      even though their final status ("canceled") is in PAID_STATUSES.
    - latest_status: customer_id -> status of their most recently created subscription
      (on a `created` tie, the first one seen wins)
    - deal_subs: customer_id -> {started_at, status, trial_end} of their most recently
      started subscription on the deal price (only populated when `deal_price_id` is set).
    """
    first_paid: dict[str, int] = {}
    latest_created: dict[str, int] = {}
    latest_status: dict[str, str] = {}
    deal_subs: dict[str, dict[str, Any]] = {}

    for sub in stripe.Subscription.list(status="all", limit=100).auto_paging_iter():
        cid = sub.customer
        if cid is None:
            continue

        if cid not in latest_created or sub.created > latest_created[cid]:
            latest_created[cid] = sub.created
            latest_status[cid] = sub.status

        if sub.status in PAID_STATUSES and not _ended_during_trial(sub):
            paid_started = sub.start_date or sub.created
            # With a trial, billing only begins once the trial is over.
            if sub.trial_end and sub.trial_end > paid_started:
                paid_started = sub.trial_end
            prior = first_paid.get(cid)
            if prior is None or paid_started < prior:
                first_paid[cid] = paid_started

        if deal_price_id and _has_price(sub, deal_price_id):
            started = sub.start_date or sub.created
            existing = deal_subs.get(cid)
            if existing is None or started > existing["started_at"]:
                deal_subs[cid] = {
                    "started_at": started,
                    "status": sub.status,
                    "trial_end": sub.trial_end,
                }

    return first_paid, latest_status, deal_subs


def _ended_during_trial(sub) -> bool:
    """Return True if `sub` had a trial and ended at or before the trial's end.

    Ending at the trial end (e.g. cancel-at-period-end while trialing) also
    counts, because the subscription never started its paid period.
    """
    trial_end = getattr(sub, "trial_end", None)
    ended_at = getattr(sub, "ended_at", None)
    return bool(trial_end) and ended_at is not None and ended_at <= trial_end


def _has_price(sub, price_id: str) -> bool:
    items = getattr(sub, "items", None)
    if items is None:
        return False
    data = getattr(items, "data", None) or []
    for item in data:
        price = getattr(item, "price", None)
        if price is not None and getattr(price, "id", None) == price_id:
            return True
    return False


if __name__ == "__main__":
    # Ad-hoc smoke run: load environment variables via python-dotenv, read
    # config.yaml from the current working directory, fetch customers, and
    # print a short summary plus one sample record.
    import json

    import yaml
    from dotenv import load_dotenv

    load_dotenv()
    with open("config.yaml", encoding="utf-8") as f:
        # safe_load returns None for an empty file; fall back to an empty config.
        config = yaml.safe_load(f) or {}
    customers = fetch_customers(config)
    print(f"\nFetched {len(customers)} customers.")
    # Records use None for "no deal subscription", so test for None explicitly.
    deal_count = sum(1 for c in customers if c["deal_started_at"] is not None)
    print(f"Of those, {deal_count} have a deal subscription.")
    if customers:
        print("Sample record:")
        print(json.dumps(customers[0], indent=2, default=str))
