from bz2 import BZ2File
import csv
import email.utils
from datetime import datetime, timedelta
from gzip import GzipFile
from ipaddress import ip_network
import os
from pathlib import Path
from lcconfig import LCConfigParser
from lcrequests import Request
import orjson
import tempfile
from typing import NamedTuple
from urllib.parse import urlparse
import urllib3
from ..exceptionlist import DNS_EXCEPTION_LIST, EXCEPTION_SUFFIXES, SKIP_PATTERNS_MP_DOMAINS

from ..settings import (
    DALMATIAN_CONFIG_FILE,
    logging,
    IPSET_DIR,
    REDWOOD_CATEGORY_DIR,
    RPZ_DIR,
    SUBSCRIPTIONS_DIR,
)

# Module-level logger; note that `logging` is imported from the package
# settings module rather than directly from the standard library.
logger = logging.getLogger(__name__)
# Prefixes that mark a feed entry as already carrying a scheme (or being
# protocol-relative). ConfigWriter.pattern_lines() prepends '//' to entries
# that start with none of these before handing them to urlparse().
SCHEMAS = ('https:', 'http:', '//', 'ftp:', 'ftps:')


class Patterns(NamedTuple):
    """
    Container for the sets of patterns extracted from a feed.

    Track URLs that consist only of valid hostnames,
    so they can be used either in Redwood, IPSets, or RPZs.

    Also track ALL hostnames from ALL lines in the feed,
    since they might be used in ACL categories for tagging.
    """
    # Every pattern kept from the feed: bare hostnames for host-only entries,
    # and the scheme-less URL for entries that carry a path/query/fragment.
    urls: set[str]
    # Hostnames of entries that had no path, query, params or fragment
    # (and were not excluded); used as the input for RPZ lines.
    urls_hostname_only: set[str]
    # Hostname of every non-excluded entry, regardless of whether the
    # entry also had a path.
    all_hostnames: set[str]


def rpz_pattern(pattern: str):
    """
    Translate a feed entry into the owner name used in an RPZ zone.

    Anything `ipaddress.ip_network` accepts is rewritten into the reversed
    `rpz-ip` form; every other value (i.e. a hostname) is returned untouched.

    >>> rpz_pattern('10.3.5.190')
    '32.190.5.3.10.rpz-ip'
    >>> rpz_pattern('10.3.5.0/24')
    '24.0.5.3.10.rpz-ip'
    >>> rpz_pattern('phishingsite.net')
    'phishingsite.net'
    """
    try:
        # Used purely as a validator; the parsed network itself is not needed.
        ip_network(pattern)
    except ValueError:
        # Not an IP address / network, so it is passed through as a hostname.
        return pattern

    return f'{reverse_ip(pattern)}.rpz-ip'


def reverse_ip(ip: str) -> str:
    """
    Reverse map IP address / IP network, adding cidr mask for RPZ compatibility.

    Surrounding whitespace is ignored. An address without an explicit prefix
    length is treated as a single host (``/32``). The implementation splits on
    dots and defaults to ``/32``, so it only produces meaningful output for
    dotted-quad IPv4 input.

    :param ip: IPv4 address or CIDR network, e.g. ``10.3.5.190`` or ``10.3.5.0/24``.
    :return: prefix length followed by the octets in reverse order, dot-joined.

    >>> reverse_ip('10.3.5.190')
    '32.190.5.3.10'
    >>> reverse_ip('10.3.5.0/24')
    '24.0.5.3.10'
    >>> reverse_ip(' 10.3.5.0/24\\n')
    '24.0.5.3.10'
    """
    # Strip *before* looking for the mask: otherwise trailing whitespace hides
    # an existing '/NN' (or ends up embedded in front of the appended '/32').
    ip = ip.strip()
    if '/' not in ip:
        ip = f'{ip}/32'
    # '10.3.5.0/24' -> '10.3.5.0.24' -> ['10', '3', '5', '0', '24'] -> reversed.
    return '.'.join(reversed(ip.replace('/', '.').split('.')))


