"""
Based on https://github.com/crgwbr/asymmetric-jwt-auth
Copyright (c) 2021, Craig Weber <crgwbr@gmail.com>

Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
OF THIS SOFTWARE.
"""

try:
    from pybase64 import urlsafe_b64encode
except (ImportError, ModuleNotFoundError):
    from base64 import urlsafe_b64encode
from cryptography.hazmat.primitives.asymmetric import rsa, ec, ed25519

try:
    from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes, PublicKeyTypes
except ImportError:
    from cryptography.hazmat.primitives.asymmetric.types import (
        PRIVATE_KEY_TYPES as PrivateKeyTypes,
        PUBLIC_KEY_TYPES as PublicKeyTypes,
    )
from cryptography.hazmat.primitives import serialization
import hashlib
from lclazy import LazyLoader
import os
import struct
from typing import Generic, List, Literal, TypeVar, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
    from pyseto import KeyInterface
    import pyseto
else:
    pyseto = LazyLoader('pyseto', globals(), 'pyseto')

# Unions of the concrete facade classes defined further down this module.
# They are spelled as strings because the classes do not exist yet at this point.
FacadePrivateKey = Union["RSAPrivateKey", "EllipticCurvePrivateKey", "Ed25519PrivateKey"]
FacadePublicKey = Union["RSAPublicKey", "EllipticCurvePublicKey", "Ed25519PublicKey"]

# Constrained type variables that parameterize the generic PrivateKey /
# PublicKey bases with the underlying `cryptography` key class they wrap.
PrivateKeyType = TypeVar(
    "PrivateKeyType",
    rsa.RSAPrivateKey,
    ec.EllipticCurvePrivateKey,
    ed25519.Ed25519PrivateKey,
)
PublicKeyType = TypeVar(
    "PublicKeyType",
    rsa.RSAPublicKey,
    ec.EllipticCurvePublicKey,
    ed25519.Ed25519PublicKey,
)
# Paseto key purposes accepted by the `as_paseto` helpers.
PurposeType = Literal['public', 'private']


# ---------------------------------------------------------------------------
class PublicKey(Generic[PublicKeyType]):
    """
    Base facade around a ``cryptography`` public key.

    Concrete subclasses store the wrapped key in ``self._key`` and list the
    JWT algorithms that key family may be used with.
    """

    # The wrapped ``cryptography`` public key; assigned by subclass constructors.
    _key: PublicKeyType

    @staticmethod
    def from_cryptography_pubkey(pubkey: PublicKeyTypes) -> FacadePublicKey:
        """
        Wrap a ``cryptography`` public key in the matching facade class.

        :raises TypeError: If the key is not RSA, elliptic-curve or Ed25519.
        """
        if isinstance(pubkey, rsa.RSAPublicKey):
            return RSAPublicKey(pubkey)
        if isinstance(pubkey, ec.EllipticCurvePublicKey):
            return EllipticCurvePublicKey(pubkey)
        if isinstance(pubkey, ed25519.Ed25519PublicKey):
            return Ed25519PublicKey(pubkey)
        raise TypeError(f"Unknown key type: {pubkey}")

    @classmethod
    def load_pem(cls, pem: bytes) -> FacadePublicKey:
        """
        Load a PEM-format public key
        """
        public_key = serialization.load_pem_public_key(pem)
        return cls.from_cryptography_pubkey(public_key)

    @classmethod
    def load_openssh(cls, key: bytes) -> FacadePublicKey:
        """
        Load an openssh-format public key
        """
        public_key = serialization.load_ssh_public_key(key)
        return cls.from_cryptography_pubkey(public_key)

    @classmethod
    def load_serialized_public_key(
        cls, key: bytes
    ) -> Tuple[Exception | None, FacadePublicKey | None]:
        """
        Load a public key serialized as either PEM or OpenSSH.

        PEM is tried first, then OpenSSH. Errors are returned, not raised.

        :return: ``(None, key)`` on success. Otherwise ``(error, None)``,
            where *error* is the exception raised by the last loader tried.
        """
        exc: Exception | None = None
        for loader in (cls.load_pem, cls.load_openssh):
            try:
                return None, loader(key)
            except Exception as e:  # deliberately broad: failures are reported to the caller
                exc = e
        return exc, None

    @property
    def as_pem(self) -> bytes:
        """
        Get the public key as a PEM-formatted (SubjectPublicKeyInfo) byte string
        """
        pem_bytes = self._key.public_bytes(
            serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
        )
        return pem_bytes

    def as_paseto(self, version: int, purpose: PurposeType = 'public') -> 'KeyInterface':
        """
        Return the public key in Paseto format.

        :param version: Paseto protocol version passed to ``pyseto.Key.new``.
        :param purpose: Paseto purpose passed to ``pyseto.Key.new``.
        :raises NotImplementedError: If building the key raises ``AttributeError``.
            The original error is chained as ``__cause__`` to aid debugging.
        """
        try:
            return pyseto.Key.new(version=version, purpose=purpose, key=self.as_pem)
        except AttributeError as err:
            raise NotImplementedError('Paseto keys not supported in this class') from err

    def to_paserk_id(self, version: int) -> str:
        """
        Return the PASERK ID of this key's Paseto form for ``version``.
        """
        return self.as_paseto(version).to_paserk_id()

    @property
    def fingerprint(self) -> str:
        """
        Get a sha256 fingerprint of the key (hex digest of its PEM bytes).
        """
        return hashlib.sha256(self.as_pem).hexdigest()

    @property
    def allowed_algorithms(self) -> List[str]:  # pragma: no cover
        """
        Return a list of allowed JWT algorithms for this key, in order of most to least preferred.
        """
        raise NotImplementedError("Subclass does not implement allowed_algorithms method")


