from aspen_crypto.keys import KeyDetails
from datetime import datetime, timezone
from lchttp.json import json_loads
from lclazy import LazyLoader
from lcrequests.tokens import PasetoFooter, PasetoPayload
import json
import logging
from typing import Optional, Union, TYPE_CHECKING

if TYPE_CHECKING:
    import pyseto
    from pyseto import KeyInterface
else:
    pyseto = LazyLoader('pyseto', globals(), 'pyseto')

from .choices import SUBJECT
from .loaders import (
    KeyLoaderBase,
    AspenLoader,
    DatabaseLoader,
    LocalKeyLoader,
    PemLoader,
)
from . import get_setting
from .nonce import get_nonce_backend
from .typehints import ClaimCreds, KeyData

# Module-level logger; used below for undecodable token footers and reused nonces.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in this chunk; presumably a placeholder or a
# hook that other code/tests assign for Aspen key retrieval — confirm before removing.
_retrieve_key_from_aspen = None


# ---------------------------------------------------------------------------
class UnverifiedToken:
    """
    Wrapper class for convenience methods to
    perform builtin token verification.
    """

    def __init__(self, token: Union[bytes, str], request_data: dict | None = None):
        self.token_class = pyseto.Token.new(token)
        self.token_string = token
        self.request_data = request_data

        self._footer_payload: PasetoFooter | None = None
        self.paseto_version = int((token if isinstance(token, str) else token.decode())[1])
        self._public_key: Optional['KeyInterface'] = None
        self._verification_key: KeyLoaderBase | None = None

    def verify(self) -> 'AccessToken':
        """
        Perform builtin Paseto verification and return the verified token
        """
        token = pyseto.Paseto(
            leeway=get_setting('TIMESTAMP_TOLERANCE'),
        ).decode(
            keys=self.paseto_public_key(),
            token=self.token_string,
            deserializer=json,
        )

        self.verify_subject(payload=token.payload)

        return AccessToken(self._verification_key, token)

    def footer_payload(self) -> PasetoFooter:
        if self._footer_payload is None:
            try:
                self._footer_payload = json_loads(self.token_class.footer)
            except json.JSONDecodeError:
                logger.info(f'Unable to decode token footer: {self.token_class.footer}')
        return self._footer_payload

    def verification_key(self):
        """
        Lookup verification key from database or from Aspen.
        """
        if self._verification_key:
            return self._verification_key

        footer = self.footer_payload()
        cid = footer.get('cid') or ''
        kid = footer.get('kid') or ''

        for Loader in (
            LocalKeyLoader,
            DatabaseLoader,
            AspenLoader,
        ):
            key_loader = Loader(cid, kid)
            try:
                if key_loader.public_key():
                    self._verification_key = key_loader
                    return key_loader
            except pyseto.VerifyError:
                continue

        raise pyseto.VerifyError('No key found to verify token')

    def paseto_public_key(self) -> 'KeyInterface':
        """
        Lookup public key from database or from Aspen.
        """
        if not self._public_key:
            key_data = self.verification_key().public_key()
            self._public_key = pyseto.Key.new(
                version=self.paseto_version,
                purpose='public',
                key=key_data.key if isinstance(key_data, KeyData) else key_data.public_key.as_pem,
            )

        return self._public_key

    def verify_subject(self, payload: PasetoPayload):
        """
        The subject of the Key must correspond with the subject/source of the API call.
        """
        key_data = self.verification_key().key_data
        # If the Key couldn't be loaded from the database,
        # then subject verification isn't applicable.
        if isinstance(key_data, KeyDetails):
            return

        if key_data.subject == SUBJECT.Any:
            return

        subject = payload.get('sub', '')
        if subject and subject != key_data.subject:
            raise pyseto.VerifyError(f'Verification key is not valid for subject: {subject}')


# ---------------------------------------------------------------------------
class UnverifiedRegistrationToken(UnverifiedToken):
    """
    Unverified token presented during registration, held to stricter rules.

    Besides the normal key lookups, the verification key may be built from
    the request data, and when using pinned EncryptionKey records the payload
    must carry both the ``sub`` and the ``sub_cid`` claims.
    """

    def verification_key(self):
        """
        Return the verification key loader.

        The parent's lookups are always tried first. Only if they raise a
        ``pyseto.PysetoError`` is a ``PemLoader`` built from the request data
        (its ``pem`` and ``cid`` values) consulted; when that loader has no
        public key either, the original error propagates.
        """
        try:
            loader = super().verification_key()
            return loader if loader else None
        except pyseto.PysetoError:
            pem_loader = PemLoader(data=self.request_data)
            if not pem_loader.public_key():
                # Nothing usable in the request data either: surface the original error.
                raise
            self._verification_key = pem_loader
            return pem_loader

    def verify_subject(self, payload: PasetoPayload) -> None:
        """
        Run the standard subject check, then require non-empty ``sub`` and
        ``sub_cid`` claims (checked in that order).
        """
        super().verify_subject(payload)
        required_claims = (
            ('sub', 'Registration token must contain valid Subject claim'),
            ('sub_cid', 'Registration token must contain valid Subject CID claim'),
        )
        for claim_name, error_message in required_claims:
            if not payload.get(claim_name, ''):
                raise pyseto.VerifyError(error_message)


