from celery import Task
from celery_singleton.exceptions import DuplicateTaskError
from celery_singleton.singleton import Singleton
from celery.result import AsyncResult
from console_base.utils import cache_key
from django.db import close_old_connections
import inspect
from lchttp.json import json_dumps
import logging
from .exceptions import TaskFailedExit

# Module-level logger named after this module; used by the on_failure hooks below.
logger = logging.getLogger(__name__)


class CloseDbConnections:
    """
    Mixin that calls Django's ``close_old_connections`` once a task returns.

    Intended to reduce errors of this kind in later tasks:

        InterfaceError
        connection already closed

    Must be listed before the Celery task base class so that ``super()``
    in ``after_return`` resolves to the task's own hook.
    """

    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        """
        Close old database connections, then delegate to the next
        ``after_return`` in the MRO and return whatever it returns.
        """
        # Connections are closed first; the parent hook runs afterwards.
        close_old_connections()
        parent_hook = super().after_return
        return parent_hook(status, retval, task_id, args, kwargs, einfo)


class LCTask(CloseDbConnections, Task):
    """
    Default task base class: closes old DB connections after each run
    (via ``CloseDbConnections``) and logs every failure.
    """

    # TaskFailedExit is declared as an expected exception for this task.
    throws = (TaskFailedExit,)

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """
        Log the failed task's id, arguments and exception, including the
        traceback of the exception currently being handled.
        """
        message = f'Task failed! ID: {task_id} / args: {args} / kwargs: {kwargs} - {exc}'
        # Same as logger.exception(message): ERROR level with exc_info attached.
        logger.error(message, exc_info=True)


class LCTaskWithRetry(CloseDbConnections, Task):
    """
    Task base class configured with Celery's automatic-retry options.

    Old database connections are closed after each run via
    ``CloseDbConnections``.
    """

    # Exceptions that trigger an automatic retry. ``KeyError`` is already
    # covered by ``Exception``; it is kept so the value is unchanged.
    autoretry_for = (Exception, KeyError)

    # Retry pacing options (see Celery's automatic-retry documentation).
    retry_backoff = True
    retry_jitter = True

    # Give up after three retries.
    retry_kwargs = {'max_retries': 3}


