from annoy import AnnoyIndex
from lchttp import yaml_safe_dump, yaml_safe_load
from pathlib import Path
from urllib.parse import urlparse
import validators
from typing import List, Optional, Union

from ivylink.images import PDQHash, string_to_vector
from ivylink.typehints import DihedralHash, FastDihedralHash, SearchResults
from ivylink import settings


class DateMixin:
    """
    Mixin for objects backed by a single file on disk (``data_file``),
    providing existence and modification-time helpers.
    """
    # Path of the backing file; subclasses are expected to override it.
    data_file: str = ''

    def exists(self) -> bool:
        """
        Return True when ``data_file`` points at an existing regular file.
        """
        return Path(self.data_file).is_file()

    def modified(self) -> float:
        """
        Return the timestamp when file was modified,
        to be used in comparing files for staleness.

        Returns 0.0 when the file does not exist, so that a missing
        file always compares as older than an existing one.
        """
        try:
            # st_mtime is the content modification time. st_ctime is the
            # metadata-change time on Unix and the creation time on Windows,
            # neither of which reflects when the data was last written.
            return Path(self.data_file).stat().st_mtime
        except FileNotFoundError:
            return 0.0


class ImageIndex(DateMixin):
    """
    Class for nearest neighbor lookups of Image Hashes.

    Wraps an AnnoyIndex stored at ``data_file`` and built from the entries
    of a HashDatabase. The index is created lazily on first use and is
    rebuilt from the database when the index file is missing or stale.
    """

    def __init__(
        self,
        hash_database: Optional['HashDatabase'] = None,
        n_trees: int = 32,
        n_jobs: int = -1,
    ):
        """
        hash_database: database the index is built from; a HashDatabase()
                       is created when none is given.
        n_trees: forwarded to AnnoyIndex.build() on refresh.
        n_jobs: forwarded to AnnoyIndex.build() on refresh.
        """
        self._annoy_idx = None
        self.hash_database = hash_database or HashDatabase()
        self.n_trees = n_trees
        self.n_jobs = n_jobs

    def search(
        self,
        vectors: Union[DihedralHash, FastDihedralHash],
        items: int = 1,
        distance: int = 30,
        fast: bool = True,
    ) -> List[SearchResults]:
        """
        Check if any of vector orientations in the dihedral hash are found in the index.

        If fast=True, then check r0 and fh orientations only, and return on first match.

        vectors: hash orientations to look up.
        items: number of nearest neighbors requested per orientation.
        distance: maximum distance (inclusive) for a neighbor to count as a match.

        Returns a list of SearchResults(index position, distance); empty
        when nothing is within ``distance``.
        """
        results = []
        if fast:
            vectors = FastDihedralHash(vectors.r0, vectors.fh)

        for vector in vectors:
            indexes, distances = self.annoy_index().get_nns_by_vector(
                vector,
                items,
                include_distances=True,
            )

            # indexes and distances are parallel lists; keep close-enough matches.
            results.extend(
                SearchResults(item_index, hash_distance)
                for item_index, hash_distance in zip(indexes, distances)
                if hash_distance <= distance
            )

            if results and fast:
                break

        return results

    def annoy_index(self, prefault=False) -> AnnoyIndex:
        """
        Get AnnoyIndex object for nearest neighbor lookups.

        The object is created on first call and cached. ``prefault`` is
        forwarded to the AnnoyIndex load/save call made on first use.
        """
        if self._annoy_idx is None:
            self._annoy_idx = AnnoyIndex(PDQHash.hash_length, 'hamming')

            # Load the index into a mmap file. If the index
            # doesn't exist or is stale, refresh from the database.
            if not self.exists() or self.is_stale():
                self.refresh_from_database(prefault=prefault)
            else:
                self._annoy_idx.load(self.data_file, prefault=prefault)

        return self._annoy_idx

    def is_stale(self) -> bool:
        """
        Index is stale when the Index List file is newer than
        the AnnoyIndex file.

        Returns False when the database file does not exist.
        """
        if not self.hash_database.exists():
            return False

        return self.hash_database.modified() > self.modified()

    def count(self) -> int:
        """
        Number of items currently held by the AnnoyIndex.
        """
        return self.annoy_index().get_n_items()

    def refresh_from_database(self, prefault=False) -> AnnoyIndex:
        """
        Load all hashes from the database into the index.

        Rebuilds the index from every hash in ``hash_database`` and saves
        it to ``data_file``. An item's position in the index is its
        position in the flattened list of database hashes.
        """
        all_entries = []
        for hashes in self.hash_database.hashes().values():
            all_entries.extend(hashes.keys())

        # Create the index object directly rather than through annoy_index():
        # on a fresh instance with a missing or stale index file that call
        # would itself trigger a full refresh, building the index twice.
        if self._annoy_idx is None:
            self._annoy_idx = AnnoyIndex(PDQHash.hash_length, 'hamming')
        hash_index = self._annoy_idx

        # Drop any loaded / built state so that items can be added again.
        hash_index.unload()
        hash_index.unbuild()

        for i, hash_string in enumerate(all_entries):
            vector = string_to_vector(hash_string, 'bool', PDQHash.hash_length, 'base64')
            hash_index.add_item(i, vector)

        hash_index.build(n_trees=self.n_trees, n_jobs=self.n_jobs)
        hash_index.save(self.data_file, prefault=prefault)

        return hash_index

    def __str__(self) -> str:
        return f'{self.__class__.__name__} containing {self.count()} vectors'

    def __repr__(self) -> str:
        return self.__str__()


