import logging
from typing import TYPE_CHECKING

from django.core.exceptions import FieldError
from django.db.models import QuerySet
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response

logger = logging.getLogger(__name__)

# Static type checkers see the mixin as a GenericAPIView subclass, so uses of
# ``self.request`` / ``self.filter_queryset`` inside it type-check. At runtime
# the mixin derives from plain ``object``, so it does not pull GenericAPIView
# into the class hierarchy of the views it is mixed into.
if TYPE_CHECKING:
    from rest_framework.generics import GenericAPIView as TenantQueryMixinBase
else:
    TenantQueryMixinBase = object


# ---------------------------------------------------------------------------
class TenantQueryMixin(TenantQueryMixinBase):
    """
    Scope every API query to the requesting user.

    User is always required for all API queries to check for Company
    relationships to the records being retrieved: the base queryset's
    ``tenant()`` method is called with ``user=self.request.user``.

    Intended to be mixed into a DRF ``GenericAPIView``-derived view. It also
    adds two list-level (``detail=False``) actions, ``count`` and
    ``canonical_ids``, both excluded from the generated API schema.
    """

    def get_queryset(self) -> QuerySet:
        """
        Return the view's base queryset, tenanted to the requesting user.

        If the queryset provides a ``tenant()`` method, the method returns
        ``tenant(user=self.request.user)``. If it has no ``tenant()``, the
        unscoped queryset is returned and an info message is logged, as a
        prompt to double-check that this view is meant to be untenanted.

        Only the *absence* of ``tenant()`` triggers that fallback. Errors
        raised while resolving ``self.request.user`` or inside ``tenant()``
        propagate to the caller, so a tenanting failure can never silently
        degrade into an unfiltered (cross-tenant) result.

        NOTE(review): a previous docstring said is_active queries are altered
        here rather than via filters. No such logic is visible in this
        method, so presumably it lives in ``tenant()``. Confirm.
        """
        qs = super().get_queryset()

        # Look the method up instead of wrapping the call in
        # ``except AttributeError``. The broad try/except also caught
        # AttributeErrors raised *inside* tenant() (or from request.user) and
        # then returned unfiltered data.
        tenant = getattr(qs, 'tenant', None)
        if tenant is None:
            logger.info('%s is not a tenanted model; are you sure this is correct!?!', qs.model)
            return qs

        return tenant(user=self.request.user)

    @extend_schema(exclude=True)
    @action(detail=False)
    def count(self, request: Request) -> Response:
        """
        Return count of all records matching queryset.

        The queryset is the tenanted queryset with the view's filter backends
        applied. Response body: ``{'count': <int>}``. ``request`` is unused
        but is part of the DRF action handler signature.
        """
        queryset = self.filter_queryset(self.get_queryset())
        return Response({'count': queryset.count()})

    @extend_schema(exclude=True)
    @action(detail=False)
    def canonical_ids(self, request: Request) -> Response:
        """
        Return list of Canonical IDs of all records matching queryset.

        Used for confirming that Publisher and Subscriber servers are synced.
        Response body: ``{'count': <int>, 'results': [<cid>, ...]}``. If the
        model has no ``cid`` field, an info message is logged and an empty
        list is returned.
        """
        queryset = self.filter_queryset(self.get_queryset())
        try:
            # Django resolves field names when values_list() is built, so a
            # missing 'cid' field raises FieldError here, before any query
            # runs.
            cid_values = queryset.values_list('cid', flat=True)
        except FieldError:
            logger.info('%s has no "cid" field', queryset.model)
            results = []
        else:
            # Evaluate outside the try: any other FieldError from the query
            # must not be misreported as "no cid field". An empty id list
            # here could make a synced peer look out of sync.
            results = list(cid_values)
        return Response({'count': len(results), 'results': results})


# Names exported by ``from <module> import *``: only the mixin is public.
__all__ = ['TenantQueryMixin']
