import uuid

from datetime import timedelta

from django.conf import settings as django_settings
from django.db.models import Sum, Count
from django.utils import timezone
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny, IsAuthenticated
from apps.core.permissions import IsGymOwner, IsGymStaff, IsGymOwnerOrManager, get_gym_owner
from rest_framework.response import Response
from rest_framework.views import APIView

import logging

from .models import GymPackage, Invoice, InvoiceItem, Payment, PaymentGatewayConfig, Subscription, SubscriptionPackage
from .mpesa import MpesaClient, decrypt_config, encrypt_config, normalize_phone
from .serializers import (
    CreateSubscriptionSerializer,
    GymPackageSerializer,
    InvoiceSerializer,
    InvoiceItemSerializer,
    MpesaConfigSerializer,
    PaymentSerializer,
    STKPushRequestSerializer,
    SubscriptionPackageSerializer,
    SubscriptionSerializer,
)

# Module-level logger named after this module; used by the M-Pesa views below.
logger = logging.getLogger(__name__)


class PaymentConfigView(APIView):
    """Expose the public payment configuration.

    Returns the Paystack public key and the configured currency from Django
    settings. Open to unauthenticated callers (``AllowAny``).
    """

    permission_classes = [AllowAny]

    def get(self, request):
        payload = {
            "paystack_public_key": django_settings.PAYSTACK_PUBLIC_KEY,
            "currency": django_settings.PAYSTACK_CURRENCY,
        }
        return Response(payload)


class GymPackageViewSet(viewsets.ModelViewSet):
    """CRUD for gym membership packages (Gold, Silver, etc.).

    Superusers see every package. Everyone else sees only packages of the
    gym owner resolved by ``get_gym_owner``, or nothing when no owner is
    resolved. ``?active=true|false`` filters on ``is_active``.
    """

    serializer_class = GymPackageSerializer
    permission_classes = [IsGymStaff]

    def get_queryset(self):
        requester = self.request.user
        packages = GymPackage.objects.all()

        if not requester.is_superuser:
            owner = get_gym_owner(requester)
            packages = packages.filter(gym=owner) if owner else packages.none()

        active_param = self.request.query_params.get("active")
        if active_param is None:
            return packages
        # Any value other than a case-insensitive "true" means inactive.
        return packages.filter(is_active=(active_param.lower() == "true"))

    def perform_create(self, serializer):
        # New packages always belong to the requester's gym owner.
        serializer.save(gym=get_gym_owner(self.request.user))


class SubscriptionPackageViewSet(viewsets.ModelViewSet):
    """Platform subscription packages.

    ``list`` and ``retrieve`` are public; every other action requires
    ``IsGymStaff``. On ``list``, callers that are not authenticated staff
    users only see active packages.
    """

    serializer_class = SubscriptionPackageSerializer
    queryset = SubscriptionPackage.objects.all()

    def get_permissions(self):
        is_read_only = self.action in ("list", "retrieve")
        permission = AllowAny() if is_read_only else IsGymStaff()
        return [permission]

    def get_queryset(self):
        packages = SubscriptionPackage.objects.all()
        if self.action != "list":
            return packages
        viewer = self.request.user
        if viewer.is_authenticated and viewer.is_staff:
            return packages
        return packages.filter(is_active=True)


class CreateSubscriptionView(APIView):
    """Record a subscription after a Paystack payment.

    Cancels the caller's existing ``active``/``expired`` subscriptions and
    creates the new one. Both steps run in one database transaction, so a
    failed insert cannot leave the user with every previous subscription
    cancelled and no replacement.

    NOTE(review): amount, paid_at and payment_reference come from the client
    (validated only by ``CreateSubscriptionSerializer``); nothing here checks
    the reference with Paystack. Confirm verification happens elsewhere.
    """

    permission_classes = [IsAuthenticated]

    def post(self, request):
        from django.db import transaction

        serializer = CreateSubscriptionSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        data = serializer.validated_data

        # An unknown package_id is tolerated and recorded as "Unknown".
        package = SubscriptionPackage.objects.filter(id=data["package_id"]).first()

        paid_at = data["paid_at"]
        cycle = data["billing_cycle"]
        amount = data["amount"]
        cycle_days = {"monthly": 30, "quarterly": 90, "yearly": 365}
        # Free packages never expire; unknown cycles fall back to 30 days.
        expires_at = None if amount == 0 else paid_at + timedelta(days=cycle_days.get(cycle, 30))

        with transaction.atomic():
            Subscription.objects.filter(
                user=request.user, status__in=["active", "expired"]
            ).update(status="cancelled")

            sub = Subscription.objects.create(
                user=request.user,
                package=package,
                package_name=package.name if package else "Unknown",
                amount=amount,
                billing_cycle=cycle,
                payment_reference=data["payment_reference"],
                paid_at=paid_at,
                expires_at=expires_at,
            )

        return Response(
            SubscriptionSerializer(sub).data,
            status=status.HTTP_201_CREATED,
        )


