import logging
from contextlib import closing
from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
from uuid import UUID

from django.conf import settings
from django.core.exceptions import FieldError
from django.db import connection, connections, transaction
from django.db.models import Model, Q, QuerySet
from django.db.utils import Error
from psqlextra.manager import PostgresQuerySet

from .typehints import CanonicalIDOrString
from .utils.models import lookup_field

if TYPE_CHECKING:
    # Imported only for type annotations (used as string annotations below);
    # presumably guarded to avoid a runtime/circular import with console_base.
    from console_base.models import BaseUUIDPKModel

    # For type checkers the mixin classes below look like QuerySets, so their
    # calls to self.filter / self.exclude / self.none type-check; at runtime
    # they inherit from plain `object` and get QuerySet behavior only when
    # combined with a real QuerySet class (see LCBaseQuerySet).
    QuerySetMixin = QuerySet
else:
    QuerySetMixin = object

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
def equal_or_null_qs(model: Model, field: str, values: Sequence | set) -> Q:
    """
    Build a Q filter matching the model's lookup field against ``values``.

    ``model`` and ``field`` are handed to ``lookup_field``, which yields the
    field name to filter on and whether that field is required. For a field
    that is not required, rows where the field IS NULL are matched as well.

    :param model: passed straight through to ``lookup_field``
    :param field: field identifier passed through to ``lookup_field``
    :param values: values to match; any iterable, materialized to a tuple
    :return: - ``Q()`` (no filtering) when no field name is resolved, or when
               ``values`` holds exactly one falsy value;
             - ``Q(<name>__isnull=True)`` for no values on an optional field;
             - otherwise an exact / ``__in`` match, OR-ed with the NULL check
               when the field is optional. An empty ``values`` on a required
               field yields ``__in=()``, which matches nothing.
    """
    field_name, required = lookup_field(model, field)
    candidates = tuple(values)
    total = len(candidates)

    single_blank = total == 1 and not candidates[0]
    if single_blank or not field_name:
        return Q()

    null_lookup = (f'{field_name}__isnull', True)

    if total == 0 and not required:
        # Example: user may have company relations but no policy relations
        # in which case, the company filter takes care of filtering by company,
        # and we only want policy relations that are null
        return Q(null_lookup)

    if total == 1:
        match_lookup = (field_name, candidates[0])
    else:
        match_lookup = (f'{field_name}__in', candidates)

    if required:
        return Q(match_lookup)
    return Q(match_lookup, null_lookup, _connector=Q.OR)


# ---------------------------------------------------------------------------
def get_new_value(field: str, value: Any) -> Any:
    """
    Resolve the value to be assigned to ``field``.

    - A callable ``value`` is invoked and its result returned unchanged.
    - For the ``cid`` and ``id`` fields, a string is parsed as a UUID
      (canonical or hex form); a string that does not parse is returned as-is.
    - Anything else passes through untouched.
    """
    if callable(value):
        return value()

    wants_uuid = field in ('cid', 'id')
    if wants_uuid and isinstance(value, str):
        try:
            value = UUID(value)
        except (AttributeError, TypeError, ValueError):
            # not a UUID string: keep the original text
            pass
    return value


# ---------------------------------------------------------------------------
def update_model_data(record: 'BaseUUIDPKModel', values: dict) -> tuple['BaseUUIDPKModel', bool]:
    """
    Copy ``values`` onto ``record`` and report whether it is now dirty.

    Each value is first resolved with ``get_new_value``, which calls callables
    and parses UUID strings for ``cid``/``id``. An attribute is only assigned
    when the resolved value differs from the current one.

    :return: ``(record, changed)``, where ``changed`` is True when
        ``record.get_dirty_fields()`` reports any dirty field
    """
    for attr_name, raw_value in values.items():
        resolved = get_new_value(attr_name, raw_value)
        if getattr(record, attr_name) != resolved:
            setattr(record, attr_name, resolved)

    is_dirty = bool(record.get_dirty_fields())
    return record, is_dirty


