# type: ignore
try:
    from pybase64 import b64encode
except (ImportError, ModuleNotFoundError):
    from base64 import b64encode
import logging
from http import HTTPStatus
from threading import Event
from typing import Any, Callable
from urllib.parse import quote, urlparse

from ..http import Request, Response
from .http_server import start_http_server, stop_http_server

# Module-level logger named after this module.
_logger = logging.getLogger(__name__)


class OAuthError(Exception):
    def __init__(
        self,
        status_code: HTTPStatus,
        error: str,
        error_description: str | None = None,
    ):
        self.status_code = status_code
        self.error = error
        self.error_description = error_description

    def __str__(self) -> str:
        return f'{self.status_code}  - {self.error} : {self.error_description}'


class ServiceInformation:
    def __init__(
        self,
        authorize_service: str | None,
        token_service: str | None,
        client_id: str,
        client_secret: str,
        scopes: list,
        verify: bool = True,
    ):
        self.authorize_service = authorize_service
        self.token_service = token_service
        self.client_id = client_id
        self.client_secret = client_secret
        self.scopes = scopes
        auth = f'{self.client_id}:{self.client_secret}'.encode()
        self.auth = b64encode(auth).decode()
        self.verify = verify


class AuthorizeResponseCallback(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.response = Event()

    def wait(self, timeout: float | None = None):
        self.response.wait(timeout)

    def register_parameters(self, parameters: dict):
        self.update(parameters)
        self.response.set()


class AuthorizationContext:
    """Bookkeeping for one in-flight authorization-code flow.

    Remembers the ``state`` that was sent and creates the callback that
    collects the redirect parameters. It then starts the redirect server via
    ``start_http_server`` on *host*:*port*, wired to that callback. The
    returned server handle is kept in ``server`` so it can be stopped later.
    """

    def __init__(self, state: str, port: int, host: str):
        callback = AuthorizeResponseCallback()
        self.state = state
        self.results = callback
        self.server = start_http_server(port, host, callback.register_parameters)


class CredentialManager:
    """Obtain OAuth 2 tokens and perform bearer-authenticated HTTP calls.

    Tokens are acquired through one of the ``init_with_*`` methods. The options
    are an authorization code, resource-owner password, client credentials, or
    an existing refresh token. The access token is stored as an
    ``Authorization: Bearer`` header on a shared ``Request`` session. That
    session is used by :meth:`get`, :meth:`post`, :meth:`put`, :meth:`patch`
    and :meth:`delete`.

    When a refresh token is known and a call fails with a 401 ``invalid_token``
    error, the token is refreshed and the call is retried once.

    NOTE(review): the shared session's ``url`` is mutated on every call, so a
    single instance should not be used from several threads concurrently.
    """

    def __init__(self, service_information: ServiceInformation):
        self.service_information = service_information
        # AuthorizationContext of a running authorization-code flow, else None.
        self.authorization_code_context = None
        # Refresh token of the current grant, if the server issued one.
        self.refresh_token = None
        # Request whose headers carry the bearer token; None until a token is set.
        self._session = None

        if not service_information.verify:
            self._silence_insecure_request_warning()

    @staticmethod
    def _silence_insecure_request_warning():
        """Hide urllib3's per-request warning once TLS verification is disabled on purpose."""
        import warnings

        try:
            from urllib3.exceptions import InsecureRequestWarning
        except ImportError:
            # Without urllib3 nothing can emit this warning, so there is nothing to filter.
            return
        warnings.filterwarnings(
            'ignore', 'Unverified HTTPS request is being made.*', InsecureRequestWarning
        )

    @staticmethod
    def _as_http_status(status_code: int) -> HTTPStatus | int:
        """Return *status_code* as an ``HTTPStatus``, or unchanged if it is not a standard code."""
        try:
            return HTTPStatus(status_code)
        except ValueError:
            # Non-standard codes (e.g. 499, 520) are not HTTPStatus members.
            return status_code

    @staticmethod
    def _handle_bad_response(response: Response):
        """Convert a failed token-endpoint response into an :class:`OAuthError`.

        If the body is a JSON object, its ``error`` and ``error_description``
        fields are used. Otherwise the error code is ``'unknown_error'`` and
        the raw body text becomes the description.

        Raises:
            OAuthError: always.
        """
        status = CredentialManager._as_http_status(response.status_code)
        try:
            error = response.json()
            error_code = error.get('error')
            error_description = error.get('error_description')
        except Exception as ex:
            _logger.exception(
                '_handle_bad_response - error while getting error as json - %s - %s',
                type(ex),
                str(ex),
            )
            raise OAuthError(status, 'unknown_error', response.text) from ex
        raise OAuthError(status, error_code, error_description)

    def generate_authorize_url(self, redirect_uri: str, state: str, **kwargs) -> str:
        """Build the authorization-endpoint URL that starts the code flow.

        Args:
            redirect_uri: URI the authorization server redirects back to.
            state: value expected back unchanged; it is compared on return by
                :meth:`wait_and_terminate_authorize_code_process`.
            **kwargs: extra query parameters appended to the URL.

        Returns:
            ``authorize_service`` followed by the percent-encoded parameters.
            A query string already present in ``authorize_service`` is kept.
        """
        parameters = dict(
            client_id=self.service_information.client_id,
            redirect_uri=redirect_uri,
            response_type='code',
            scope=' '.join(self.service_information.scopes),
            state=state,
            **kwargs,
        )
        # Same unreserved set as JavaScript's encodeURIComponent; str() lets
        # callers pass non-string extra parameters such as numbers.
        query = '&'.join(
            '%s=%s' % (key, quote(str(value), safe='~()*!.\''))
            for key, value in parameters.items()
        )
        base_url = str(self.service_information.authorize_service)
        if '?' not in base_url:
            separator = '?'
        elif base_url.endswith(('?', '&')):
            separator = ''
        else:
            # The endpoint already carries query parameters: append to them.
            separator = '&'
        return f'{base_url}{separator}{query}'

    def init_authorize_code_process(self, redirect_uri: str, state: str = '', **kwargs) -> str:
        """Start a local redirect server and return the URL the user must open.

        The server listens on the host and port of *redirect_uri* (port 80 if
        none is given). Any redirect server left by a previous unfinished call
        is stopped first.

        Raises:
            NotImplementedError: if *redirect_uri* uses ``https``.
        """
        uri_parsed = urlparse(redirect_uri)
        if uri_parsed.scheme == 'https':
            raise NotImplementedError("Redirect uri cannot be secured")
        # urlparse().port is already an int, or None when absent.
        port = uri_parsed.port
        if port is None:
            _logger.warning('You should use a port above 1024 for redirect uri server')
            port = 80
        if uri_parsed.hostname not in ('localhost', '127.0.0.1'):
            _logger.warning(
                'Remember to put %s in your hosts config to point to loop back address',
                uri_parsed.hostname,
            )
        if self.authorization_code_context is not None:
            # Don't leak the previous server (it may also hold the same port).
            stop_http_server(self.authorization_code_context.server)
            self.authorization_code_context = None
        self.authorization_code_context = AuthorizationContext(state, port, uri_parsed.hostname)
        return self.generate_authorize_url(redirect_uri, state, **kwargs)

    def wait_and_terminate_authorize_code_process(self, timeout: float | None = None) -> str:
        """Wait for the redirect, stop the local server and return the authorization code.

        The server is stopped and the context cleared in every case.

        Raises:
            Exception: if :meth:`init_authorize_code_process` was not called.
            OAuthError: when no response arrives within *timeout*, the server
                returned an ``error``, the ``state`` does not match, or no
                ``code`` was returned.
        """
        context = self.authorization_code_context
        if context is None:
            raise Exception('Authorization code not started')
        try:
            if not context.results.response.wait(timeout):
                raise OAuthError(
                    HTTPStatus.REQUEST_TIMEOUT,  # noqa
                    'timeout',
                    'No authorization response received within %s seconds' % timeout,
                )
            results = context.results
            error = results.get('error', None)
            error_description = results.get('error_description', '')
            code = results.get('code', None)
            state = results.get('state', None)
            if error is not None:
                raise OAuthError(HTTPStatus.UNAUTHORIZED, error, error_description)  # noqa
            if state != context.state:
                _logger.warning('State received does not match the one that was sent')
                raise OAuthError(
                    HTTPStatus.INTERNAL_SERVER_ERROR,  # noqa
                    'invalid_state',
                    'State returned does not match: Sent(%s) <> Got(%s)'
                    % (context.state, state),
                )
            if code is None:
                raise OAuthError(
                    HTTPStatus.INTERNAL_SERVER_ERROR,  # noqa
                    'no_code',
                    'No code returned',
                )
            return code
        finally:
            stop_http_server(context.server)
            self.authorization_code_context = None

    def init_with_authorize_code(self, redirect_uri: str, code: str, **kwargs):
        """Exchange an authorization code for tokens (a refresh token is required)."""
        self._token_request(self._grant_code_request(code, redirect_uri, **kwargs), True)

    def init_with_user_credentials(self, login: str, password: str):
        """Obtain tokens with the password grant (a refresh token is required)."""
        self._token_request(self._grant_password_request(login, password), True)

    def init_with_client_credentials(self):
        """Obtain an access token with the client-credentials grant."""
        self._token_request(self._grant_client_credentials_request(), False)

    def init_with_token(self, refresh_token: str):
        """Obtain an access token from an existing refresh token.

        *refresh_token* is kept if the server does not issue a new one.
        """
        self._token_request(self._grant_refresh_token_request(refresh_token), False)
        if self.refresh_token is None:
            self.refresh_token = refresh_token

    def _grant_code_request(self, code: str, redirect_uri: str, **kwargs) -> dict:
        """Form parameters for the ``authorization_code`` grant."""
        return dict(
            grant_type='authorization_code',
            code=code,
            scope=' '.join(self.service_information.scopes),
            redirect_uri=redirect_uri,
            **kwargs,
        )

    def _grant_password_request(self, login: str, password: str) -> dict:
        """Form parameters for the ``password`` grant."""
        return dict(
            grant_type='password',
            username=login,
            scope=' '.join(self.service_information.scopes),
            password=password,
        )

    def _grant_client_credentials_request(self) -> dict:
        """Form parameters for the ``client_credentials`` grant."""
        return dict(
            grant_type="client_credentials",
            scope=' '.join(self.service_information.scopes),
        )

    def _grant_refresh_token_request(self, refresh_token: str) -> dict:
        """Form parameters for the ``refresh_token`` grant."""
        return dict(
            grant_type="refresh_token",
            scope=' '.join(self.service_information.scopes),
            refresh_token=refresh_token,
        )

    def _refresh_token(self):
        """Renew the access token with the stored refresh token.

        On a 401 the session and refresh token are discarded before the error
        is re-raised. If the server does not issue a new refresh token
        (allowed by RFC 6749 section 6), the current one is kept.

        Raises:
            OAuthError: if the token endpoint rejects the request.
        """
        current_refresh_token = self.refresh_token
        payload = self._grant_refresh_token_request(current_refresh_token)
        try:
            self._token_request(payload, False)
        except OAuthError as err:
            if err.status_code == HTTPStatus.UNAUTHORIZED:
                _logger.debug('refresh_token - unauthorized - cleaning token')
                self._session = None
                self.refresh_token = None
            raise
        if self.refresh_token is None:
            self.refresh_token = current_refresh_token

    def _token_request(self, request_parameters: dict, refresh_token_mandatory: bool):
        """POST *request_parameters* to the token endpoint and store the resulting tokens.

        Raises:
            OAuthError: if the endpoint does not answer 200 OK.
        """
        headers = self._token_request_headers(request_parameters['grant_type'])
        headers['Authorization'] = f'Basic {self.service_information.auth}'
        headers['Content-Type'] = 'application/x-www-form-urlencoded'

        request = Request(
            self.service_information.token_service,
            verify=self.service_information.verify,
        )
        response = request.post(body=request_parameters, headers=headers)

        if response.status_code != HTTPStatus.OK:
            CredentialManager._handle_bad_response(response)
        else:
            # Never log the body: it contains the access and refresh tokens.
            _logger.debug(
                '_token_request - %s - token received', request_parameters['grant_type']
            )
            self._process_token_response(response.json(), refresh_token_mandatory)

    def _process_token_response(self, token_response: dict, refresh_token_mandatory: bool):
        """Store the tokens from a successful token response.

        A missing ``refresh_token`` raises KeyError when it is mandatory and
        yields None otherwise. A missing ``access_token`` always raises KeyError.
        """
        self.refresh_token = (
            token_response['refresh_token']
            if refresh_token_mandatory
            else token_response.get('refresh_token')
        )
        self._access_token = token_response['access_token']

    @property
    def _access_token(self) -> str | None:
        """Current bearer token read back from the session headers, or None."""
        authorization_header = (
            self._session._headers.get('Authorization') if self._session is not None else None
        )
        if authorization_header is not None:
            return authorization_header[len('Bearer ') :]
        else:
            return None

    @_access_token.setter
    def _access_token(self, access_token: str | None):
        """Create the session if needed and install *access_token* as its bearer header.

        A None or empty token leaves any existing header untouched.
        """
        if self._session is None:
            self._session = Request(url='')
            self._session.verify = self.service_information.verify
            self._session.trust_env = False
        if access_token is not None and len(access_token) > 0:
            self._session._headers.update(dict(Authorization=f'Bearer {access_token}'))

    def get(self, url: str, params: dict | None = None, **kwargs) -> Response:
        """GET *url* with the bearer token."""
        kwargs['params'] = params
        return self._bearer_request(self._get_session(url).get, **kwargs)

    def post(self, url: str, body=None, data=None, **kwargs) -> Response:
        """POST to *url* with the bearer token."""
        kwargs['body'] = body
        kwargs['data'] = data
        return self._bearer_request(self._get_session(url).post, **kwargs)

    def put(self, url: str, body=None, data=None, **kwargs) -> Response:
        """PUT to *url* with the bearer token."""
        kwargs['body'] = body
        kwargs['data'] = data
        return self._bearer_request(self._get_session(url).put, **kwargs)

    def patch(self, url: str, body=None, data=None, **kwargs) -> Response:
        """PATCH *url* with the bearer token."""
        kwargs['body'] = body
        kwargs['data'] = data
        return self._bearer_request(self._get_session(url).patch, **kwargs)

    def delete(self, url: str, **kwargs) -> Response:
        """DELETE *url* with the bearer token."""
        return self._bearer_request(self._get_session(url).delete, **kwargs)

    def _get_session(self, url) -> Request:
        """Return the authenticated session pointed at *url*.

        Raises:
            OAuthError: (401, ``no_token``) if no token has been obtained yet.
        """
        if self._session is None:
            raise OAuthError(HTTPStatus.UNAUTHORIZED, 'no_token', "no token provided")  # noqa
        self._session.url = url
        return self._session

    def _bearer_request(self, method: Callable[..., Response], **kwargs) -> Response:
        """Call *method* (a bound session verb); refresh and retry once if the token expired."""
        # Make sure an explicit headers dict is always passed to the session method.
        if kwargs.get('headers', None) is None:
            kwargs['headers'] = dict()
        _logger.debug("_bearer_request on %s - %s", method.__name__, self._session.url)
        response = method(**kwargs)  # noqa
        if self.refresh_token is not None and self._is_token_expired(response):
            # The refresh updates the headers of this same session object, so
            # the bound method picks up the new token on retry.
            self._refresh_token()
            return method(**kwargs)  # noqa
        return response

    @staticmethod
    def _token_request_headers(grant_type: str) -> dict:  # noqa
        """Hook returning extra headers for a token request (none by default).

        Presumably meant to be overridden by subclasses; the returned dict is mutated.
        """
        return dict()

    @staticmethod
    def _is_token_expired(response: Response) -> bool:
        """True when *response* is a 401 whose JSON object body has ``error == 'invalid_token'``."""
        if response.status_code != HTTPStatus.UNAUTHORIZED:
            return False
        try:
            json_data = response.json()
        except ValueError:
            return False
        # A non-object JSON body (list, string, ...) carries no OAuth error code.
        return isinstance(json_data, dict) and json_data.get('error') == 'invalid_token'


# Public API of this module: the names exported by ``from <module> import *``.
__all__ = (
    'CredentialManager',
    'AuthorizationContext',
    'AuthorizeResponseCallback',
    'OAuthError',
    'ServiceInformation',
)
