from datetime import date, timedelta

from cryptography.fernet import Fernet

from django.conf import settings
from django.core.exceptions import FieldError, ImproperlyConfigured
from django.db import connection, models as dj_models
from django.test import TestCase, override_settings
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str

import console_base.models.fernet as fields
from fernet.models import (
    EncryptedText,
    EncryptedChar,
    EncryptedEmail,
    EncryptedInt,
    EncryptedDate,
    EncryptedDateTime,
    EncryptedNullable,
)


class TestEncryptedField(TestCase):
    """Key handling and constructor validation for the encrypted field classes."""

    @override_settings(FERNET_KEYS=['psst-the-secret'])
    def test_key_from_settings(self):
        """If present, use settings.FERNET_KEYS."""
        f = fields.EncryptedTextField()
        self.assertEqual(f.keys, settings.FERNET_KEYS)

    @override_settings(FERNET_KEYS=[])
    def test_fallback_to_secret_key(self):
        """If FERNET_KEYS is empty, fall back to SECRET_KEY."""
        f = fields.EncryptedTextField()
        self.assertEqual(f.keys, [settings.SECRET_KEY])

    @override_settings(FERNET_KEYS=['key1', 'key2'])
    def test_key_rotation(self):
        """Can supply multiple `keys` for key rotation."""
        f = fields.EncryptedTextField()

        # Tokens made with either derived key must be readable by the
        # field's combined fernet.
        enc1 = Fernet(f.fernet_keys[0]).encrypt(b'enc1')
        enc2 = Fernet(f.fernet_keys[1]).encrypt(b'enc2')

        self.assertEqual(f.fernet.decrypt(enc1), b'enc1')
        self.assertEqual(f.fernet.decrypt(enc2), b'enc2')

    def test_no_hkdf(self):
        """Can set FERNET_USE_HKDF=False to avoid HKDF.

        With HKDF disabled, a plain ``Fernet`` built from the raw configured
        key must be able to decrypt what the field encrypts.
        """
        raw_key = Fernet.generate_key()
        # The key only exists at runtime, so it cannot go in a decorator.
        # self.settings() scopes both values to this block and restores them
        # afterwards, instead of mutating django.conf.settings directly.
        with self.settings(FERNET_USE_HKDF=False, FERNET_KEYS=[raw_key]):
            f = fields.EncryptedTextField()
            # Touch f.fernet while the overrides are active, in case it is
            # built lazily from settings.
            token = f.fernet.encrypt(b'foo')

        self.assertEqual(Fernet(raw_key).decrypt(token), b'foo')

    def test_not_allowed(self):
        """Options that would need to compare ciphertext are rejected."""
        for key in ['primary_key', 'db_index', 'unique']:
            with self.subTest(option=key):
                with self.assertRaises(ImproperlyConfigured):
                    fields.EncryptedIntegerField(**{key: True})

    def test_get_integer_field_validators(self):
        """Accessing ``validators`` on an encrypted integer field must not raise."""
        f = fields.EncryptedIntegerField()

        # Raises no error
        f.validators


@override_settings(USE_TZ=False)
class TestEncryptedFieldQueries(TestCase):
    """Database round-trip and lookup behaviour of the encrypted test models.

    USE_TZ is disabled for the whole class, so ``timezone.now()`` returns naive
    datetimes. The EncryptedDateTime round-trip assertions below compare those
    naive values for equality.
    """

    @staticmethod
    def _cases():
        """Return ``(model, [first_value, second_value])`` pairs to exercise.

        ``first_value`` is what each test inserts; ``second_value`` is used by
        the update test.
        """
        today = timezone.now()
        in_three_days = today + timedelta(days=3)
        return [
            (EncryptedText, ['foo', 'bar']),
            (EncryptedChar, ['one', 'two']),
            (EncryptedEmail, ['a@example.com', 'b@example.com']),
            (EncryptedInt, [1, 2]),
            (EncryptedDate, [date(2015, 2, 5), date(2015, 2, 8)]),
            (EncryptedDateTime, [today, in_three_days]),
        ]

    def test_insert(self):
        """Data stored in DB is actually encrypted."""
        quote = connection.ops.quote_name
        for model, vals in self._cases():
            with self.subTest(model=model.__name__):
                field = model._meta.get_field('value')
                model.objects.create(value=vals[0])
                # Read the raw column with SQL, bypassing the ORM, and decrypt
                # it by hand. Fernet.decrypt raises on anything that is not a
                # valid token, so this fails if the value was stored in plaintext.
                sql = 'SELECT %s FROM %s' % (
                    quote(field.column),
                    quote(model._meta.db_table),
                )
                with connection.cursor() as cur:
                    cur.execute(sql)
                    data = [
                        force_str(field.fernet.decrypt(force_bytes(row[0])))
                        for row in cur.fetchall()
                    ]
                self.assertEqual(list(map(field.to_python, data)), [vals[0]])

    def test_insert_and_select(self):
        """Data round-trips through insert and select."""
        for model, vals in self._cases():
            with self.subTest(model=model.__name__):
                model.objects.create(value=vals[0])
                found = model.objects.get()
                self.assertEqual(found.value, vals[0])

    def test_update_and_select(self):
        """Data round-trips through update and select."""
        for model, vals in self._cases():
            with self.subTest(model=model.__name__):
                model.objects.create(value=vals[0])
                model.objects.update(value=vals[1])
                found = model.objects.get()
                self.assertEqual(found.value, vals[1])

    def test_lookups_raise_field_error(self):
        """Lookups are not allowed (they cannot succeed)."""
        # isnull is excluded: it is expected to work on nullable encrypted
        # fields (see the isnull tests below). The set is the same for every
        # model, so build it once; sort it for a deterministic run order.
        lookups = sorted(set(dj_models.Field.class_lookups) - {'isnull'})

        for model, vals in self._cases():
            model.objects.create(value=vals[0])
            field_name = model._meta.get_field('value').__class__.__name__

            for lookup in lookups:
                with self.subTest(model=model.__name__, lookup=lookup):
                    with self.assertRaises(FieldError) as fe:
                        model.objects.get(**{'value__' + lookup: vals[0]})

                    exception_msg = str(fe.exception)
                    self.assertIn(field_name, exception_msg)
                    self.assertIn(lookup, exception_msg)
                    self.assertIn('does not support lookups', exception_msg)

    def test_nullable(self):
        """A None value round-trips through a nullable encrypted field."""
        EncryptedNullable.objects.create(value=None)
        found = EncryptedNullable.objects.get()
        self.assertIsNone(found.value)

    def test_isnull_false_lookup(self):
        """isnull False lookup succeeds on nullable fields"""
        test_val = 3
        EncryptedNullable.objects.create(value=None)
        EncryptedNullable.objects.create(value=test_val)
        found = EncryptedNullable.objects.get(value__isnull=False)

        self.assertEqual(found.value, test_val)

    def test_isnull_true_lookup(self):
        """isnull True lookup succeeds on nullable fields"""
        test_val = 3
        EncryptedNullable.objects.create(value=None)
        EncryptedNullable.objects.create(value=test_val)
        found = EncryptedNullable.objects.get(value__isnull=True)

        self.assertIsNone(found.value)
