"""
IvyLink Test Cases. Run the entire suite with:

python3 -m unittest

Run an individual test module with:

python3 -m unittest ivylink.tests.test_<name>
"""
from pathlib import Path
import tempfile
from unittest import TestCase
from ivylink import settings
from ivylink.images import (
    BlockHashDatabase,
    BlockImageIndex,
    PDQHash,
    string_to_vector,
)
from ivylink.typehints import DihedralHash, SearchResults


class Setup:
    """Mixin that redirects the block hash database and image index to scratch files.

    List this class *before* ``TestCase`` in a test class's bases: every hook
    below chains to ``super()``, which only resolves when ``TestCase`` follows
    it in the MRO.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Point the settings and both data classes at files in the temp directory."""
        tmp_dir = tempfile.gettempdir()
        super().setUpClass()
        settings.BLOCK_ANNOY_INDEX_FILE = f'{tmp_dir}/image_hashes.ann'
        settings.BLOCK_HASH_DATABASE_FILE = f'{tmp_dir}/image_hashes.yml'
        BlockHashDatabase.data_file = settings.BLOCK_HASH_DATABASE_FILE
        BlockImageIndex.data_file = settings.BLOCK_ANNOY_INDEX_FILE

    def setUp(self) -> None:
        """Start every test from a clean slate.

        The scratch files live at fixed names in the shared temp directory, so
        a file left behind by an interrupted run (``tearDown`` is skipped when
        the process is killed) would otherwise leak into the next run.
        """
        super().setUp()
        self._remove_test_files()

    def tearDown(self) -> None:
        """Delete the scratch files written by the test that just finished."""
        self._remove_test_files()
        super().tearDown()

    @staticmethod
    def _remove_test_files() -> None:
        """Remove both scratch files; a file that does not exist is ignored."""
        for test_file in (
                settings.BLOCK_ANNOY_INDEX_FILE,
                settings.BLOCK_HASH_DATABASE_FILE,
        ):
            Path(test_file).unlink(missing_ok=True)


class TestHashDatabase(Setup, TestCase):
    """Tests for ``BlockHashDatabase``: empty state, add, save, reload and pop."""

    def test_emtpy_hash_database(self):
        """A newly created database reports zero entries and an empty mapping."""
        database = BlockHashDatabase()

        self.assertEqual(database.count(), 0)
        self.assertEqual(database.hashes(), {})

    def test_add_and_save_entries(self):
        """Add three hashes, then check a save/clear round trip and a pop.

        The expected mapping used below is keyed by the host name of each
        location, with a ``{hash: location}`` dictionary per host.
        """
        favicon_url = 'https://draw.bridge/static/images/favicon_default.ico'
        favicon_hash = 'HLEH4eHGthYcuVpb48alLTw5UvDDxatLalo8LTwvQ9I='
        aquascope_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Aquascope_%2836186033336%29.jpg/507px-Aquascope_%2836186033336%29.jpg'
        aquascope_hash = 'dEyPGrhwU+buTN7plZJrZJ2M9LgKyVV7raUCwdITP4M='
        porto_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a4/Porto_de_Puerto_del_Carmen-3.jpg/507px-Porto_de_Puerto_del_Carmen-3.jpg'
        porto_hash = '4EkfEMfgEH8+xuL4Bz69A/B5Q8b/kBy8IkXD89/6MQQ='

        database = BlockHashDatabase()
        expected = {'draw.bridge': {favicon_hash: favicon_url}}

        # First entry: a host that is not yet known.
        database.add(location=favicon_url, img_hash=favicon_hash)
        self.assertDictEqual(database.hashes(), expected)
        self.assertEqual(database.count(), 1)

        # Second entry: another new host.
        database.add(location=aquascope_url, img_hash=aquascope_hash)
        expected['upload.wikimedia.org'] = {aquascope_hash: aquascope_url}
        self.assertDictEqual(database.hashes(), expected)
        self.assertEqual(database.count(), 2)

        # Third entry: joins the host added in the previous step.
        database.add(location=porto_url, img_hash=porto_hash)
        expected['upload.wikimedia.org'][porto_hash] = porto_url
        self.assertDictEqual(database.hashes(), expected)
        self.assertEqual(database.count(), 3)

        # Saving must not raise.
        database.save()

        # Drop the in-memory state so the next read has to come from disk.
        database.clear()
        self.assertEqual(database._hashes, {})
        self.assertDictEqual(database.hashes(), expected)

        # A popped hash must stay gone after another save/clear cycle.
        database.pop(porto_hash)
        self.assertEqual(database.count(), 2)

        database.save()
        database.clear()
        self.assertEqual(database.count(), 2)


class TestImageIndex(Setup, TestCase):
    """Tests for ``BlockImageIndex`` built on top of a ``BlockHashDatabase``."""

    def load_hash_database(self):
        """Return a saved ``BlockHashDatabase`` holding three known image hashes."""
        # Insertion order matters: test_load_image_index asserts on the
        # positions 1 and 2, so do not reorder these entries.
        entries = (
            ('https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Aquascope_%2836186033336%29.jpg/507px-Aquascope_%2836186033336%29.jpg',
             'dEyPGrhwU+buTN7plZJrZJ2M9LgKyVV7raUCwdITP4M='),
            ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a4/Porto_de_Puerto_del_Carmen-3.jpg/507px-Porto_de_Puerto_del_Carmen-3.jpg',
             '4EkfEMfgEH8+xuL4Bz69A/B5Q8b/kBy8IkXD89/6MQQ='),
            ('https://draw.bridge/static/images/favicon_default.ico',
             'HLEH4eHGthYcuVpb48alLTw5UvDDxatLalo8LTwvQ9I='),
        )

        database = BlockHashDatabase()
        for location, img_hash in entries:
            database.add(location, img_hash)
        database.save()

        return database

    def test_emtpy_hash_index(self):
        """An index created without a database is not stale and holds nothing."""
        index = BlockImageIndex()

        self.assertFalse(index.is_stale())
        self.assertEqual(index.count(), 0)

    def test_load_image_index(self):
        """Refresh the index from the database, then run exact and partial searches."""

        def pdq_vector(hash_string):
            # Decode a hash string using the PDQHash parameters.
            return string_to_vector(
                hash_string=hash_string,
                dtype=PDQHash.dtype,
                hash_length=PDQHash.hash_length,
                hash_format=PDQHash.hash_format,
            )

        index = BlockImageIndex(self.load_hash_database())
        index.refresh_from_database()
        self.assertEqual(index.count(), 3)

        exact_vector = pdq_vector('HLEH4eHGthYcuVpb48alLTw5UvDDxatLalo8LTwvQ9I=')
        partial_vector = pdq_vector('cE+qGc+QQG08R/b4h7yPg/hZC2ZDkry8NAXD88/6MQQ=')

        # The same vector fills all eight slots of the dihedral hash.
        exact_query = DihedralHash(*([exact_vector] * 8))
        found = index.search(exact_query)
        # The favicon hash is the third entry loaded (position 2), at distance 0.
        self.assertListEqual(found, [SearchResults(2, 0.0)])

        # With the default distance the partial match must be filtered out.
        partial_query = DihedralHash(*([partial_vector] * 8))
        found = index.search(partial_query)
        self.assertListEqual(found, [], msg='Partial match should be further than 30')

        # Widening the distance lets the partial match through.
        found = index.search(partial_query, distance=100)
        self.assertListEqual(found, [SearchResults(1, 50.0)])
