from aspen_crypto.otp import generate_totp, verify_totp
import jwt
import logging

import time
from typing import Optional, Union

from . import keys, get_setting
from .nonce import get_nonce_backend
from .signals import invalid_signed_token
from .typehints import ClaimCreds

logger = logging.getLogger(__name__)


class Token:
    """
    Represents a JWT that's either been constructed by our code or has been
    verified to be valid.

    When signed, the payload always carries ``time`` and ``nonce``, an
    optional ``public_key_cid``, at most one identity claim (the first
    non-empty of ``cid``, ``email``, ``username``) and, on request, an
    ``otp`` code.
    """

    username: str
    timestamp: int

    def __init__(
            self,
            cid: str = '',
            email: str = '',
            username: str = '',
            timestamp: Optional[int] = None,
            public_key_cid: str = '',
    ):
        """
        :param cid: identity claim; takes priority over ``email`` and ``username`` when signing
        :param email: identity claim used when ``cid`` is empty
        :param username: identity claim used when ``cid`` and ``email`` are empty
        :param timestamp: integer seconds from ``time.time()``; defaults to "now"
        :param public_key_cid: optional ``public_key_cid`` claim to embed when signing
        """
        if timestamp is None:
            timestamp = int(time.time())
        self.cid, self.email, self.username = cid, email, username
        self.timestamp = timestamp
        self.public_key_cid = public_key_cid

    def create_auth_header(self, private_key: keys.PrivateKey) -> str:
        """
        Create an HTTP Authorization header value of the form
        ``"<AUTH_METHOD setting> <signed JWT>"``.
        """
        scheme = get_setting("AUTH_METHOD")
        return f"{scheme} {self.sign(private_key)}"

    def create_auto_auth_header(self, private_key: keys.PrivateKey, include_otp=False) -> str:
        """
        Create an HTTP Authorization header value for requests where the public
        key accompanies the request in either the POST body or a GET param.

        The value has the form ``"<AUTO_AUTH_METHOD setting> <signed JWT>"``.

        :param include_otp: forwarded to ``sign`` to embed an ``otp`` claim
        """
        scheme = get_setting("AUTO_AUTH_METHOD")
        return f"{scheme} {self.sign(private_key, include_otp=include_otp)}"

    def sign(self, private_key: keys.PrivateKey, include_otp=False) -> str:
        """
        Create and return a signed authentication JWT.

        The token is signed with ``private_key`` using the first entry of the
        matching public key's ``allowed_algorithms``. The JWT header's ``kid``
        is set to the public key's fingerprint.

        :param private_key: key used to sign the token
        :param include_otp: when true, add an 8-digit ``otp`` claim generated
            from the public key's PEM text
        :return: the encoded JWT
        """
        verifying_key = private_key.public_key
        chosen_algorithm = verifying_key.allowed_algorithms[0]
        fresh_nonce = get_nonce_backend().generate_nonce()
        key_id = verifying_key.fingerprint

        # Base claims present on every token.
        payload = {"time": self.timestamp, "nonce": fresh_nonce}
        if self.public_key_cid:
            payload['public_key_cid'] = self.public_key_cid

        # Only the highest-priority non-empty identity is emitted.
        for claim_name in ('cid', 'email', 'username'):
            claim_value = getattr(self, claim_name)
            if claim_value:
                payload[claim_name] = claim_value
                break

        if include_otp:
            totp = generate_totp(verifying_key.as_pem.decode(), digits=8)
            payload['otp'] = totp.now()

        return jwt.encode(
            payload=payload,
            key=private_key.as_pem,
            algorithm=chosen_algorithm,
            headers={"kid": key_id},
        )


