# Taken from https://github.com/thorn-oss/perception/blob/main/perception/hashers/tools.py
# See LICENSE.txt file
import os
import io
import math
import pybase64
from lcrequests import Request
from PIL import Image
from typing import (
    Optional,
    Tuple,
    Union,
)

import numpy as np
import validators

# Inputs accepted by the image helpers in this module: a file path or URL
# string, a numpy array, a PIL image, or an in-memory byte buffer.
ImageInputType = Union[str, np.ndarray, "Image.Image", io.BytesIO]

# Number of bits used to store one hash element of each supported dtype.
SIZES = {"float32": 32, "uint8": 8, "bool": 1}


def get_string_length(hash_length: int, dtype: str, hash_format="hex") -> int:
    """Compute the expected length of a hash string.

    Args:
        hash_length: The length of the hash vector (number of elements).
        dtype: The dtype of the vector; one of the keys of ``SIZES``.
        hash_format: One of 'base64' or 'hex'.

    Returns:
        The expected string length.

    Raises:
        KeyError: If ``dtype`` is not a key of ``SIZES``.
        NotImplementedError: If ``hash_format`` is not recognized.
    """
    hash_bits = hash_length * SIZES[dtype]
    # Ceiling division done in integer arithmetic: float division followed
    # by math.ceil silently rounds wrong once sizes exceed float precision.
    hash_bytes = int(-(-hash_bits // 8))

    if hash_format == "base64":
        # Padded base64 emits 4 characters for every started 3-byte group.
        return 4 * ((hash_bytes + 2) // 3)
    if hash_format == "hex":
        # Two hex digits per byte.
        return 2 * hash_bytes
    raise NotImplementedError(f"Unknown hash format: {hash_format}")


def vector_to_string(vector: np.ndarray, dtype: str, hash_format: str) -> str:
    """Convert a hash vector to its string encoding.

    Args:
        vector: Input vector. May be ``None`` (see below).
        dtype: How to serialize the elements: 'uint8' (one byte each),
            'float32' (four bytes each) or 'bool' (packed 8 per byte).
        hash_format: Output encoding, either 'base64' or 'hex'.

    Returns:
        The encoded string, or an empty string when ``vector`` is ``None``.

    Raises:
        NotImplementedError: If ``dtype`` or ``hash_format`` is not supported.
    """
    # At times, a vector returned by a hasher is None (e.g., for hashes
    # that depend on the image not being featureless). In those cases an
    # empty string is returned rather than raising.
    if vector is None:
        return ''
    if dtype == "uint8":
        vector_bytes = vector.astype("uint8")
    elif dtype == "float32":
        vector_bytes = vector.astype("float32")
    elif dtype == "bool":
        # Pack 8 boolean elements per byte; the final byte is zero-padded.
        vector_bytes = np.packbits(vector.astype("bool"))
    else:
        raise NotImplementedError(f"Cannot convert hash of type {dtype}.")

    raw = vector_bytes.tobytes()
    if hash_format == "base64":
        return pybase64.b64encode(raw).decode("utf-8")
    if hash_format == "hex":
        return raw.hex()
    raise NotImplementedError(f"Cannot convert to string format: {hash_format}.")


def string_to_vector(
    hash_string: str,
    dtype: str,
    hash_length: int,
    hash_format: str,
    verify_length: bool = True,
) -> np.ndarray:
    """Convert hash back to vector.

    Args:
        hash_string: The input hash string
        dtype: The data type of the hash ('uint8', 'float32' or 'bool')
        hash_length: The length of the hash vector
        hash_format: The input format of the hash (base64 or hex)
        verify_length: Whether to verify the string length

    Returns:
        The first ``hash_length`` elements of the decoded vector. For 'bool'
        the packed bytes are unpacked into a boolean array.

    Raises:
        AssertionError: If ``verify_length`` is set and the string length
            does not match the expected length for this hash.
        NotImplementedError: If ``hash_format`` or ``dtype`` is not supported.
    """
    if verify_length and len(hash_string) != get_string_length(
        hash_length=hash_length, hash_format=hash_format, dtype=dtype
    ):
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; AssertionError is kept for callers.
        raise AssertionError("Incorrect string length for this hash format.")

    if hash_format == "base64":
        raw = pybase64.b64decode(hash_string)
    elif hash_format == "hex":
        raw = bytearray.fromhex(hash_string)
    else:
        raise NotImplementedError(f"Cannot convert to string format: {hash_format}")
    # bool hashes are stored packed, so they are read back as raw bytes.
    vector_bytes = np.frombuffer(
        raw,
        dtype="uint8" if dtype in ["bool", "uint8"] else dtype,
    )

    if dtype in ("uint8", "float32"):
        return vector_bytes[:hash_length]
    if dtype == "bool":
        # Drop the zero padding added when the bits were packed.
        return np.unpackbits(vector_bytes)[:hash_length].astype("bool")
    raise NotImplementedError(f"Cannot convert hash of type {dtype}.")


def hex_to_b64(hash_string: str, dtype: str, hash_length: int, verify_length: bool = True) -> str:
    """Convert a hex-encoded hash to base64.

    The string is decoded to its vector and re-encoded, so the result is
    exactly what ``vector_to_string(..., hash_format="base64")`` would
    produce for the same vector.

    Args:
        hash_string: The input hex hash string
        dtype: The data type of the hash
        hash_length: The length of the hash vector
        verify_length: Whether to verify the length of the input string

    Returns:
        The base64-encoded hash string.
    """
    return vector_to_string(
        string_to_vector(
            hash_string,
            hash_length=hash_length,
            hash_format="hex",
            dtype=dtype,
            verify_length=verify_length,
        ),
        dtype=dtype,
        hash_format="base64",
    )


def b64_to_hex(hash_string: str, dtype: str, hash_length: int, verify_length: bool = True) -> str:
    """Convert a base64-encoded hash to hex.

    The string is decoded to its vector and re-encoded, so the result is
    exactly what ``vector_to_string(..., hash_format="hex")`` would produce
    for the same vector.

    Args:
        hash_string: The input base64 hash string
        dtype: The data type of the hash
        hash_length: The length of the hash vector
        verify_length: Whether to verify the length of the input string

    Returns:
        The hex-encoded hash string.
    """
    return vector_to_string(
        string_to_vector(
            hash_string,
            hash_length=hash_length,
            hash_format="base64",
            dtype=dtype,
            verify_length=verify_length,
        ),
        dtype=dtype,
        hash_format="hex",
    )


def to_image_array(image: ImageInputType, require_color=True):
    if isinstance(image, np.ndarray):
        assert image.flags["C_CONTIGUOUS"], (
            "Provided arrays must be contiguous to avoid "
            "erroneous results when arrays are passed to "
            "underlying libraries. This can be achieved using"
            "np.ascontiguousarray(image)"
        )
        assert not require_color or (
            len(image.shape) == 3 and image.shape[-1] == 3
        ), "Provided images must be RGB images."
        return image
    return read(image)


def get_isometric_transforms(image: ImageInputType, require_color=True):
    """Return the eight flip/rotation variants of an image.

    Args:
        image: Any input accepted by ``to_image_array``.
        require_color: Passed through to ``to_image_array``. When False,
            2-D (single-channel) arrays are supported as well.

    Returns:
        A dict of C-contiguous arrays keyed by transform:
        ``r0`` (the input array itself), ``fv`` (rows reversed),
        ``fh`` (columns reversed), ``r180`` (both reversed),
        ``r90`` (equivalent to ``np.rot90(image)``), ``r90fv`` (the first
        two axes swapped), ``r90fh`` (``r90`` with columns reversed) and
        ``r270`` (equivalent to ``np.rot90(image, -1)``).
    """
    image = to_image_array(image, require_color=require_color)
    # Swap only the two leading (spatial) axes. Unlike transpose(1, 0, 2),
    # this also works for 2-D arrays admitted by require_color=False.
    transposed = image.swapaxes(0, 1)
    return dict(
        r0=image,
        fv=np.ascontiguousarray(image[::-1]),
        fh=np.ascontiguousarray(image[:, ::-1]),
        r180=np.ascontiguousarray(image[::-1, ::-1]),
        r90=np.ascontiguousarray(transposed[::-1]),
        r90fv=np.ascontiguousarray(transposed),
        r90fh=np.ascontiguousarray(transposed[::-1, ::-1]),
        r270=np.ascontiguousarray(transposed[:, ::-1]),
    )


def get_isometric_dct_transforms(dct: np.ndarray):
    """Derive the DCT of every isometric image variant from a single 2-D DCT.

    For (type-II) DCT coefficients, reversing an axis multiplies coefficient
    k along that axis by (-1)**k, and transposing the image transposes its
    DCT. The coefficients for the eight variants produced by
    ``get_isometric_transforms`` therefore follow from ``dct`` by sign changes
    and a transpose, without recomputing any DCT.

    Args:
        dct: A 2-D array of DCT coefficients. It need not be square.

    Returns:
        A dict with the same keys as ``get_isometric_transforms``. ``r0`` is
        ``dct`` itself and ``r90fv`` is the view ``dct.T``. The other entries
        are new arrays with the same dtype as ``dct``.
    """
    n_rows, n_cols = dct.shape

    def _alternating_signs(length):
        # +1 at even indices and -1 at odd ones, built in dct's dtype so the
        # products below keep that dtype.
        signs = np.ones(length, dtype=dct.dtype)
        signs[1::2] = -1
        return signs

    # Column vector: reverses the row order (vertical flip) of the image.
    row_signs = _alternating_signs(n_rows)[:, np.newaxis]
    # Row vector: reverses the column order (horizontal flip) of the image.
    col_signs = _alternating_signs(n_cols)[np.newaxis, :]
    checker = row_signs * col_signs
    transposed = dct.T
    # For the transposed coefficients the two axes swap roles, hence .T on
    # each sign pattern. This is what makes non-square inputs broadcast.
    return dict(
        r0=dct,
        fv=dct * row_signs,
        fh=dct * col_signs,
        r180=dct * checker,
        r90=transposed * col_signs.T,
        r90fv=transposed,
        r90fh=transposed * checker.T,
        r270=transposed * row_signs.T,
    )


def read(filepath_or_buffer: ImageInputType, timeout=None):
    """Read an image into an RGB numpy array downscaled to fit 128x128.

    Args:
        filepath_or_buffer: A PIL image, a path to an image file (``str`` or
            ``os.PathLike``), an image URL, or any object with a `read`
            method (such as `io.BytesIO`)
        timeout: If filepath_or_buffer is a URL, the timeout to
            use for making the HTTP request.

    Returns:
        The image converted to RGB, after ``thumbnail((128, 128))``, as a
        numpy array.

    Raises:
        SystemExit: If a URL cannot be fetched or does not serve an image,
            a file path does not exist, or the input type is not handled.
    """
    if isinstance(filepath_or_buffer, Image.Image):
        # NOTE(review): thumbnail() resizes in place, so a PIL image passed
        # in by the caller is downscaled as a side effect.
        filepath_or_buffer.thumbnail((128, 128), resample=Image.BICUBIC)
        return np.array(filepath_or_buffer.convert("RGB"))

    if hasattr(filepath_or_buffer, "read"):
        # Any readable file-like object, not only io.BytesIO. The buffer
        # belongs to the caller, so it is not wrapped in a `with` here.
        return read(Image.open(filepath_or_buffer))

    if isinstance(filepath_or_buffer, os.PathLike):
        filepath_or_buffer = os.fspath(filepath_or_buffer)

    if isinstance(filepath_or_buffer, str):
        if validators.url(filepath_or_buffer):
            try:
                resp = Request(filepath_or_buffer, timeout=timeout).get()
                content_type = resp.headers.get('Content-Type') or ''
                if not content_type.startswith('image/'):
                    raise SystemExit(f'Not a supported image content type: {content_type!r}')
                # The image is opened (and closed) here, so `with` is safe.
                with Image.open(io.BytesIO(resp.content)) as image:
                    return read(image)
            except Exception as err:
                raise SystemExit('Unable to retrieve image - please confirm URL.') from err

        if not os.path.isfile(filepath_or_buffer):
            raise SystemExit(f"Could not find image at {filepath_or_buffer!r}")

        # Close the file handle opened here once the pixels are converted.
        with Image.open(filepath_or_buffer) as image:
            return read(image)

    raise SystemExit(f"Unhandled filepath_or_buffer type: {type(filepath_or_buffer)}")


def unletterbox(image) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Return bounds of non-trivial region of image or None.

    Unletterboxing is cropping an image such that trivial edge regions
    are removed. Trivial in this context means that the majority of
    the values in that row or column are zero or very close to
    zero. This is why we don't use the terms "non-blank" or
    "non-empty."

    A pixel is non-trivial when its mean over the last axis (the average
    of R, G, B for a color image) is > 2. For a 2-D (single-channel) image
    the pixel value itself is used.

    In order to do unletterboxing, this function returns half-open bounds
    in the form (x1, x2), (y1, y2), suitable for ``image[y1:y2, x1:x2]``,
    where:

    - x1 is the index of the first column where over 10% of the pixels
      are non-trivial.
    - x2 is one past the index of the last column where over 10% of the
      pixels are non-trivial.
    - y1 is the index of the first row where over 10% of the pixels
      are non-trivial.
    - y2 is one past the index of the last row where over 10% of the
      pixels are non-trivial.

    If every pixel is non-trivial, the bounds cover the whole image:
    (0, width), (0, height). If there are zero columns or zero rows where
    over 10% of the pixels are non-trivial, this function returns `None`.

    Note that in the case(s) of a single column and/or row of
    non-trivial pixels, x2 = x1 + 1 and/or y2 = y1 + 1.

    Consider these examples to understand edge cases.  Given two
    images, `L` (entire left and bottom edges are 1, all other pixels
    0) and `U` (left, bottom and right edges 1, all other pixels 0),
    `unletterbox(L)` would return the bounds of the single bottom-left
    pixel and `unletterbox(U)` would return the bounds of the entire
    bottom row.

    Consider `U1` which is the same as `U` but with the bottom two
    rows all 1s. `unletterbox(U1)` returns the bounds of the bottom
    two rows.

    Args:
        image: The image from which to remove letterboxing, as a 3-D
            (rows, columns, channels) or 2-D (rows, columns) array.

    Returns:
        A pair of coordinates bounds of the form (x1, x2)
        and (y1, y2) representing the left, right, top, and
        bottom bounds, or None.
    """
    # adj should be thought of as a boolean at each pixel indicating
    # whether or not that pixel is non-trivial (True) or not (False).
    pixel_values = image.mean(axis=2) if image.ndim == 3 else image
    adj = pixel_values > 2
    height, width = adj.shape

    if adj.all():
        # Same half-open convention as the general case below.
        return (0, width), (0, height)

    rows = np.where(adj.sum(axis=1) > 0.1 * width)[0]
    cols = np.where(adj.sum(axis=0) > 0.1 * height)[0]

    if len(rows) == 0 or len(cols) == 0:
        return None

    # rows/cols are sorted, so the first and last entries are the extremes
    # (they coincide when only one row/column qualifies).
    return (int(cols[0]), int(cols[-1]) + 1), (int(rows[0]), int(rows[-1]) + 1)
