from aspen_crypto.keys import (
    get_key_password,
    generate_kid,
    FacadePrivateKey,
    FacadePublicKey,
    PrivateKey,
    PublicKey,
)
from django.conf import settings
from django.db import models
from django.db.models import Q
from django.utils.translation import gettext_lazy as tr
from django.utils.encoding import force_str, force_bytes

from typing import Optional

from console_base.theme import ICONS
from console_base.models import BaseUUIDPKModel, LCDateTimeField, LCTextField, NameCITextField
from console_base.models.fernet import EncryptedBinaryField

from .choices import STATUS, SUBJECT
from .fields import KeyField
from .managers import EncryptionQuerySet
from .validators import validate_public_key


class EncryptionKey(BaseUUIDPKModel):
    """
    Store a public key or public/private key pair and associate it to a particular user.

    Implements the same concept as the OpenSSH ``~/.ssh/authorized_keys`` file on a Unix system.

    The public key (``key``) is required; the private key is optional and is
    stored in an encrypted binary field. ``save()`` fills in ``name`` and
    ``kid`` from the public key when they have not been set explicitly.
    """

    css_icon = ICONS.TLS
    STATUS = STATUS

    user = models.ForeignKey(  # type: ignore[var-annotated]
        settings.AUTH_USER_MODEL,
        verbose_name=tr("User"),
        related_name="encryption_keys",
        on_delete=models.CASCADE,
    )

    # Key text in either PEM or OpenSSH format.
    key = KeyField(
        tr("Encryption Key"),
        help_text=tr("The user's RSA / Ed25519 / ECDSA public key"),
        validators=[validate_public_key],
    )
    # Identifier generated from ``key`` by ``generate_kid`` in ``save()``
    # whenever it has not already been set.
    kid = LCTextField(
        help_text=tr('Paserk Key ID derived from the Public Key'),
    )

    # Private Key bytes in either PEM or OpenSSH format.
    # Just because this field is on the model doesn't mean we should use it.
    # Included here for fast backup / restores of cloud systems, which to
    # accomplish means we need to store the private key on disk anyway.
    private_key = EncryptedBinaryField(
        tr("Private Key"),
        null=True,
        blank=True,
        help_text=tr("The user's RSA / Ed25519 / ECDSA private key"),
    )

    # Info describing the key. What system is authenticating with the key etc.
    name = NameCITextField(tr("Name"))

    # Date and time that key was last used for authenticating a request.
    last_used_on = LCDateTimeField(tr("Last Used On"), null=True, blank=True)
    revocation_date = LCDateTimeField(tr("Revocation Date"), null=True, blank=True)
    status = LCTextField(tr("Status"), choices=STATUS.choices, default=STATUS.active)
    subject = LCTextField(tr("Subject"), choices=SUBJECT.choices, default=SUBJECT.Any)

    objects = EncryptionQuerySet.as_manager()

    class Meta:
        verbose_name = tr("Encryption Key")
        verbose_name_plural = tr("Encryption Keys")
        constraints = [
            # Names only need to be unique per user among keys that have
            # no revocation date set.
            models.UniqueConstraint(
                fields=('user', 'name'),
                name='unique_key_name_per_user',
                condition=Q(revocation_date__isnull=True),
            ),
            models.UniqueConstraint(
                fields=('key',),
                name='encryption_keys_must_be_unique',
            ),
            models.UniqueConstraint(
                fields=('kid',),
                name='public_key_ids_must_be_unique',
            ),
        ]

    def __str__(self):
        return f'{self.__class__.__name__}({self.name!r})'

    def get_key(self) -> FacadePublicKey:
        """
        Load the stored public key text into a public key object.

        :return: The key returned by ``PublicKey.load_serialized_public_key``.
        :raises Exception: The exception reported by the loader when no key
            could be loaded.
        :raises ValueError: When the loader returned neither a key nor an
            exception.
        """
        key_bytes = force_bytes(self.key)
        # The loader reports failure as an (exception, key) pair instead of raising.
        exc, key = PublicKey.load_serialized_public_key(key_bytes)
        if key is None:
            if exc is None:  # pragma: no cover
                raise ValueError("Failed to load key")
            raise exc
        return key

    def get_private_key(self) -> Optional[FacadePrivateKey]:
        """
        Load the stored private key, if this record has one.

        :return: The key loaded by ``PrivateKey.load_pem``, or ``None`` when
            no private key is stored, or when a password prefix is configured
            in settings but no password could be obtained for this key.
        """
        private_key = self.private_key
        # The field is nullable: records holding only a public key have
        # nothing to load, so don't call bytes methods on ``None``.
        if not private_key:
            return None

        openssh_banner = b'-----BEGIN OPENSSH PRIVATE KEY-----'
        # The prefix lives in Django settings, not on the model instance.
        password_prefix = settings.PRIVATE_KEY_PASSWORD_PREFIX
        password = get_key_password(
            self.cid,
            password_prefix,
            private_key.startswith(openssh_banner),
        )
        # NOTE(review): presumably a configured prefix means the stored key is
        # password protected, so loading without a password is not attempted.
        if not password and password_prefix:
            return None
        return PrivateKey.load_pem(private_key, password=password)

    def save(self, *args, **kwargs) -> None:
        """
        Fill in ``name`` and ``kid`` when they are unset, then save the record.

        When no name is set and the key text consists of exactly three
        space-separated parts, the last part becomes the name (presumably
        the trailing comment of an OpenSSH ``type data comment`` line).
        """
        if not self.name:
            key_parts = force_str(self.key).split(" ")
            if len(key_parts) == 3:
                self.name = key_parts.pop()
        if not self.kid:
            self.kid = generate_kid(self.key)
        super().save(*args, **kwargs)
