from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from datetime import datetime, timedelta
import os
import uuid

from django.conf import settings
from django.core.exceptions import ValidationError
from django.test import SimpleTestCase

from console_base.validators import *
from console_base.utils.encryption import save_x509_pem, save_private_key


# -------------------------------------------------------------------------
class TestValidateKey(SimpleTestCase):
    """Exercise ``validate_key`` against a junk file and a freshly generated RSA key."""

    @classmethod
    def setUpClass(cls):
        """Create the two fixture files under ``settings.TMP_DIR``.

        ``invalidkey`` holds arbitrary bytes that are not a key; ``validkey`` holds
        a newly generated 2048-bit RSA private key written via ``save_private_key``.
        """
        super().setUpClass()

        tmp_dir = settings.TMP_DIR
        cls.invalidkey = f'{tmp_dir}/invalidkey.key'
        cls.validkey = f'{tmp_dir}/validkey.key'

        with open(cls.invalidkey, 'wb') as junk_file:
            junk_file.write(b'not a valid key')

        private_key = rsa.generate_private_key(
            public_exponent=65537, key_size=2048, backend=default_backend()
        )
        save_private_key(private_key, cls.validkey)

    @classmethod
    def tearDownClass(cls):
        """Delete the fixture files, tolerating any that were never created."""
        for fixture_path in (cls.invalidkey, cls.validkey):
            try:
                os.remove(fixture_path)
            except FileNotFoundError:
                pass

        super().tearDownClass()

    def test_invalid_key(self):
        """Bytes that are not a key must be rejected with ``ValidationError``."""
        self.assertRaises(ValidationError, validate_key, self.invalidkey)

    def test_valid_key(self):
        """A well-formed private key passes validation and ``None`` is returned."""
        outcome = validate_key(self.validkey)
        self.assertIsNone(outcome)


# -------------------------------------------------------------------------
class TestValidateCert(SimpleTestCase):
    """Exercise ``validate_certificate`` against a junk file and a real self-signed cert."""

    @classmethod
    def setUpClass(cls):
        """Create the two fixture files under ``settings.TMP_DIR``.

        ``invalidcert`` holds arbitrary bytes that are not a certificate;
        ``validcert`` holds a self-signed certificate for
        ``settings.DEFAULT_HOSTNAME`` written via ``save_x509_pem``.
        """
        # Local import so the name cannot be shadowed by anything the
        # module-level star import from console_base.validators re-exports.
        from datetime import timezone

        super().setUpClass()
        cls.invalidcert = f'{settings.TMP_DIR}/invalidcert.pem'
        cls.validcert = f'{settings.TMP_DIR}/validcert.pem'
        with open(cls.invalidcert, 'wb') as junk_file:
            junk_file.write(b'not a valid cert')

        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )

        # Self-signed: the subject and issuer are the same name.
        name = x509.Name(
            [x509.NameAttribute(NameOID.COMMON_NAME, settings.DEFAULT_HOSTNAME)]
        )
        # cryptography treats naive datetimes as UTC, so use an aware UTC "now"
        # rather than naive local time; compute it once so both bounds agree.
        now = datetime.now(timezone.utc)
        builder = x509.CertificateBuilder(
            # The certificate embeds the PUBLIC half of the key pair; the
            # builder's public_key must be a public key object, not the private key.
            public_key=private_key.public_key(),
            issuer_name=name,
            subject_name=name,
            serial_number=int(uuid.uuid4()),
            not_valid_before=now - timedelta(days=1),
            not_valid_after=now + timedelta(days=365 * 10),
        )
        cert = builder.sign(
            private_key=private_key,
            algorithm=hashes.SHA256(),
            backend=default_backend(),
        )

        save_x509_pem(cert, cls.validcert)

    @classmethod
    def tearDownClass(cls):
        """Delete the fixture files, tolerating any that were never created."""
        for fixture_path in (cls.invalidcert, cls.validcert):
            try:
                os.remove(fixture_path)
            except FileNotFoundError:
                pass

        super().tearDownClass()

    def test_invalid_cert(self):
        """Bytes that are not a certificate must be rejected with ``ValidationError``."""
        with self.assertRaises(ValidationError):
            validate_certificate(self.invalidcert)

    def test_valid_cert(self):
        """A well-formed self-signed certificate passes validation and returns ``None``."""
        self.assertIsNone(validate_certificate(self.validcert))