def download_url(name: str, url: str, headers):
    """
    Fetch `url` and hand back the response body, decompressed when needed.

    The body is treated as gzip or bzip2 when either the Content-Type header
    or the URL itself ends with the matching suffix; otherwise it is returned
    as read from the response.

    :param name: subscription name, only used in log messages.
    :param url: location to download.
    :param headers: request headers forwarded to the HTTP client.
    :return: the (decompressed) body, or False when the connection fails or
        the response status is not 200.
    """
    try:
        resp = Request(url, stream=True, headers=headers).get()
    except ConnectionError:
        logger.error(f'Connection failed! Unable to download updates for {name}.')
        return False

    # NOTE(review): the check reads `resp.status` while the message reads
    # `resp.status_code` - confirm the response object exposes both.
    if resp.status != 200:
        logger.error(f'Status {resp.status_code}. Unable download updates for {name}.')
        return False

    content_type = resp.headers.get('Content-Type') or ''

    # The response object is handed to the decompressors as a file-like source.
    if content_type.endswith('gzip') or url.endswith('gz'):
        decompressor = GzipFile(fileobj=resp)
    elif content_type.endswith('bzip2') or url.endswith('bz2'):
        decompressor = BZ2File(resp)
    else:
        return resp.read()

    with decompressor as uncompressed:
        return uncompressed.read()


class Loader:
    """
    Base class that keeps a per-family registry of subclasses keyed by `origin`.

    Loader itself defines neither `Subclasses` nor `origin`; the class at the
    root of each family (a direct child of Loader) must define both a
    `Subclasses` dict and an `origin` string in its class body. That root class
    and every descendant are then registered in the shared `Subclasses` dict.

    `__str__` reads `self.name`, which instances are expected to set.
    """

    def __init_subclass__(cls, **kwargs: object) -> None:
        """
        Include all child subclasses, including subclasses of subclasses
        to make an easy lookup dict, to retrieve the right subclass
        based on the origin property.

        :raises ValueError: when another class in the same family is already
            registered for `cls.origin`.
        """
        super().__init_subclass__(**kwargs)
        if cls.origin in cls.Subclasses:
            # Name the class that already owns the origin so the clash is easy
            # to locate (the class of a class is always `type`, which is useless
            # in an error message).
            existing = cls.Subclasses[cls.origin]
            raise ValueError(f'{existing.__name__} already specified for {cls.origin}')
        cls.Subclasses[cls.origin] = cls

    @classmethod
    def getClass(cls, origin):
        """
        Return the correct class for the specified origin.

        Falls back to the class the method is called on when no subclass is
        registered for `origin`.
        """
        return cls.Subclasses.get(origin, cls)

    def __str__(self):
        return f'{self.__class__.__name__}("{self.name}")'

    def __repr__(self):
        return self.__str__()