# ---------------------------------------------------------------------------
class SearchQuerySet(QuerySetMixin):
    """
    Used for Select2 select field filters
    """

    def build_qs(self, qkey: str | list | tuple, queryvalue: str = '') -> Q:
        """
        Build a Q object matching ``queryvalue`` against one or more lookups.

        :param qkey: a single lookup string such as ``date__range``, or a
            list/tuple of lookups such as ``[date__range, created__range]``.
            With several lookups the conditions are OR-ed together.
        :param queryvalue: Value to query, whether string, int, uuid
        :return: ``Q()`` when ``queryvalue`` is falsy, else the lookup Q
        """
        if not queryvalue:
            return Q()

        if isinstance(qkey, str):
            return Q((qkey, queryvalue))

        lookups = [(qkey[0], queryvalue)] + [(key, queryvalue) for key in qkey[1:]]
        # a single lookup keeps Q's default connector
        connector = {} if len(lookups) == 1 else {'_connector': Q.OR}
        return Q(*lookups, **connector)

    def search(self, **kwargs: Any) -> QuerySet:
        """
        Filter by every ``(lookup, field)`` pair from the model's
        ``filter_fields()``. Each pair uses the value supplied in ``kwargs``
        under ``field``, and the non-empty conditions are AND-ed together.
        """
        filter_specs = self.model().filter_fields()
        conditions = [
            self.build_qs(qkey=lookup, queryvalue=kwargs.get(name, ''))
            for lookup, name in filter_specs
        ]

        combined = Q()
        for condition in conditions:
            if condition:
                combined &= condition

        return self.filter(combined)