class LCSingleton(CloseDbConnections, Singleton):
    """
    Singleton task base class with failure logging and optional
    replacement of an already-pending duplicate.

    Differences from the plain ``Singleton``:

    * lock keys are built with this module's ``generate_lock`` function;
    * ``apply_async`` accepts ``clear_pending_tasks=True`` to revoke a
      pending duplicate and queue a fresh task instead;
    * failures are logged before the parent failure hook runs.
    """

    # TaskFailedExit is declared as an expected exception for this task.
    throws = (TaskFailedExit,)

    # Class-level default for the per-call flag set in ``apply_async`` so
    # ``_raise_on_duplicate`` can always read it.
    clear_pending_tasks = False

    def generate_lock(self, task_name, task_args=None, task_kwargs=None):
        """
        Copy of the parent method so we can use our own module-level
        ``generate_lock`` function, which handles UUIDs.

        When ``unique_on`` is set, only the named arguments (resolved by
        binding the call arguments to the signature of ``self.run``)
        take part in the lock key; otherwise all args and kwargs do.
        """
        unique_on = self.unique_on
        task_args = task_args or []
        task_kwargs = task_kwargs or {}

        if unique_on:
            if isinstance(unique_on, str):
                unique_on = [unique_on]
            sig = inspect.signature(self.run)
            bound = sig.bind(*task_args, **task_kwargs).arguments

            unique_args = []
            unique_kwargs = {key: bound[key] for key in unique_on}
        else:
            unique_args = task_args
            unique_kwargs = task_kwargs

        # Resolves to the module-level function, not this method.
        return generate_lock(
            task_name,
            unique_args,
            unique_kwargs,
            key_prefix=self.singleton_config.key_prefix,
        )

    def apply_async(
        self,
        args=None,
        kwargs=None,
        task_id=None,
        producer=None,
        link=None,
        link_error=None,
        shadow=None,
        **options
    ):
        """
        Queue the task, optionally replacing a pending duplicate.

        Pass ``clear_pending_tasks=True`` to clear the lock of a pending
        task, revoke the pending task itself, and then schedule a new
        task to run immediately. The option is consumed here and is not
        forwarded to the parent ``apply_async``.

        A ``DuplicateTaskError`` raised by the second attempt (after the
        lock was released) is not caught and propagates to the caller.
        """
        # Stored on the task so ``_raise_on_duplicate`` can see it while
        # the parent ``apply_async`` runs.
        self.clear_pending_tasks = options.pop('clear_pending_tasks', False)

        # ``task_id`` is forwarded so a caller-supplied id is honoured
        # instead of being silently dropped.
        publish_kwargs = dict(
            args=args,
            kwargs=kwargs,
            task_id=task_id,
            producer=producer,
            link=link,
            link_error=link_error,
            shadow=shadow,
            **options,
        )

        if self.clear_pending_tasks:
            try:
                return super().apply_async(**publish_kwargs)
            except DuplicateTaskError as e:
                # A duplicate is pending: drop its lock and revoke it,
                # then fall through to queue the replacement.
                self.release_lock(task_args=args or [], task_kwargs=kwargs or {})
                AsyncResult(e.task_id).revoke()

        return super().apply_async(**publish_kwargs)

    @property
    def _raise_on_duplicate(self):
        """
        When clearing tasks, we always need to raise on duplicates
        to ensure that future tasks are indeed cancelled and a new
        task created.
        """
        try:
            return super()._raise_on_duplicate or self.clear_pending_tasks
        except AttributeError:
            # Kept from the original: treat a failed lookup as
            # "do not raise on duplicate".
            return False

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """
        Log the failure, then run the parent failure hook.

        The parent hook must run: in celery_singleton it releases the
        singleton lock. Skipping it leaves the lock held after a failure,
        so later submissions are treated as duplicates.
        """
        try:
            logger.exception(f'Task failed! ID: {task_id} / args: {args} / kwargs: {kwargs} - {exc}')
        finally:
            # Run the parent hook even if logging itself raises.
            super().on_failure(exc, task_id, args, kwargs, einfo)


class LCSingletonWithRetry(LCSingleton):
    """
    ``LCSingleton`` configured with Celery's automatic-retry options.
    """

    # Exceptions that trigger an automatic retry. ``KeyError`` is already
    # covered by ``Exception``; it is kept so the value is unchanged.
    autoretry_for = (Exception, KeyError)

    # Retry pacing options (see Celery's automatic-retry documentation).
    retry_backoff = True
    retry_jitter = True

    # Give up after three retries.
    retry_kwargs = {'max_retries': 3}


def generate_lock(task_name, task_args=None, task_kwargs=None, key_prefix="SINGLETONLOCK_"):
    """
    Build the singleton lock key for a task invocation.

    The positional and keyword arguments are serialized with
    ``json_dumps(..., sort_keys=True)``, joined with the UTF-8 encoded
    task name, and passed through ``cache_key``. The result is prefixed
    with ``key_prefix``.

    NOTE(review): the serialized values are interpolated into a bytes
    template, so ``json_dumps`` is expected to return bytes -- confirm.

    :param task_name: name of the task (``str``).
    :param task_args: positional arguments; falsy values are treated as ``[]``.
    :param task_kwargs: keyword arguments; falsy values are treated as ``{}``.
    :param key_prefix: string prepended to the hashed key.
    :return: ``key_prefix`` followed by the value returned by ``cache_key``.
    """
    serialized_args = json_dumps(task_args or [], sort_keys=True)
    serialized_kwargs = json_dumps(task_kwargs or {}, sort_keys=True)
    encoded_name = task_name.encode('utf8')
    payload = b'%s%s%s' % (encoded_name, serialized_args, serialized_kwargs)
    digest = cache_key(payload)
    return f'{key_prefix}{digest}'


# Public API of this module. The CloseDbConnections mixin is not exported.
__all__ = (
    'LCTask',
    'LCTaskWithRetry',
    'LCSingleton',
    'LCSingletonWithRetry',
    'generate_lock',
)
