from functools import lru_cache
import grp
from pathlib import Path
import pwd
from string import punctuation
from subprocess import run
from system_env.environment import is_drawbridge_os
from typing import Any, Iterator

from django.conf import settings
from django.utils import timezone
from django.utils.regex_helper import _lazy_re_compile as re_compile
from datetime import timedelta
from lcutils import ValidIP

from pid import PidFile as PidFileBase, PidFileError

# Translation table that maps every ASCII punctuation character to a space.
punc_tble = str.maketrans(dict.fromkeys(punctuation, ' '))

# Captures "<version>-<leading release digits>" from `rpm -q` output,
# e.g. "redwood-1.1-42.x86_64" -> "1.1-42".
RPM_VERSION_REGEX = re_compile(r'-([\d\.]*\d*-\d*)')

# Decorator for callables that should only be run once: results are memoised
# per distinct argument tuple with no size limit.
# NOTE(review): when applied to instance methods the cache also holds a
# reference to every `self` it has seen (ruff B019).
call_once = lru_cache(maxsize=None)


# ---------------------------------------------------------------------------
def package_version(name: str) -> str:
    """
    Get version of installed DEB or RPM package.

    Dispatches to ``deb_package_version`` when ``is_drawbridge_os()`` is true
    and to ``rpm_package_version`` otherwise. Those helpers return an empty
    string when no version can be extracted from the package tool's output.

    Examples on an RPM system (illustrative only, not doctests; the result
    depends on what is installed)::

        # rpm -q pdns-recursor -> pdns-recursor-4.2.1-1pdns.el6.x86_64
        package_version('pdns-recursor')  # '4.2.1-1'
        # rpm -q redwood -> redwood-1.1-42.x86_64
        package_version('redwood')  # '1.1-42'
        # rpm -q logcabin -> logcabin-1.9-715.noarch
        package_version('logcabin')  # '1.9-715'
    """
    # Drawbridge OS is treated as DEB-based; every other system as RPM-based.
    if is_drawbridge_os():
        return deb_package_version(name)

    return rpm_package_version(name)


def deb_package_version(name: str) -> str:
    """
    Get package version on DEB systems.

    Runs ``apt show <name>`` and returns the value of the first ``Version:``
    field in its output. Returns an empty string when ``apt`` is not
    available or no ``Version:`` field is found (e.g. unknown package).

    NOTE(review): ``apt show`` describes the candidate version from the
    package index, which can differ from the *installed* version when an
    upgrade is pending; ``dpkg-query -W`` reports the installed one. Confirm
    which is intended.
    """
    try:
        output = run(['apt', 'show', name], capture_output=True).stdout
    except FileNotFoundError:
        # apt is not installed: fall back to '' like rpm_package_version does.
        return ''

    for line in output.decode(errors='replace').splitlines():
        # Split on the first colon only, so versions that carry an epoch
        # (e.g. "Version: 1:2.3-4") are kept intact instead of breaking parsing.
        field, sep, value = line.partition(':')
        if sep and field == 'Version':
            return value.strip()
    return ''


def rpm_package_version(name: str) -> str:
    """
    Get package version on RPM systems.

    Runs ``rpm -q <name>`` and extracts the ``<version>-<release digits>``
    portion of its output with ``RPM_VERSION_REGEX``. Returns an empty string
    if ``rpm`` is missing or the output does not match.
    """
    try:
        raw = run(['rpm', '-q', name], capture_output=True).stdout
        found = RPM_VERSION_REGEX.search(raw.decode())
        return found.group(1) if found else ''
    except (FileNotFoundError, AttributeError, IndexError, TypeError):
        return ''


