from lchttp.json import json_dumps, json_loads
import msgspec
from django.db.backends.postgresql.psycopg_any import is_psycopg3
from netfields.fields import InetAddressField

if is_psycopg3:
    from netfields.psycopg3_types import Inet
else:
    from psycopg2.extras import Inet
    from psycopg2.extensions import adapt

from sequential_uuids.generators import uuid_time_nextval
from typing import Any, TYPE_CHECKING

from citext import CITextField
from django.forms.fields import JSONField as JSONFormField
from django.db.models.lookups import Transform
from django.db.models.fields.mixins import CheckFieldDefaultMixin
from django.db.models.fields.json import (
    KeyTransformFactory,
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
)
from django.db.models import (
    Field,
    CharField,
    DateTimeField,
    EmailField,
    Max,
    IntegerField,
    JSONField,
    UUIDField,
    TextField,
)
from django.core import exceptions
from django.core.serializers.json import DjangoJSONEncoder
from django.core.validators import MaxValueValidator, MinValueValidator
from django.utils.translation import gettext_lazy as tr
from console_base.forms.fields import (
    IP4RangeAddressFormField,
    RemoteURLField as RemoteURLFormField,
)
from console_base.validators import (
    NoControlCharactersValidator,
    ProhibitNullCharactersValidator,
    RemoteURLValidator,
    validate_legal_punctuation,
    validate_no_spaces,
    validate_private_lan_ip,
)

# Base used by ``LowerCaseBase``: type checkers see ``Field`` (so that calls
# such as ``super().get_prep_value`` resolve), while at runtime the mixin
# only inherits from plain ``object``.
if TYPE_CHECKING:
    LowerCaseBaseMixin = Field
else:
    LowerCaseBaseMixin = object


class LCJSONField(JSONField):
    """
    JSONField that decodes database values with ``lchttp.json.json_loads``
    and checks serializability with ``lchttp.json.json_dumps``.

    The encoder defaults to ``DjangoJSONEncoder``.
    """

    def __init__(  # type: ignore[no-untyped-def]
        self,
        verbose_name: str | None = None,
        name: str | None = None,
        encoder: type[DjangoJSONEncoder] = DjangoJSONEncoder,
        decoder=None,
        **kwargs: Any,
    ) -> None:
        super().__init__(verbose_name, name, encoder=encoder, decoder=decoder, **kwargs)

    def from_db_value(self, value, expression, connection):
        """
        Decode ``value`` with ``json_loads``.

        ``None`` is returned untouched. When decoding yields a string, the
        result is decoded once more (double-encoded payloads). Any failure
        along the way hands the original ``value`` to the stock
        ``JSONField.from_db_value``.
        """
        if value is not None:
            try:
                decoded = json_loads(value)
                if not isinstance(decoded, str):
                    return decoded
                # if double-string encoded, perform double-loading
                # TODO - remove after Archives & Autofixes from 01/06/2024-01/20/2024 are obsolete
                return self.from_db_value(decoded, expression, connection)  # type: ignore[no-untyped-call]
            except Exception:
                return super().from_db_value(value, expression, connection)
        return value

    def validate(self, value, model_instance):
        """
        Run the inherited validation, then require that ``json_dumps`` can
        serialize ``value``; raise ``ValidationError`` (code ``invalid``)
        otherwise.
        """
        super().validate(value, model_instance)
        try:
            json_dumps(value)
        except Exception:
            invalid = exceptions.ValidationError(
                self.error_messages['invalid'],
                code='invalid',
                params={'value': value},
            )
            raise invalid from None