class Retriever(Loader):
    """
    Download and save data from Subscription URL.

    Each URL of the subscription is stored as its own file inside the
    subscription directory; the files are then concatenated into a single
    "combined" file.

    Attributes:
        name: subscription name; also the name of its directory.
        urls: URLs that make up the subscription.
        interval: minutes that must elapse before the feed is checked again.
        last_modified: naive local datetime of the combined file's mtime;
            None until `is_subscription_stale()` or `update_modified_date()`
            has run.
    """
    Subclasses = {}
    origin = ''

    def __init__(self, name: str, urls: list[str], interval: int, **kwargs):
        self.name = name
        self.urls = urls
        self.interval = int(interval)
        self._headers = kwargs.get('headers', {})
        self.last_modified = None

    @property
    def headers(self):
        """
        Override on subclasses to customize headers
        """
        return self._headers

    @property
    def directory(self):
        """
        Return the directory location
        """
        return f'{SUBSCRIPTIONS_DIR}/{self.name}'

    @property
    def combined_file(self) -> str:
        """
        Return file name of raw subscription data, containing all
        the data of all the URLs that the subscription consists of.

        Subscriptions whose name starts with 'mp_domains' share a single
        file in the system temp directory instead of a per-subscription one.
        """
        if self.name.startswith('mp_domains'):
            return f'{tempfile.gettempdir()}/mp_malicious_domains.list'
        return f'{self.directory}/subscription.list'

    def is_subscription_stale(self) -> bool:
        """
        Check whether combined data file is older than current time,
        minus the next update check interval.

        Side effects: sets `self.last_modified`, and creates an empty combined
        file (plus its parent directory) when none exists yet.
        """
        path = Path(self.combined_file)
        if not path.exists():
            # First run for this subscription: the directory may not exist yet,
            # so create it before touching the placeholder file.
            path.parent.mkdir(parents=True, exist_ok=True)
            path.touch()
            # Epoch sentinel guarantees the feed is treated as out of date.
            self.last_modified = datetime(1970, 1, 1, 0, 0)
            return True

        self.last_modified = datetime.fromtimestamp(path.stat().st_mtime)
        return datetime.now() - timedelta(minutes=self.interval) > self.last_modified

    def update_modified_date(self) -> None:
        """
        Update the modified date of the combined file, so that if an update check
        has _no_ new updates, that we don't check again until an entire update
        interval has elapsed
        """
        os.utime(self.combined_file)
        path = Path(self.combined_file)
        self.last_modified = datetime.fromtimestamp(path.stat().st_mtime)

    def remote_file_newer(self, url: str) -> bool:
        """
        Make head request to see if remote file is newer than the local file.

        Requires `self.last_modified` to be set (see `is_subscription_stale`).

        :return: True only when the server reports a Last-Modified time later
            than the local combined file; False on connection errors, non-OK
            responses, or a missing / unparseable Last-Modified header.
        """
        # NOTE(review): `last_modified` is a naive *local* datetime, and
        # format_datetime() renders naive values with a '-0000' zone; servers
        # may ignore the header. The comparison below does not rely on it.
        headers = {'If-Modified-Since': email.utils.format_datetime(self.last_modified)}
        headers.update(self.headers)
        try:
            # NOTE(review): certificate verification is disabled here but not
            # in download_url() - confirm this is intentional.
            req = Request(url, verify=False, headers=headers)
            resp = req.head()
        except ConnectionError:
            logger.error(f'Connection failed! Unable to check for updates for {self.name}.')
            return False

        if not resp.ok:
            logger.error(f'Status {resp.status_code}. Unable check for updates for {self.name}.')
            return False

        last_modified = resp.headers.get('Last-Modified')
        if not last_modified:
            return False

        parsed = email.utils.parsedate_tz(last_modified)
        if parsed is None:
            logger.error(f'Unparseable Last-Modified header for {self.name}: {last_modified}')
            return False

        # Convert through a POSIX timestamp so the remote time ends up as a
        # naive local datetime, i.e. directly comparable with `last_modified`.
        remote_modified = datetime.fromtimestamp(email.utils.mktime_tz(parsed))
        return remote_modified > self.last_modified

    def update(self) -> datetime:
        """
        Download the contents of all URLs and save them to file.

        Nothing is downloaded while the combined file is still fresh. When at
        least one URL yields content, the combined file is rebuilt from the
        local copies of all URLs.

        :return: modification time of the combined file after the call.
        """
        if not self.is_subscription_stale():
            next_check = self.last_modified + timedelta(minutes=self.interval)
            logger.info(f'Skipping update - {self.name} not stale until {next_check}')
            return self.last_modified

        os.makedirs(self.directory, exist_ok=True)

        all_files = []
        updated_files = []

        for url in self.urls:
            local_file = f'{self.directory}/{os.path.basename(url)}'
            all_files.append(local_file)

            content = download_url(self.name, url, self.headers)
            if not content:
                continue

            with open(local_file, 'wb') as uf:
                uf.write(content)
            updated_files.append(local_file)

        if not updated_files:
            self.update_modified_date()
            logger.info(f'No updates available for {self.name}')
            return self.last_modified

        # regenerate combined file from all files, and not just changed files
        # as some files might not have been updated this round.
        with open(self.combined_file, 'wb') as cf:
            for file in all_files:
                try:
                    with open(file, 'rb') as af:
                        cf.write(af.read())
                except FileNotFoundError:
                    # The download failed and there is no copy from an earlier
                    # run; build the combined file from what is available.
                    logger.warning(f'No local copy of {file} for {self.name}; skipping.')

        # Reflect the freshly written file rather than the stale timestamp
        # captured before the download started.
        self.last_modified = datetime.fromtimestamp(Path(self.combined_file).stat().st_mtime)
        logger.info(f'Successfully downloaded updates for {self.name}')

        return self.last_modified


