from __future__ import unicode_literals

import json
import logging
import re
from builtins import str
from collections import defaultdict
from pathlib import Path

from future.utils import iteritems, itervalues

from snips_inference_agl.common.dataset_utils import get_slot_name_mappings
from snips_inference_agl.common.log_utils import log_elapsed_time, log_result
from snips_inference_agl.common.utils import (
    check_persisted_path, deduplicate_overlapping_items, fitted_required,
    json_string, ranges_overlap, regex_escape,
    replace_entities_with_placeholders)
from snips_inference_agl.constants import (
    DATA, END, ENTITIES, ENTITY,
    INTENTS, LANGUAGE, RES_INTENT, RES_INTENT_NAME,
    RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES,
    RES_PROBA)
from snips_inference_agl.dataset import validate_and_format_dataset
from snips_inference_agl.dataset.utils import get_stop_words_whitelist
from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError
from snips_inference_agl.intent_parser.intent_parser import IntentParser
from snips_inference_agl.pipeline.configs import DeterministicIntentParserConfig
from snips_inference_agl.preprocessing import normalize_token, tokenize, tokenize_light
from snips_inference_agl.resources import get_stop_words
from snips_inference_agl.result import (empty_result, extraction_result,
                              intent_classification_result, parsing_result,
                              unresolved_slot)

# Regex fragment matching a (possibly empty) run of whitespace characters. It
# is inserted between, before and after the pieces of every generated
# utterance pattern.
WHITESPACE_PATTERN = r"\s*"

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


