from django.core.exceptions import EmptyResultSet
from django.db.models import QuerySet
from rest_framework_datatables.pagination import DatatablesLimitOffsetPagination

from console_base.utils import (
    query_cache_key,
    get_cache_query_count,
    set_cache_query_count,
)


# ---------------------------------------------------------------------------
def CachedCountQueryset(queryset: QuerySet) -> QuerySet:
    """
    Return a copy of ``queryset`` whose ``count()`` result is cached.

    The copy is made with ``queryset._chain()`` so the caller's queryset is
    left untouched. The copy's ``count`` method is replaced by a wrapper that:

    * builds a cache key from the queryset's SQL via
      ``query_cache_key(str(qs.query))``;
    * returns the cached count on a hit;
    * otherwise runs the original ``count()`` and stores the result.

    Querysets that can never match any rows (e.g. ``.none()`` or
    ``filter(pk__in=[])``) raise ``EmptyResultSet`` when their SQL is rendered.
    For those, the original ``count()`` is called directly and is not cached.
    """
    queryset = queryset._chain()  # type: ignore[attr-defined]
    real_count = queryset.count

    def count(qs: QuerySet) -> int:
        try:
            sql = str(qs.query)
        except EmptyResultSet:
            # Rendering SQL for an always-empty query raises, but Django's
            # own count() handles this case, so defer to it uncached.
            return real_count()

        key = query_cache_key(sql)
        if not key:
            # No usable cache key: count without touching the cache at all.
            return real_count()

        cached = get_cache_query_count(key)
        # Truthiness check: a cached count of 0 is treated as a miss, so
        # empty results are re-counted on every call.
        if cached:
            return cached

        value = real_count()
        set_cache_query_count(key, value)
        return value

    # Bind the wrapper as a method of this specific queryset instance only.
    queryset.count = count.__get__(queryset, type(queryset))  # type: ignore[method-assign]
    return queryset


# ---------------------------------------------------------------------------
class CachedCountLimitOffsetPagination(DatatablesLimitOffsetPagination):
    """
    Limit/offset pagination that caches the count query and always paginates.

    Django querysets are wrapped with ``CachedCountQueryset`` so the count
    query can be served from cache. Unlike the parent class, datatables
    requests are paginated even when no ``length`` param is supplied, instead
    of returning the entire table.
    """

    def paginate_queryset(self, queryset, request, view=None):
        """
        Return the requested page of ``queryset`` as a list.

        Returns ``None`` when ``get_limit()`` yields no limit, meaning the
        response is not paginated.

        Merged from these sources:
            rest_framework_datatables.DatatablesLimitOffsetPagination
            rest_framework.LimitOffsetPagination

        It is simplified to avoid the duplicate attribute assignment in those
        two superclasses, and modified to:

        * paginate datatables requests even without a ``length`` param;
        * resolve the limit before counting, so unpaginated requests skip the
          count query entirely;
        * cache the count of Django querysets whenever possible. For
          datatables requests, ``get_count_and_total_count`` may also use
          counts provided by the view, avoiding the count query altogether.

        Sets ``is_datatable_request`` and ``limit`` on the paginator. When a
        page is produced, it also sets ``count``, ``total_count`` (datatables
        requests only), ``offset`` and ``request``, which are used to build
        the paginated response.
        """
        # Only real Django querysets can be wrapped. Plain lists also have a
        # ``count`` attribute (list.count) but no ``_chain()``, and DRF's
        # get_count() handles non-queryset inputs itself.
        if isinstance(queryset, QuerySet):
            queryset = CachedCountQueryset(queryset)

        self.is_datatable_request = request.accepted_renderer.format == 'datatables'
        if self.is_datatable_request:
            # DataTables server-side requests use these param names.
            self.limit_query_param = 'length'
            self.offset_query_param = 'start'

        # Unlike the parent class, don't bail out when no ``length`` param was
        # sent, since that would return the entire table; let get_limit()
        # decide. Resolve the limit before counting so that an unpaginated
        # request never runs a count query.
        self.limit = self.get_limit(request)
        if self.limit is None:
            return None

        if self.is_datatable_request:
            self.count, self.total_count = self.get_count_and_total_count(queryset, view)
        else:
            self.count = self.get_count(queryset)

        self.offset = self.get_offset(request)
        self.request = request

        if self.count > self.limit and self.template is not None:
            self.display_page_controls = True

        # Nothing to return for an empty result or an offset past the end.
        if self.count == 0 or self.offset > self.count:
            return []

        return list(queryset[self.offset : self.offset + self.limit])