# ---------------------------------------------------------------------------
class PidFile(PidFileBase):
    """
    ``pid.PidFile`` whose ``piddir`` defaults to ``settings.TMP_DIR``.

    An explicitly passed ``piddir`` keyword (even ``None``) is left untouched.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        caller_chose_dir = 'piddir' in kwargs
        if not caller_chose_dir:
            kwargs['piddir'] = settings.TMP_DIR
        super().__init__(*args, **kwargs)


# ---------------------------------------------------------------------------
class LCPidFile(PidFileBase):
    """
    PID file set with logcabin user and group.

    ``gid`` and ``uid`` default to the numeric ids of the group and user named
    by ``settings.CONSOLE_GID`` and ``settings.CONSOLE_UID``. Explicit keyword
    arguments take precedence. ``grp.getgrnam`` / ``pwd.getpwnam`` raise
    ``KeyError`` if the named group/user does not exist.

    NOTE(review): the settings are passed to name lookups, so they are
    presumably group/user *names* despite the ``_GID``/``_UID`` suffix.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Only resolve the account ids the caller did not supply.
        if 'gid' not in kwargs:
            group_entry = grp.getgrnam(settings.CONSOLE_GID)
            kwargs['gid'] = group_entry.gr_gid
        if 'uid' not in kwargs:
            user_entry = pwd.getpwnam(settings.CONSOLE_UID)
            kwargs['uid'] = user_entry.pw_uid
        super().__init__(*args, **kwargs)


# ---------------------------------------------------------------------------
def punctuation_to_spaces(string: str) -> str:
    """
    Replace every ASCII punctuation character in ``string`` with a space,
    then strip leading and trailing whitespace from the result.
    """
    spaced = string.translate(punc_tble)
    return spaced.strip()


def remove_trailing_digits(word: str) -> str:
    """
    Return word, minus any trailing digits.

    Decimal digits are removed from the end of ``word`` one at a time. Stripping
    stops as soon as the remaining text is empty, ends in a non-digit, or is
    accepted by ``ValidIP(...).valid_ip()``. The IP check is repeated after
    every removed digit, so an input that becomes a valid IP part-way through
    is returned at that point.

    Implemented as a loop rather than recursion, so long runs of trailing
    digits cannot exceed Python's recursion limit.
    """
    # str.isdecimal() is true exactly for the single characters int() accepts,
    # matching the original int(word[-1]) test.
    while word and not ValidIP(word).valid_ip() and word[-1].isdecimal():
        word = word[:-1]
    return word


# -------------------------------------------------------------------
def valid_pid(pidfile: str, age: int = 12 * 60 * 60) -> bool:
    """
    Checks if PID file exists and is younger than age in seconds.

    "Younger" is judged by the file's last-modification time (``lstat``). A
    PID file that is too old is deleted as a side effect.

    :param pidfile: Path to file
    :param age: Maximum age in seconds (default: 12 hours)
    :return: True if the file exists and was modified less than ``age``
             seconds ago, otherwise False
    """
    import time

    pf = Path(pidfile)
    if not pf.is_file():
        return False

    try:
        modified = pf.lstat().st_mtime
    except FileNotFoundError:
        # Removed by another process between is_file() and lstat().
        return False

    # st_mtime is POSIX epoch seconds, so compare against the same clock.
    if modified > time.time() - age:
        return True

    # Stale: remove it, tolerating a concurrent removal by another process.
    pf.unlink(missing_ok=True)
    return False


# -------------------------------------------------------------------
def chunks(mlist: list, length: int) -> Iterator[list]:
    """
    Break ``mlist`` into consecutive sub-lists of at most ``length`` items.

    The last chunk holds whatever is left over and may be shorter.

    >>> list(chunks([1, 2, 3, 4, 5], 2))
    [[1, 2], [3, 4], [5]]

    :param mlist: List to split (each chunk is a slice of it)
    :param length: Maximum chunk size; must be at least 1
    :raises ValueError: if ``length`` < 1 (raised when iteration starts)
    """
    # A negative length used to yield nothing and silently drop all items;
    # reject it with the same exception type a zero length already produced.
    if length < 1:
        raise ValueError(f'chunk length must be >= 1, got {length}')
    for start in range(0, len(mlist), length):
        yield mlist[start : start + length]


# Names exported by `from <this module> import *`.
__all__ = (
    'call_once', 'chunks',
    'PidFile', 'PidFileError', 'LCPidFile',
    'punctuation_to_spaces', 're_compile', 'remove_trailing_digits',
    'package_version', 'valid_pid',
)