@IntentParser.register("deterministic_intent_parser")
class DeterministicIntentParser(IntentParser):
    """Intent parser using pattern matching in a deterministic manner

    This intent parser is very strict by nature, and tends to have a very good
    precision but a low recall. For this reason, it is interesting to use it
    first before potentially falling back to another parser.

    The parser keeps, for each intent, a list of compiled case-insensitive
    regular expressions. Named groups of these expressions correspond to
    slots.
    """

    # Config class, used to rebuild the configuration when deserializing
    config_type = DeterministicIntentParserConfig

    def __init__(self, config=None, **shared):
        """Creates an unfitted deterministic intent parser

        Args:
            config (optional): a :class:`.DeterministicIntentParserConfig`
                used to configure the parser, forwarded to the parent
                constructor
            **shared: additional keyword arguments forwarded as is to the
                parent constructor
        """
        super(DeterministicIntentParser, self).__init__(config, **shared)
        # Backing fields of the properties exposed by this class
        self._language = self._slot_names_to_entities = None
        self._group_names_to_slot_names = None
        # Stop words related state
        self._stop_words = self._stop_words_whitelist = None
        # State derived by the property setters or built when fitting
        self.slot_names_to_group_names = None
        self.regexes_per_intent = self.entity_scopes = None

    @property
    def language(self):
        """Language of the parser, None as long as it has not been set"""
        return self._language

    @language.setter
    def language(self, value):
        """Sets the language and updates the stop words accordingly

        The stop words are reset to None along with a None language.
        Otherwise they are fetched from the resources when the config has
        ``ignore_stop_words`` enabled, and are an empty set when it has not.
        """
        self._language = value
        if value is None:
            self._stop_words = None
            return
        if not self.config.ignore_stop_words:
            self._stop_words = set()
            return
        self._stop_words = get_stop_words(self.resources)

    @property
    def slot_names_to_entities(self):
        """Mapping intent -> slot name -> entity, None when not set"""
        return self._slot_names_to_entities

    @slot_names_to_entities.setter
    def slot_names_to_entities(self, value):
        """Sets the mapping and recomputes the entity scope of each intent

        The scope of an intent holds the set of its builtin entities under
        the "builtin" key and the set of its other entities under the
        "custom" key.
        """
        self._slot_names_to_entities = value
        if value is None:
            self.entity_scopes = None
            return

        scopes = dict()
        for intent, slot_mapping in iteritems(value):
            builtin_scope = set()
            custom_scope = set()
            for ent in itervalues(slot_mapping):
                if is_builtin_entity(ent):
                    builtin_scope.add(ent)
                else:
                    custom_scope.add(ent)
            scopes[intent] = {"builtin": builtin_scope,
                              "custom": custom_scope}
        self.entity_scopes = scopes

    @property
    def group_names_to_slot_names(self):
        """Mapping regex group name -> slot name, None when not set"""
        return self._group_names_to_slot_names

    @group_names_to_slot_names.setter
    def group_names_to_slot_names(self, value):
        """Sets the mapping and keeps the reverse mapping in sync

        ``slot_names_to_group_names`` is rebuilt by inverting *value*. It is
        reset to None along with a None *value*, so that no stale reverse
        mapping survives, in the same way as the other setters of this class
        reset their derived state.
        """
        self._group_names_to_slot_names = value
        if value is None:
            self.slot_names_to_group_names = None
        else:
            self.slot_names_to_group_names = {
                slot_name: group for group, slot_name in iteritems(value)}

    @property
    def patterns(self):
        """Dictionary of patterns per intent

        Returns None when no regexes have been set on the parser.
        """
        if self.regexes_per_intent is None:
            return None
        return {intent: [regex.pattern for regex in regex_list]
                for intent, regex_list in iteritems(self.regexes_per_intent)}

    @patterns.setter
    def patterns(self, value):
        """Compiles the patterns of each intent, ignoring the case

        Args:
            value (dict or None): mapping intent -> list of pattern strings.
                A None value clears the compiled regexes, which keeps the
                getter consistent with what has been set.
        """
        if value is None:
            self.regexes_per_intent = None
            return
        # Same compilation flags as the ones used when fitting
        self.regexes_per_intent = {
            intent: [re.compile(p, re.IGNORECASE) for p in pattern_list]
            for intent, pattern_list in iteritems(value)}

    @property
    def fitted(self):
        """Whether or not the intent parser has already been trained

        The parser is considered fitted as soon as its per-intent regexes are
        set, even when they are empty.
        """
        regexes = self.regexes_per_intent
        return regexes is not None

    @log_elapsed_time(
        logger, logging.INFO, "Fitted deterministic parser in {elapsed_time}")
    def fit(self, dataset, force_retrain=True):
        """Fits the intent parser with a valid Snips dataset

        One pattern is generated per distinct utterance of each intent.
        Patterns whose length reaches ``config.max_pattern_length`` are
        dropped, as well as patterns shared by several intents, and at most
        ``config.max_queries`` patterns are kept per intent.

        Args:
            dataset (dict): the dataset, which is validated and formatted
                before being used
            force_retrain (bool, optional): not used by this method

        Returns:
            DeterministicIntentParser: the fitted parser itself
        """
        logger.info("Fitting deterministic intent parser...")
        dataset = validate_and_format_dataset(dataset)
        language = dataset[LANGUAGE]
        self.load_resources_if_needed(language)
        self.fit_builtin_entity_parser_if_needed(dataset)
        self.fit_custom_entity_parser_if_needed(dataset)
        self.language = language
        self.regexes_per_intent = dict()
        entity_placeholders = _get_entity_placeholders(dataset, self.language)
        self.slot_names_to_entities = get_slot_name_mappings(dataset)
        self.group_names_to_slot_names = _get_group_names_to_slot_names(
            self.slot_names_to_entities)
        self._stop_words_whitelist = get_stop_words_whitelist(
            dataset, self._stop_words)

        # First pass: collect the candidate patterns of each intent and spot
        # the ones which appear in more than one intent
        seen_patterns = set()
        ambiguous_patterns = set()
        candidates_per_intent = dict()
        for intent_name, intent in iteritems(dataset[INTENTS]):
            candidates = [
                p for p in self._generate_patterns(
                    intent_name, intent[UTTERANCES], entity_placeholders)
                if len(p) < self.config.max_pattern_length]
            ambiguous_patterns.update(
                p for p in candidates if p in seen_patterns)
            seen_patterns.update(candidates)
            candidates_per_intent[intent_name] = candidates

        # Second pass: do not use the ambiguous patterns, then truncate and
        # compile what is left
        for intent_name, candidates in iteritems(candidates_per_intent):
            unambiguous = [p for p in candidates
                           if p not in ambiguous_patterns]
            self.regexes_per_intent[intent_name] = [
                re.compile(p, re.IGNORECASE)
                for p in unambiguous[:self.config.max_queries]]
        return self

    @log_result(
        logger, logging.DEBUG, "DeterministicIntentParser result -> {result}")
    @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.")
    @fitted_required
    def parse(self, text, intents=None, top_n=None):
        """Performs intent parsing on the provided *text*

        Intent and slots are extracted simultaneously through pattern matching

        Args:
            text (str): input
            intents (str or list of str): if provided, reduces the scope of
                intent parsing to the provided list of intents
            top_n (int, optional): when provided, this method will return a
                list of at most top_n most likely intents, instead of a single
                parsing result.
                Note that the returned list can contain less than ``top_n``
                elements, for instance when the parameter ``intents`` is not
                None, or when ``top_n`` is greater than the total number of
                intents.

        Returns:
            dict or list: the most likely intent(s) along with the extracted
            slots. See :func:`.parsing_result` and :func:`.extraction_result`
            for the output format. When ``top_n`` is None and either nothing
            matched or the best probability does not exceed 0.5, an empty
            result is returned.

        Raises:
            NotTrained: when the intent parser is not fitted
        """
        if top_n is not None:
            return self._parse_top_intents(text, top_n=top_n, intents=intents)

        candidates = self._parse_top_intents(text, top_n=1, intents=intents)
        if not candidates:
            return empty_result(text, probability=1.0)

        best_candidate = candidates[0]
        intent = best_candidate[RES_INTENT]
        slots = best_candidate[RES_SLOTS]
        if intent[RES_PROBA] > 0.5:
            return parsing_result(text, intent, slots)
        # return an empty result in case of ambiguity
        return empty_result(text, probability=1.0)

    def _parse_top_intents(self, text, top_n, intents=None):
        """Returns at most *top_n* extraction results, best ones first

        Each intent in scope contributes at most one result: the one of its
        first matching regex. A regex is tried on the cleaned text first, and
        then on the cleaned text in which the parsed entities have been
        replaced with placeholders.

        Args:
            text (str): input
            top_n (int): maximum number of results, must be at least 1
            intents (str or list of str, optional): restricts the matching to
                these intents

        Raises:
            ValueError: when *top_n* is lower than 1
        """
        if isinstance(intents, str):
            intents = {intents}
        elif isinstance(intents, list):
            intents = set(intents)

        if top_n < 1:
            raise ValueError(
                "top_n argument must be greater or equal to 1, but got: %s"
                % top_n)

        def placeholder_fn(entity_name):
            return _get_entity_name_placeholder(entity_name, self.language)

        results = []
        for intent, entity_scope in iteritems(self.entity_scopes):
            if intents is not None and intent not in intents:
                continue

            builtin_entities = self.builtin_entity_parser.parse(
                text, scope=entity_scope["builtin"], use_cache=True)
            custom_entities = self.custom_entity_parser.parse(
                text, scope=entity_scope["custom"], use_cache=True)
            mapping, processed_text = replace_entities_with_placeholders(
                text, builtin_entities + custom_entities,
                placeholder_fn=placeholder_fn)

            cleaned_text = self._preprocess_text(text, intent)
            cleaned_processed_text = self._preprocess_text(processed_text,
                                                           intent)
            # The second attempt is pointless when the placeholders did not
            # change anything
            try_placeholders = cleaned_text != cleaned_processed_text

            for regex in self.regexes_per_intent[intent]:
                match = self._get_matching_result(text, cleaned_text, regex,
                                                  intent)
                if match is None and try_placeholders:
                    match = self._get_matching_result(
                        text, cleaned_processed_text, regex, intent, mapping)
                if match is None:
                    continue
                # Only the first matching regex of an intent is kept
                results.append(match)
                break

        # In some rare cases there can be multiple ambiguous intents
        # In such cases, priority is given to results containing fewer slots
        weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results]
        total_weight = sum(weights)
        for res, weight in zip(results, weights):
            res[RES_INTENT][RES_PROBA] = weight / total_weight

        # The sort is stable, hence ties keep their matching order
        results.sort(key=lambda r: r[RES_INTENT][RES_PROBA], reverse=True)
        return results[:top_n]

    @fitted_required
    def get_intents(self, text):
        """Returns the list of intents ordered by decreasing probability

        The length of the returned list is exactly the number of intents in the
        dataset + 1 for the None intent

        Args:
            text (str): input

        Returns:
            list: the intent classification results of the matched intents,
            followed by the other intents and the None intent which all get a
            0.0 probability
        """
        nb_intents = len(self.regexes_per_intent)
        top_intents = []
        # A parser fitted without any intent has nothing to match, and asking
        # for the top 0 intents would be rejected with a ValueError
        if nb_intents > 0:
            top_intents = [
                intent_result[RES_INTENT] for intent_result in
                self._parse_top_intents(text, top_n=nb_intents)]

        matched_intents = {res[RES_INTENT_NAME] for res in top_intents}
        top_intents.extend(
            intent_classification_result(intent, 0.0)
            for intent in self.regexes_per_intent
            if intent not in matched_intents)

        # The None intent is not included in the regex patterns and is thus
        # never matched by the deterministic parser
        top_intents.append(intent_classification_result(None, 0.0))
        return top_intents

    @fitted_required
    def get_slots(self, text, intent):
        """Extracts slots from a text input, with the knowledge of the intent

        Args:
            text (str): input
            intent (str): the intent which the input corresponds to

        Returns:
            list: the list of extracted slots, which is empty when the intent
            is None or when the parsing result holds no slots

        Raises:
            IntentNotFoundError: When the intent was not part of the training
                data
        """
        if intent is None:
            return []
        if intent not in self.regexes_per_intent:
            raise IntentNotFoundError(intent)

        parsing = self.parse(text, intents=[intent])
        extracted_slots = parsing[RES_SLOTS]
        return [] if extracted_slots is None else extracted_slots

    def _get_intent_stop_words(self, intent):
        """Returns the stop words minus the ones whitelisted for *intent*"""
        intent_whitelist = self._stop_words_whitelist.get(intent, set())
        intent_stop_words = self._stop_words.difference(intent_whitelist)
        return intent_stop_words

    def _preprocess_text(self, string, intent):
        """Replaces stop words and characters that are tokenized out by
            whitespaces

        The returned string has the same length as *string*, so that the
        character ranges of both strings are aligned.
        """
        tokens = tokenize(string, self.language)
        stop_words = self._get_intent_stop_words(intent)
        pieces = []
        cursor = 0
        for token in tokens:
            if stop_words and normalize_token(token) in stop_words:
                # Blank out the stop word while preserving its length
                token.value = " " * len(token.value)
            # Fill the gap between the previous token and this one
            pieces.append(" " * (token.start - cursor))
            pieces.append(token.value)
            cursor = token.end
        # Fill what remains after the last token
        pieces.append(" " * (len(string) - cursor))
        return "".join(pieces)

    def _get_matching_result(self, text, processed_text, regex, intent,
                             entities_ranges_mapping=None):
        """Matches *regex* against *processed_text* and builds the result

        Args:
            text (str): original input, from which slot values are extracted
            processed_text (str): the text which is actually matched
            regex: compiled pattern of the intent
            intent (str): name of the intent the pattern belongs to
            entities_ranges_mapping (dict, optional): maps ranges of
                *processed_text*, as (start, end) tuples, to the corresponding
                ranges of *text*. It is used to convert the matched ranges
                back to ranges of *text*.

        Returns:
            dict or None: the extraction result, None when there is no match
        """
        match = regex.match(processed_text)
        if match is None:
            return None

        parsed_intent = intent_classification_result(intent_name=intent,
                                                     probability=1.0)
        slots = []
        for group_name in match.groupdict():
            # Repeated slots use groups named "<group>_<count>", all of them
            # referring to the slot of "<group>"
            ref_group_name = group_name.split("_")[0]
            slot_name = self.group_names_to_slot_names[ref_group_name]
            entity = self.slot_names_to_entities[intent][slot_name]
            group_start, group_end = match.span(group_name)
            group_range = (group_start, group_end)
            if entities_ranges_mapping is None:
                rng = {START: group_start, END: group_end}
            elif group_range in entities_ranges_mapping:
                rng = entities_ranges_mapping[group_range]
            else:
                shift = _get_range_shift(
                    group_range, entities_ranges_mapping)
                rng = {START: group_start + shift, END: group_end + shift}
            slots.append(unresolved_slot(
                match_range=rng, value=text[rng[START]:rng[END]],
                entity=entity, slot_name=slot_name))

        deduplicated_slots = _deduplicate_overlapping_slots(
            slots, self.language)
        ordered_slots = sorted(
            deduplicated_slots, key=lambda s: s[RES_MATCH_RANGE][START])
        return extraction_result(parsed_intent, ordered_slots)

    def _generate_patterns(self, intent, intent_utterances,
                           entity_placeholders):
        """Returns the distinct patterns of the utterances of *intent*

        The patterns are returned in the order of their first occurrence.
        """
        stop_words = self._get_intent_stop_words(intent)
        already_generated = set()
        ordered_patterns = []
        for utterance in intent_utterances:
            candidate = self._utterance_to_pattern(
                utterance, stop_words, entity_placeholders)
            if candidate in already_generated:
                continue
            already_generated.add(candidate)
            ordered_patterns.append(candidate)
        return ordered_patterns

    def _utterance_to_pattern(self, utterance, stop_words,
                              entity_placeholders):
        """Converts an utterance into an anchored regex pattern string

        Slot chunks become named groups wrapping the placeholder of their
        entity. Text chunks become their escaped lowercased tokens, without
        the tokens whose normalized form is a stop word. All the pieces are
        separated and surrounded by optional whitespaces.
        """
        from snips_nlu_utils import normalize

        slot_occurrences = defaultdict(int)
        pieces = []
        for chunk in utterance[DATA]:
            if SLOT_NAME not in chunk:
                tokens = tokenize_light(chunk[TEXT], self.language)
                pieces.extend([regex_escape(t.lower()) for t in tokens
                               if normalize(t) not in stop_words])
                continue

            slot_name = chunk[SLOT_NAME]
            slot_occurrences[slot_name] += 1
            group_name = self.slot_names_to_group_names[slot_name]
            occurrence = slot_occurrences[slot_name]
            if occurrence > 1:
                # Group names must be unique within a pattern
                group_name = "%s_%s" % (group_name, occurrence)
            placeholder = entity_placeholders[chunk[ENTITY]]
            pieces.append(r"(?P<%s>%s)" % (group_name, placeholder))

        joined_pieces = WHITESPACE_PATTERN.join(pieces)
        return r"^%s%s%s$" % (WHITESPACE_PATTERN, joined_pieces,
                              WHITESPACE_PATTERN)

    @check_persisted_path
    def persist(self, path):
        """Persists the object at the given path

        The directory at *path* is created, the serialized parser is written
        in its "intent_parser.json" file and the metadata is then persisted
        in the same directory.
        """
        path.mkdir()
        serialized_parser = json_string(self.to_dict())
        with (path / "intent_parser.json").open(
                mode="w", encoding="utf8") as f:
            f.write(serialized_parser)
        self.persist_metadata(path)

    @classmethod
    def from_path(cls, path, **shared):
        """Loads a :class:`DeterministicIntentParser` instance from a path

        The data at the given path must have been generated using
        :func:`~DeterministicIntentParser.persist`

        Raises:
            LoadingError: when the "intent_parser.json" file is missing
        """
        model_path = Path(path) / "intent_parser.json"
        if model_path.exists():
            with model_path.open(encoding="utf8") as f:
                unit_dict = json.load(f)
            return cls.from_dict(unit_dict, **shared)
        raise LoadingError(
            "Missing deterministic intent parser metadata file: %s"
            % model_path.name)

    def to_dict(self):
        """Returns a json-serializable dict

        The stop words whitelist of each intent is serialized as a sorted
        list.
        """
        whitelist = self._stop_words_whitelist
        serialized_whitelist = None if whitelist is None else {
            intent: sorted(values) for intent, values in iteritems(whitelist)}
        return {
            "config": self.config.to_dict(),
            "language_code": self.language,
            "patterns": self.patterns,
            "group_names_to_slot_names": self.group_names_to_slot_names,
            "slot_names_to_entities": self.slot_names_to_entities,
            "stop_words_whitelist": serialized_whitelist
        }

    @classmethod
    def from_dict(cls, unit_dict, **shared):
        """Creates a :class:`DeterministicIntentParser` instance from a dict

        The dict must have been generated with
        :func:`~DeterministicIntentParser.to_dict`
        """
        parser = cls(
            config=cls.config_type.from_dict(unit_dict["config"]), **shared)
        parser.patterns = unit_dict["patterns"]
        parser.language = unit_dict["language_code"]
        parser.group_names_to_slot_names = unit_dict[
            "group_names_to_slot_names"]
        parser.slot_names_to_entities = unit_dict["slot_names_to_entities"]
        if not parser.fitted:
            return parser

        # The whitelist is only restored on fitted parsers, each list of
        # values being converted back to a set
        serialized_whitelist = unit_dict.get("stop_words_whitelist", dict())
        # pylint:disable=protected-access
        parser._stop_words_whitelist = {
            intent: set(values)
            for intent, values in iteritems(serialized_whitelist)}
        # pylint:enable=protected-access
        return parser


