#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

import json
import logging
import os
import stat
import sys
import time
from zipfile import ZipFile
import shutil
import click
from xdg import XDG_CONFIG_HOME, XDG_CACHE_HOME

import nlgeval
import nlgeval.utils

CODE_PATH = nlgeval.__path__[0]


def _download_file(d):
    import requests
    from tqdm import tqdm

    url, target_dir = d['url'], d['target_dir']
    filename = url[url.rfind('/') + 1:]
    target_path = os.path.join(target_dir, filename)
    if not os.path.exists(target_path):
        # Collect data 1MB at a time.
        chunk_size = 1 * 1024 * 1024

        num_attempts = 3

        for attempt_num in range(num_attempts):
            try:
                print("Downloading {} to {}.".format(url, target_dir))
                r = requests.get(url, stream=True)
                r.raise_for_status()

                total = None
                length = r.headers.get('Content-length')
                if length is not None:
                    total = int(length) // chunk_size + 1

                with open(target_path, 'wb') as f:
                    for chunk in tqdm(r.iter_content(chunk_size=chunk_size),
                                      desc="{}".format(filename),
                                      total=total,
                                      unit_scale=True, mininterval=15, unit=" chunks"):
                        sys.stdout.flush()
                        f.write(chunk)
                break
            except:
                if attempt_num < num_attempts - 1:
                    wait_s = 1 * 60
                    logging.exception("Error downloading file, will retry in %ds.", wait_s)
                    # Wait and try to download later.
                    time.sleep(wait_s)
                else:
                    raise


@click.command()
@click.argument("data_path", required=False)
def setup(data_path):
    """
    Download required code and data files for nlg-eval.

    If the data_path argument is provided, install to the given location.
    Otherwise, your cache directory is used (usually ~/.cache/nlgeval).
    """
    # NOTE: the docstring above doubles as the help text click shows for this
    # command, so keep it user-facing.
    #
    # Steps performed, in order:
    #   1. download the NLTK 'punkt' data;
    #   2. resolve the data directory (argument, $NLGEVAL_DATA, or the XDG cache dir);
    #   3. download code/data files in parallel;
    #   4. convert GloVe vectors (only if the converted model is missing);
    #   5. install the Stanford CoreNLP jars under pycocoevalcap/spice/lib;
    #   6. make multi-bleu.perl executable;
    #   7. record the data directory in <XDG config>/nlgeval/rc.json.
    from nltk.downloader import download
    download('punkt')

    from multiprocessing import Pool

    if data_path is None:
        data_path = os.getenv('NLGEVAL_DATA', os.path.join(XDG_CACHE_HOME, 'nlgeval'))
    click.secho("Installing to {}".format(data_path), fg='red')
    click.secho("In case of incomplete downloads, delete the directory and run `nlg-eval --setup {}' again.".format(data_path),
               fg='red')

    # Always fetch a fresh copy of the conversion script: remove any existing
    # one so the downloader does not skip it.
    path = os.path.join(CODE_PATH, 'word2vec/glove2word2vec.py')
    if os.path.exists(path):
        os.remove(path)

    downloads = []

    # The conversion script is fetched from a different fork for Python 2 and Python 3.
    if sys.version_info[0] == 2:
        downloads.append(dict(
            url='https://raw.githubusercontent.com/manasRK/glove-gensim/42ce46f00e83d3afa028fb6bf17ed3c90ca65fcc/glove2word2vec.py',
            target_dir=os.path.join(CODE_PATH, 'word2vec')
        ))
    else:
        downloads.append(dict(
            url='https://raw.githubusercontent.com/robmsmt/glove-gensim/4c2224bccd61627b76c50a5e1d6afd1c82699d22/glove2word2vec.py',
            target_dir=os.path.join(CODE_PATH, 'word2vec')
        ))

    # The raw GloVe archive is only needed if the converted model is not there yet.
    setup_glove = not os.path.exists(os.path.join(data_path, 'glove.6B.300d.model.bin'))
    if setup_glove:
        downloads.append(dict(
            url='http://nlp.stanford.edu/data/glove.6B.zip',
            target_dir=data_path
        ))

    # Skip-thoughts data.
    downloads.append(dict(
        url='http://mirror.nubenum.de/www.cs.toronto.edu/~rkiros/models/dictionary.txt',
        target_dir=data_path
    ))
    # Stanford CoreNLP (its jars are installed for SPICE further down).
    downloads.append(dict(
        url='http://nlp.stanford.edu/software/stanford-corenlp-full-2015-12-09.zip',
        target_dir=CODE_PATH
    ))
    # Remaining skip-thoughts files; kept in this order so the download queue is unchanged.
    for model_file in ('utable.npy', 'btable.npy',
                       'uni_skip.npz', 'uni_skip.npz.pkl',
                       'bi_skip.npz', 'bi_skip.npz.pkl'):
        downloads.append(dict(
            url='http://mirror.nubenum.de/www.cs.toronto.edu/~rkiros/models/' + model_file,
            target_dir=data_path
        ))

    # multi-bleu.perl
    downloads.append(dict(
        url='https://raw.githubusercontent.com/moses-smt/mosesdecoder/b199e654df2a26ea58f234cbb642e89d9c1f269d/scripts/generic/multi-bleu.perl',
        target_dir=os.path.join(CODE_PATH, 'multibleu')
    ))

    for target_dir in {d['target_dir'] for d in downloads}:
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

    # Limit the number of worker processes so that we don't download too much from the same source concurrently.
    pool = Pool(min(4, len(downloads)))
    pool.map(_download_file, downloads)
    pool.close()
    pool.join()

    if setup_glove:
        from nlgeval.word2vec.generate_w2v_files import generate
        with ZipFile(os.path.join(data_path, 'glove.6B.zip')) as z:
            z.extract('glove.6B.300d.txt', data_path)
        generate(data_path)
        # Remove the archive and the text files once generate() has run.
        for p in [
            os.path.join(data_path, 'glove.6B.zip'),
            os.path.join(data_path, 'glove.6B.300d.txt'),
            os.path.join(data_path, 'glove.6B.300d.model.txt'),
        ]:
            if os.path.exists(p):
                os.remove(p)

    # Unpack CoreNLP and move the two jars into pycocoevalcap/spice/lib.
    corenlp_zip = os.path.join(CODE_PATH, 'stanford-corenlp-full-2015-12-09.zip')
    with ZipFile(corenlp_zip) as z:
        z.extractall(CODE_PATH)
    temp_path = os.path.join(CODE_PATH, 'stanford-corenlp-full-2015-12-09')
    spice_lib_dir = os.path.join(CODE_PATH, 'pycocoevalcap/spice/lib/')
    # Existence check instead of makedirs(exist_ok=True): the latter does not
    # exist on Python 2, which this function otherwise still caters for.
    if not os.path.exists(spice_lib_dir):
        os.makedirs(spice_lib_dir)
    shutil.move(os.path.join(temp_path, 'stanford-corenlp-3.6.0.jar'),
                os.path.join(CODE_PATH, 'pycocoevalcap/spice/lib/stanford-corenlp-3.6.0.jar'))
    shutil.move(os.path.join(temp_path, 'stanford-corenlp-3.6.0-models.jar'),
                os.path.join(CODE_PATH, 'pycocoevalcap/spice/lib/stanford-corenlp-3.6.0-models.jar'))
    os.remove(corenlp_zip)

    # Make the downloaded perl script executable for its owner.
    path = os.path.join(CODE_PATH, 'multibleu/multi-bleu.perl')
    stats = os.stat(path)
    os.chmod(path, stats.st_mode | stat.S_IEXEC)

    # Record the data directory in rc.json, keeping any other settings already stored there.
    cfg_path = os.path.join(XDG_CONFIG_HOME, "nlgeval")
    if not os.path.exists(cfg_path):
        os.makedirs(cfg_path)
    rc_path = os.path.join(cfg_path, "rc.json")
    rc = dict()
    # A missing rc.json is the normal first-run case and not worth a warning.
    if os.path.exists(rc_path):
        try:
            with open(rc_path, 'rt') as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                # Valid JSON but not an object: 'data_path' cannot be stored in it.
                raise ValueError("rc.json does not contain a JSON object")
            rc = loaded
        except (IOError, OSError, ValueError):
            # Unreadable or malformed file (JSON/Unicode decode errors are ValueErrors).
            print("WARNING: could not read rc.json in %s, overwriting" % cfg_path)
    rc['data_path'] = data_path
    with open(rc_path, 'wt') as f:
        f.write(json.dumps(rc))


