from copy import copy
import json
import msgspec
from pathlib import Path
from unittest import TestCase
from redwoodctl.redwood.categories import (
    prune_child_categories,
    sort_categories,
    load_category_conf_files,
)
from redwoodctl.typehints.categories import CategoryConf
from redwoodctl.typehints import (
    ClassifierCategoryStat,
    ClassifyUrlAnalyzedResponse,
    RatingName,
    RedwoodAction,
)
from redwoodctl.settings import REDWOOD_CATEGORY_DIR
from .fixtures.categories import (
    cnn,
    training_without_parent_cat,
)
from ..settings import (
    LOW_CONFIDENCE_SCORE_THRESHOLD,
    MED_CONFIDENCE_SCORE_THRESHOLD,
    HIGH_CONFIDENCE_SCORE_THRESHOLD,
    PHRASE_SCORING_SKEWED_RATIO,
    PHRASE_SCORING_LOW_CONFIDENCE_RATIO,
    PHRASE_SCORING_MED_CONFIDENCE_RATIO,
    PHRASE_SCORING_HIGH_CONFIDENCE_RATIO,
)


class TestClassifierCategoryStat(TestCase):
    """Confidence-scoring behaviour of ``ClassifierCategoryStat``."""

    @staticmethod
    def _make_stat(score, **overrides):
        """Build a ``ClassifierCategoryStat`` for these tests.

        Defaults: SILT rating, Allow action, every per-rule-type score 0 and
        every rule list a fresh empty list. ``overrides`` replaces any field.
        """
        fields = {
            "score": score,
            "rating": RatingName.SILT,
            "action": RedwoodAction.Allow,
            "domain_score": 0,
            "domain_rules": [],
            "ip_score": 0,
            "ip_rules": [],
            "phrase_score": 0,
            "phrase_rules": [],
            "regex_score": 0,
            "regex_rules": [],
        }
        fields.update(overrides)
        return ClassifierCategoryStat(**fields)

    def test_single_request_rule_type_found(self):
        """Only a domain score: phrase confidence 0.0, overall confidence 0.5."""
        # NOTE(review): regex_rules is non-empty while regex_score is 0; this
        # presumably still counts as a single rule type because only scores are
        # considered -- confirm against ClassifierCategoryStat.confidence().
        single_rule_type = self._make_stat(
            300,
            domain_score=300,
            domain_rules=["bank"],
            regex_rules=["/bank\\b/h"],
        )
        self.assertEqual(single_rule_type.phrase_confidence(), 0.0)
        self.assertEqual(single_rule_type.confidence(), 0.5)

    def test_two_request_rule_types_found(self):
        """Domain and regex scores: phrase confidence 0.0, overall confidence 0.75."""
        two_rule_types = self._make_stat(
            750,
            domain_score=300,
            domain_rules=["bank"],
            regex_score=450,
            regex_rules=["/bank\\b/h"],
        )
        self.assertEqual(two_rule_types.phrase_confidence(), 0.0)
        self.assertEqual(two_rule_types.confidence(), 0.75)

    def test_three_request_rule_types_found(self):
        """Domain, IP and phrase scores give high confidence despite skewed phrasing."""
        # NOTE(review): the total score (1925) is not the sum of the per-type
        # scores (300 + 250 + 625 = 1175); kept as-is -- confirm this is intended.
        three_rule_types = self._make_stat(
            1925,
            domain_score=300,
            domain_rules=["bank"],
            ip_score=250,
            ip_rules=["203.15.17.90"],
            phrase_score=625,
            phrase_rules=list(range(3)),
        )
        self.assertEqual(three_rule_types.phrase_confidence(), PHRASE_SCORING_SKEWED_RATIO)
        self.assertEqual(
            three_rule_types.confidence(),
            PHRASE_SCORING_HIGH_CONFIDENCE_RATIO,
            msg="When 3 rule types are found, confidence is High even if phrasing is skewed.",
        )

    def test_skewed_phrase_scoring(self):
        """A high phrase score from few phrase rules is skewed; many rules are not."""
        skewed_phrase = self._make_stat(
            900,
            phrase_score=900,
            phrase_rules=list(range(3)),
        )
        self.assertTrue(skewed_phrase.phrase_scoring_skewed())
        self.assertEqual(skewed_phrase.phrase_confidence(), PHRASE_SCORING_SKEWED_RATIO)

        # A shallow copy suffices: phrase_rules is rebound below, not mutated,
        # so skewed_phrase is left untouched.
        many_phrases = copy(skewed_phrase)
        many_phrases.phrase_rules = list(range(25))
        self.assertFalse(many_phrases.phrase_scoring_skewed())

    def test_low_phrase_scoring_confidence(self):
        """A phrase score just above the low threshold yields the low-confidence ratio."""
        score = LOW_CONFIDENCE_SCORE_THRESHOLD + 5
        low_confidence = self._make_stat(
            score,
            phrase_score=score,
            phrase_rules=list(range(10)),
        )
        self.assertFalse(low_confidence.phrase_scoring_skewed())
        self.assertEqual(low_confidence.phrase_confidence(), PHRASE_SCORING_LOW_CONFIDENCE_RATIO)

    def test_med_phrase_scoring_confidence(self):
        """A phrase score just above the medium threshold yields the medium ratio."""
        score = MED_CONFIDENCE_SCORE_THRESHOLD + 5
        medium_confidence = self._make_stat(
            score,
            phrase_score=score,
            phrase_rules=list(range(20)),
        )
        self.assertFalse(medium_confidence.phrase_scoring_skewed())
        self.assertEqual(medium_confidence.phrase_confidence(), PHRASE_SCORING_MED_CONFIDENCE_RATIO)

    def test_high_phrase_scoring_confidence(self):
        """A phrase score just above the high threshold yields the high ratio."""
        score = HIGH_CONFIDENCE_SCORE_THRESHOLD + 5
        high_confidence = self._make_stat(
            score,
            phrase_score=score,
            phrase_rules=list(range(30)),
        )
        self.assertFalse(high_confidence.phrase_scoring_skewed())
        self.assertEqual(high_confidence.phrase_confidence(), PHRASE_SCORING_HIGH_CONFIDENCE_RATIO)


