from lcrequests import LogCabinApiSession
from lcrequests.exceptions import LCHTTPError
import msgspec
from http import HTTPStatus
from urllib.parse import urlencode
from typing import Any, Literal, overload, TypeVar, TYPE_CHECKING

from ..redwood.categories import (
    add_scoring_category_scores,
    add_classifier_rule_analysis,
    prune_child_categories,
    prune_score_analysis_categories,
    load_category_conf_files,
    calculate_total_score,
)
from ..redwood.tally import (
    add_rating_scores,
    calculate_tally_stats,
    collect_tally_rules,
    collect_rule_types,
)
from ..settings import REDWOOD_API
from ..typehints.redwood_api_types import (
    ApiError,
    ApiFunction,
    ApiMode,
    ClassifyText,
)
from ..typehints.extended_types import (
    ClassifyUrlAnalyzedResponse,
    ClassifyUrlVerboseResponse,
    ClassifyTallyResponse,
    EXTENDED_RESPONSES,
)
from ..typehints import (
    CATEGORY_CONFIGS,
)

# Constrained TypeVar used by ``perform_extended_analysis`` so that the
# function's return type is the same response type it was handed
# (either the tally response or the analyzed-URL response).
TALLY_RESPONSES = TypeVar(
    "TALLY_RESPONSES",
    ClassifyTallyResponse,
    ClassifyUrlAnalyzedResponse,
)


# -------------------------------------------------------------------------
def classify_tally(
    tally: str,
    category_configs: CATEGORY_CONFIGS | None = None,
) -> ApiError | ClassifyTallyResponse:
    """
    Send a logline tally string to Redwood and build the extended response.

    :param tally: Tally column from logline for Redwood to classify
    :param category_configs: dict of required values from category.conf;
        forwarded unchanged to ``perform_extended_analysis``
    :return: The ``ApiError`` produced by ``classify`` when the request
        failed, otherwise the tally response after extended analysis
        (run with ``verbose=False``).
    """
    result = classify(tally, ApiFunction.Tally, ApiMode.Normal)

    if not isinstance(result, ApiError):
        # /analyze-tally doesn't return a "categories" key, so it is filled in
        # here (followed by the stats) before the extended analysis runs.
        result.categories = add_scoring_category_scores(result)
        result.stats = calculate_tally_stats(result)
        result = perform_extended_analysis(result, category_configs, verbose=False)

    return result


# -------------------------------------------------------------------------
def classify_tally_json(
    tally: dict[str, Any],
    category_configs: CATEGORY_CONFIGS | None = None,
) -> ApiError | ClassifyTallyResponse:
    """
    Classify a Block Page tally that arrives as JSON (a dict).

    Every ``key: value`` pair is rendered as ``"key value"`` and the pairs
    are joined with ``", "`` to form the tally string that is handed to
    ``classify_tally``.

    :param tally: Tally mapping to convert into a tally string
    :param category_configs: Passed through to ``classify_tally``
    :return: Whatever ``classify_tally`` returns for the built string
    """
    pairs = (f"{key} {value}" for key, value in tally.items())
    return classify_tally(", ".join(pairs), category_configs)


# -------------------------------------------------------------------------
def classify_text(text: str) -> ApiError | ClassifyText:
    """
    Classify a text value with Redwood (text function, verbose mode).

    :param text: Text corpus for Redwood to classify
    :return: The ``ApiError`` from ``classify`` when the request failed,
        otherwise the value returned by ``prune_child_categories`` for the
        decoded response.
    """
    outcome = classify(text, ApiFunction.Text, ApiMode.Verbose)

    # NOTE(review): elsewhere in this module ``prune_child_categories`` is
    # called with a category-config argument and its return value is ignored;
    # here it gets only the response and its result is returned. Confirm the
    # helper returns the response and has a default for the configs.
    return outcome if isinstance(outcome, ApiError) else prune_child_categories(outcome)


