import base64
import hashlib
import logging
from datetime import datetime, timedelta, timezone

import requests
from cryptography.fernet import Fernet, InvalidToken
from django.conf import settings as django_settings

logger = logging.getLogger(__name__)

# Daraja API hosts; MpesaClient selects one from its ``environment`` argument.
SANDBOX_BASE_URL = "https://sandbox.safaricom.co.ke"
PRODUCTION_BASE_URL = "https://api.safaricom.co.ke"

# Config keys whose values are encrypted before being stored in the DB and
# decrypted after being read back. A frozenset so this constant cannot be
# mutated at runtime by any importer.
_SENSITIVE_FIELDS = frozenset({"consumer_key", "consumer_secret", "passkey"})


def _get_fernet() -> Fernet:
    """Build the Fernet cipher used to protect sensitive config fields.

    The key is the SHA-256 digest of Django's SECRET_KEY (exactly the 32 bytes
    Fernet needs), url-safe base64 encoded as Fernet requires. Because the key
    depends only on SECRET_KEY, changing SECRET_KEY makes previously stored
    ciphertext undecryptable.
    """
    secret_bytes = django_settings.SECRET_KEY.encode()
    digest = hashlib.sha256(secret_bytes).digest()
    fernet_key = base64.urlsafe_b64encode(digest)
    return Fernet(fernet_key)


def encrypt_config(config: dict) -> dict:
    """Return a copy of *config* with its sensitive fields encrypted for storage.

    Only keys listed in ``_SENSITIVE_FIELDS`` are touched, and only when their
    value is truthy; each such value is UTF-8 encoded, Fernet-encrypted, and
    stored back as a text token. The input dict is not modified.
    """
    cipher = _get_fernet()
    encrypted = dict(config)
    for name in _SENSITIVE_FIELDS:
        plaintext = encrypted.get(name)
        if not plaintext:
            # Missing or empty values are left as they are.
            continue
        encrypted[name] = cipher.encrypt(plaintext.encode()).decode()
    return encrypted


def decrypt_config(config: dict) -> dict:
    """Return a copy of *config* with its sensitive fields decrypted.

    Counterpart of ``encrypt_config``: each truthy string value under a key in
    ``_SENSITIVE_FIELDS`` is Fernet-decrypted. The input dict is not modified.

    Values that are not valid tokens for the current key are passed through
    unchanged. This keeps legacy plaintext rows readable, but it also means
    ciphertext written under a different SECRET_KEY is returned as-is. Such
    failures are therefore logged instead of silently swallowed. Unrelated
    errors are no longer hidden.
    """
    # Fernet tokens start with version byte 0x80 followed by a big-endian
    # timestamp whose high bytes are zero, which base64url-encodes to this.
    fernet_token_prefix = "gAAAAA"

    f = _get_fernet()
    out = dict(config)
    for field in _SENSITIVE_FIELDS:
        value = out.get(field)
        if not value or not isinstance(value, str):
            # Nothing to decrypt: missing/empty, or not a string and therefore
            # not a token we produced. Leave it untouched.
            continue
        try:
            out[field] = f.decrypt(value.encode()).decode()
        except InvalidToken:
            if value.startswith(fernet_token_prefix):
                # Looks like our ciphertext but won't decrypt: most likely
                # SECRET_KEY changed since it was stored. Never log the value.
                logger.warning(
                    "Could not decrypt config field %r; it looks encrypted, "
                    "so SECRET_KEY may have changed since it was stored.",
                    field,
                )
            else:
                # Legacy plaintext stored before encryption was introduced.
                logger.debug("Config field %r is not encrypted; using as-is.", field)
    return out


def normalize_phone(phone: str) -> str:
    """Normalize a Kenyan phone number to the ``254XXXXXXXXX`` format.

    Accepts the common ways a number is typed:

    * local trunk format: ``0712 345 678``
    * international, with or without ``+``: ``+254-712-345-678``
    * an international dialling prefix: ``00254712345678``
    * the bare subscriber number: ``712345678``

    Every character other than an ASCII digit (spaces, tabs, dashes, dots,
    parentheses, ``+``) is discarded first. No length or prefix validation is
    done; malformed input still yields a ``254``-prefixed string.

    >>> normalize_phone("+254 712 345 678")
    '254712345678'
    >>> normalize_phone("(0712) 345.678")
    '254712345678'
    """
    digits = "".join(ch for ch in phone if ch in "0123456789")
    # Leading zeros are either the local trunk prefix ("07...") or an
    # international dialling prefix ("00254...", "000254..."); neither belongs
    # in the normalized form.
    digits = digits.lstrip("0")
    if not digits.startswith("254"):
        digits = "254" + digits
    return digits


