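"""
Base DRF serializer fields and serializers for encryption-related data:
public keys, private keys, X.509 certificates and certificate signing requests.
"""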

import os
from typing import Any, TYPE_CHECKING

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import (
    load_pem_private_key,
    load_pem_public_key,
    load_ssh_public_key,
)

try:
    from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
except ImportError:
    from cryptography.hazmat.primitives.asymmetric.types import (
        PRIVATE_KEY_TYPES as PrivateKeyTypes,
    )

from django.conf import settings
from rest_framework import exceptions
from rest_framework.serializers import CharField

from .serializers import LCSerializer

# Base class for the key-validation mixin below. For static type checkers the
# mixin appears to derive from CharField, so attributes such as
# ``error_messages`` resolve. At runtime it is a plain ``object`` mixin, and
# each concrete field lists CharField as a separate base class.
PublicKeyCharFieldMixinBase = CharField if TYPE_CHECKING else object


# ---------------------------------------------------------------------------
class PublicKeyCharFieldMixin(PublicKeyCharFieldMixinBase):  # type: ignore
    """
    Validate a text field by trying a sequence of ``cryptography`` loaders.

    Subclasses set ``pem_loaders`` to a tuple of callables that take the raw
    encoded bytes and raise on malformed input. The submitted string is
    returned unchanged as soon as any loader accepts it. Otherwise the
    field's ``invalid`` error message is raised as a ValidationError.

    Concrete fields combine this mixin with ``CharField``, which supplies
    ``error_messages`` at runtime.
    """

    # Callables tried in order; the first one that does not raise wins.
    pem_loaders: tuple

    def to_internal_value(self, data: str) -> str:
        # Non-string payloads (numbers, lists, objects) can never be a key.
        # Report them as invalid rather than letting ``.encode()`` raise
        # AttributeError and surface as a server error.
        if not isinstance(data, str):
            raise exceptions.ValidationError(self.error_messages['invalid'])

        key_bytes = data.encode()
        for loader in self.pem_loaders:
            try:
                loader(key_bytes)
            except Exception:
                # Any loader failure only means "not this format"; try the
                # next one and report a generic validation error at the end.
                continue
            return data

        raise exceptions.ValidationError(self.error_messages['invalid'])


# ---------------------------------------------------------------------------
class PublicKeyField(PublicKeyCharFieldMixin, CharField):
    """
    Public key supplied either as PEM or in OpenSSH format.

    The PEM loader is tried first, then the OpenSSH loader; the original
    text is kept as the validated value.
    """

    pem_loaders = (load_pem_public_key, load_ssh_public_key)
    default_error_messages = {
        "invalid": "Enter a valid OpenSSH or PEM Public Key",
    }


# ---------------------------------------------------------------------------
class OpenSshPublicKeyField(PublicKeyCharFieldMixin, CharField):
    """Public key that must be in OpenSSH format; the text is kept as-is."""

    pem_loaders = (
        load_ssh_public_key,
    )
    default_error_messages = {
        "invalid": "Enter a valid OpenSSH Public Key",
    }


# ---------------------------------------------------------------------------
class CertificateSigningRequestField(PublicKeyCharFieldMixin, CharField):
    """PEM-encoded X.509 certificate signing request, kept as text."""

    pem_loaders = (
        x509.load_pem_x509_csr,
    )
    default_error_messages = {
        "invalid": "Enter a valid Certificate Signing Request",
    }


# ---------------------------------------------------------------------------
class PEMCertificateField(PublicKeyCharFieldMixin, CharField):
    """PEM-encoded X.509 certificate, kept as text."""

    pem_loaders = (
        x509.load_pem_x509_certificate,
    )
    default_error_messages = {
        "invalid": "Enter a valid PEM Certificate",
    }


# ---------------------------------------------------------------------------
class PEMPublicKeyField(PublicKeyCharFieldMixin, CharField):
    """Public key that must be PEM-encoded; the text is kept as-is."""

    pem_loaders = (
        load_pem_public_key,
    )
    default_error_messages = {
        "invalid": "Enter a valid PEM Public Key",
    }


# ---------------------------------------------------------------------------
class PEMPrivateKeyField(CharField):
    """
    Text field holding an unencrypted PEM private key.

    Incoming PEM text is parsed into a ``cryptography`` private key object.
    Outgoing key objects are serialized back to unencrypted PEM text.
    """

    default_error_messages = {"invalid": "Enter a valid PEM Private Key"}

    def to_internal_value(self, data: str) -> PrivateKeyTypes:  # type: ignore[override]
        """
        Parse ``data`` as an unencrypted PEM private key.

        Raises:
            ValidationError: if ``data`` cannot be loaded as an unencrypted
                PEM private key. This includes non-string input and
                malformed PEM; keys that need a password also fail, since
                ``password=None``.
        """
        try:
            return load_pem_private_key(
                data=data.encode('utf8'),
                password=None,
                backend=default_backend(),
            )
        except Exception:
            # ``from None`` drops the cryptography traceback; clients only get
            # the generic "invalid" message.
            raise exceptions.ValidationError(self.error_messages['invalid']) from None

    def to_representation(self, value: PrivateKeyTypes) -> str:  # type: ignore[override]
        """
        Serialize ``value`` as unencrypted PEM text.

        The traditional OpenSSL format is used when the key type supports
        it. Key types without a traditional encoding (e.g. Ed25519, which
        ``to_internal_value`` happily accepts in PKCS#8 form) make
        ``private_bytes`` raise ValueError. For those, fall back to PKCS#8
        instead of failing.
        """
        try:
            key = value.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
        except ValueError:
            key = value.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        return key.decode('utf8')


# ---------------------------------------------------------------------------
class SavePem:
    """
    Serializer mixin that writes PEM bytes to a file under ``settings.FILES_DIR``.

    Listed before the DRF serializer base in the serializers of this module,
    so its ``save`` takes precedence over the serializer's default one.
    """

    def create(self, *args: Any, **kwargs: Any) -> None:
        # Intentional no-op: persistence happens in ``save`` (a file write),
        # not through a model instance. Accepting arbitrary arguments keeps
        # DRF's ``create(validated_data)`` call signature working.
        return None

    def save(self, **kwargs: Any) -> str:
        """
        Write ``kwargs['data']`` (bytes, whitespace-stripped) to
        ``<settings.FILES_DIR>/<kwargs['name']>`` and return that path.

        Raises:
            ValidationError: if ``name`` resolves to a location outside
                ``settings.FILES_DIR`` (e.g. via ``../`` components). The
                serializers in this module build ``name`` from a
                client-supplied field.
        """
        files_dir = settings.FILES_DIR
        # Keep the original path format for the returned value; callers may
        # rely on it.
        file_path = f'{files_dir}/{kwargs["name"]}'

        # Guard against path traversal: the resolved target must be strictly
        # inside the resolved files directory.
        base = os.path.realpath(files_dir)
        target = os.path.realpath(file_path)
        try:
            is_inside = target != base and os.path.commonpath([base, target]) == base
        except ValueError:
            # e.g. paths on different drives (Windows) cannot share a prefix.
            is_inside = False
        if not is_inside:
            raise exceptions.ValidationError({'name': 'Invalid file name'})

        with open(file_path, 'wb') as sf:
            sf.write(kwargs['data'].strip())

        return file_path


# ---------------------------------------------------------------------------
class PEMCertificateSerializer(SavePem, LCSerializer):  # type: ignore
    """
    Accept a file ``name`` plus a PEM certificate and write it to disk.

    The file is called ``<name>.<extension>``. Subclasses override
    ``extension`` and ``pem`` to store other kinds of PEM documents.
    """

    name = CharField()
    pem = PEMCertificateField()

    # File suffix appended to the submitted name.
    extension = 'pem'

    def save(self, **kwargs: Any) -> str:
        validated = self.validated_data
        file_name = '{}.{}'.format(validated['name'], self.extension)
        pem_text = validated['pem']
        return super().save(name=file_name, data=pem_text.encode())


# ---------------------------------------------------------------------------
class PEMCertificateSigningRequestSerializer(PEMCertificateSerializer):
    """Same as PEMCertificateSerializer, but for a CSR stored as ``<name>.csr``."""

    pem = CertificateSigningRequestField()
    extension = 'csr'


# ---------------------------------------------------------------------------
class PEMPrivateKeySerializer(SavePem, LCSerializer):  # type: ignore
    """
    Accept a file ``name`` plus a PEM private key and write it to ``<name>.key``.

    The key is re-serialized without encryption before being written.
    """

    name = CharField()
    key = PEMPrivateKeyField()

    def save(self, **kwargs: Any) -> str:
        """Write the validated key as unencrypted PEM; return the file path."""
        name = f"{self.validated_data['name']}.key"
        private_key = self.validated_data['key']
        try:
            pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
        except ValueError:
            # Key types with no traditional OpenSSL encoding (e.g. Ed25519,
            # accepted by the field in PKCS#8 form) raise ValueError above;
            # PKCS#8 is the PEM private format they do support.
            pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        return super().save(name=name, data=pem)


# Names exported by ``from <this module> import *``.
__all__ = (
    'CertificateSigningRequestField',
    'OpenSshPublicKeyField',
    'PEMCertificateField',
    'PEMPublicKeyField',
    'PEMPrivateKeyField',
    'PublicKeyField',
    'PEMCertificateSerializer',
    'PEMCertificateSigningRequestSerializer',
    'PEMPrivateKeySerializer',
)