# ---------------------------------------------------------------------------
class AccessToken:
    """
    A Paseto token already decoded and signature-checked by ``pyseto``,
    together with the key loader that verified it.

    :meth:`verify` adds the application-level checks: required ``exp`` and
    ``nonce`` claims and nonce uniqueness (replay protection).
    """

    def __init__(
        self,
        verification_key: 'KeyLoaderBase',
        token: 'pyseto.Token',
    ) -> None:
        """
        :param verification_key: Key loader whose public key verified ``token``.
        :param token: Decoded Paseto token; its payload must be a dict.
        :raises pyseto.VerifyError: if the token payload is not a dict.
        """
        # Annotations are strings so evaluating them at class-definition time does
        # not touch the lazily-loaded ``pyseto`` module during module import.
        self.verification_key = verification_key
        if not isinstance(token.payload, dict):
            raise pyseto.VerifyError('Token payload must be data object.')
        self.payload: PasetoPayload = token.payload

        self.token = token
        self.footer = token.footer

    def verify(self) -> 'AccessToken':
        """
        Verify the application-level claims of this token and mark the key as used.

        :returns: ``self``, to allow chaining.
        :raises pyseto.VerifyError: if ``exp``/``nonce`` are missing or invalid,
            or the nonce was already used.
        """
        expiration_ts = self.validate_exp()
        claimed_nonce = self.validate_nonce()
        self.validate_unique_nonce(expiration_ts, claimed_nonce)
        self.verification_key.update_last_used()

        return self

    def get_claimed_creds(self) -> 'ClaimCreds':
        """
        After verifying that the token is valid, get the claimed user
        credentials (``user_cid`` and ``email``) from the token payload.
        """
        return ClaimCreds(
            user_cid=self.payload.get('user_cid', ''),
            email=self.payload.get('email', ''),
        )

    def validate_exp(self) -> int:
        """
        Return the ``exp`` claim as an integer POSIX timestamp.

        By the time this code has been reached, the 'exp' value has passed
        builtin pyseto validation, but we _require_ all tokens to have the
        ``exp`` key present in the payload. A trailing ``Z`` designator is
        accepted on every Python version, and a timestamp without an offset
        is interpreted as UTC rather than server-local time.

        :raises pyseto.VerifyError: if ``exp`` is missing or not an ISO 8601 string.
        """
        try:
            raw_exp = self.payload['exp']
            # datetime.fromisoformat only understands "Z" from Python 3.11 onward.
            if raw_exp[-1:] in ('Z', 'z'):
                raw_exp = f'{raw_exp[:-1]}+00:00'
            expires_at = datetime.fromisoformat(raw_exp)
        except (KeyError, TypeError, ValueError):
            raise pyseto.VerifyError('Expiration value is required') from None

        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return int(expires_at.timestamp())

    def validate_nonce(self) -> str:
        """
        The payload must contain a non-empty nonce value!

        :raises pyseto.VerifyError: if ``nonce`` is missing or empty.
        """
        nonce = self.payload.get('nonce')
        if not nonce:
            raise pyseto.VerifyError('Nonce value is required.')
        return nonce

    def validate_unique_nonce(self, expiration_ts: int, nonce: str) -> None:
        """
        Ensure that the nonce is unique to prevent replay attacks, then record it.

        The nonce must stay recorded for as long as the token can still be
        accepted: until ``exp`` plus the ``TIMESTAMP_TOLERANCE`` leeway that
        pyseto is given when decoding.

        :raises pyseto.VerifyError: if the nonce has already been used.
        """
        leeway = int(get_setting('TIMESTAMP_TOLERANCE') or 0)
        remaining = expiration_ts - datetime.now(timezone.utc).timestamp()
        # Without the leeway a token inside its grace period (exp < now <= exp + leeway)
        # would give a zero/negative lifetime, which a backend may treat as "don't keep",
        # re-opening a replay window; always request at least one second.
        claimed_time = max(int(remaining) + leeway, 1)

        nonce_backend = get_nonce_backend()
        if not nonce_backend.validate_nonce(claimed_time, nonce):
            logger.error("Claimed nonce has already been used - payload: %s", nonce)
            raise pyseto.VerifyError('Nonce has already been used.')

        # If we've gotten this far, the token is valid
        nonce_backend.log_used_nonce(claimed_time, nonce)