# ---------------------------------------------------------------------------
class RSAPublicKey(PublicKey[rsa.RSAPublicKey]):
    """Represents an RSA public key"""

    def __init__(self, key: rsa.RSAPublicKey):
        self._key = key

    @property
    def as_jwk(self) -> dict:
        """
        Return the public key as a JWK dictionary.

        Field values:
        - ``alg``: the most preferred entry of ``allowed_algorithms``.
        - ``kid``: the key's SHA-256 fingerprint.
        - ``n`` / ``e``: the modulus and public exponent as unpadded base64url strings.
        """
        public_numbers = self._key.public_numbers()
        return {
            "kty": "RSA",
            "use": "sig",
            "alg": self.allowed_algorithms[0],
            "kid": self.fingerprint,
            "n": long_to_base64(public_numbers.n),
            "e": long_to_base64(public_numbers.e),
        }

    @property
    def allowed_algorithms(self) -> List[str]:
        """
        RSA JWT algorithms, in order of most to least preferred.
        """
        return [
            "RS512",
            "RS384",
            "RS256",
        ]


# ---------------------------------------------------------------------------
class EllipticCurvePublicKey(PublicKey[ec.EllipticCurvePublicKey]):
    """Represents an EllipticCurve public key"""

    def __init__(self, key: ec.EllipticCurvePublicKey):
        self._key = key

    @property
    def allowed_algorithms(self) -> List[str]:
        """
        Return the names advertised for elliptic-curve keys.

        NOTE(review): unlike the RSA and Ed25519 facades, these entries are curve
        names matching ``cryptography``'s ``ec`` curve classes (e.g. ``SECP384R1``).
        They are not JWS ``alg`` values such as ``ES384``. The list is also static,
        so it does not reflect this key's actual curve. Confirm how callers use it
        before passing it to a JWT library as an algorithm whitelist.
        """
        return [
            "SECT571R1",
            "SECT409R1",
            "SECT283R1",
            "SECT233R1",
            "SECT163R2",
            "SECT571K1",
            "SECT409K1",
            "SECT283K1",
            "SECT233K1",
            "SECT163K1",
            "SECP521R1",
            "SECP384R1",
            "SECP256R1",
            "SECP256K1",
            "SECP224R1",
            "SECP192R1",
            "BrainpoolP256R1",
            "BrainpoolP384R1",
            "BrainpoolP512R1",
        ]


# ---------------------------------------------------------------------------
class Ed25519PublicKey(PublicKey[ed25519.Ed25519PublicKey]):
    """Represents an Ed25519 public key"""

    def __init__(self, key: ed25519.Ed25519PublicKey):
        self._key = key

    @property
    def allowed_algorithms(self) -> List[str]:
        """
        JWT algorithms usable with an Ed25519 key (only ``EdDSA``).
        """
        return [
            "EdDSA",
        ]


