import arrow
from datetime import datetime
import logging
from typing import TYPE_CHECKING

from django.db import connection
from django.db.models import Model
from django.utils.timezone import make_naive
from psqlextra.partitioning.time_partition_size import PostgresTimePartitionUnit
from .typehints import Partition

if TYPE_CHECKING:
    from psqlextra.backend.schema import PostgresSchemaEditor

logger = logging.getLogger(__name__)
# Format of the <start>/<end> date segments in partition table names created and
# parsed by this module, e.g. 20240801 (four-digit year, month, day).
DATE_FMT = '%Y%m%d'


def partition_list(model: type[Model]) -> list[Partition]:
    """
    Return the partitions attached to this model's parent table.

    Child tables are read from ``pg_catalog.pg_inherits`` for the parent table in
    the ``public`` schema. Names following ``<parent_table>_<start>_<end>``, with
    both dates in ``DATE_FMT`` (``%Y%m%d``), get parsed start/end dates, e.g. for
    the parent ``reporter_summarywebviews``::

        reporter_summarywebviews_20240801_20240831
        reporter_summarywebviews_20240901_20240930

    Any other child table is still returned, with ``start_date`` and
    ``end_date`` set to ``None``.

    NOTE(review): tables named with six-digit dates (e.g. ``..._240801_240831``)
    do not parse with ``DATE_FMT`` and are therefore returned undated.

    :param model: The Django model whose table is the partitioned parent.
    :return: ``Partition`` entries. Dated partitions come first, ordered by start
        date. Undated partitions follow, ordered by name.
    """
    parent_table_name = model._meta.db_table
    name_prefix = f'{parent_table_name}_'

    # The regclass cast renders each child table OID as its table name.
    sql = """
    SELECT inhrelid::regclass AS child
    FROM   pg_catalog.pg_inherits
    WHERE  inhparent = %s::regclass;
    """

    with connection.cursor() as cursor:
        cursor.execute(sql, [f'public.{parent_table_name}'])
        rows = cursor.fetchall()

    partition_tables: list[Partition] = []
    for row in rows:
        partition_name = row[0]
        try:
            start, end = partition_name.removeprefix(name_prefix).split('_')
            start_date = datetime.strptime(start, DATE_FMT)
            end_date = datetime.strptime(end, DATE_FMT)
            partition_tables.append(Partition(partition_name, start_date, end_date))
        except (IndexError, ValueError):
            # Not named <parent>_<start>_<end>: keep it, but without date bounds.
            partition_tables.append(Partition(partition_name, None, None))

    # The key must be homogeneous. Keying on "start_date or name" mixes datetime and
    # str, so sort() raises TypeError when dated and undated partitions coexist,
    # leaving the list unsorted or partially reordered.
    partition_tables.sort(
        key=lambda pt: (pt.start_date is None, pt.start_date or datetime.min, pt.name)
    )
    return partition_tables


def create_range_partitions(
    model: type[Model],
    interval: PostgresTimePartitionUnit,
    date_start: datetime | None = None,
    future_partition_count: int = 1,
) -> None:
    """
    Create range partitions for this model, one per ``interval`` period.

    Creation starts at the period containing ``date_start`` (or the current
    period when omitted). It continues up to and including the period containing
    ``now`` shifted forward by ``future_partition_count`` intervals.

    Each partition is named ``<start>_<end>`` using ``DATE_FMT`` (``%Y%m%d``),
    where ``<end>`` is the last day of the period. An exception raised while
    creating one partition is logged, and the remaining periods are still
    attempted.

    NOTE: ``arrow.get`` treats a naive ``date_start`` as UTC, while ``now`` is
    taken in local time.

    :param model: The Django model for which to create partitions.
    :param interval: The range type (weeks, months, years). Its value is used as
        the arrow floor/ceil/shift unit.
    :param date_start: Date that the first partition should include.
    :param future_partition_count: Number of partitions to make beyond the current date.
    """
    time_interval = str(interval.value)
    now = arrow.now()
    start = (arrow.get(date_start) if date_start else now).floor(time_interval)  # type: ignore[arg-type]
    final_partition_date = now.shift(**{time_interval: future_partition_count})

    # Loop-invariant: a single schema editor serves every partition created below.
    schema_editor: 'PostgresSchemaEditor' = connection.schema_editor()  # type: ignore[assignment]

    while start <= final_partition_date:
        end = start.ceil(time_interval)  # type: ignore[arg-type]
        start_value = start.strftime(DATE_FMT)

        # The partition's upper bound is exclusive ("<" rather than "<="), so pass the
        # day after the period's last day. E.g. a 2024-08-13..2024-08-20 partition has:
        # ((date IS NOT NULL) AND (date >= '20240813'::text) AND (date < '20240821'::text))
        range_end = end.shift(days=1)
        try:
            schema_editor.add_range_partition(
                model=model,
                name=f'{start_value}_{end.strftime(DATE_FMT)}',
                from_values=start_value,
                to_values=range_end.strftime(DATE_FMT),
            )
        except Exception as e:
            # Best effort: log and continue with the next period.
            name = model._meta.model_name
            logger.error('Partitioning %s %s-%s failed with error: %s', name, start, end, e)

        start = start.shift(**{time_interval: 1})


def drop_partitions(model: type[Model], retain_months: int = 2) -> list[Partition]:
    """
    Drop old partitions of the specified model.

    :param model: The Django model whose partitions are dropped.
    :param retain_months: The number of historical full months for which to retain
        partitions. A partition is dropped when its start date falls before the
        first day of the month ``abs(retain_months)`` months ago. Partitions whose
        names did not yield a start date are kept. When falsy (``0``), every
        partition is dropped, including undated, current and future ones.
    :return: The partitions whose deletion succeeded.
    """
    candidates = partition_list(model)

    if retain_months:
        # First day of the month `retain_months` months back, made naive so it
        # compares against the naive dates parsed from partition names.
        cutoff = make_naive(
            arrow.now().shift(months=-abs(retain_months)).floor('month').datetime
        )
        candidates = [
            candidate
            for candidate in candidates
            if candidate.start_date and candidate.start_date < cutoff
        ]

    removed: list[Partition] = []
    for candidate in candidates:
        # delete_partition logs its own failures and reports success as a bool.
        if delete_partition(model, candidate.name):
            removed.append(candidate)
    return removed


def delete_partition(model: type[Model], name: str) -> bool:
    """
    Delete the specified partition from the model.

    ``name`` may be the full partition table name. A leading
    ``<model db_table>_`` prefix is stripped before it is passed to the schema
    editor.

    :return: ``True`` when the deletion call completed. ``False`` when any
        exception was raised; the error is logged, not propagated.
    """
    try:
        editor: 'PostgresSchemaEditor' = connection.schema_editor()  # type: ignore[assignment]
        short_name = name.removeprefix(f'{model._meta.db_table}_')
        editor.delete_partition(model=model, name=short_name)
    except Exception as e:
        logger.error('Partition %s deletion failed. %s', name, e)
        return False
    else:
        return True