# ---------------------------------------------------------------------------
class LCBaseQuerySet(SearchQuerySet, PostgresQuerySet):
    """
    Base queryset providing lookup, counting and sync-aware upsert helpers.

    'browse' method to be available on all models.
    'scid' looks up by the 'cid' field and falls back to a primary-key
    lookup ('browse') when no 'cid' field is defined on the model.
    """

    def active(self):
        """
        Objects with ``is_active=True``; an empty queryset when the model
        has no ``is_active`` field.
        """
        try:
            return self.filter(is_active=True)
        except FieldError:
            return self.none()

    def inactive(self):
        """
        Objects whose ``is_active`` is not True; an empty queryset when the
        model has no ``is_active`` field.
        """
        try:
            # use 'exclude' to catch NULL or False values
            return self.exclude(is_active=True)
        except FieldError:
            return self.none()

    def browse(
        self, pk: Union[CanonicalIDOrString, 'BaseUUIDPKModel', int]
    ) -> Optional['BaseUUIDPKModel']:
        """
        Return the object with primary key ``pk``, or None if none matches.

        A falsy ``pk`` returns None without querying. An instance of this
        queryset's model is returned unchanged.
        """
        if not pk:
            return None
        if isinstance(pk, self.model):
            return pk
        return self.filter(pk=pk).first()

    def scid(
        self,
        cid: Union[CanonicalIDOrString, 'BaseUUIDPKModel'],
    ) -> Optional['BaseUUIDPKModel']:
        """
        Return the object whose ``cid`` equals ``cid``, or None.

        A falsy ``cid`` returns None, and a model instance is returned as-is.
        Models without a ``cid`` field fall back to ``browse`` (pk lookup).
        """
        if not cid:
            return None
        if isinstance(cid, self.model):
            return cid
        try:
            return self.filter(cid=cid).first()
        except FieldError:
            return self.browse(cid)

    def count(self) -> int:
        """
        Override to use fuzzy count when the queryset spans the whole table.

        Only applies to models that set ``CACHE_COUNT_QUERIES``, and only when
        the result size must equal the table's row count (see
        ``_can_fuzzy_count``). All other cases use the builtin count.
        """
        if self._can_fuzzy_count():
            return self.fuzzy_count()

        return super().count()

    def _can_fuzzy_count(self) -> bool:
        """
        Return True when the pg_class row estimate may stand in for count().
        """
        query = self.query
        return bool(
            getattr(self.model, 'CACHE_COUNT_QUERIES', False)
            and not getattr(self, 'use_builtin_count', False)
            # already evaluated: the builtin count just measures the cache
            and self._result_cache is None
            # filtering, slicing, DISTINCT, GROUP BY and UNION/INTERSECT/EXCEPT
            # all make the result size differ from the table's row count
            and not query.where
            and not query.is_sliced
            and not query.distinct
            and query.group_by is None
            and not query.combinator
        )

    def fuzzy_count(self) -> int:
        """
        For more efficient `count` query, lookup
        the `reltuples` column in the pg_class table.

        The estimate for the model's table is summed with any partition
        tables named ``<table>_<digits>_...``. Negative (unknown) estimates
        count as zero.

        If no positive value is found, the vacuum job might be stale, so
        use the builtin `count` method to query the actual table.
        """
        db_table = self.model._meta.db_table
        try:
            # use the database this queryset is routed to, not always 'default'
            with closing(connections[self.db].cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT sum(GREATEST(reltuples, 0))
                    FROM pg_class
                    WHERE relkind = 'r'
                    AND (relname = %(table)s OR relname ~ %(partition)s);
                    """,
                    {
                        'table': db_table,
                        # anchored, so other tables merely containing this
                        # name are not summed in; \d+ matches the digits
                        'partition': rf'^{db_table}_\d+_',
                    },
                )
                approximate_count = int(cursor.fetchone()[0])
        except (IndexError, TypeError):
            # no row / NULL sum: no statistics available
            approximate_count = 0
        except Error:
            logger.debug('Fuzzy count failed for table %s', db_table, exc_info=True)
            approximate_count = 0

        return approximate_count if approximate_count > 0 else self.builtin_count()

    def builtin_count(self) -> int:
        """
        Helper method to ensure that builtin count is called.

        ``use_builtin_count`` is only forced on for the duration of this call
        and then restored. Later ``count()`` calls on the same queryset can
        therefore still use the fuzzy estimate.
        """
        previous = getattr(self, 'use_builtin_count', False)
        self.use_builtin_count = True
        try:
            return self.count()
        finally:
            self.use_builtin_count = previous

    def get_by_natural_key(self, name: str) -> Optional['BaseUUIDPKModel']:
        """
        Return the first object whose ``name`` equals ``name``, or None.
        """
        return self.filter(name=name).first()

    def update_if_changed_or_create(
        self,
        defaults: dict | None = None,
        **kwargs: Any,
    ) -> tuple['BaseUUIDPKModel', bool]:
        """
        A slight modification of the QuerySet.update_or_create method,
        which always saves the model every time, whether values changed
        or not. This makes the 'modified' field worthless for the
        purpose of tracking sync status.

        If there are no changes to the model, we don't want to save it
        to the database, so avoiding an update to the 'modified' field.

        A keyword named by ``settings.SYNC_OPERATION_MODE`` is removed from
        the lookup kwargs. On the update path it is forwarded to ``save()``.

        :return: ``(obj, created)``
        """
        sync_mode = kwargs.pop(settings.SYNC_OPERATION_MODE, None)
        defaults = defaults or {}
        self._for_write = True
        with transaction.atomic(using=self.db):
            # lock the matched row for the rest of the transaction
            obj, created = self.select_for_update().get_or_create(defaults, **kwargs)
            if created:
                # NOTE(review): sync_mode is not passed on this path; the
                # object is saved by get_or_create itself — confirm intended.
                return obj, created

            obj, changed = update_model_data(record=obj, values=defaults)

            if changed:
                if sync_mode:
                    obj.save(**{settings.SYNC_OPERATION_MODE: sync_mode, 'using': self.db})
                else:
                    obj.save(using=self.db)

        return obj, False  # Updated, rather than created object


# ---------------------------------------------------------------------------
class LTreeBaseQuerySet(QuerySetMixin):
    """
    Query methods for handling models that have ltree path

    ``path`` values are dot-separated label strings such as ``'a.b.c'``;
    root-level objects are those with ``path__depth=1`` (see ``roots``).
    """

    def roots(self):
        """
        Get all objects that are at the root of the tree.
        """
        return self.filter(path__depth=1)

    def children(self, path):
        """
        Get the direct children of ``path``.

        Matches objects under ``path`` (``path__descendants``) whose depth is
        exactly one level deeper than ``path`` (its dot count + 2).
        """
        return self.filter(path__descendants=path, path__depth=path.count('.') + 2)

    def ancestors(self, path):
        """
        Get objects whose path is an ancestor of ``path``
        (via the ``path__ancestors`` lookup).
        """
        return self.filter(path__ancestors=path)

    def ancestors_children(self, path):
        """
        Get all parents and children of path.

        An empty path yields an empty queryset. A single-label path uses the
        ``path__subpath`` lookup. Otherwise the ancestors of ``path`` are
        OR-ed with paths matching the pattern ``'<path>.*'``.
        """
        if not path:
            return self.none()

        if '.' not in path:
            return self.filter(path__subpath=path)

        return self.filter(Q(path__ancestors=path, path__match=f'{path}.*', _connector=Q.OR))

    def family(self, path):
        """
        Get root parent, preceding parent and all leaf siblings.
        Do not get intermediate siblings as they should be considered unrelated.

        An empty/None path yields an empty queryset, consistent with
        ``ancestors_children``, instead of failing on ``'.' in None`` or
        sending an empty match pattern to the database.
        """
        if not path:
            return self.none()

        if '.' not in path:
            return self.filter(path__match=path)

        # 'a.b.c' -> root parent 'a', immediate parent 'a.b'
        root_parent = path[: path.index('.')]
        first_parent = path[: path.rindex('.')]
        return self.filter(
            Q(('path__match', root_parent), ('path__match', first_parent), _connector=Q.OR)
            # siblings: direct children of the immediate parent
            | Q(path__descendants=first_parent, path__depth=first_parent.count('.') + 2)
        )