class MySubscriptionView(APIView):
    """Return the caller's gym subscription together with usage counts.

    The subscription belongs to the caller's gym owner, or to the caller when
    ``get_gym_owner`` returns nothing. The newest ``active``/``expired``
    record is used, and the view answers 404 when none exists.

    A paid subscription with ``paid_at`` but no ``expires_at`` gets an
    expiry computed from its billing cycle, which is persisted.
    """

    permission_classes = [IsAuthenticated]

    def get(self, request):
        from apps.members.models import MemberProfile
        from apps.staff.models import StaffProfile

        owner = get_gym_owner(request.user) or request.user

        sub = (
            Subscription.objects.select_related("package")
            .filter(user=owner, status__in=["active", "expired"])
            .order_by("-created_at")
            .first()
        )
        if sub is None:
            return Response(
                {"detail": "No subscription found."},
                status=status.HTTP_404_NOT_FOUND,
            )

        package = sub.package
        member_count = MemberProfile.objects.filter(user__gym=owner).count()
        staff_count = StaffProfile.objects.filter(user__gym=owner, is_active=True).count()

        # Backfill a missing expiry on paid plans: paid_at + cycle length (30 days if unknown).
        expires_at = sub.expires_at
        needs_backfill = not expires_at and sub.paid_at and float(sub.amount) > 0
        if needs_backfill:
            span = {"monthly": 30, "quarterly": 90, "yearly": 365}.get(sub.billing_cycle, 30)
            expires_at = sub.paid_at + timedelta(days=span)
            sub.expires_at = expires_at
            sub.save(update_fields=["expires_at"])

        return Response({
            "id": str(sub.id),
            "package_name": sub.package_name,
            "billing_cycle": sub.billing_cycle,
            "amount": float(sub.amount),
            "currency": sub.currency,
            "paid_at": sub.paid_at.isoformat() if sub.paid_at else None,
            "expires_at": expires_at.isoformat() if expires_at else None,
            "status": sub.status,
            "member_limit": package.member_limit if package else -1,
            "staff_accounts": package.staff_accounts if package else -1,
            "current_member_count": member_count,
            "current_staff_count": staff_count,
            "package_id": str(package.id) if package else None,
            "features": package.features if package else [],
        })


