import json
import msgspec
from unittest import TestCase
from redwoodctl.redwood.categories import (
    add_scoring_category_scores,
    add_classifier_rule_analysis,
    prune_child_categories,
    load_category_conf_files,
)
from redwoodctl.typehints.extended_types import (
    ClassifyTallyResponse,
    ClassifierCategoryStat,
    RatingStat,
)
from redwoodctl.typehints import RedwoodAction, RatingName
from redwoodctl.redwood.tally import add_rating_scores
from .fixtures import domain_rules_only, domain_phrases, reuters


class TestTally(TestCase):
    """End-to-end checks of the category, classifier and rating tally stages.

    Each test decodes a JSON fixture into a ``ClassifyTallyResponse``, runs the
    ``redwoodctl.redwood`` tally helpers on it in order, and asserts the
    intermediate results after each stage.
    """

    # The expected dicts in these tests are large; show the full diff on
    # failure instead of unittest's truncated default (applies to every test,
    # not just one).
    maxDiff = None

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Loaded once per test class rather than once per test method.
        cls.category_configs = load_category_conf_files()  # type: ignore[attr-defined]

    @staticmethod
    def _decode(fixture: object) -> ClassifyTallyResponse:
        """Serialize *fixture* to JSON and decode it as a ``ClassifyTallyResponse``."""
        return msgspec.json.decode(json.dumps(fixture), type=ClassifyTallyResponse)

    def test_domain_only_tally(self):
        response = self._decode(domain_rules_only)

        # ------------------------------------------------------------------------
        # Add category scorers to tally dict
        response.categories = add_scoring_category_scores(response)
        scorers = {
            "lc_Cortana_block": 1500,
            "lc_matchallNone": 1500,
            "finance": 750,
        }
        self.assertDictEqual(response.categories, scorers)

        # ------------------------------------------------------------------------
        # Add classifier scores to tally dict
        response.classifierAnalysis = add_classifier_rule_analysis(response, self.category_configs)
        classifier = {
            "finance": ClassifierCategoryStat(
                score=750,
                rating=RatingName.SILT,
                action=RedwoodAction.Allow,
                domain_score=300,
                domain_rules=["bank"],
                ip_score=0,
                ip_rules=[],
                phrase_score=0,
                phrase_rules=[],
                regex_score=450,
                regex_rules=["/bank\\b/h"],
            ),
        }

        # Only "finance" is expected here even though the scorers above also
        # contain the "lc_*" entries.
        self.assertDictEqual(
            response.classifierAnalysis,
            classifier,
            msg="ACL and Misc categories must be removed",
        )

        # ------------------------------------------------------------------------
        # Add rating scores to tally dict
        response.ratings = add_rating_scores(response)
        self.assertDictEqual(
            response.ratings,
            {RatingName.SILT: RatingStat(total_score=750, phrase_score=0)},
        )

    def test_tally_with_phrases(self):
        response = self._decode(domain_phrases)

        # ------------------------------------------------------------------------
        # Add category scorers to tally dict
        response.categories = add_scoring_category_scores(response)
        scorers = {
            "base_blanket_block": 3000,
            "lc_Apple_Configurator": 3000,
            "lc_CA_EmailAndUpdates": 3000,
            "lc_Silver_ML_Default": 3000,
            "lc_Silver_ML_EmailandUpdates": 3000,
            "lc_email_only": 3000,
            "lc_epmc_emailandupdates_73": 3000,
            "lc_Cortana_block": 1500,
            "lc_matchallNone": 1500,
            "pageassets": 400,
            "os": 375,
            "computer": 251,
        }
        self.assertDictEqual(response.categories, scorers)

        # Not asserted in this test; presumably populated because the rating
        # stage below reads it — confirm against add_rating_scores.
        response.classifierAnalysis = add_classifier_rule_analysis(response, self.category_configs)

        # ------------------------------------------------------------------------
        # Add rating scores to tally dict
        response.ratings = add_rating_scores(response)

        expected_ratings = {
            RatingName.SILT: RatingStat(total_score=626, phrase_score=626.0, phrase_count=9),
        }

        self.assertDictEqual(response.ratings, expected_ratings)

    def test_tally_reuters(self):
        response = self._decode(reuters)

        # ------------------------------------------------------------------------
        # Add category scorers to tally dict
        response.categories = add_scoring_category_scores(response)

        scorers = {
            "news/news_politics": 2174,
            "news/news_military": 1824,
            "lc_Cortana_block": 1500,
            "lc_matchallNone": 1500,
            "politics": 1315,
            "news": 1249,
            "news/news_sports": 1139,
            "finance": 1055,
            "news/news_celebrity": 995,
            "news/news_fashion": 995,
            "news/news_weapons": 995,
            "news/news_medical": 930,
            "military": 925,
            "news/news_genealogy": 909,
            "news/news_humor": 850,
            "computer": 765,
            "legal": 725,
            "isptelecom": 693,
            "auto": 634,
            "content_music": 625,
            "state_sites": 533,
            "foods/beverages": 510,
            "woodworking": 435,
            "travel_utility": 384,
            "travel_countryprofiles": 320,
            "ag": 280,
            "airfare": 255,
            "lawn": 250,
            "finance_alternate": 230,
        }
        self.assertDictEqual(response.categories, scorers)

        # Not asserted in this test; presumably populated because the rating
        # stage below reads it — confirm against add_rating_scores.
        response.classifierAnalysis = add_classifier_rule_analysis(response, self.category_configs)

        # ------------------------------------------------------------------------
        # Add rating scores to tally dict
        response.ratings = add_rating_scores(response)

        # NOTE(review): unlike every other rating here, STONE's total_score
        # (2614) differs from its phrase_score (3114.0) — confirm this expected
        # value is intentional and not a transcription slip.
        expected_ratings = {
            RatingName.PEBBLE: RatingStat(total_score=10321, phrase_score=10321.0, phrase_count=88),
            RatingName.STONE: RatingStat(total_score=2614, phrase_score=3114.0, phrase_count=27),
            RatingName.SILT: RatingStat(total_score=6519, phrase_score=6519.0, phrase_count=69),
            RatingName.ROCK: RatingStat(total_score=1990, phrase_score=1990.0, phrase_count=16),
            RatingName.SAND: RatingStat(total_score=550, phrase_score=550.0, phrase_count=6),
        }

        self.assertDictEqual(response.ratings, expected_ratings)

        # ------------------------------------------------------------------------
        # Prune duplicated child categories (prune_child_categories mutates
        # response.categories in place; its return value is not used).
        prune_child_categories(response)

        # Same as ``scorers`` minus seven "news/news_*" entries; only
        # "news/news_politics" and "news/news_military" survive pruning.
        expected_pruned_categories = {
            "news/news_politics": 2174,
            "news/news_military": 1824,
            "lc_Cortana_block": 1500,
            "lc_matchallNone": 1500,
            "politics": 1315,
            "news": 1249,
            "finance": 1055,
            "military": 925,
            "computer": 765,
            "legal": 725,
            "isptelecom": 693,
            "auto": 634,
            "content_music": 625,
            "state_sites": 533,
            "foods/beverages": 510,
            "woodworking": 435,
            "travel_utility": 384,
            "travel_countryprofiles": 320,
            "ag": 280,
            "airfare": 255,
            "lawn": 250,
            "finance_alternate": 230,
        }

        self.assertDictEqual(response.categories, expected_pruned_categories)
