"""
Provides RestFramework auth backend for signed tokens.
"""
import base64
import binascii
import logging

from rest_framework import exceptions
from rest_framework.authentication import get_authorization_header, BaseAuthentication
from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext_lazy as _
from typing import Callable

from signed_jwt_auth import get_setting, keys
from signed_jwt_auth.tokens import UntrustedToken
from signed_jwt_auth.repos import get_user_repository, get_public_key_repositories
from signed_jwt_auth.typehints import UntrustedAuth, VerifiedAuth
# Module-level logger (named after this module); used below to record failed verifications.
logger = logging.getLogger(__name__)


class SignedTokenMixin:
    """
    Signed token based authentication for Django Rest Framework.

    Subclasses set ``header_method_key`` to the name of the setting that holds
    their ``Authorization`` header method, and implement
    ``authenticate_credentials(token, request)``, which ``authenticate``
    delegates to once the header has been parsed.
    """

    # Name of the setting (read through ``get_setting``) holding the auth method keyword.
    header_method_key = ''

    def authenticate(self, request: HttpRequest):
        """
        Parse the ``Authorization`` header and delegate to ``authenticate_credentials``.

        Returns ``None`` when the header is absent, when no auth method is
        configured, or when the header uses a different auth method, so that
        other authentication backends may handle the request.

        Raises ``AuthenticationFailed`` when the header names our method but
        carries no token, or the token bytes cannot be decoded to text.
        """
        # Ensure this auth header was meant for us (it has the JWT auth method).
        auth = get_authorization_header(request).split()

        method = self.authenticate_header(request)
        # A missing method can never match; bail out instead of calling
        # ``.encode()`` on None.
        if not auth or not method:
            return None
        # Auth schemes are compared case-insensitively, so normalise both sides
        # (previously only the header was upper-cased).
        if auth[0].upper() != method.upper().encode():
            return None

        try:
            header_token = auth[1].decode()
        except IndexError:
            msg = _('Invalid signed token header. No credentials provided.')
            raise exceptions.AuthenticationFailed(msg)
        except UnicodeError:
            msg = _('Invalid signed token header. Token string contains invalid characters.')
            raise exceptions.AuthenticationFailed(msg)

        return self.authenticate_credentials(header_token, request)

    def lookup_user_record(self, token: str) -> UntrustedAuth:
        """
        Lookup user record from token claims.

        The token is wrapped in an ``UntrustedToken`` (it has NOT been verified
        at this point) and the credentials it claims are used to fetch the user
        from the user repository.

        Raises ``AuthenticationFailed`` when no user is found or the user is
        not active.
        """
        untrusted_token = UntrustedToken(token)
        lookup_creds = untrusted_token.get_claimed_creds()

        user = get_user_repository().get_user(lookup_creds)
        if not user:
            raise exceptions.AuthenticationFailed(_('User not found for provided token.'))
        if not user.is_active:
            raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))

        return UntrustedAuth(user, untrusted_token)

    def authenticate_header(self, request):
        """
        Return the configured auth method keyword for this backend, or ``None``
        when ``header_method_key`` is not set.
        """
        if not self.header_method_key:
            return None
        return get_setting(self.header_method_key)


class SignedTokenAuthentication(SignedTokenMixin, BaseAuthentication):
    """
    Authenticates requests whose "Authorization" HTTP header carries a signed
    token, prefixed with the auth method keyword read from the ``AUTH_METHOD``
    setting.  For example:

        Authorization: SIGNED_JWT 401f7ac837da42b97f613d789819ff93537bee6a
    """
    header_method_key = 'AUTH_METHOD'

    def authenticate_credentials(self, token: str, request) -> VerifiedAuth:
        """
        Verify the signed token against each public key repository in turn.

        The first repository that yields a verified token wins. If none does,
        ``AuthenticationFailed`` is raised.
        """
        account, unverified = self.lookup_user_record(token)

        for key_repo in get_public_key_repositories():
            verified = key_repo.attempt_to_verify_token(account, unverified)
            if not verified:
                # This repository could not verify the token; try the next one.
                continue
            return VerifiedAuth(account, verified)

        raise exceptions.AuthenticationFailed(_('Token verification failed.'))