class AdminSubscriptionView(APIView):
    """Superuser-only: get, assign, or update the subscription of a gym owner.

    ``user_id`` identifies the gym owner. Every method returns 403 for
    non-superusers and 404 when the user does not exist.
    """

    permission_classes = [IsAuthenticated]

    # Billing-cycle lengths in days; the keys are also the cycles POST accepts.
    CYCLE_DAYS = {"monthly": 30, "quarterly": 90, "yearly": 365}

    def _check_superuser(self, request):
        """Return a 403 Response when the caller is not a superuser, else None."""
        if not request.user.is_superuser:
            return Response({"detail": "Not authorized."}, status=status.HTTP_403_FORBIDDEN)
        return None

    def _get_owner(self, user_id):
        """Return ``(user, None)``, or ``(None, 404 Response)`` if no such user."""
        from django.contrib.auth import get_user_model

        User = get_user_model()
        try:
            return User.objects.get(pk=user_id), None
        except User.DoesNotExist:
            return None, Response({"detail": "User not found."}, status=status.HTTP_404_NOT_FOUND)

    def _latest_subscription(self, owner):
        """Newest non-cancelled subscription of *owner*, else the newest of any status."""
        subs = (
            Subscription.objects.filter(user=owner)
            .select_related("package")
            .order_by("-created_at")
        )
        return subs.exclude(status="cancelled").first() or subs.first()

    def get(self, request, user_id):
        """Return the owner's subscription (or ``None``) plus member/staff counts."""
        denied = self._check_superuser(request)
        if denied is not None:
            return denied

        from apps.members.models import MemberProfile
        from apps.staff.models import StaffProfile

        owner, error = self._get_owner(user_id)
        if error is not None:
            return error

        # Prefer active/expired subs over cancelled ones.
        sub = self._latest_subscription(owner)
        member_count = MemberProfile.objects.filter(user__gym=owner).count()
        staff_count = StaffProfile.objects.filter(user__gym=owner, is_active=True).count()

        if not sub:
            return Response({
                "subscription": None,
                "current_member_count": member_count,
                "current_staff_count": staff_count,
            })

        pkg = sub.package

        # Backfill: a paid plan missing expires_at gets paid_at + cycle length, persisted.
        expires_at = sub.expires_at
        if not expires_at and sub.paid_at and float(sub.amount) > 0:
            expires_at = sub.paid_at + timedelta(days=self.CYCLE_DAYS.get(sub.billing_cycle, 30))
            sub.expires_at = expires_at
            sub.save(update_fields=["expires_at"])

        return Response({
            "subscription": {
                "id": str(sub.id),
                "package_name": sub.package_name,
                "package_id": str(pkg.id) if pkg else None,
                "billing_cycle": sub.billing_cycle,
                "amount": float(sub.amount),
                "currency": sub.currency,
                "status": sub.status,
                "paid_at": sub.paid_at.isoformat() if sub.paid_at else None,
                "expires_at": expires_at.isoformat() if expires_at else None,
                "member_limit": pkg.member_limit if pkg else -1,
                "staff_accounts": pkg.staff_accounts if pkg else -1,
                "features": pkg.features if pkg else [],
            },
            "current_member_count": member_count,
            "current_staff_count": staff_count,
        })

    def post(self, request, user_id):
        """Assign/upgrade/change the subscription of a gym owner.

        Requires ``package_id``. ``billing_cycle`` defaults to ``"monthly"``
        and must be a key of ``CYCLE_DAYS`` (400 otherwise). Existing
        active/expired subscriptions are cancelled and the new admin-assigned
        one is created in a single transaction.
        """
        from django.db import transaction

        denied = self._check_superuser(request)
        if denied is not None:
            return denied

        owner, error = self._get_owner(user_id)
        if error is not None:
            return error

        package_id = request.data.get("package_id")
        billing_cycle = request.data.get("billing_cycle", "monthly")

        if not package_id:
            return Response({"detail": "package_id is required."}, status=status.HTTP_400_BAD_REQUEST)
        # Unknown cycles used to be stored verbatim with a 30-day expiry and monthly price.
        if not isinstance(billing_cycle, str) or billing_cycle not in self.CYCLE_DAYS:
            return Response(
                {"detail": f"billing_cycle must be one of: {', '.join(self.CYCLE_DAYS)}."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        package = SubscriptionPackage.objects.filter(id=package_id).first()
        if not package:
            return Response({"detail": "Package not found."}, status=status.HTTP_404_NOT_FOUND)

        now = timezone.now()
        is_free = package.price == 0
        expires_at = None if is_free else now + timedelta(days=self.CYCLE_DAYS[billing_cycle])

        # Pricing: yearly = 12 months with a 20% discount; quarterly = 3 months.
        amount = float(package.price)
        if billing_cycle == "yearly":
            amount = round(amount * 12 * 0.8, 2)
        elif billing_cycle == "quarterly":
            amount = round(amount * 3, 2)

        with transaction.atomic():
            Subscription.objects.filter(
                user=owner, status__in=["active", "expired"]
            ).update(status="cancelled")

            sub = Subscription.objects.create(
                user=owner,
                package=package,
                package_name=package.name,
                amount=amount,
                currency=package.currency,
                billing_cycle=billing_cycle,
                payment_reference="admin_assigned",
                payment_method="admin",
                status="active",
                paid_at=now,
                expires_at=expires_at,
            )

        return Response(SubscriptionSerializer(sub).data, status=status.HTTP_201_CREATED)

    def patch(self, request, user_id):
        """Update subscription status via ``action``: cancel, reactivate or extend.

        Targets the same subscription GET shows: the newest non-cancelled
        one, falling back to the newest of any status so that a cancelled
        subscription can be reactivated. ``extend`` takes ``days`` (default
        30), which must be a positive integer.
        """
        denied = self._check_superuser(request)
        if denied is not None:
            return denied

        owner, error = self._get_owner(user_id)
        if error is not None:
            return error

        sub = self._latest_subscription(owner)
        if not sub:
            return Response({"detail": "No subscription found."}, status=status.HTTP_404_NOT_FOUND)

        action_type = request.data.get("action")

        if action_type == "cancel":
            sub.status = "cancelled"
            sub.save(update_fields=["status"])
        elif action_type == "reactivate":
            sub.status = "active"
            # Restart the paid period from now; free plans (no expiry) stay open-ended.
            if sub.expires_at:
                sub.expires_at = timezone.now() + timedelta(
                    days=self.CYCLE_DAYS.get(sub.billing_cycle, 30)
                )
            sub.save(update_fields=["status", "expires_at"])
        elif action_type == "extend":
            try:
                days = int(request.data.get("days", 30))
            except (TypeError, ValueError):
                return Response({"detail": "days must be an integer."}, status=status.HTTP_400_BAD_REQUEST)
            if days <= 0:
                return Response(
                    {"detail": "days must be a positive integer."},
                    status=status.HTTP_400_BAD_REQUEST,
                )
            if sub.expires_at:
                # Extend from the later of the current expiry and now.
                base = max(sub.expires_at, timezone.now())
                sub.expires_at = base + timedelta(days=days)
            sub.status = "active"
            sub.save(update_fields=["status", "expires_at"])
        else:
            return Response({"detail": "Invalid action."}, status=status.HTTP_400_BAD_REQUEST)

        return Response(SubscriptionSerializer(sub).data)


class SubscriptionStatsView(APIView):
    """Platform-wide subscription statistics for the superadmin dashboard.

    Superusers only (403 otherwise). Reports:

    - the number of ``active`` subscriptions;
    - the sum of their amounts;
    - a per-package breakdown ordered by revenue;
    - the ten most recently created subscriptions of any status.
    """

    permission_classes = [IsAuthenticated]

    def get(self, request):
        if not request.user.is_superuser:
            return Response(
                {"detail": "Not authorized."},
                status=status.HTTP_403_FORBIDDEN,
            )

        active = Subscription.objects.filter(status="active")
        revenue_sum = active.aggregate(total=Sum("amount"))["total"]
        per_package = list(
            active.values("package_name")
            .annotate(count=Count("id"), revenue=Sum("amount"))
            .order_by("-revenue")
        )
        latest = Subscription.objects.select_related("user").order_by("-created_at")[:10]

        return Response({
            "active_count": active.count(),
            "total_revenue": float(revenue_sum or 0),
            "breakdown": per_package,
            "recent": SubscriptionSerializer(latest, many=True).data,
        })


def _generate_invoice_number():
    """Build an invoice number of the form ``INV-<YYYYMMDD>-<XXXX>``.

    The date comes from ``timezone.now()``. The suffix is the first four hex
    digits of a random UUID4, upper-cased.

    Only 16**4 = 65536 suffixes exist per day, so uniqueness is not
    guaranteed. NOTE(review): callers must cope with collisions if
    ``invoice_number`` is unique.
    """
    date_part = timezone.now().strftime("%Y%m%d")
    random_part = uuid.uuid4().hex[:4].upper()
    return "-".join(("INV", date_part, random_part))


class InvoiceViewSet(viewsets.ModelViewSet):
    """CRUD for member invoices plus ``mark-paid``, ``mark-sent`` and ``cancel``.

    Non-superusers only see invoices whose member's ``user.gym`` is their gym
    owner, or themselves when no distinct owner is resolved. The optional
    ``?status=`` and ``?member=`` query parameters filter the list.
    """

    serializer_class = InvoiceSerializer
    permission_classes = [IsGymOwnerOrManager]

    def get_queryset(self):
        user = self.request.user
        qs = Invoice.objects.select_related("member__user", "location").prefetch_related("items")
        if not user.is_superuser:
            gym_owner = get_gym_owner(user)
            owner = gym_owner if (gym_owner and gym_owner != user) else user
            qs = qs.filter(member__user__gym=owner)

        s = self.request.query_params.get("status")
        if s:
            qs = qs.filter(status=s)
        member = self.request.query_params.get("member")
        if member:
            qs = qs.filter(member_id=member)
        return qs

    @staticmethod
    def _parse_items(items_data):
        """Validate the raw ``items`` list from the request body.

        Returns one dict per item with the following keys:

        - ``description``
        - ``quantity`` (int)
        - ``unit_price`` and ``total_price`` (Decimal)
        - ``item_type``

        Raises ``ValidationError`` (HTTP 400) on malformed input before
        anything is written.
        """
        from decimal import Decimal, InvalidOperation

        from rest_framework.exceptions import ValidationError

        if not isinstance(items_data, list):
            raise ValidationError({"items": "Expected a list of line items."})

        parsed = []
        for index, item in enumerate(items_data):
            if not isinstance(item, dict):
                raise ValidationError({"items": f"Item {index} must be an object."})
            try:
                quantity = int(item.get("quantity", 1))
                # str() first so float inputs keep their printed value, not binary noise.
                unit_price = Decimal(str(item.get("unit_price", 0)))
            except (TypeError, ValueError, InvalidOperation):
                raise ValidationError(
                    {"items": f"Item {index} has an invalid quantity or unit_price."}
                ) from None
            if not unit_price.is_finite():
                raise ValidationError({"items": f"Item {index} has an invalid unit_price."})
            parsed.append({
                "description": item.get("description", ""),
                "quantity": quantity,
                "unit_price": unit_price,
                "total_price": quantity * unit_price,
                "item_type": item.get("item_type", "other"),
            })
        return parsed

    def perform_create(self, serializer):
        """Create the invoice with a fresh number, its line items and totals.

        Items are validated before anything is saved. Creation, items and the
        totals update run in one transaction, so a failure never leaves an
        invoice without its items or with stale totals. Totals use Decimal:
        subtotal = sum of item totals; total = subtotal + tax - discount.
        """
        from decimal import Decimal

        from django.db import transaction

        items = self._parse_items(self.request.data.get("items", []))

        invoice_number = _generate_invoice_number()
        # The random suffix is short; re-draw on collision with an existing
        # invoice (best-effort only — not race-free).
        while Invoice.objects.filter(invoice_number=invoice_number).exists():
            invoice_number = _generate_invoice_number()

        with transaction.atomic():
            invoice = serializer.save(invoice_number=invoice_number)

            for item in items:
                InvoiceItem.objects.create(invoice=invoice, **item)

            # Re-read items from the DB so any items the serializer itself created are included.
            subtotal = sum(
                (Decimal(str(i.total_price)) for i in invoice.items.all()),
                Decimal("0"),
            )
            invoice.subtotal = subtotal
            invoice.total_amount = (
                subtotal
                + Decimal(str(invoice.tax_amount or 0))
                - Decimal(str(invoice.discount_amount or 0))
            )
            invoice.save()

    @action(detail=True, methods=["post"], url_path="mark-paid")
    def mark_paid(self, request, pk=None):
        """Set the invoice status to ``paid`` and stamp ``paid_at`` with now."""
        invoice = self.get_object()
        invoice.status = "paid"
        invoice.paid_at = timezone.now()
        invoice.save()
        return Response(InvoiceSerializer(invoice).data)

    @action(detail=True, methods=["post"], url_path="mark-sent")
    def mark_sent(self, request, pk=None):
        """Set the invoice status to ``sent``."""
        invoice = self.get_object()
        invoice.status = "sent"
        invoice.save()
        return Response(InvoiceSerializer(invoice).data)

    @action(detail=True, methods=["post"])
    def cancel(self, request, pk=None):
        """Set the invoice status to ``cancelled``."""
        invoice = self.get_object()
        invoice.status = "cancelled"
        invoice.save()
        return Response(InvoiceSerializer(invoice).data)


def _extend_membership_if_applicable(payment):
    """Extend a member's membership using package info stored on *payment*.

    Does nothing unless all of these hold:

    - ``payment.gateway_response`` is a dict carrying both ``package_id``
      and ``plan_duration``;
    - the payment status is ``"completed"`` or ``"pending"``;
    - the ``GymPackage`` exists.

    Membership state is kept as JSON in ``member.notes``. The new period
    starts at the current expiry (``lastPaymentStart + lastPaymentDays``)
    when that is still in the future, otherwise today. The member is marked
    ``active``, and so is the user if it has ``is_active_member``.

    Repeat calls for the same payment are ignored: if ``notes`` already
    records this payment as ``lastPaymentId``, nothing changes. A payment
    created ``pending`` and later marked ``completed`` reaches this function
    twice; without the guard each call stacked another period. The guard
    only detects the most recently applied payment.
    """
    import json as _json
    from datetime import date as date_type

    gr = payment.gateway_response
    if not isinstance(gr, dict):
        return
    package_id = gr.get("package_id")
    plan_duration = gr.get("plan_duration")
    if not package_id or not plan_duration:
        return
    # NOTE(review): pending payments also extend membership — confirm intended.
    if payment.status not in ("completed", "pending"):
        return

    package = GymPackage.objects.filter(id=package_id).first()
    if not package:
        return

    member = payment.member
    notes = {}
    try:
        if member.notes:
            notes = _json.loads(member.notes)
    except (_json.JSONDecodeError, TypeError):
        # Non-JSON notes are replaced by the structured notes written below.
        pass
    if not isinstance(notes, dict):
        # Valid JSON that is not an object (e.g. a list) cannot hold our keys.
        notes = {}

    payment_key = str(payment.id)
    if notes.get("lastPaymentId") == payment_key:
        return

    plans = package.plans if isinstance(package.plans, list) else []
    plan = next(
        (p for p in plans if isinstance(p, dict) and p.get("duration") == plan_duration),
        None,
    )
    plan_price = float(plan.get("price", 0)) if plan else 0

    duration_days_map = {
        "daily": 1, "1_month": 30, "3_months": 90,
        "6_months": 180, "12_months": 365,
    }
    # Unknown durations fall back to 30 days.
    days = duration_days_map.get(plan_duration, 30)

    # Determine start date: extend from current expiry if still active.
    today = timezone.now().date()
    start_date = today
    current_start = notes.get("lastPaymentStart")
    current_days = notes.get("lastPaymentDays")
    if current_start and current_days:
        try:
            current_expiry = date_type.fromisoformat(current_start) + timedelta(days=int(current_days))
            if current_expiry > today:
                start_date = current_expiry
        except (ValueError, TypeError):
            pass

    notes.update({
        "packageId": str(package.id),
        "packageName": package.name,
        "planId": str(package.id),
        "planName": package.name,
        "planDuration": plan_duration,
        "planPrice": plan_price,
        "paymentMethod": payment.payment_method,
        "lastPaymentStart": str(start_date),
        "lastPaymentDays": days,
        "lastPaymentId": payment_key,
        "lastPaymentDate": str(today),
    })
    member.notes = _json.dumps(notes)
    member.status = "active"
    member.save(update_fields=["notes", "status"])

    # Also update the user's active flag when the user model has one.
    if hasattr(member.user, "is_active_member"):
        member.user.is_active_member = True
        member.user.save(update_fields=["is_active_member"])


def _extend_guest_pass_if_applicable(payment):
    """Renew a guest pass after an M-Pesa payment for its renewal completes.

    Does nothing unless ``payment.gateway_response`` is a dict with a truthy
    ``guest_pass_id`` that matches an existing ``GuestPass``. The pass is:

    - reset to start today for ``renewal_days`` days (default 1);
    - marked not checked in;
    - recorded as paid via M-Pesa with ``payment.amount``.

    A non-integer or non-positive ``renewal_days`` falls back to 1 instead
    of raising. A bad value therefore cannot abort the caller's post-payment
    processing part-way through.
    """
    gr = payment.gateway_response
    if not isinstance(gr, dict):
        return
    guest_pass_id = gr.get("guest_pass_id")
    if not guest_pass_id:
        return

    from apps.members.models import GuestPass
    guest_pass = GuestPass.objects.filter(id=guest_pass_id).first()
    if not guest_pass:
        return

    raw_days = gr.get("renewal_days", 1)
    try:
        renewal_days = int(raw_days)
    except (TypeError, ValueError):
        logger.warning(
            "Invalid renewal_days %r on payment %s; defaulting to 1", raw_days, payment.id
        )
        renewal_days = 1
    if renewal_days < 1:
        renewal_days = 1

    today = timezone.now().date()
    guest_pass.visit_date = today
    guest_pass.days = renewal_days
    guest_pass.checked_in = False
    guest_pass.checked_in_at = None
    guest_pass.payment_status = "completed"
    guest_pass.payment_method = "mpesa"
    guest_pass.amount_paid = payment.amount
    guest_pass.save()


class PaymentViewSet(viewsets.ModelViewSet):
    """CRUD for member payments plus a ``mark-completed`` action.

    Non-superusers only see payments whose member's ``user.gym`` is their gym
    owner, or themselves when no distinct owner is resolved. The optional
    ``?status=``, ``?payment_method=`` and ``?member=`` query parameters
    filter the list.
    """

    serializer_class = PaymentSerializer
    permission_classes = [IsGymOwnerOrManager]

    def get_queryset(self):
        user = self.request.user
        qs = Payment.objects.select_related("member__user", "invoice", "location")
        if not user.is_superuser:
            gym_owner = get_gym_owner(user)
            owner = gym_owner if (gym_owner and gym_owner != user) else user
            qs = qs.filter(member__user__gym=owner)

        s = self.request.query_params.get("status")
        if s:
            qs = qs.filter(status=s)
        method = self.request.query_params.get("payment_method")
        if method:
            qs = qs.filter(payment_method=method)
        member = self.request.query_params.get("member")
        if member:
            qs = qs.filter(member_id=member)
        return qs

    def perform_create(self, serializer):
        """Save the payment, settle its invoice if completed, and apply package info."""
        payment = serializer.save()
        # If payment is completed and linked to an invoice, mark invoice as paid.
        if payment.status == "completed" and payment.invoice:
            invoice = payment.invoice
            if invoice.status != "paid":
                invoice.status = "paid"
                invoice.paid_at = timezone.now()
                invoice.save()

        # If a package was selected, extend the member's membership.
        _extend_membership_if_applicable(payment)

    @action(detail=True, methods=["post"], url_path="mark-completed")
    def mark_completed(self, request, pk=None):
        """Complete a payment, mark its invoice paid and extend membership.

        A payment that is already ``completed`` is returned unchanged. The
        action previously re-stamped ``processed_at`` and re-ran the
        membership extension on every call.
        """
        payment = self.get_object()
        if payment.status == "completed":
            return Response(PaymentSerializer(payment).data)

        payment.status = "completed"
        payment.processed_at = timezone.now()
        payment.save()

        # Also mark linked invoice as paid.
        if payment.invoice and payment.invoice.status != "paid":
            payment.invoice.status = "paid"
            payment.invoice.paid_at = timezone.now()
            payment.invoice.save()

        # Extend membership if a package was selected.
        _extend_membership_if_applicable(payment)

        return Response(PaymentSerializer(payment).data)


# ---------------------------------------------------------------------------
# M-Pesa Daraja Integration Views
# ---------------------------------------------------------------------------

def _get_mpesa_config(user):
    """Return the M-Pesa ``PaymentGatewayConfig`` of *user*'s gym owner, or ``None``.

    NOTE(review): when ``get_gym_owner`` returns ``None``, the filter becomes
    ``gym=None``, which Django treats as "gym IS NULL". That would match a
    config with no gym — confirm this is intended.
    """
    mpesa_configs = PaymentGatewayConfig.objects.filter(gateway="mpesa")
    owner = get_gym_owner(user)
    return mpesa_configs.filter(gym=owner).first()


def _build_mpesa_client(config_obj) -> MpesaClient:
    """Construct an ``MpesaClient`` from a stored ``PaymentGatewayConfig``.

    The stored ``config`` is decrypted with ``decrypt_config``. Missing
    credential keys become empty strings; a missing ``environment``
    defaults to ``"sandbox"``.
    """
    creds = decrypt_config(config_obj.config)
    client_kwargs = {
        key: creds.get(key, "")
        for key in ("consumer_key", "consumer_secret", "shortcode", "passkey", "callback_url")
    }
    client_kwargs["environment"] = creds.get("environment", "sandbox")
    return MpesaClient(**client_kwargs)


class MpesaConfigView(APIView):
    """GET / PUT the gym's M-Pesa Daraja credentials.

    GET returns the stored credentials with ``consumer_secret`` and
    ``passkey`` masked as ``****`` plus their last four characters.

    PUT does the following:

    - validates with ``MpesaConfigSerializer``;
    - forces the production environment and the callback URL
      (``settings.MPESA_CALLBACK_URL`` if set, else the built-in default);
    - encrypts the credentials and upserts the config for the gym owner.

    If PUT receives a masked value (starting with ``****``) for a secret
    field, the stored secret is kept. A form that echoes back what GET
    returned therefore no longer overwrites the real secret with the mask.
    """

    permission_classes = [IsGymOwner]

    # Fields GET masks; a PUT echoing the masked value keeps the stored one.
    MASKED_FIELDS = ("consumer_secret", "passkey")
    MASK_PREFIX = "****"
    DEFAULT_CALLBACK_URL = "https://api.isoftke.co.ke/api/billing/mpesa/callback/"

    @classmethod
    def _mask(cls, value):
        """Mask *value* as ``****`` + its last 4 characters (just ``****`` if shorter)."""
        return f"{cls.MASK_PREFIX}{value[-4:]}" if len(value) >= 4 else cls.MASK_PREFIX

    def get(self, request):
        config_obj = _get_mpesa_config(request.user)
        if not config_obj:
            return Response({
                "configured": False,
                "consumer_key": "",
                "consumer_secret": "",
                "shortcode": "",
                "passkey": "",
            })

        c = decrypt_config(config_obj.config)
        return Response({
            "configured": True,
            "consumer_key": c.get("consumer_key", ""),
            "consumer_secret": self._mask(c.get("consumer_secret") or ""),
            "shortcode": c.get("shortcode", ""),
            "passkey": self._mask(c.get("passkey") or ""),
        })

    def put(self, request):
        serializer = MpesaConfigSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        data = dict(serializer.validated_data)

        existing = _get_mpesa_config(request.user)
        if existing:
            stored = decrypt_config(existing.config)
            for field in self.MASKED_FIELDS:
                submitted = data.get(field)
                if (
                    isinstance(submitted, str)
                    and submitted.startswith(self.MASK_PREFIX)
                    and stored.get(field)
                ):
                    data[field] = stored[field]

        # Always use production environment and the (configurable) callback URL.
        data["environment"] = "production"
        data["callback_url"] = (
            getattr(django_settings, "MPESA_CALLBACK_URL", None) or self.DEFAULT_CALLBACK_URL
        )

        gym_owner = get_gym_owner(request.user)
        encrypted_data = encrypt_config(data)
        config_obj, created = PaymentGatewayConfig.objects.get_or_create(
            gateway="mpesa",
            gym=gym_owner,
            defaults={"is_active": True, "config": encrypted_data},
        )
        if not created:
            config_obj.config = encrypted_data
            config_obj.is_active = True
            config_obj.save()

        return Response({"detail": "M-Pesa configuration saved."})


class MpesaTestConnectionView(APIView):
    """Check M-Pesa credentials supplied in the request body.

    Builds a production ``MpesaClient`` from the submitted (not stored)
    credentials and calls ``test_credentials()``. On success the result is
    merged into a "Connection successful." payload. Any exception is logged
    and reported back as a 400.
    """

    permission_classes = [IsGymOwner]

    def post(self, request):
        serializer = MpesaConfigSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        creds = serializer.validated_data

        client = MpesaClient(
            consumer_key=creds["consumer_key"],
            consumer_secret=creds["consumer_secret"],
            shortcode=creds["shortcode"],
            passkey=creds["passkey"],
            callback_url="",
            environment="production",
        )

        try:
            outcome = client.test_credentials()
            success_body = {"detail": "Connection successful.", **outcome}
        except Exception as exc:
            logger.warning("M-Pesa test connection failed: %s", exc)
            failure_body = {"detail": f"Connection failed: {str(exc)}"}
            return Response(failure_body, status=status.HTTP_400_BAD_REQUEST)
        return Response(success_body)


class MpesaSTKPushView(APIView):
    """Initiate an M-Pesa STK Push and create a pending Payment.

    Uses the caller's gym M-Pesa config. Unless the caller is a superuser,
    the member and the optional invoice must belong to the caller's gym.
    Without this scoping a manager could charge another gym's member and,
    through the stored package info, extend that member's membership.
    Package info (``package_id``, ``plan_duration``, ``notes``) is stored in
    ``gateway_response`` for the callback to use.
    """

    permission_classes = [IsGymOwnerOrManager]

    def post(self, request):
        from apps.members.models import MemberProfile

        serializer = STKPushRequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        data = serializer.validated_data

        config_obj = _get_mpesa_config(request.user)
        if not config_obj:
            return Response(
                {"detail": "M-Pesa is not configured. Go to Settings to set up your Daraja credentials."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        user = request.user
        members = MemberProfile.objects.all()
        invoices = Invoice.objects.all()
        if not user.is_superuser:
            gym_owner = get_gym_owner(user)
            # Fall back to the caller when no distinct gym owner is resolved.
            owner = gym_owner if (gym_owner and gym_owner != user) else user
            members = members.filter(user__gym=owner)
            invoices = invoices.filter(member__user__gym=owner)

        try:
            member = members.get(id=data["member_id"])
        except MemberProfile.DoesNotExist:
            return Response({"detail": "Member not found."}, status=status.HTTP_404_NOT_FOUND)

        phone = normalize_phone(data["phone_number"])
        amount = data["amount"]
        invoice = None
        if data.get("invoice_id"):
            invoice = invoices.filter(id=data["invoice_id"]).first()

        client = _build_mpesa_client(config_obj)

        try:
            result = client.stk_push(
                phone=phone,
                amount=amount,
                account_ref=f"GYM-{str(member.id)[:8].upper()}",
                description="Gym membership payment",
            )
        except Exception as e:
            logger.error("M-Pesa STK Push failed: %s", e)
            return Response(
                {"detail": f"STK Push failed: {str(e)}"},
                status=status.HTTP_502_BAD_GATEWAY,
            )

        checkout_request_id = result.get("CheckoutRequestID", "")
        if not checkout_request_id:
            # The callback finds payments by this ID; without it the Payment
            # could never be matched and would stay pending forever.
            logger.error("M-Pesa STK Push returned no CheckoutRequestID: %s", result)
            return Response(
                {"detail": "STK Push was not accepted by M-Pesa."},
                status=status.HTTP_502_BAD_GATEWAY,
            )

        # Store package info alongside the Daraja response.
        gateway_response = dict(result)
        pkg_id = str(data.get("package_id", "")) if data.get("package_id") else ""
        plan_dur = data.get("plan_duration", "")
        if pkg_id:
            pkg = GymPackage.objects.filter(id=pkg_id).first()
            gateway_response["package_id"] = pkg_id
            gateway_response["plan_duration"] = plan_dur
            gateway_response["notes"] = f"{pkg.name} - {plan_dur}" if pkg else ""

        payment = Payment.objects.create(
            member=member,
            invoice=invoice,
            amount=amount,
            currency="KES",
            payment_method="mpesa",
            status="pending",
            gateway_transaction_id=checkout_request_id,
            gateway_response=gateway_response,
        )

        return Response({
            "detail": "STK Push sent. Waiting for member to confirm.",
            "payment_id": str(payment.id),
            "checkout_request_id": checkout_request_id,
        }, status=status.HTTP_201_CREATED)


class MpesaCallbackView(APIView):
    """Safaricom STK Push callback webhook — no authentication required.

    Always acknowledges with ``{"ResultCode": 0, "ResultDesc": "Accepted"}``.
    The Payment is found by ``CheckoutRequestID``. On ``ResultCode`` 0 the
    view:

    - completes the payment (the receipt number replaces
      ``gateway_transaction_id``);
    - marks its invoice paid;
    - runs the membership and guest-pass extensions.

    Any other result code marks the payment failed.

    The custom keys stored in ``gateway_response`` before the callback are
    carried over into the callback body. Previously only the package keys
    survived, so ``guest_pass_id``/``renewal_days`` were lost and the
    guest-pass renewal never ran. A payment that is already ``completed`` is
    left untouched, so a repeated or late callback cannot re-run the
    extensions or flip it to ``failed``.

    NOTE(review): the endpoint is unauthenticated, so a forged body with a
    known CheckoutRequestID would be trusted. Consider confirming the result
    with the gateway before completing the payment.
    """

    permission_classes = [AllowAny]

    # Keys stored in gateway_response before the callback that must survive
    # replacing it with the callback body (package info from the STK push,
    # guest-pass info read by _extend_guest_pass_if_applicable).
    PRESERVED_GATEWAY_KEYS = ("package_id", "plan_duration", "notes", "guest_pass_id", "renewal_days")

    @staticmethod
    def _accepted():
        return Response({"ResultCode": 0, "ResultDesc": "Accepted"})

    def post(self, request):
        body = request.data
        logger.info("M-Pesa callback received: %s", body)

        if not isinstance(body, dict):
            logger.warning("Ignoring M-Pesa callback with a non-object body.")
            return self._accepted()
        envelope = body.get("Body")
        stk_callback = envelope.get("stkCallback") if isinstance(envelope, dict) else None
        if not isinstance(stk_callback, dict):
            stk_callback = {}
        checkout_request_id = stk_callback.get("CheckoutRequestID", "")
        result_code = stk_callback.get("ResultCode")

        if not checkout_request_id:
            return self._accepted()

        payment = Payment.objects.filter(
            gateway_transaction_id=checkout_request_id,
            payment_method="mpesa",
        ).first()

        if not payment:
            logger.warning("No payment found for CheckoutRequestID: %s", checkout_request_id)
            return self._accepted()

        if payment.status == "completed":
            logger.info("Ignoring repeated M-Pesa callback for completed payment %s", payment.id)
            return self._accepted()

        # Replace the stored response with the callback body, keeping our own keys.
        original_gr = payment.gateway_response if isinstance(payment.gateway_response, dict) else {}
        new_gr = dict(body)
        for key in self.PRESERVED_GATEWAY_KEYS:
            if key in original_gr:
                new_gr[key] = original_gr[key]
        payment.gateway_response = new_gr

        # Success is ResultCode 0; also accept "0" in case it arrives as a string.
        if result_code in (0, "0"):
            metadata = stk_callback.get("CallbackMetadata")
            metadata_items = metadata.get("Item", []) if isinstance(metadata, dict) else []
            receipt = ""
            for item in metadata_items:
                if isinstance(item, dict) and item.get("Name") == "MpesaReceiptNumber":
                    receipt = item.get("Value", "")
                    break

            payment.status = "completed"
            payment.processed_at = timezone.now()
            if receipt:
                payment.gateway_transaction_id = receipt
            payment.save()

            # Mark linked invoice as paid.
            if payment.invoice and payment.invoice.status != "paid":
                payment.invoice.status = "paid"
                payment.invoice.paid_at = timezone.now()
                payment.invoice.save()

            # Extend membership if a package was selected.
            _extend_membership_if_applicable(payment)

            # Extend guest pass if applicable.
            _extend_guest_pass_if_applicable(payment)
        else:
            payment.status = "failed"
            payment.save()

        return self._accepted()


class MpesaPaymentStatusView(APIView):
    """Poll an M-Pesa payment's status by checkout request ID.

    First looks the payment up by ``gateway_transaction_id``. Failing that,
    it matches the ``CheckoutRequestID`` stored in the callback body, since
    the receipt number may have replaced ``gateway_transaction_id``.

    Unless the caller is a superuser, only payments of members in the
    caller's gym are visible. Previously any gym's staff could read another
    gym's payment details.
    """

    permission_classes = [IsGymStaff]

    def get(self, request, checkout_id):
        user = request.user
        payments = Payment.objects.filter(payment_method="mpesa")
        if not user.is_superuser:
            gym_owner = get_gym_owner(user)
            # Fall back to the caller when no distinct gym owner is resolved.
            owner = gym_owner if (gym_owner and gym_owner != user) else user
            payments = payments.filter(member__user__gym=owner)

        payment = payments.filter(gateway_transaction_id=checkout_id).first()
        if not payment:
            payment = payments.filter(
                gateway_response__Body__stkCallback__CheckoutRequestID=checkout_id,
            ).first()

        if not payment:
            return Response({"detail": "Payment not found."}, status=status.HTTP_404_NOT_FOUND)

        return Response({
            "payment_id": str(payment.id),
            "status": payment.status,
            "receipt_number": payment.gateway_transaction_id if payment.status == "completed" else None,
            "amount": str(payment.amount),
            "member": str(payment.member_id),
        })
