# Evaluation util functions for PropBank SRL.

import codecs
from collections import Counter
import operator
import os
from os.path import join
import subprocess
from Evaluator import conll09_utils

# Shell wrapper around the span-based SRL scorer. compute_srl_f1 runs it as
# `sh <script> <file_a> <file_b>`. The path is relative, so it is resolved
# against the current working directory of the process.
_SPAN_SRL_CONLL_EVAL_SCRIPT = "srl_scripts/run_conll_eval.sh"
# Perl scorer for dependency-based SRL (presumably the CoNLL-2009 evaluation
# script -- TODO confirm). Not invoked by any active code visible in this
# module.
_DEPENDENCY_SRL_CONLL_EVAL_SCRIPT = "srl_scripts/eval_09.pl"


def split_example_for_eval(example):
    """Split a document-level example into per-sentence samples for evaluation.

    The relations in ``example["srl"][i]`` use word positions counted from the
    start of the document; they are shifted here so that every position is
    relative to the start of sentence ``i``.

    Args:
      example: Mapping with at least the keys
        ``"sentences"`` (list of sentences, each a sequence of tokens) and
        ``"srl"`` (indexable in parallel with ``"sentences"``; entry ``i``
        holds ``(predicate, arg_start, arg_end, label)`` relations for
        sentence ``i``).

    Returns:
      A list with one ``(sentence, srl_rels, ner_spans)`` tuple per sentence.
      ``srl_rels`` maps the sentence-relative predicate position to a list of
      ``(arg_start, arg_end, label)`` tuples. ``ner_spans`` is always an empty
      list; it is an unused placeholder.
    """
    samples = []
    word_offset = 0  # Number of words in all preceding sentences.
    for i, sentence in enumerate(example["sentences"]):
        srl_rels = {}
        for pred, arg_start, arg_end, label in (r[:4] for r in example["srl"][i]):
            srl_rels.setdefault(pred - word_offset, []).append(
                (arg_start - word_offset, arg_end - word_offset, label))
        ner_spans = []  # Unused.
        samples.append((sentence, srl_rels, ner_spans))
        word_offset += len(sentence)
    return samples