class MsgspecStructField(CheckFieldDefaultMixin, Field):
    """
    Encode msgspec.Struct field as JSON and decode back into Struct.

    The column is declared as a ``JSONField`` internally (see
    ``get_internal_type``); values are serialized with ``msgspec.json.encode``
    and read back with ``msgspec.json.decode`` into ``struct_type``.

    :param struct_type: the ``msgspec.Struct`` subclass stored in this field
    """

    empty_strings_allowed = False
    description = tr("A msgspec Struct object")
    default_error_messages = {
        "invalid": tr("Value must be valid msgspec Struct object."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        struct_type: type[msgspec.Struct],
        verbose_name: str | None = None,
        name: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        :raises ValueError: if ``struct_type`` is not a ``msgspec.Struct`` subclass
        """
        # issubclass() raises a bare TypeError for non-class arguments; check
        # for a class first so that every misuse gets the same ValueError.
        if not (isinstance(struct_type, type) and issubclass(struct_type, msgspec.Struct)):
            raise ValueError("The struct parameter must be a subclass of msgspec.Struct.")
        self.struct_type = struct_type

        super().__init__(verbose_name, name, **kwargs)

    def deconstruct(self) -> tuple[str, str, Any, Any]:
        """Include ``struct_type`` in the keyword arguments needed to rebuild the field."""
        name, path, args, kwargs = super().deconstruct()
        kwargs['struct_type'] = self.struct_type
        return name, path, args, kwargs

    def from_db_value(
        self,
        value: Any,
        expression: Any,
        connection: Any,
    ) -> msgspec.Struct | dict | None:
        """
        Decode a database value into ``struct_type``.

        ``None`` and an empty dict both yield ``None``. Decoding is attempted
        strictly first, then leniently when strict validation fails. If the
        value cannot be decoded into the Struct at all, the plain
        ``json_loads`` result is returned instead.
        """
        if value is None or value == {}:
            return None
        try:
            return msgspec.json.decode(value, type=self.struct_type, strict=True)
        except msgspec.ValidationError:
            try:
                return msgspec.json.decode(value, type=self.struct_type, strict=False)
            except Exception:
                pass
        except (TypeError, msgspec.MsgspecError):
            pass

        return json_loads(value)

    def get_internal_type(self) -> str:
        return 'JSONField'

    def get_prep_value(self, value: Any) -> str | None:
        """Return ``None`` and strings unchanged; JSON-encode anything else to ``str``."""
        if value is None or isinstance(value, str):
            return value
        return msgspec.json.encode(value).decode()

    def get_transform(self, name: str) -> type[Transform] | KeyTransformFactory:  # type: ignore[override]
        """Use a registered transform when one exists, else treat ``name`` as a JSON key."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run the inherited validation, then require that msgspec can encode
        ``value``.

        :raises ValidationError: (code ``invalid``) when encoding fails
        """
        super().validate(value, model_instance)
        try:
            msgspec.json.encode(value)
        # This is an *encode* call, so catching only DecodeError missed
        # msgspec's encoding errors. MsgspecError is the common base class
        # and still covers DecodeError.
        except (TypeError, msgspec.MsgspecError):
            raise exceptions.ValidationError(
                self.error_messages['invalid'],
                code='invalid',
                params={'value': value},
            ) from None

    def value_to_string(self, obj):
        return self.value_from_object(obj)

    def formfield(self, **kwargs):  # type: ignore[override]
        """
        Build the form field, defaulting to Django's JSON form field.

        ``struct_type`` is forwarded only to a caller-supplied ``form_class``.
        Django's built-in form fields do not accept that keyword and would
        raise ``TypeError`` on construction.
        """
        defaults = {
            'form_class': JSONFormField,
            'struct_type': self.struct_type,
            **kwargs,
        }
        # form_class=None makes Field.formfield fall back to a stock form
        # field as well, so treat it the same as the default JSON form field.
        if defaults['form_class'] is None or defaults['form_class'] is JSONFormField:
            defaults.pop('struct_type', None)
        return super().formfield(**defaults)


# Register the JSON lookups imported from django.db.models.fields.json on
# MsgspecStructField, in the same order as they are listed here.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    MsgspecStructField.register_lookup(_json_lookup)
del _json_lookup


class LowerCaseBase(LowerCaseBaseMixin):
    """Mixin that lowercases the prepared value before it is stored."""

    def get_prep_value(self, value: str) -> str:
        """
        Lowercase text values such as emails and usernames.

        Values without a ``lower`` method (``None``, for instance) are
        returned exactly as the parent class prepared them.
        """
        prepared = super().get_prep_value(value)
        try:
            lowered = prepared.lower()
        except AttributeError:
            # not a string-like value; nothing to lowercase
            return prepared
        return lowered


class LowerCaseTextField(LowerCaseBase, TextField):
    """
    TextField combined with ``LowerCaseBase``, so values are lowercased
    when prepared for the database.
    """


class LCEmailField(LowerCaseBase, EmailField):
    """
    EmailField combined with ``LowerCaseBase``: email addresses are
    always saved in lower case.
    """


# Alias only: LCDateTimeField is Django's DateTimeField under another name.
LCDateTimeField = DateTimeField


class LCTextField(TextField):
    """
    TextField with validation attached at the Model level, so forms and
    Rest Framework share the same rules.

    Every instance gets a ``ProhibitNullCharactersValidator``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        null_char_check = ProhibitNullCharactersValidator()
        self.validators.append(null_char_check)


class NameTextField(LCTextField):
    """
    Text field for names: additionally rejects control characters.

    Cannot be used for fields where regex strings are valid.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        control_char_check = NoControlCharactersValidator()
        self.validators.append(control_char_check)


class NameCITextField(CITextField):
    """
    CITextField for names: rejects control characters and null characters.

    Cannot be used for fields where regex strings are valid.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # same order as before: control characters first, then null characters
        self.validators.append(NoControlCharactersValidator())
        self.validators.append(ProhibitNullCharactersValidator())


class CodeTextField(NameTextField):
    """
    Name field for codes: adds the legal-punctuation and no-spaces
    validators on top of the inherited ones.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        for code_check in (validate_legal_punctuation, validate_no_spaces):
            self.validators.append(code_check)  # type: ignore[arg-type]


class UUID6Field(UUIDField):
    """
    UUIDField that defaults to ``uuid_time_nextval`` when the field is
    neither nullable nor blankable and no explicit default was given.
    """

    db_returning = True

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        is_optional = kwargs.get('null', False) or kwargs.get('blank', False)
        if not is_optional:
            # an explicit default always wins
            kwargs.setdefault('default', uuid_time_nextval)
        super().__init__(*args, **kwargs)


class CanonicalIdField(UUID6Field):
    """UUID6Field that ships with a standard help text unless one is supplied."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs.setdefault(
            'help_text',
            'Canonical ID assigned to this record, locally or on external systems. '
            'Should not be changed, except to reconcile sync errors.',
        )
        super().__init__(*args, **kwargs)


class ComputerPortField(IntegerField):
    """
    Accept values from 1-65535

    The range validators are always present. Validators passed by the
    caller are kept and run before them.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Start from the caller's validators instead of discarding them.
        validators = list(kwargs.get('validators') or ())
        for range_check in (MinValueValidator(1), MaxValueValidator(65535)):
            # Skip a range validator that is already present (compared by
            # equality), e.g. when the field is rebuilt from its own
            # deconstructed kwargs, so that it is not added twice.
            if range_check not in validators:
                validators.append(range_check)
        kwargs['validators'] = validators
        super().__init__(*args, **kwargs)


class SequenceField(IntegerField):
    """
    Integer sequence that is filled in automatically when a record is added.

    Inspired by https://github.com/cordery/django-autosequence/blob/master/autosequence/fields.py

    Works on non-primary-key columns and, through ``unique_with``, can keep
    a separate sequence per combination of other model fields.

    Differences from the upstream example:
        * the field stays editable, because sort order can be updated
        * the unnecessary "start_at" parameter is gone
        * db_index is set by default

    :param unique_with: string or tuple of strings: name or names of attributes
        that this sequence will be unique with
    """

    def __init__(self, unique_with: str | tuple | None = None, *args: Any, **kwargs: Any) -> None:
        scope_fields = unique_with or ()
        # a single field name is normalized to a one-element tuple
        self.unique_with = (scope_fields,) if isinstance(scope_fields, str) else scope_fields

        kwargs.setdefault('db_index', True)
        kwargs.setdefault('help_text', 'Numerical value to ensure consistent record sort order.')

        super().__init__(*args, **kwargs)

    def deconstruct(self):
        """Add ``unique_with`` to the rebuilt kwargs unless it is the empty tuple."""
        name, path, args, kwargs = super().deconstruct()

        if self.unique_with != ():
            kwargs['unique_with'] = self.unique_with

        return name, path, args, kwargs

    def pre_save(self, instance, add):
        """
        Return the value to store for this field.

        Updates keep the instance's current value. On insert, a truthy
        preset value is kept as well (allows creating entries via
        bulk_create); otherwise the next number after the current maximum
        is assigned to the instance and returned. The maximum is computed
        over all rows, or only over rows matching the instance's
        ``unique_with`` field values when those are configured.
        """
        current = getattr(instance, self.attname)
        if not add or current:
            return current

        candidates = self.model.objects.all()  # type: ignore[attr-defined]

        if self.unique_with:
            scope = {field: getattr(instance, field) for field in self.unique_with}
            candidates = candidates.filter(**scope)

        highest = candidates.aggregate(max=Max(self.attname))['max']
        next_value = highest + 1 if highest else 1
        setattr(instance, self.attname, next_value)

        return next_value


# ----------------------------------------------------------------------
class IPRange(Inet):
    """
    Wrap a string for the IPRANGE type.

    ``getquoted`` is only defined when psycopg3 is not in use.
    """

    if not is_psycopg3:

        def getquoted(self):
            """Quote the wrapped address and append the ``::iprange`` cast."""
            adapted = adapt(self.addr)
            if hasattr(adapted, 'prepare'):
                adapted.prepare(self._conn)
            quoted = adapted.getquoted()
            return quoted + b"::iprange"


# ----------------------------------------------------------------------
class IP4RangeAddressField(InetAddressField):
    """
    Address field for the `ip4r` Postgres extension module.

    A value is either a single address or a two-item range written as
    ``first-second``.
    """

    description = "PostgreSQL iprange field"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # don't store prefix length, as Redwood doesn't use CIDR notation
        self.store_prefix_length = False

    def db_type(self, connection: Any) -> str:
        return 'iprange'

    def to_python(self, value):
        """
        Convert ``value`` with the parent field.

        A string containing exactly one ``-`` is split into its two ends;
        lists are converted item by item and returned as a list.
        """
        parts = value
        if isinstance(parts, str) and parts.count('-') == 1:
            parts = parts.split('-')

        if not isinstance(parts, list):
            return super().to_python(parts)

        return [self.to_python(part) for part in parts]  # type: ignore[no-untyped-call]

    def get_prep_value(self, value: Any) -> str | None:
        """
        Return the ``-``-joined string form of the address or addresses,
        or ``None`` for an empty value.
        """
        if not value:
            return None

        converted = self.to_python(value)  # type: ignore[no-untyped-call]
        addresses = converted if isinstance(converted, list) else [converted]

        try:
            return '-'.join(str(ip) for ip in addresses)
        except TypeError:
            return str(value)

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Prepare the value for the database.

        The prepared string is returned as-is when ``prepared`` is False and
        the field is not attached to a model as an ``ArrayField``; otherwise
        it is wrapped in ``IPRange``.
        """
        model = getattr(self, 'model', None)
        is_array_field = bool(
            model and model._meta.get_field(self.name).get_internal_type() == 'ArrayField'
        )

        range_text = self.get_prep_value(value)
        if prepared is False and is_array_field is False:
            return range_text

        return IPRange(range_text)

    def form_class(self):
        return IP4RangeAddressFormField

    def validate(self, value: Any, model_instance: Any) -> None:
        """
        Run the inherited validation, pass every address through
        ``validate_private_lan_ip``, and reject a two-item range whose
        first address is greater than its second.
        """
        super().validate(value, model_instance)

        addresses = value if isinstance(value, list) else [value]
        for address in addresses:
            validate_private_lan_ip(address)

        if len(addresses) == 2 and addresses[0] > addresses[1]:
            raise exceptions.ValidationError('First IP may not be greater than second IP')


# ----------------------------------------------------------------------
class RemoteURLField(CharField):
    """
    CharField for URLs whose default validator is ``RemoteURLValidator``,
    so that only Remote URLs are permitted.

    ``max_length`` defaults to 200.
    """

    default_validators = [RemoteURLValidator()]
    description = tr("URL")

    def __init__(self, verbose_name: str | None = None, name: str | None = None, **kwargs: Any):
        if "max_length" not in kwargs:
            kwargs["max_length"] = 200
        super().__init__(verbose_name, name, **kwargs)

    def deconstruct(self) -> tuple[str, str, Any, Any]:
        """Leave ``max_length`` out of the rebuilt kwargs when it is the default 200."""
        name, path, args, kwargs = super().deconstruct()
        if kwargs.get("max_length") == 200:
            kwargs.pop("max_length")
        return name, path, args, kwargs

    def formfield(self, **kwargs):  # type: ignore[override]
        """Use ``RemoteURLFormField`` as the form class unless the caller overrides it."""
        options = {"form_class": RemoteURLFormField}
        options.update(kwargs)
        return super().formfield(**options)


# Names exported by ``from <this module> import *``.
__all__ = (
    'LowerCaseTextField',
    'LCDateTimeField',
    'LCEmailField',
    'LCJSONField',
    'LCTextField',
    'MsgspecStructField',
    'NameCITextField',
    'CodeTextField',
    'NameTextField',
    'CanonicalIdField',
    'ComputerPortField',
    'SequenceField',
    'UUID6Field',
    'IP4RangeAddressField',
    'RemoteURLField',
)