# ---------------------------------------------------------------------------
class PrivateKey(Generic[PrivateKeyType]):
    """
    Base facade around a ``cryptography`` private key.

    Concrete subclasses store the wrapped key in ``self._key`` and expose the
    matching public-key facade through ``public_key``.
    """

    # The wrapped ``cryptography`` private key; assigned by subclass constructors.
    _key: PrivateKeyType

    @staticmethod
    def from_cryptography_private_key(private_key: PrivateKeyTypes) -> FacadePrivateKey:
        """
        Wrap a ``cryptography`` private key in the matching facade class.

        :raises TypeError: If the key is not RSA, elliptic-curve or Ed25519.
        """
        if isinstance(private_key, rsa.RSAPrivateKey):
            return RSAPrivateKey(private_key)
        if isinstance(private_key, ec.EllipticCurvePrivateKey):
            return EllipticCurvePrivateKey(private_key)
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            return Ed25519PrivateKey(private_key)
        # Report only the type name. The argument may be secret key material,
        # which must not leak into logs through the exception message.
        raise TypeError(f"Unknown key type: {type(private_key).__name__}")

    @classmethod
    def load_pem_from_file(
        cls,
        filepath: str | os.PathLike,
        password: bytes | None = None,
    ) -> FacadePrivateKey:
        """
        Load a PEM-format (or OpenSSH-format) private key from disk.

        :param filepath: Path of the key file; it is read in binary mode.
        :param password: Passphrase for an encrypted key, or ``None``.
        """
        with open(filepath, "rb") as fh:
            key_bytes = fh.read()
        return cls.load_pem(key_bytes, password=password)

    @classmethod
    def load_pem(cls, pem: bytes, password: bytes | None = None) -> FacadePrivateKey:
        """
        Load a private key from serialized bytes.

        PEM parsing is tried first. If it raises ``ValueError``, the bytes are
        parsed as an OpenSSH private key instead.

        :param pem: Serialized private key.
        :param password: Passphrase for an encrypted key, or ``None``.
        :raises TypeError: If the loaded key belongs to an unsupported family.
        """
        try:
            private_key = serialization.load_pem_private_key(pem, password=password)
        except ValueError:
            private_key = serialization.load_ssh_private_key(pem, password)

        return cls.from_cryptography_private_key(private_key)

    @property
    def as_pem(self) -> bytes:
        """
        Get the private key as an unencrypted PKCS#8 PEM byte string.
        """
        pem_bytes = self._key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.PKCS8,
            serialization.NoEncryption(),
        )
        return pem_bytes

    def as_paseto(self, version: int, purpose: PurposeType = 'public') -> 'KeyInterface':
        """
        Return the private key in Paseto format.

        :param version: Paseto protocol version passed to ``pyseto.Key.new``.
        :param purpose: Paseto purpose passed to ``pyseto.Key.new``.
        :raises NotImplementedError: If building the key raises ``AttributeError``.
            The original error is chained as ``__cause__`` to aid debugging.
        """
        try:
            return pyseto.Key.new(version=version, purpose=purpose, key=self.as_pem)
        except AttributeError as err:
            raise NotImplementedError('Paseto keys not supported in this class') from err

    @property
    def public_key(self) -> FacadePublicKey:  # pragma: no cover
        """
        The matching public-key facade; implemented by subclasses.
        """
        raise NotImplementedError()

    def encryption(self, password: bytes = b'') -> serialization.KeySerializationEncryption:
        """
        Choose the serialization encryption for ``password``.

        An empty password means no encryption. Otherwise the library's best
        available encryption is used.
        """
        if not password:
            return serialization.NoEncryption()
        return serialization.BestAvailableEncryption(password)


# ---------------------------------------------------------------------------
class RSAPrivateKey(PrivateKey[rsa.RSAPrivateKey]):
    """Facade over a ``cryptography`` RSA private key."""

    # Facade class used to wrap the derived public key.
    pubkey_cls = RSAPublicKey

    def __init__(self, key: rsa.RSAPrivateKey):
        # ``private_key`` is kept as a public alias of ``_key``.
        self._key = key
        self.private_key = key

    @classmethod
    def generate(cls, size: int = 2048, public_exponent: int = 65537) -> "RSAPrivateKey":
        """
        Create a fresh RSA private key.

        :param size: Modulus size in bits.
        :param public_exponent: RSA public exponent.
        """
        return cls(rsa.generate_private_key(public_exponent=public_exponent, key_size=size))

    @property
    def public_key(self) -> FacadePublicKey:
        """The matching public key, wrapped in ``pubkey_cls``."""
        derived = self._key.public_key()
        return self.pubkey_cls(derived)

    def private_bytes(self, password: bytes = b'') -> bytes:
        """
        Serialize the key as PEM in the traditional OpenSSL layout.

        :param password: Optional passphrase; an empty value leaves the key unencrypted.
        """
        return self._key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            self.encryption(password),
        )