class MalwarePatrolRetriever(Retriever):
    """
    Retriever for the 'malware-patrol' origin.

    Adds an HTTP Basic authorization header, built from the `username` and
    `password` found in the section named after `origin` in
    DALMATIAN_CONFIG_FILE, to the headers configured for the retriever.
    """
    origin = 'malware-patrol'

    @property
    def headers(self):
        """
        Return the configured headers plus the Basic auth header.

        The config file is read on every access.
        """
        cfg = LCConfigParser()
        cfg.read(DALMATIAN_CONFIG_FILE)
        data = cfg.as_typed_dict(self.origin)
        auth = urllib3.make_headers(basic_auth=f'{data["username"]}:{data["password"]}')
        # Build a new dict instead of updating in place: the base property
        # returns the very dict stored on the instance (possibly supplied by
        # the caller), and credentials must not leak into it.
        return {**super().headers, **auth}


class ConfigWriter(Loader):
    """
    Take raw data from a Subscription and generate conf files for the various services,
    such as RPZs, IPSets and Redwood Categories.

    Default is <firehol>.netset URL from https://iplists.firehol.org

    Keyword arguments read by the constructor:
        name: subscription name (required).
        interval: written as the `$TTL` of the RPZ zone file (default 500).
        ipset / redwood / rpz: booleans enabling the respective output.
        redwood_category: Redwood category directory; defaults to `name`.
        redwood_domain_acl_category: optional extra category that receives
            every hostname found in the feed.
    """
    Subclasses = {}
    origin = ''

    def __init__(self, **kwargs):
        self.name: str = kwargs.get('name')
        if not self.name:
            raise ValueError('Name is required!')

        self.interval: int = kwargs.get('interval', 500)
        self.ipset: bool = kwargs.get('ipset', False)
        self.redwood: bool = kwargs.get('redwood', False)
        self.redwood_category: str = kwargs.get('redwood_category') or self.name
        self.redwood_domain_acl_category: str = kwargs.get('redwood_domain_acl_category') or ''
        self.rpz: bool = kwargs.get('rpz', False)
        # Memoized result of pattern_lines(); None until first computed.
        self._patterns = None

    def save(self):
        """
        Save config files to disk in correct format.
        """
        self.generate_ipset_file()
        self.generate_redwood_category_file()
        self.generate_rpz_file()

    @property
    def combined_file(self):
        """Return the path of the raw subscription data for this subscription."""
        return f'{SUBSCRIPTIONS_DIR}/{self.name}/subscription.list'

    def application_file_is_stale(self, location: str) -> bool:
        """
        Check if the application file (Redwood Category, RPZ)
        is older than the latest feed update.

        A missing application file counts as stale.
        """
        if not (application_file := Path(location)).exists():
            return True

        feed_last_updated = datetime.fromtimestamp(Path(self.combined_file).stat().st_mtime)
        return feed_last_updated > datetime.fromtimestamp(application_file.stat().st_mtime)

    def symlink(self, src, dst):
        """Create symlink `dst` -> `src`; an already existing `dst` is left as is."""
        try:
            os.symlink(src, dst)
        except FileExistsError:
            # Best effort: the link was created on an earlier run.
            pass

    def load_patterns_file(self) -> set[str]:
        """
        Return the non-empty, non-comment lines of the combined file.

        Lines are stripped first, so comments (`#` or `;`) are recognised even
        when they are indented.

        Override on subclasses to perform any pattern normalization / cleanup.
        """
        with open(self.combined_file) as f:
            stripped_lines = (line.strip() for line in f)
            return {
                line for line in stripped_lines if line and not line.startswith(('#', ';'))
            }

    def pattern_lines(self) -> Patterns:
        """
        Extract lines from file and extract URLs and hostnames.

        https://sites.google.com/view/adrianoferriani/
        https://sites.google.com/view/orange-facture1/admin
        https://d.oftde.top/
        https://form.jotform.com/221295519236559
        http://interadtivebrokers.com

        The result is computed once and cached on the instance.
        """
        if self._patterns is None:
            all_hostnames = set()
            hostnames_only = set()
            urls = set()
            for url in self.load_patterns_file():
                url = url.strip().rstrip('/').lower()
                # urlparse() only fills `netloc` when the value starts with a
                # scheme or '//', so make bare entries protocol-relative.
                parseable_url = url if url.startswith(SCHEMAS) else f'//{url}'
                try:
                    parsed = urlparse(parseable_url)
                except ValueError as e:
                    logger.error('Unable to parse URL: %s. %s', url, e)
                    continue

                hostname = parsed.netloc.split(':')[0]
                if not hostname:
                    # e.g. a bare '/' or 'http://': nothing usable, and an empty
                    # hostname would produce blank lines in the generated files.
                    logger.debug('Skipping pattern without hostname: %s', url)
                    continue

                exclude_hostname = self._exclude_hostname(hostname)
                if not exclude_hostname:
                    all_hostnames.add(hostname)

                hostname_only = ((not parsed.path or parsed.path == '/') and not parsed.fragment
                                 and not parsed.query and not parsed.params)

                if hostname_only:
                    if not exclude_hostname:
                        hostnames_only.add(hostname)
                        urls.add(hostname)
                else:
                    # The pattern has a full path; don't extract the hostname,
                    # since RPZs and IPSets affect the entire hostname.
                    # NOTE(review): CIDR entries such as '10.3.5.0/24' parse as
                    # host '10.3.5.0' + path '/24' and therefore end up here,
                    # not in `hostnames_only` - confirm this is intended.
                    schema_idx = parseable_url.find('//', 0, 9) + 2
                    urls.add(parseable_url[schema_idx:])

            self._patterns = Patterns(
                urls=urls,
                urls_hostname_only=hostnames_only,
                all_hostnames=all_hostnames,
            )

        return self._patterns

    def _exclude_hostname(self, hostname: str) -> bool:
        """
        Run checks to see if this hostname should be excluded
        Override on subclasses to extend functionality.
        """
        return hostname in DNS_EXCEPTION_LIST or hostname.endswith(EXCEPTION_SUFFIXES)

    def rpz_pattern_lines(self) -> list[str]:
        """
        Get all RPZ-compatible lines from the subscription file,
        deduplicated and cleaned of exception entries.

        Override on subclasses to customize behavior.
        """
        hostnames = self.pattern_lines().urls_hostname_only
        # sort to minimize file diffs for rsync
        return sorted(hostnames.difference(DNS_EXCEPTION_LIST))

    def redwood_pattern_lines(self) -> list[str]:
        """
        Redwood can handle full URLs, which is much more granular than
        the hostnames that RPZs and IPSets support, so the hostname-level
        exceptions aren't applied.

        A leading 'www.' is removed from every pattern; the result is
        deduplicated (trimming can make two patterns identical) and sorted.

        Override on subclasses to customize behavior.
        """
        trimmed_urls = {url.removeprefix('www.') for url in self.pattern_lines().urls}
        # A pattern consisting solely of 'www.' trims down to nothing.
        trimmed_urls.discard('')
        return sorted(trimmed_urls)  # sort to minimize file diffs for rsync

    def generate_ipset_file(self) -> None:
        """
        Generate IPSet config file & settings file in the Log Cabin console location.

        Both files are symlinks into the subscription directory.
        """
        if not self.ipset:
            return

        ipset_dir = f'{IPSET_DIR}/{self.name}'
        os.makedirs(ipset_dir, exist_ok=True)
        self.symlink(self.combined_file, f'{ipset_dir}/{self.name}.ipset')
        self.symlink(f'{SUBSCRIPTIONS_DIR}/{self.name}/settings.yml', f'{ipset_dir}/settings.yml')
        logger.info(f'Updated IP Set files for subscription: {self.name}')

    def generate_rpz_file(self) -> None:
        """
        Generate RPZ zone file from expected netset data in the Log Cabin console location.
        Override if not IPSet formatted data.

        The zone file is only rewritten when it is older than the feed.
        """
        if not self.rpz:
            return

        rpz_dir = f'{RPZ_DIR}/{self.name}'
        os.makedirs(rpz_dir, exist_ok=True)
        self.symlink(f'{SUBSCRIPTIONS_DIR}/{self.name}/settings.yml', f'{rpz_dir}/settings.yml')

        zone_file = f'{rpz_dir}/zone.conf'
        if not self.application_file_is_stale(zone_file):
            return

        with open(zone_file, 'w') as f:
            f.write(
                f'$TTL {self.interval}\n'
                f'@ SOA {self.name}.compassfoundation.io. hostmaster.compassfoundation.io. 1 12h 15m 3w 2h\n'
                '  NS localhost.\n'
                ';\n\n'
            )
            f.write('\n'.join(f'{rpz_pattern(line)} CNAME .' for line in self.rpz_pattern_lines()))

        logger.info(f'Updated RPZ files for subscription: {self.name}')

    def generate_redwood_category_file(self) -> None:
        """
        Generate Redwood Category file from raw data. Override on subclasses as necessary.
        On systems where the DNS Recursor is running, Redwood categories will never fire
        as the RPZ Zone will fire first and redirect the user to the RPZ block page.

        But not all systems will be able to run DNS Recursor because of slow connections,
        so by creating Redwood Categories when possible, we offer them some protection
        by blocking HTTP requests via Redwood.
        """
        if not self.redwood:
            return

        category_dir = f'{REDWOOD_CATEGORY_DIR}/{self.redwood_category}'
        os.makedirs(category_dir, exist_ok=True)

        category_list_file = f'{category_dir}/{self.name}.list'
        if not self.application_file_is_stale(category_list_file):
            return

        header = '# Automatically generated file. Do not edit!\n\ndefault 500\n\n'

        with open(category_list_file, 'w') as f:
            f.write(header)
            f.write('\n'.join(self.redwood_pattern_lines()))

        # If this subscription collects all the domains found in the feed,
        # save them to the specified ACL category. Such categories are for the
        # use of Starlark-based deep URL analysis functions.
        if self.redwood_domain_acl_category:
            acl_category_dir = f'{REDWOOD_CATEGORY_DIR}/{self.redwood_domain_acl_category}'
            os.makedirs(acl_category_dir, exist_ok=True)
            with open(f'{acl_category_dir}/{self.name}.list', 'w') as f:
                f.write(header)
                f.write('\n'.join(sorted(self.pattern_lines().all_hostnames)))

        logger.info(f'Updated Redwood Category files for subscription: {self.name}')