def _get_range_shift(matched_range, ranges_mapping):
    """Returns the offset to apply to *matched_range* to undo replacements

    Among the replaced ranges which end at or before the start of
    *matched_range*, the one ending last gives the shift: the difference
    between the end of its original range and its own end. The shift is 0
    when there is no such range.
    """
    matched_start = matched_range[0]
    preceding_ranges = [
        (replaced_range, orig_range)
        for replaced_range, orig_range in iteritems(ranges_mapping)
        if replaced_range[1] <= matched_start]
    if not preceding_ranges:
        return 0
    # max returns the first encountered item in case of equal ends
    closest_replaced, closest_orig = max(
        preceding_ranges, key=lambda ranges: ranges[0][1])
    return closest_orig[END] - closest_replaced[1]


def _get_group_names_to_slot_names(slot_names_mapping):
    """Assigns a "group<i>" name to each distinct slot name

    The slot names of all the mappings are gathered and numbered following
    their sorted order.
    """
    all_slot_names = set()
    for mapping in itervalues(slot_names_mapping):
        all_slot_names.update(mapping)

    group_names = dict()
    for index, slot_name in enumerate(sorted(all_slot_names)):
        group_names["group%s" % index] = slot_name
    return group_names


def _get_entity_placeholders(dataset, language):
    """Maps each entity of the dataset to its placeholder"""
    placeholders = dict()
    for entity_name in dataset[ENTITIES]:
        placeholders[entity_name] = _get_entity_name_placeholder(
            entity_name, language)
    return placeholders


def _deduplicate_overlapping_slots(slots, language):
    """Deduplicates the overlapping slots, ordering the result by start

    The sort key handed to the deduplication helper is lower for slots whose
    value has more tokens and more characters.
    """

    def slots_overlap(lhs_slot, rhs_slot):
        lhs_range = lhs_slot[RES_MATCH_RANGE]
        rhs_range = rhs_slot[RES_MATCH_RANGE]
        return ranges_overlap(lhs_range, rhs_range)

    def slot_sort_key(slot):
        slot_value = slot[RES_VALUE]
        nb_tokens = len(tokenize(slot_value, language))
        return -(nb_tokens + len(slot_value))

    remaining_slots = list(deduplicate_overlapping_items(
        slots, slots_overlap, slot_sort_key))
    remaining_slots.sort(key=lambda slot: slot[RES_MATCH_RANGE][START])
    return remaining_slots


def _get_entity_name_placeholder(entity_label, language):
    """Builds the placeholder of an entity

    The placeholder is made of the uppercased concatenation of the tokens of
    the entity label, enclosed in percent signs.
    """
    label_tokens = tokenize_light(entity_label, language)
    joined_label = "".join(label_tokens)
    return "%%%s%%" % joined_label.upper()