# -------------------------------------------------------------------------
def classify_analyze_url(
    url: str,
    category_configs: CATEGORY_CONFIGS | None = None,
    verbose: bool = False,
) -> ApiError | ClassifyUrlAnalyzedResponse:
    """
    Classify URL in `analyze-score` mode and extend the response.

    :param url: URL for Redwood to classify
    :param category_configs: Passed through to ``perform_extended_analysis``
    :param verbose: Passed through to ``perform_extended_analysis``
    :return: The ``ApiError`` from ``classify`` when the request failed,
        otherwise the analyzed response with the extended keys and
        ``rules`` populated.
    """
    analyzed = classify(url, ApiFunction.URL, ApiMode.Analyze)

    if not isinstance(analyzed, ApiError):
        analyzed = perform_extended_analysis(analyzed, category_configs, verbose)
        # Rules are collected only after the extended analysis has run.
        analyzed.rules = collect_tally_rules(analyzed)

    return analyzed


def perform_extended_analysis(
    response: TALLY_RESPONSES,
    category_configs: CATEGORY_CONFIGS | None = None,
    verbose: bool = False,
) -> TALLY_RESPONSES:
    """
    Populate the extra keys that make up the Extended Classification Response.

    The response object is updated in place and also returned.

    :param response: Tally or analyzed-URL response to extend
    :param category_configs: Category configuration; when falsy (``None`` or
        empty) it is obtained from ``load_category_conf_files()``
    :param verbose: When false, ``prune_score_analysis_categories`` is applied
        to the response as well
    :return: The same ``response`` object with ``classifierAnalysis``,
        ``total_score``, ``total_phrase_score``, ``ratings`` and
        ``rule_types`` assigned
    """
    configs = category_configs
    if not configs:
        configs = load_category_conf_files()

    prune_child_categories(response, configs)
    if not verbose:
        prune_score_analysis_categories(response)

    # Each step below reads the response as left by the previous ones, so the
    # order of these assignments is kept as-is.
    response.classifierAnalysis = add_classifier_rule_analysis(response, configs)
    totals = calculate_total_score(response)
    response.total_score, response.total_phrase_score = totals
    response.ratings = add_rating_scores(response)
    response.rule_types = collect_rule_types(response)  # type: ignore[assignment]

    return response


# -------------------------------------------------------------------------
def classify_url_verbose(url: str) -> ApiError | ClassifyUrlVerboseResponse:
    """
    Classify URL in `verbose` mode.

    :param url: URL for Redwood to classify
    :return: The ``ApiError`` from ``classify`` when the request failed,
        otherwise the verbose response with ``total_score`` set to the sum
        of the values in its ``categories`` mapping.
    """
    verbose_response = classify(url, ApiFunction.URL, ApiMode.Verbose)

    if not isinstance(verbose_response, ApiError):
        verbose_response.total_score = sum(verbose_response.categories.values())

    return verbose_response


# -------------------------------------------------------------------------
# Typing-only overloads for ``classify``: each one ties a (function, mode)
# pair used by the public helpers in this module to the response type that
# ``classify`` decodes for that pair. The block is skipped at runtime because
# ``TYPE_CHECKING`` is False outside of static analysis.
#
# NOTE(review): ``classify`` can also return ``ApiError`` (callers check for
# it with ``isinstance``), but none of these overloads include it in their
# return type -- confirm whether that omission is deliberate.
if TYPE_CHECKING:

    # Tally function, normal mode -> tally response.
    @overload
    def classify(
        body: str, function: Literal[ApiFunction.Tally], mode: Literal[ApiMode.Normal]
    ) -> ClassifyTallyResponse: ...

    # Text function, verbose mode -> text response.
    @overload
    def classify(
        body: str, function: Literal[ApiFunction.Text], mode: Literal[ApiMode.Verbose]
    ) -> ClassifyText: ...

    # URL function, verbose mode -> verbose URL response.
    @overload
    def classify(
        body: str, function: Literal[ApiFunction.URL], mode: Literal[ApiMode.Verbose]
    ) -> ClassifyUrlVerboseResponse: ...

    # URL function, analyze mode -> analyzed URL response.
    @overload
    def classify(
        body: str, function: Literal[ApiFunction.URL], mode: Literal[ApiMode.Analyze]
    ) -> ClassifyUrlAnalyzedResponse: ...