class FireholConfigWriter(ConfigWriter):
    """
    ConfigWriter registered under the 'firehol' origin.

    Adds no behaviour of its own; everything is inherited from ConfigWriter.
    """
    origin = 'firehol'


class RpzConfigWriter(ConfigWriter):
    """
    ConfigWriter registered under the 'rpz' origin.

    Adds no behaviour of its own; everything is inherited from ConfigWriter.
    """
    origin = 'rpz'


class PhishTankConfigWriter(ConfigWriter):
    """
    ConfigWriter for the 'phishtank' origin, whose feed is a CSV file.
    """
    origin = 'phishtank'

    def load_patterns_file(self) -> set[str]:
        """
        Parse CSV file and extract url from columns.

        phish_id,url,phish_detail_url,submission_time,verified,verification_time,online,target
        6689513,https://payee-failure.cc/hsbc/,http://www.phishtank.com/phish_detail.php?phish_id=6689513,2020-07-22T13:23:23+00:00,yes,2020-07-22T13:33:32+00:00,yes,"HSBC Group"
        6689510,https://o2.uk.inv80.com/?o2=2,http://www.phishtank.com/phish_detail.php?phish_id=6689510,2020-07-22T13:19:24+00:00,yes,2020-07-22T13:22:03+00:00,yes,"Telefónica UK"
        6689507,https://o2.uk.inv50.com/?o2=2,http://www.phishtank.com/phish_detail.php?phish_id=6689507,2020-07-22T13:16:39+00:00,yes,2020-07-22T13:22:03+00:00,yes,"Telefónica UK"

        Rows without a value in the `url` column are skipped.

        :raises KeyError: when the file has rows but no `url` column.
        """
        patterns = set()

        # newline='' is what the csv module expects for file objects; the
        # explicit encoding keeps parsing independent of the process locale
        # (the feed contains non-ASCII text such as "Telefónica").
        with open(self.combined_file, 'r', newline='', encoding='utf-8', errors='replace') as cp:
            for row in csv.DictReader(cp):
                url = row['url']
                # DictReader fills the columns of short rows with None, which
                # would break the string handling done on every pattern later.
                if url:
                    patterns.add(url)

        return patterns