class TestCategory(TestCase):
    """Loading, sorting and pruning of Redwood categories."""

    def test_load_category(self):
        """Category conf files load and the ``finance`` entry matches expectations."""
        # A missing category directory is reported as a test failure (with the
        # same message as before) instead of an unrelated ValueError error.
        self.assertTrue(
            Path(REDWOOD_CATEGORY_DIR).is_dir(),
            msg=f"Redwood Category directory {REDWOOD_CATEGORY_DIR!r} is not found",
        )

        category_config = load_category_conf_files()
        self.assertIn("auto", category_config)
        # Assert presence first so a missing key is a failure, not a KeyError.
        self.assertIn("finance", category_config)

        expected_conf = CategoryConf(
            id="1e659a87-2646-6a34-a910-0010184a8fec",
            action=RedwoodAction.Allow,
            genre_id="IAB13",
            rating=RatingName.SILT,
            parent_multiplier=1.0,
        )
        self.assertEqual(category_config["finance"], expected_conf)

    def test_sort_categories(self):
        """``sort_categories`` orders categories by descending score."""
        categories = {"auto": 250, "ag": 350, "construction": 700}
        sorted_categories = sort_categories(categories)
        # Dict equality ignores insertion order, so compare the ordered item
        # sequence; otherwise an unsorted result would still pass.
        self.assertEqual(
            list(sorted_categories.items()),
            [("construction", 700), ("ag", 350), ("auto", 250)],
            msg="Category dicts must be sorted in descending order by score",
        )

    def test_prune_categories_retain_children_if_parent_not_present(self):
        """A child category survives pruning when its parent category is absent."""
        tally = msgspec.json.decode(
            json.dumps(training_without_parent_cat),
            type=ClassifyUrlAnalyzedResponse,
        )
        self.assertIn("training/safety", tally.categories)

        self.assertDictEqual(
            prune_child_categories(tally).categories,
            {
                "lawn": 235,
                "training/safety": 220,
            },
            msg="Child categories must be retained if parent category is absent",
        )

    def test_prune_child_categories(self):
        """Child categories sharing their parent's score are removed by pruning."""
        tally = msgspec.json.decode(json.dumps(cnn), type=ClassifyUrlAnalyzedResponse)
        self.assertIn("news/news_military", tally.categories)

        self.assertDictEqual(
            prune_child_categories(tally).categories,
            {
                "news": 1570,
                "webmarketing": 512,
                "auctions": 315,
                "computer": 257,
            },
            msg="Child categories with same score as parent must be removed",
        )