class UntrustedToken:
    """
    Represents a JWT received from user input (and not yet trusted).

    ``token_data``, ``get_claimed_creds``, ``requested_public_key`` and
    ``totp_code`` read the payload WITHOUT verifying the signature. Their
    results must not be used for authorization; only ``verify`` checks the
    signature, timestamp window and nonce.
    """

    token: str

    def __init__(self, token: str):
        """
        :param token: JWT claim
        """
        self.token = token
        # Unverified payload, decoded lazily; None means "not decoded yet".
        self._token_data = None
        # TOTP check results keyed by the PEM text of the public key they were
        # checked against, so the same instance can be verified against several
        # candidate keys without reusing another key's result.
        self._totp_results = {}

    def token_data(self) -> dict:
        """
        Return the token payload decoded WITHOUT signature verification.

        The result is cached on the instance. A token that cannot be decoded
        yields ``{}``, and decoding is not retried on later calls.
        """
        if self._token_data is None:
            try:
                self._token_data = jwt.decode(self.token, options={"verify_signature": False})
            except jwt.PyJWTError:
                self._token_data = {}
        return self._token_data

    def get_claimed_creds(self) -> ClaimCreds:
        """
        Given a JWT, get the identity claims (``cid``, ``email``, ``username``)
        it is claiming, `without verifying that the signature is valid`.

        Missing claims are returned as ``''``.

        :raises jwt.PyJWTError: if the token cannot be decoded at all
        """
        # Decoded directly (not via token_data) so malformed tokens still raise,
        # as callers of this method may rely on.
        unverified_data = jwt.decode(self.token, options={"verify_signature": False})
        return {key: unverified_data.get(key, '') for key in ('cid', 'email', 'username')}

    def requested_public_key(self) -> str:
        """
        The token may request to have been signed by a specific PublicKey
        record to be considered valid, rather than allowing _any_ PublicKey
        that might be assigned to the specific user.

        Returns the unverified ``public_key_cid`` claim, or ``''``.
        """
        return self.token_data().get('public_key_cid', '')

    def totp_code(self) -> str:
        """
        Return the Time-based One-time Password (unverified ``otp`` claim)
        if the token contains it, else ``''``.
        """
        return self.token_data().get('otp', '')

    def valid_otp(self, public_key: keys.PublicKey) -> bool:
        """
        If the token contains a TOTP, it must be verified first.
        This step is primarily used for registering new PublicKey
        records automatically.

        A token without an ``otp`` claim is considered valid. Otherwise the
        code is checked with ``verify_totp`` against this key's PEM text. The
        result is cached per key, so calls with different keys are evaluated
        independently.
        """
        otp_code = self.totp_code()
        if not otp_code:
            return True

        key_text = public_key.as_pem.decode()
        if key_text not in self._totp_results:
            is_valid = verify_totp(key_text, otp=otp_code)
            self._totp_results[key_text] = is_valid
            if not is_valid:
                logger.error('TOTP: %s is invalid for public key: %s', otp_code, key_text)

        return self._totp_results[key_text]

    def verify(self, public_key: keys.PublicKey) -> Union[None, Token]:
        """
        Verify the validity of the given JWT using the given public key.

        Checks, in order: the TOTP (if present), the signature using the key's
        ``allowed_algorithms``, that an identity claim, ``time`` and ``nonce``
        are present, that ``time`` is a number within ``TIMESTAMP_TOLERANCE``
        seconds of now, and that the nonce has not been used.
        On success the nonce is recorded as used.

        :return: a ``Token`` carrying the first non-empty identity claim
            (``cid`` > ``email`` > ``username``) and the claimed time,
            or ``None`` if any check fails
        """
        # valid_otp can be called separately to provide custom error message to user,
        # but call here to ensure it's included
        if not self.valid_otp(public_key):
            return None

        try:
            token_data = jwt.decode(
                jwt=self.token,
                key=public_key.as_pem.decode(),
                algorithms=public_key.allowed_algorithms,
            )
        except jwt.InvalidTokenError:
            logger.error("JWT failed verification")
            return None

        claimed_time = token_data.get("time", 0)
        claimed_nonce = token_data.get("nonce")

        # First non-empty identity claim, in the same priority Token.sign uses.
        user_field, user_value = '', None
        for field in ('cid', 'email', 'username'):
            if token_data.get(field):
                user_field, user_value = field, token_data[field]
                break

        # Ensure fields aren't blank
        if not user_value or not claimed_time or not claimed_nonce:
            return None

        # A non-numeric time (e.g. a string) would otherwise raise TypeError below.
        if not isinstance(claimed_time, (int, float)):
            logger.error("Claimed time is not a number: %r. User: %r", claimed_time, user_value)
            return None

        # Ensure time is within acceptable bounds
        current_time = time.time()
        timestamp_tolerance = get_setting("TIMESTAMP_TOLERANCE")
        min_time, max_time = (
            int(current_time - timestamp_tolerance),
            int(current_time + timestamp_tolerance),
        )
        # Chained comparison rejects NaN as well: every comparison with NaN is false.
        if not (min_time <= claimed_time <= max_time):
            msg = f"Claimed time is outside {timestamp_tolerance} second range."
            msg = f"{msg} '{claimed_time}' outside '{min_time}-{max_time}'. User: {user_value!r}"
            invalid_signed_token.send(sender=self.__class__, message=msg)
            logger.error(msg)
            return None

        # Ensure nonce is unique
        nonce_backend = get_nonce_backend()
        if not nonce_backend.validate_nonce(
            user_value, claimed_time, claimed_nonce
        ):
            msg = f"Claimed nonce has already been used - {claimed_nonce!r}. User: {user_value!r}"
            invalid_signed_token.send(sender=self.__class__, message=msg)
            logger.error(msg)
            return None

        # If we've gotten this far, the token is valid
        nonce_backend.log_used_nonce(user_value, claimed_time, claimed_nonce)
        return Token(**{user_field: user_value, 'timestamp': claimed_time})