def evaluate_retrieval(span_starts, span_ends, span_scores, pred_starts, pred_ends, gold_spans,
                       text_length, evaluators, debugging=False):
    """Evaluate unlabeled span retrieval under several selection strategies.

    Each key ``k`` of ``evaluators`` selects how the predicted span set is
    built before it is passed to that evaluator:

      * ``k == -3``: candidate spans that are also gold spans.
      * ``k == -2``: the spans given by ``pred_starts`` / ``pred_ends``.
      * ``k == -1``: the ``len(gold_spans)`` highest-scoring candidates.
      * ``k == 0``:  every candidate whose score is greater than zero.
      * otherwise:   the ``k * text_length // 100`` highest-scoring candidates.

    Args:
      span_starts: Start positions of the candidate spans.
      span_ends: End positions of the candidate spans (parallel to
        ``span_starts``).
      span_scores: Scores of the candidate spans (parallel to ``span_starts``).
      pred_starts: Start positions of the spans used for ``k == -2``.
      pred_ends: End positions of the spans used for ``k == -2``.
      gold_spans: Set of tuples of (start, end).
      text_length: Length of the text; scales the span budget for positive
        ``k`` (presumably a word count -- confirm with callers).
      evaluators: Mapping from ``k`` to an object exposing
        ``update(gold_set=..., predicted_set=...)``.
      debugging: If True, print the top-ranked candidates and the gold spans
        when handling ``k == -2``.

    Returns:
      None. Results are accumulated through ``evaluator.update``.
    """
    # Candidates ranked by descending score. sorted() is stable, so ties keep
    # their input order. An empty candidate list simply yields an empty ranking.
    ranked = sorted(zip(span_starts, span_ends, span_scores),
                    key=operator.itemgetter(2), reverse=True)
    ranked_spans = [(start, end) for start, end, _ in ranked]

    for k, evaluator in evaluators.items():
        if k == -3:
            predicted_spans = set(zip(span_starts, span_ends)) & gold_spans
        elif k == -2:
            predicted_spans = set(zip(pred_starts, pred_ends))
            if debugging:
                print("Predicted", ranked[:len(gold_spans)])
                print("Gold", gold_spans)
        elif k == 0:
            # Element-wise filtering works for plain lists as well as arrays.
            # NOTE(review): the original carried a "FIXME: scalar index error"
            # on its boolean-mask indexing here; confirm this covers that case.
            predicted_spans = {
                (start, end)
                for start, end, score in zip(span_starts, span_ends, span_scores)
                if score > 0
            }
        else:
            if k == -1:
                num_predictions = len(gold_spans)
            else:
                # Floor division: a slice bound must be an integer.
                num_predictions = int((k * text_length) // 100)
            predicted_spans = set(ranked_spans[:num_predictions])
        evaluator.update(gold_set=gold_spans, predicted_set=predicted_spans)


def _print_f1(total_gold, total_predicted, total_matched, message=""):
    """Print and return precision, recall and F1 as percentages.

    Args:
      total_gold: Number of gold items.
      total_predicted: Number of predicted items.
      total_matched: Number of predicted items counted as correct.
      message: Prefix for the printed summary line.

    Returns:
      Tuple ``(precision, recall, f1)``. A score whose denominator is not
      positive is reported as 0.
    """
    precision = recall = f1 = 0
    if total_predicted > 0:
        precision = 100.0 * total_matched / total_predicted
    if total_gold > 0:
        recall = 100.0 * total_matched / total_gold
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    summary = "{}: Precision: {:.4f}, Recall: {:.4f}, F1: {:.4f}".format(message, precision, recall, f1)
    print(summary)
    return precision, recall, f1


def compute_span_f1(gold_data, predictions, task_name):
    """Compute labeled and unlabeled span scores over parallel examples.

    Every gold span is compared with every predicted span of the same
    example. Spans are indexed as ``(start, end, label)``: equal start and end
    is an unlabeled match, and an additionally equal label is a labeled match.

    Args:
      gold_data: Per-example collections of gold spans.
      predictions: Per-example collections of predicted spans; must have the
        same length as ``gold_data``.
      task_name: Name used in the printed summary lines.

    Returns:
      ``(prec, recall, f1, ul_prec, ul_recall, ul_f1, label_confusions)``
      where the ``ul_`` values are the unlabeled scores and
      ``label_confusions`` is a Counter of (gold label, predicted label)
      pairs over all unlabeled matches.
    """
    assert len(gold_data) == len(predictions)
    num_gold = num_predicted = 0
    labeled_hits = unlabeled_hits = 0
    label_confusions = Counter()  # Counter of (gold, pred) label pairs.

    for gold_spans, pred_spans in zip(gold_data, predictions):
        num_gold += len(gold_spans)
        num_predicted += len(pred_spans)
        for g in gold_spans:
            for p in pred_spans:
                if g[0] != p[0] or g[1] != p[1]:
                    continue  # Different boundaries: not even an unlabeled match.
                unlabeled_hits += 1
                label_confusions[(g[2], p[2])] += 1
                if g[2] == p[2]:
                    labeled_hits += 1

    labeled_scores = _print_f1(num_gold, num_predicted, labeled_hits, task_name)
    unlabeled_scores = _print_f1(num_gold, num_predicted, unlabeled_hits,
                                 "Unlabeled " + task_name)
    prec, recall, f1 = labeled_scores
    ul_prec, ul_recall, ul_f1 = unlabeled_scores
    return prec, recall, f1, ul_prec, ul_recall, ul_f1, label_confusions


def compute_unlabeled_span_f1(gold_data, predictions, task_name):
    """Compute labeled and unlabeled span scores; alias of compute_span_f1.

    This function used to be a line-for-line copy of ``compute_span_f1``. It
    now delegates to it so that the two cannot drift apart; arguments, printed
    output and return value are unchanged.

    Args:
      gold_data: Per-example collections of gold ``(start, end, label)`` spans.
      predictions: Per-example collections of predicted spans; must have the
        same length as ``gold_data``.
      task_name: Name used in the printed summary lines.

    Returns:
      ``(prec, recall, f1, ul_prec, ul_recall, ul_f1, label_confusions)``,
      exactly as returned by ``compute_span_f1``.
    """
    # NOTE(review): despite the name, labels are still read from both gold and
    # predicted spans -- confirm whether a label-free variant was intended.
    return compute_span_f1(gold_data, predictions, task_name)


def compute_dependency_f1(sentences, gold_srl, predictions, srl_conll_eval_path=None,
                          disamb_rate=0.95, use_gold=False):
    """Compute unofficial precision/recall/F1 for dependency-based SRL.

    Counts (all kept as floats):
      GP / PP: gold / predicted predicates.
      GA / PA: gold / predicted arguments whose label is not "_".
      CP: gold predicates that also occur in the prediction.
      CA: (gold, predicted) argument pairs of a shared predicate with equal
          position and label, the label not being "_".

    precision = (CA + CP * disamb_rate) / (PA + PP)
    recall    = (CA + CP * disamb_rate) / (GA + GP)

    Only these unofficial scores are computed. The official evaluation through
    the external Perl scorer is disabled, so ``sentences`` and
    ``srl_conll_eval_path`` are currently not used.

    Args:
      sentences: Unused.
      gold_srl: Per-sentence dicts mapping a predicate id to a list of
        arguments indexed as ``(position, label, ...)``. A ``None`` entry
        makes the sentence be skipped entirely.
      predictions: Per-sentence dicts in the same layout as ``gold_srl``.
      srl_conll_eval_path: Unused.
      disamb_rate: Weight given to every correct predicate (presumably an
        assumed predicate-disambiguation accuracy -- confirm).
      use_gold: If True, assert that predicted and gold predicate counts agree.

    Returns:
      ``(precision, recall, f1)`` scaled to percentages.
    """
    assert len(gold_srl) == len(predictions)

    # Float counters, so the summary line prints values such as "GA: 3.0".
    gold_args = gold_preds = 0.0
    sys_args = sys_preds = 0.0
    correct_args = correct_preds = 0.0

    # Compute unofficial F1 of SRL relations.
    for gold, prediction in zip(gold_srl, predictions):
        if gold is None:
            continue
        gold_preds += len(gold)
        sys_preds += len(prediction)
        gold_args += sum(1 for args in gold.values() for a in args if a[1] != "_")
        sys_args += sum(1 for args in prediction.values() for a in args if a[1] != "_")
        for pred_id, g_args in gold.items():
            if pred_id not in prediction:
                continue
            correct_preds += 1
            p_args = prediction[pred_id]
            correct_args += sum(
                1
                for g in g_args
                for p in p_args
                if g[0] == p[0] and g[1] == p[1] and p[1] != "_"
            )

    if use_gold:
        assert sys_preds == gold_preds

    precision = 0
    if (sys_args + sys_preds) > 0:
        precision = (correct_args + correct_preds * disamb_rate) / (sys_args + sys_preds)
    recall = 0
    if (gold_args + gold_preds) > 0:
        recall = (correct_args + correct_preds * disamb_rate) / (gold_args + gold_preds)
    f1 = 0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    print('SRL(unofficial) GA:', gold_args, 'GP:', gold_preds, 'PA:', sys_args,
          'PP:', sys_preds, 'CA:', correct_args, 'CP:', correct_preds)
    print("SRL(unofficial) Precision: {:.4f}, Recall: {:.4f}, F1: {:.4f}".format(
        precision * 100, recall * 100, f1 * 100))

    return precision * 100, recall * 100, f1 * 100


def compute_srl_f1(sentences, gold_srl, predictions, srl_conll_eval_path):
    """Compute unofficial and official (script-based) span SRL scores.

    The unofficial scores are counted directly from the relation dicts,
    ignoring arguments labeled "V" or "C-V". For the official scores the
    predictions (and, if needed, the gold relations) are written as CoNLL
    files under ``./tmp`` and scored by the external shell script
    ``_SPAN_SRL_CONLL_EVAL_SCRIPT``. The files are not removed afterwards.

    Args:
      sentences: List of sentences, each a sequence of words.
      gold_srl: Per-sentence dicts mapping a predicate position to a list of
        ``(start, end, label)`` arguments.
      predictions: Per-sentence dicts in the same layout as ``gold_srl``.
      srl_conll_eval_path: Path of an existing gold CoNLL file. If falsy, a
        gold file is generated from ``gold_srl``.

    Returns:
      ``(conll_precision, conll_recall, conll_f1, precision, recall, f1)``.
      The first three come from the external script and are 0 when its output
      cannot be parsed; the last three are the unofficial percentages.
    """
    assert len(gold_srl) == len(predictions)
    total_gold = 0
    total_predicted = 0
    total_matched = 0
    total_unlabeled_matched = 0
    # NOTE: comp_sents (sentences predicted completely right) and
    # label_confusions are tallied but are not part of the return value.
    comp_sents = 0
    label_confusions = Counter()
    excluded_labels = ("V", "C-V")

    # Compute unofficial F1 of SRL relations.
    for gold, prediction in zip(gold_srl, predictions):
        gold_rels = 0
        pred_rels = 0
        matched = 0
        for pred_id, gold_args in gold.items():
            # gold_args: (start, end, label)
            filtered_gold_args = [a for a in gold_args if a[2] not in excluded_labels]
            total_gold += len(filtered_gold_args)
            gold_rels += len(filtered_gold_args)
            if pred_id not in prediction:
                continue
            for a0 in filtered_gold_args:
                for a1 in prediction[pred_id]:
                    if a0[0] == a1[0] and a0[1] == a1[1]:  # Same span boundaries.
                        total_unlabeled_matched += 1
                        label_confusions.update([(a0[2], a1[2]), ])
                        if a0[2] == a1[2]:
                            total_matched += 1
                            matched += 1
        for args in prediction.values():
            num_filtered = sum(1 for a in args if a[2] not in excluded_labels)
            total_predicted += num_filtered
            pred_rels += num_filtered

        if gold_rels == matched and pred_rels == matched:
            comp_sents += 1

    precision, recall, f1 = _print_f1(total_gold, total_predicted, total_matched, "SRL (unofficial)")
    # Printed for the log only; the unlabeled scores are not returned.
    _print_f1(total_gold, total_predicted, total_unlabeled_matched,
              "Unlabeled SRL (unofficial)")

    # Prepare to compute official F1. The scratch directory may not exist yet.
    os.makedirs("./tmp", exist_ok=True)
    if not srl_conll_eval_path:
        gold_path = "./tmp/srl_pred_%d.gold" % os.getpid()
        print_to_span_conll(sentences, gold_srl, gold_path, None)
        gold_predicates = None
    else:
        gold_path = srl_conll_eval_path
        gold_predicates = read_gold_predicates(gold_path)

    temp_output = "./tmp/srl_pred_%d.tmp" % os.getpid()
    print_to_span_conll(sentences, predictions, temp_output, gold_predicates)

    def run_official_eval(first_path, second_path):
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact.
        child = subprocess.Popen(
            ["sh", _SPAN_SRL_CONLL_EVAL_SCRIPT, first_path, second_path],
            stdout=subprocess.PIPE)
        return child.communicate()[0].decode("utf-8")

    def overall_score(report):
        # Sixth whitespace-separated field on the seventh line of the report
        # (presumably the overall recall column of the scorer -- confirm).
        return float(report.strip().split("\n")[6].strip().split()[5])

    # Evaluate twice with the official script. The value read from the run
    # with swapped arguments is used as precision.
    eval_info = run_official_eval(gold_path, temp_output)
    eval_info2 = run_official_eval(temp_output, gold_path)
    try:
        conll_recall = overall_score(eval_info)
        conll_precision = overall_score(eval_info2)
        if conll_recall + conll_precision > 0:
            conll_f1 = 2 * conll_recall * conll_precision / (conll_recall + conll_precision)
        else:
            conll_f1 = 0
        print("Official CoNLL Precision={}, Recall={}, Fscore={}".format(
            conll_precision, conll_recall, conll_f1))
    except (IndexError, ValueError):
        # Missing lines/fields (IndexError) or a non-numeric field (ValueError):
        # the script failed or printed something unexpected.
        conll_recall = 0
        conll_precision = 0
        conll_f1 = 0
        print("Unable to get FScore. Skipping.")

    return conll_precision, conll_recall, conll_f1, precision, recall, f1


def print_sentence_to_conll(fout, tokens, labels):
    """Write one labeled sentence in CoNLL format.

    One row is written per token: the token followed by the value of every
    label column at that position, separated by tabs. A blank line ends the
    sentence.

    Args:
      fout: Object with a ``write`` method accepting text.
      tokens: Sequence of strings forming the first column.
      labels: List of label columns; every column must be as long as
        ``tokens``.
    """
    num_tokens = len(tokens)
    assert all(len(column) == num_tokens for column in labels)
    for row, token in enumerate(tokens):
        cells = [token]
        cells.extend(column[row] for column in labels)
        fout.write("\t".join(cells) + "\n")
    fout.write("\n")


def read_gold_predicates(gold_path):
    """Read the first column of a blank-line-separated CoNLL file.

    Args:
      gold_path: Path of a UTF-8 encoded file with one token per line and a
        blank line between sentences.

    Returns:
      A list with one list per sentence, holding the first
      whitespace-separated field of each of its lines. Every blank line starts
      a new (initially empty) sentence, so a file ending with a blank line
      yields a trailing empty list.
    """
    gold_predicates = [[], ]
    # The context manager closes the file even if reading fails midway.
    with codecs.open(gold_path, "r", "utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                gold_predicates.append([])
            else:
                gold_predicates[-1].append(line.split()[0])
    return gold_predicates


def print_dependency_conll(sentences, srl_labels, output_filename):
    """Write dependency-style SRL labels to a UTF-8 CoNLL file.

    For every sentence the first column marks predicates ("Y" at a predicate
    position, "_" elsewhere); the words themselves are not written. It is
    followed by one column per predicate, in ascending predicate position,
    holding the argument label at each argument position and "_" elsewhere.

    Args:
      sentences: List of sentences, each a sequence of words.
      srl_labels: Indexable by sentence id; each entry maps a predicate
        position to an iterable of ``(position, label)`` arguments.
      output_filename: Path of the file to (over)write.
    """
    # The context manager closes the file even if a sentence fails to serialize.
    with codecs.open(output_filename, "w", "utf-8") as fout:
        for sent_id, words in enumerate(sentences):
            pred_to_args = srl_labels[sent_id]
            props = ["_" for _ in words]
            col_labels = [["_" for _ in words] for _ in range(len(pred_to_args))]

            for i, pred_id in enumerate(sorted(pred_to_args.keys())):
                props[pred_id] = 'Y'
                for start, label in pred_to_args[pred_id]:
                    col_labels[i][start] = label

            print_sentence_to_conll(fout, props, col_labels)


def print_to_span_conll(sentences, srl_labels, output_filename, gold_predicates):
    """Write span-based SRL labels to a UTF-8 CoNLL file.

    For every sentence the first column holds "-" for non-predicates. At a
    predicate position it holds the gold predicate entry when one is given and
    is not "-", otherwise "P" followed by the word. It is followed by one
    column per predicate, in ascending predicate position, in bracket
    notation: "(LABEL*" opens a span, "*)" closes it, "*" is used elsewhere.

    Arguments are added in the given order. An argument overlapping tokens
    already covered by an earlier argument of the same predicate is dropped.
    If no argument covers the predicate token it is labeled "(V*)".

    Args:
      sentences: List of sentences, each a sequence of words.
      srl_labels: Indexable by sentence id; each entry maps a predicate
        position to an iterable of ``(start, end, label)`` arguments with an
        inclusive ``end``.
      output_filename: Path of the file to (over)write.
      gold_predicates: Optional per-sentence lists of first-column entries
        (as returned by ``read_gold_predicates``), or None.

    Raises:
      ValueError: If an argument span selects no token of its sentence.
    """
    # The context manager closes the file even if a sentence fails to serialize.
    with codecs.open(output_filename, "w", "utf-8") as fout:
        for sent_id, words in enumerate(sentences):
            if gold_predicates:
                assert len(gold_predicates[sent_id]) == len(words)

            pred_to_args = srl_labels[sent_id]
            props = ["-" for _ in words]
            col_labels = [["*" for _ in words] for _ in range(len(pred_to_args))]

            for i, pred_id in enumerate(sorted(pred_to_args.keys())):
                # To make sure CoNLL-eval script count matching predicates as correct.
                if gold_predicates and gold_predicates[sent_id][pred_id] != "-":
                    props[pred_id] = gold_predicates[sent_id][pred_id]
                else:
                    props[pred_id] = "P" + words[pred_id]

                covered = [False for _ in words]  # Tokens already inside a span.
                for start, end, label in pred_to_args[pred_id]:
                    span_flags = covered[start:end + 1]
                    if not span_flags:
                        # Previously surfaced as an opaque "max() arg is an
                        # empty sequence" ValueError.
                        raise ValueError(
                            "Invalid argument span ({}, {}) in sentence {} of {} words".format(
                                start, end, sent_id, len(words)))
                    if any(span_flags):
                        continue  # Overlaps an earlier argument: drop it.
                    col_labels[i][start] = "(" + label + col_labels[i][start]
                    col_labels[i][end] = col_labels[i][end] + ")"
                    for j in range(start, end + 1):
                        covered[j] = True

                # Add unpredicted verb (for predicted SRL).
                if not covered[pred_id]:
                    col_labels[i][pred_id] = "(V*)"

            print_sentence_to_conll(fout, props, col_labels)