class MpesaClient:
    """Thin client for Safaricom's Daraja (M-Pesa) REST API.

    Covers two calls: fetching an OAuth access token and initiating an STK push
    ("Lipa na M-Pesa Online") payment prompt. A new access token is requested
    for every API call; nothing is cached.
    """

    # East Africa Time (UTC+3). Kenya does not observe daylight saving time,
    # so a fixed offset is exact all year round.
    _EAT = timezone(timedelta(hours=3), "EAT")

    def __init__(self, consumer_key: str, consumer_secret: str, shortcode: str,
                 passkey: str, callback_url: str, environment: str = "sandbox"):
        """Store credentials and select the API host.

        Args:
            consumer_key: App consumer key, used as the OAuth basic-auth username.
            consumer_secret: App consumer secret, used as the basic-auth password.
            shortcode: Business shortcode; sent as BusinessShortCode and PartyB
                and used in the STK password.
            passkey: STK push passkey, combined with the shortcode and timestamp
                to build the request password.
            callback_url: Sent as CallBackURL in STK push requests.
            environment: ``"production"`` selects the live API host; any other
                value (including the default) selects the sandbox host.
        """
        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.shortcode = shortcode
        self.passkey = passkey
        self.callback_url = callback_url
        self.base_url = PRODUCTION_BASE_URL if environment == "production" else SANDBOX_BASE_URL

    @staticmethod
    def _raise_for_status(resp, action: str) -> None:
        """Log the error body of a failed ``requests`` response, then raise.

        ``raise_for_status()`` only reports the status line; the response body
        usually carries the API's explanation, so it is logged (truncated)
        first. Successful responses pass through silently.

        Raises:
            requests.HTTPError: for 4xx/5xx responses, exactly as before.
        """
        if not resp.ok:
            logger.error(
                "M-Pesa %s failed: HTTP %s: %s",
                action, resp.status_code, resp.text[:500],
            )
        resp.raise_for_status()

    @classmethod
    def _timestamp(cls, moment=None) -> str:
        """Format *moment* (default: now) as ``YYYYMMDDHHMMSS`` in Kenyan time.

        The STK timestamp was previously taken from the host's naive local
        clock, so its value depended on the server's TZ setting (often UTC in
        containers). It is now always expressed in EAT. An aware *moment* is
        converted to EAT; a naive one is treated as host-local time by
        ``astimezone``.
        """
        if moment is None:
            moment = datetime.now(cls._EAT)
        else:
            moment = moment.astimezone(cls._EAT)
        return moment.strftime("%Y%m%d%H%M%S")

    def _get_access_token(self) -> str:
        """Fetch a fresh OAuth access token using the consumer key/secret.

        Raises:
            requests.HTTPError: if the token request returns an error status.
            KeyError: if the response JSON has no ``access_token`` field.
        """
        url = f"{self.base_url}/oauth/v1/generate?grant_type=client_credentials"
        resp = requests.get(url, auth=(self.consumer_key, self.consumer_secret), timeout=30)
        self._raise_for_status(resp, "access token request")
        return resp.json()["access_token"]

    def _generate_password(self, timestamp: str) -> str:
        """Return base64(shortcode + passkey + timestamp), the STK request password."""
        data = f"{self.shortcode}{self.passkey}{timestamp}"
        return base64.b64encode(data.encode()).decode()

    def stk_push(self, phone: str, amount: int, account_ref: str, description: str) -> dict:
        """Send an STK push payment prompt to *phone*.

        Args:
            phone: Customer number, sent as both PartyA and PhoneNumber. It is
                used verbatim (not normalized here).
            amount: Sent unchanged as Amount.
            account_ref: Sent as AccountReference.
            description: Sent as TransactionDesc.

        Returns:
            The decoded JSON body of the processrequest response.

        Raises:
            requests.HTTPError: if the token or STK request returns an error status.
        """
        token = self._get_access_token()
        # The same timestamp must go into both the password and the payload.
        timestamp = self._timestamp()
        password = self._generate_password(timestamp)

        payload = {
            "BusinessShortCode": self.shortcode,
            "Password": password,
            "Timestamp": timestamp,
            "TransactionType": "CustomerPayBillOnline",
            "Amount": amount,
            "PartyA": phone,
            "PartyB": self.shortcode,
            "PhoneNumber": phone,
            "CallBackURL": self.callback_url,
            "AccountReference": account_ref,
            "TransactionDesc": description,
        }

        url = f"{self.base_url}/mpesa/stkpush/v1/processrequest"
        resp = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=30,
        )
        self._raise_for_status(resp, "STK push request")
        return resp.json()

    def test_credentials(self) -> dict:
        """Verify credentials by attempting to fetch an access token.

        Returns:
            ``{"success": True, "access_token_preview": ...}`` where the preview
            is the first 10 characters of the token followed by ``"..."``.

        Raises:
            requests.HTTPError: if the credentials are rejected.
        """
        token = self._get_access_token()
        return {"success": True, "access_token_preview": token[:10] + "..."}