class HashDatabase(DateMixin):
    """
    Manage the Image Hash file containing all the image hashes, grouped by the
    site they came from, and the original URL / file name.

    In memory the data is a dict of ``{source: {img_hash: location}}``.
    """
    data_file = ''

    def __init__(self):
        # Lazily loaded cache of the database file contents.
        self._hashes = {}
        # Cached total number of hashes across all sources.
        self._entry_count = 0

    def hashes(self) -> dict:
        """
        Return dictionary of all the hashes in the database file.

        The file is read on first use and cached. Returns an empty
        dict when the database file does not exist.
        """
        if self._hashes:
            return self._hashes

        try:
            with open(self.data_file) as hdb:
                saved_hashes = yaml_safe_load(hdb.read())
        except FileNotFoundError:
            return {}

        if saved_hashes:
            self._hashes = saved_hashes
            # Set the cached count directly; accumulating it via count()
            # from here is what made a first count() call double-count.
            self._entry_count = sum(len(entries) for entries in saved_hashes.values())

        return self._hashes

    def add(self, location: str, img_hash: str) -> None:
        """
        Add a hash to the database.

        location: File path or URL source of the image
        img_hash: base64-encoded string representation of the hash.

        Raises SystemExit when location is neither a valid URL
        nor an existing file.
        """
        if validators.url(location):
            source = urlparse(location).hostname
        elif Path(location).is_file():
            source = 'local-file'
        else:
            # NOTE(review): SystemExit is kept so existing callers behave
            # the same; a ValueError would be more usual for library code.
            raise SystemExit('Invalid image source!')

        hashes = self.hashes()
        source_hashes = hashes.setdefault(source, {})

        # Only a hash that is new for this source grows the total;
        # re-adding an existing one just updates its location.
        is_new = img_hash not in source_hashes
        source_hashes[img_hash] = location

        self._hashes = hashes
        if is_new:
            self._entry_count += 1

    def pop(self, img_hash: str) -> None:
        """
        Remove the specified image hash from all locations in the database file.
        """
        for hashes in self.hashes().values():
            if img_hash not in hashes:
                continue
            del hashes[img_hash]
            if self._entry_count > 0:
                self._entry_count -= 1

    def save(self) -> None:
        """
        Save entries to database file.

        Nothing is written when there are no hashes in memory.
        """
        if not self._hashes:
            print('No hashes to save!')
            return

        # Resolve the data before opening: mode 'w' truncates the file.
        hashes = self.hashes()
        with open(self.data_file, 'w') as hdf:
            yaml_safe_dump(hashes, hdf)

    def count(self) -> int:
        """
        Total number of hashes in the database file.
        """
        if self._entry_count <= 0:
            # Assign rather than accumulate: hashes() may already have set
            # the cached count while loading the file.
            self._entry_count = sum(len(entries) for entries in self.hashes().values())
        return self._entry_count

    def clear(self) -> None:
        """
        Clear cached values so next entries are reloaded.
        """
        self._hashes = {}
        self._entry_count = 0

    def __str__(self) -> str:
        return f'{self.__class__.__name__} containing {self.count()} image hashes'

    def __repr__(self) -> str:
        return self.__str__()


class AllowHashDatabase(HashDatabase):
    """
    Hash database for images that should be Allowed.

    Mainly intended for correcting images that AI scanning
    blocks incorrectly.
    """
    data_file = settings.ALLOW_HASH_DATABASE_FILE


class AllowImageIndex(ImageIndex):
    """
    Image index to consult before performing AI scanning of images.
    """
    data_file = settings.ALLOW_ANNOY_INDEX_FILE

    def __init__(self, hash_database=None, n_trees=32, n_jobs=-1):
        # Use the allow-list database unless the caller supplies one.
        if not hash_database:
            hash_database = AllowHashDatabase()
        super().__init__(hash_database=hash_database, n_trees=n_trees, n_jobs=n_jobs)


class BlockHashDatabase(HashDatabase):
    """
    Hash database for images that should be Blocked.
    """
    data_file = settings.BLOCK_HASH_DATABASE_FILE


class BlockImageIndex(ImageIndex):
    """
    Image index to consult for all Image requests, to see whether the
    image hash matches one that has been specifically flagged to block.

    Frequently accessed images should be hashed, as a hash lookup is a
    much more efficient operation than any AI content scanning.
    """
    data_file = settings.BLOCK_ANNOY_INDEX_FILE

    def __init__(self, hash_database=None, n_trees=32, n_jobs=-1):
        # Use the block-list database unless the caller supplies one.
        if not hash_database:
            hash_database = BlockHashDatabase()
        super().__init__(hash_database=hash_database, n_trees=n_trees, n_jobs=n_jobs)