class MalwarePatrolConfigWriter(ConfigWriter):
    """
    ConfigWriter for the 'malware-patrol' origin.

    Subscriptions named 'mp_scam_domains' or starting with 'mp_domains' are
    JSON feeds and get dedicated loaders; any other subscription of this
    origin is read line by line like the base class does.
    """
    origin = 'malware-patrol'

    @property
    def combined_file(self):
        """
        Return the raw data file; 'mp_domains*' subscriptions all read the
        same file in the system temp directory.
        """
        if self.name.startswith('mp_domains'):
            return f'{tempfile.gettempdir()}/mp_malicious_domains.list'
        return super().combined_file

    def load_patterns_file(self):
        """Dispatch to the loader matching the subscription name."""
        if self.name == 'mp_scam_domains':
            return self.load_scam_domains()
        if self.name.startswith('mp_domains'):
            return self.load_malicious_domains()
        return super().load_patterns_file()

    def load_malicious_domains(self) -> set[str]:
        """
        The file is a list of JSON objects like so:

        {
            "domain": "00027fea-1142-4c83-add2-add56b9bfa18.filesusr.com",
            "detection_timestamp": "20200131211819",
            "domain_ranking": "0",
            "threat_type": "malware",
            "confidence": "100",
            "malware_classification": "trojan"
        }

        Only entries whose `domain_ranking` is not above 0 and whose
        `threat_type` equals the last '_'-separated part of the subscription
        name are returned; domains matching SKIP_PATTERNS_MP_DOMAINS are dropped.
        """
        patterns = set()
        threat_type_key = self.name.split('_')[-1]

        # Hand orjson the raw bytes: it decodes UTF-8 itself, so the result
        # does not depend on the locale's default text encoding.
        with open(self.combined_file, 'rb') as cp:
            rows = orjson.loads(cp.read())

        for row in rows:
            if not row or int(row['domain_ranking']) > 0:
                continue

            if SKIP_PATTERNS_MP_DOMAINS.search(row['domain']):
                continue

            if row['threat_type'] == threat_type_key:
                patterns.add(row['domain'])

        return patterns

    def load_scam_domains(self) -> set[str]:
        """
        The file is a list of JSON objects like so
        (shown here as the parsed Python value):

        {
            '_id': {
                '$oid': '641281ea13a4a4f9c70d0120'
            },
            'domain': '10xbet.club',
            'score': 99,
            'created_at': 1678934506,
            'trustrules': None,
            'updated_at': '1730548117',
            'ssl': {
                'valid': True,
                'type': 'DV',
                'issuer': "Let's Encrypt"
            },
            'networkinfo': {
                'ip': {
                    'address': '217.21.77.124',
                    'isp': 'Hostinger International Limited',
                    'country': 'US'
                },
                'nameservers': [{
                    'hostname': 'ns2.dns-parking.com.',
                    'ip': {
                        'address': '162.159.25.42',
                        'isp': 'CloudFlare Inc.',
                        'country': 'US'
                    }
                }, {
                    'hostname': 'ns1.dns-parking.com.',
                    'ip': {
                        'address': '162.159.24.201',
                        'isp': 'CloudFlare Inc.',
                        'country': 'US'
                    }
                }]
            },
            'scamadviser_votes': {
                'source_url': 'https://www.scamadviser.com/check-website/10xbet.club',
                'count_legit': None,
                'count_scam': None,
                'count_fake': None
            },
            'status': '200 OK',
            'status_timestaamp': '1730549106',
            'domain_ranking': '0'
        }

        Every `domain` is returned except those matching SKIP_PATTERNS_MP_DOMAINS.
        """
        patterns = set()

        # Raw bytes for the same reason as in load_malicious_domains().
        with open(self.combined_file, 'rb') as cp:
            rows = orjson.loads(cp.read())

        for row in rows:
            # Same guard as load_malicious_domains(): ignore empty entries.
            if not row:
                continue
            if SKIP_PATTERNS_MP_DOMAINS.search(row['domain']):
                continue
            patterns.add(row['domain'])

        return patterns


# Names exported by `from <module> import *`. The Loader base class, the
# Patterns tuple and the module-level helper functions are not listed.
__all__ = (
    'ConfigWriter',
    'FireholConfigWriter',
    'PhishTankConfigWriter',
    'Retriever',
    'RpzConfigWriter',
    'MalwarePatrolRetriever',
    'MalwarePatrolConfigWriter',
)