# ---------------------------------------------------------------------------
class EllipticCurvePrivateKey(PrivateKey[ec.EllipticCurvePrivateKey]):
    """Facade over a ``cryptography`` elliptic-curve private key."""

    # Facade class used to wrap the derived public key.
    pubkey_cls = EllipticCurvePublicKey

    def __init__(self, key: ec.EllipticCurvePrivateKey):
        # ``private_key`` is kept as a public alias of ``_key``.
        self._key = key
        self.private_key = key

    @classmethod
    def generate(cls, curve: ec.EllipticCurve | None = None) -> "EllipticCurvePrivateKey":
        """
        Create a fresh elliptic-curve private key.

        :param curve: Curve to generate on; defaults to ``SECP384R1`` when not given.
        """
        chosen = curve if curve else ec.SECP384R1()
        return cls(ec.generate_private_key(curve=chosen))

    @property
    def public_key(self) -> FacadePublicKey:
        """The matching public key, wrapped in ``pubkey_cls``."""
        derived = self._key.public_key()
        return self.pubkey_cls(derived)

    def private_bytes(self, password: bytes = b'') -> bytes:
        """
        Serialize the key as PEM in the traditional OpenSSL layout.

        :param password: Optional passphrase; an empty value leaves the key unencrypted.
        """
        return self._key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            self.encryption(password),
        )


# ---------------------------------------------------------------------------
class Ed25519PrivateKey(PrivateKey[ed25519.Ed25519PrivateKey]):
    """Facade over a ``cryptography`` Ed25519 private key."""

    # Facade class used to wrap the derived public key.
    pubkey_cls = Ed25519PublicKey

    def __init__(self, key: ed25519.Ed25519PrivateKey):
        # ``private_key`` is kept as a public alias of ``_key``.
        self._key = key
        self.private_key = key

    @classmethod
    def generate(cls) -> "Ed25519PrivateKey":
        """Create a fresh Ed25519 private key."""
        return cls(ed25519.Ed25519PrivateKey.generate())

    @property
    def public_key(self) -> FacadePublicKey:
        """The matching public key, wrapped in ``pubkey_cls``."""
        derived = self._key.public_key()
        return self.pubkey_cls(derived)

    def private_bytes(self, password: bytes = b'') -> bytes:
        """
        Serialize the key in the OpenSSH private-key format with PEM encoding.

        :param password: Optional passphrase; an empty value leaves the key unencrypted.
        """
        return self._key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.OpenSSH,
            self.encryption(password),
        )


def long2intarr(long_int: int) -> List[int]:
    """
    Split a non-negative integer into its big-endian byte values.

    Zero yields an empty list; any other value yields the minimal number of bytes.

    :param long_int: Non-negative integer to split.
    :return: Byte values (0-255), most significant first.
    :raises ValueError: If ``long_int`` is negative. A negative value has no
        finite base-256 digit expansion here, so the old loop never terminated.
    """
    if long_int < 0:
        raise ValueError(f"Cannot split negative integer into bytes: {long_int}")
    byte_len = (long_int.bit_length() + 7) // 8
    return list(long_int.to_bytes(byte_len, "big"))


def long_to_base64(n: int, m_len: int = 0) -> str:
    """
    Encode a non-negative integer as unpadded base64url of its big-endian bytes.

    This is the encoding used for the JWK ``n`` and ``e`` fields of RSA keys.

    :param n: Non-negative integer to encode. Zero encodes as one ``0x00`` byte.
    :param m_len: Minimum length in bytes. Shorter values are left-padded with
        zero bytes; an ``m_len`` below the natural length has no effect.
    :return: ASCII base64url string with trailing ``=`` padding removed.
    :raises ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"Cannot encode negative integer: {n}")
    # At least one byte so that zero encodes as b"\x00" rather than b"".
    length = max((n.bit_length() + 7) // 8, m_len, 1)
    data = n.to_bytes(length, "big")
    return urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


# Public API of this module: the key facade classes, the unions of the
# concrete facades, and the integer-to-base64url helpers used for JWK encoding.
__all__ = (
    'FacadePrivateKey',
    'FacadePublicKey',
    'PublicKey',
    'RSAPublicKey',
    'Ed25519PublicKey',
    'EllipticCurvePrivateKey',
    'EllipticCurvePublicKey',
    'PrivateKey',
    'RSAPrivateKey',
    'Ed25519PrivateKey',
    'long2intarr',
    'long_to_base64',
)