def classify(body: str, function: ApiFunction, mode: ApiMode) -> EXTENDED_RESPONSES:
    """
    Helper method to call the various classifying methods of the Redwood API.

    Issues a GET request to ``{REDWOOD_API}/{function}/{mode}`` with ``body``
    sent as a query parameter named after the lowercased function name, then
    decodes the reply into the response type matching the function/mode pair.

    :param body: The body (url / text / tally) to be classified.
    :param function: The classification function to perform.
    :param mode: The type / mode of classifying
    :return: The decoded response on success. An ``ApiError`` is returned
        instead when the request raises an unexpected exception (status 500),
        when the API replies with a non-OK status (that status and the reply
        text), or when the reply cannot be decoded into the expected type
        (status 422).
    :raises ValueError: If ``function`` is not the URL, Text or Tally function.
    """

    # Pick the struct type the reply is decoded into. For URL classification
    # every non-verbose mode uses the analyzed response type.
    if is_classifying_url(function):
        if is_verbose_mode(mode):
            ResponseType = ClassifyUrlVerboseResponse
        else:
            ResponseType = ClassifyUrlAnalyzedResponse  # type: ignore[assignment]
    elif is_classifying_text(function):
        ResponseType = ClassifyText  # type: ignore[assignment]
    elif is_classifying_tally(function):
        ResponseType = ClassifyTallyResponse  # type: ignore[assignment]
    else:
        raise ValueError(f"{function} is an invalid classification function")

    # Surrounding slashes are stripped from the assembled URL -- presumably so
    # a mode that renders as an empty path segment leaves no trailing "/".
    # NOTE(review): ``function`` is interpolated via ``.value`` but ``mode`` is
    # formatted directly; confirm ``ApiMode`` formats to its value in f-strings.
    base_url = f"{REDWOOD_API}/{function.value}/{mode}".strip("/")
    url = f"{base_url}?{urlencode({function.name.lower(): body})}"

    with LogCabinApiSession(url, timeout=7) as session:
        try:
            response = session.get()
        except LCHTTPError as e:
            # HTTP-level failures still carry a response; it is turned into an
            # ApiError by the ``ok`` check below.
            response = e.response
        except Exception as e:
            return ApiError(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, message=str(e))

    if not response.ok:
        return ApiError(status_code=response.status_code, message=response.text)

    # Try to decode classification response. But it may not always succeed,
    # such as when a tally classification request in which no category contains
    # any of the tally rules.
    try:
        return msgspec.json.decode(response.content, type=ResponseType)
    except msgspec.DecodeError:
        # ``DecodeError`` is the base class of ``ValidationError``, so this
        # covers both a schema mismatch and a body that is not valid JSON
        # (which previously escaped as an unhandled exception).
        try:
            message = response.json()
        except Exception:
            # Fall back to the decoded text (same as the non-OK branch above)
            # rather than raw bytes, so the message stays serializable.
            message = response.text
        return ApiError(status_code=HTTPStatus.UNPROCESSABLE_ENTITY, message=message)


def is_classifying_url(function: ApiFunction) -> bool:
    """Return True when ``function`` equals ``ApiFunction.URL``."""
    expected = ApiFunction.URL
    return function == expected


def is_classifying_text(function: ApiFunction) -> bool:
    """Return True when ``function`` equals ``ApiFunction.Text``."""
    expected = ApiFunction.Text
    return function == expected


def is_classifying_tally(function: ApiFunction) -> bool:
    """Return True when ``function`` equals ``ApiFunction.Tally``."""
    expected = ApiFunction.Tally
    return function == expected


def is_verbose_mode(mode: ApiMode) -> bool:
    """Return True when ``mode`` equals ``ApiMode.Verbose``."""
    expected = ApiMode.Verbose
    return mode == expected


# Names exported by ``from <module> import *``. The low-level ``classify``
# helper, ``perform_extended_analysis`` and the ``is_*`` predicates are not
# listed here.
__all__ = (
    "classify_tally",
    "classify_tally_json",
    "classify_text",
    "classify_analyze_url",
    "classify_url_verbose",
)