@click.command()
@click.option(
    '--references',
    type=click.Path(exists=True),
    multiple=True,
    required=True,
    help='Path of the reference file. This option can be provided multiple times for multiple reference files.',
)
@click.option(
    '--hypothesis',
    type=click.Path(exists=True),
    required=True,
    help='Path of the hypothesis file.',
)
@click.option(
    '--no-overlap',
    is_flag=True,
    help='Flag. If provided, word overlap based metrics will not be computed.',
)
@click.option(
    '--no-skipthoughts',
    is_flag=True,
    help='Flag. If provided, skip-thought cosine similarity will not be computed.',
)
@click.option(
    '--no-glove',
    is_flag=True,
    help='Flag. If provided, other word embedding based metrics will not be computed.',
)
def compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove):
    """
    Compute nlg-eval metrics.

    The --hypothesis and at least one --references parameters are required.

    To download the data and additional code files, use `nlg-eval --setup [data path]`.

    Note that nlg-eval also features an API, which may be easier to use.
    """
    # Resolve the data directory first; exit with status 1 if it is invalid.
    try:
        data_dir = nlgeval.utils.get_data_dir()
    except nlgeval.utils.InvalidDataDirException:
        sys.exit(1)

    # Tell the user which data is used and how to recover from broken downloads.
    notices = (
        "Using data from {}".format(data_dir),
        "In case of broken downloads, remove the directory and run setup again.",
    )
    for notice in notices:
        click.secho(notice, fg='green')

    # Hand the command line options straight to the library entry point.
    nlgeval.compute_metrics(hypothesis, references, no_overlap, no_skipthoughts, no_glove)


if __name__ == '__main__':
    # `--setup` as the first argument selects the setup command; anything
    # else is handled by the metrics command.
    setup_requested = sys.argv[1:2] == ['--setup']
    if not setup_requested:
        compute_metrics()
    else:
        # Drop the script name so that '--setup' takes its place in argv and
        # the arguments after it are what the setup command parses.
        sys.argv.pop(0)
        setup()