class SignedTokenByProvidedKey(SignedTokenMixin, BaseAuthentication):
    """
    Validates signed JWTs for Django Rest Framework, to be used in
    registration situations, such as the creation of the FIRST
    PublicKey for a user.

    The public key (verify_public_key) must be
    included in the POST body or GET param.
    (The public key must be base64 encoded if using the GET param.)

        Authorization: AUTO_JWT 401f7ac837da42b97f613d789819ff93537bee6a
    """
    header_method_key = 'AUTO_AUTH_METHOD'

    def authenticate_credentials(self, token: str, request) -> VerifiedAuth:
        """
        Verify the token against the public key supplied with the request.

        The key is read from request.data, the POST body or the GET query
        params, in that order (the GET param must be base64-encoded).

        Raises ``AuthenticationFailed`` when no usable key is supplied, the
        user lookup fails, the OTP check fails, or the signature does not
        verify.
        """
        verify_public_key_bytes = self._get_verify_public_key_bytes(request)
        if not verify_public_key_bytes:
            raise exceptions.AuthenticationFailed(_('No verification key was provided'))

        exc, public_key = keys.PublicKey.load_serialized_public_key(verify_public_key_bytes)

        if not public_key:
            # The bytes may come from an arbitrary base64 payload, so decode
            # leniently: a strict decode could raise here and mask the auth error.
            logger.error(
                'Key verification failed for key: %s',
                verify_public_key_bytes.decode(errors='replace'),
            )
            if not exc:
                raise exceptions.AuthenticationFailed(_('No public key provided'))
            raise exceptions.AuthenticationFailed(_('Invalid public key'))

        user, untrusted_token = self.lookup_user_record(token)

        if not untrusted_token.valid_otp(public_key=public_key):
            raise exceptions.AuthenticationFailed(
                _("Token verification failed. Check your device clock.")
            )

        verified_token = untrusted_token.verify(public_key=public_key)
        if not verified_token:
            # NOTE(review): this writes the rejected token to the log; confirm
            # that is acceptable for the deployment's log retention policy.
            logger.error('Token verification failed for token: %s', token)
            raise exceptions.AuthenticationFailed(_('Token verification failed'))

        return VerifiedAuth(user, verified_token)

    def _get_verify_public_key_bytes(self, request) -> bytes:
        """
        Extract the serialized verification public key from the request.

        Looks in request.data, then the POST body, then the GET query params.
        Returns empty bytes when no key is present anywhere. Raises
        ``AuthenticationFailed`` when the GET param is not valid base64.
        """
        param_key = get_setting('AUTO_AUTH_METHOD_PARAM')

        try:
            key_bytes = request.data[param_key].encode()
        except (AttributeError, KeyError, TypeError):
            # No ``data`` attribute, key missing, non-string value, or a
            # non-mapping body (e.g. a JSON list): fall back to the POST body.
            key_bytes = request.POST.get(param_key, '').encode()

        if key_bytes:
            return key_bytes

        encoded_key = request.GET.get(param_key, '')
        try:
            # Add 2 equals signs as padding to avoid "Incorrect padding" error. Unnecessary
            # padding is stripped, and 2 is the most ever required to avoid this specific error.
            return base64.urlsafe_b64decode(f'{encoded_key}==')
        except ValueError as e:
            # ValueError covers binascii.Error and UnicodeDecodeError, plus the
            # plain ValueError raised for non-ASCII characters in the param.
            raise exceptions.AuthenticationFailed(
                _('Unable to parse verification public key: %s') % e
            ) from e

    def lookup_user_record(self, token: str) -> UntrustedAuth:
        """
        Prevent authenticating new PublicKey records, when a user already
        has a PublicKey. They should sign JWTs with it, and perform key
        rotation with an _existing_ key.

        A token for which ``totp_code()`` is truthy is exempt from this
        check (presumably the route for registering more than one key;
        confirm against callers).
        """
        user, untrusted_token = super().lookup_user_record(token)
        if not untrusted_token.totp_code() and user.public_keys.exists():
            raise exceptions.AuthenticationFailed(
                _("JWT must be signed with one of the user's existing PublicKeys!")
            )
        return UntrustedAuth(user, untrusted_token)
