import argparse
import sys
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import librosa
import numpy as np
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from scipy.io import wavfile
from tqdm import tqdm
from ttslearn.dsp import world_spss_params


def get_parser():
    """Build the command-line parser for acoustic-model preprocessing.

    Returns:
        argparse.ArgumentParser: parser with positional arguments for the
        utterance list, wav directory, label directory, HTS question file and
        output directory, plus the ``--n_jobs`` and ``--sample_rate`` options.
    """
    parser = argparse.ArgumentParser(
        description="Preprocess for acoustic models",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("utt_list", type=str, help="utterance list")
    parser.add_argument("wav_root", type=str, help="wav directory")
    parser.add_argument("lab_root", type=str, help="label directory")
    parser.add_argument("qst_file", type=str, help="HTS style question file")
    parser.add_argument("out_dir", type=str, help="out directory")
    parser.add_argument("--n_jobs", type=int, default=1, help="Number of jobs")
    parser.add_argument("--sample_rate", type=int, default=16000, help="Sample rate")

    return parser


def _pcm_to_float_mono(x):
    """Convert an array returned by ``scipy.io.wavfile.read`` to 1-D float64."""
    if np.issubdtype(x.dtype, np.signedinteger):
        # Signed PCM: divide by the dtype's positive full scale (unchanged
        # from the original int16/int32 handling).
        x = x / np.iinfo(x.dtype).max
    elif np.issubdtype(x.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is offset-binary: centre it on zero
        # before scaling, otherwise the signal carries a large DC offset.
        mid = (np.iinfo(x.dtype).max + 1) / 2
        x = (x.astype(np.float64) - mid) / mid
    x = np.asarray(x, dtype=np.float64)
    if x.ndim > 1:
        # Multi-channel files come back as (n_samples, n_channels); downmix so
        # resampling and feature extraction operate on a single channel.
        x = x.mean(axis=1)
    return x


def preprocess(wav_file, lab_file, binary_dict, numeric_dict, sr, in_dir, out_dir):
    """Extract frame-aligned linguistic and acoustic features for one utterance.

    Linguistic features are computed from the HTS full-context label file.
    Acoustic features are computed from the waveform with ``world_spss_params``.
    Both sequences are truncated to a common number of frames. Leading and
    trailing silence is then trimmed, keeping 50 ms before the first and
    100 ms after the last non-silence segment. The results are saved as
    float32 ``.npy`` files named ``<utt_id>-feats.npy``.

    Args:
        wav_file (pathlib.Path): input waveform.
        lab_file (pathlib.Path): HTS label file; must share ``wav_file``'s stem.
        binary_dict: binary questions from ``hts.load_question_set``.
        numeric_dict: numeric questions from ``hts.load_question_set``.
        sr (int): sampling rate the waveform is resampled to before analysis.
        in_dir (pathlib.Path): directory for the linguistic (input) features.
        out_dir (pathlib.Path): directory for the acoustic (output) features.

    Raises:
        ValueError: if the file stems differ, or the label sequence is too
            short or does not start and end with a "sil" segment.
    """
    if wav_file.stem != lab_file.stem:
        raise ValueError(f"Mismatched utterance files: {wav_file} vs {lab_file}")

    # Linguistic features
    labels = hts.load(lab_file)
    in_feats = fe.linguistic_features(
        labels,
        binary_dict,
        numeric_dict,
        add_frame_features=True,
        subphone_features="coarse_coding",
    )

    # Acoustic features
    orig_sr, x = wavfile.read(wav_file)
    x = _pcm_to_float_mono(x)
    x = librosa.resample(x, orig_sr=orig_sr, target_sr=sr).astype(np.float64)
    # Workaround for upsampling: add a small white noise so the empty upper
    # band does not produce degenerate spectra.
    if sr > orig_sr:
        x = x + np.random.randn(len(x)) * (1 / 2 ** 15)

    out_feats = world_spss_params(x, sr)

    # Align the number of frames of the two streams
    min_len = min(in_feats.shape[0], out_feats.shape[0])
    in_feats, out_feats = in_feats[:min_len], out_feats[:min_len]

    # Trim leading/trailing silence, using the label boundaries of the first
    # and last non-"sil" segments.
    if len(labels.contexts) < 2:
        raise ValueError(f"Label file has too few segments: {lab_file}")
    if "sil" not in labels.contexts[0] or "sil" not in labels.contexts[-1]:
        raise ValueError(f"Label file must start and end with sil: {lab_file}")

    frame_period = 0.005  # 5 ms frame shift, in seconds
    hts_frame = 50000  # one 5 ms frame in HTS time units (100 ns)
    start_frame = int(labels.start_times[1] / hts_frame)
    end_frame = int(labels.end_times[-2] / hts_frame)

    # Keep 50 ms of leading and 100 ms of trailing silence
    start_frame = max(0, start_frame - round(0.050 / frame_period))
    end_frame = min(min_len, end_frame + round(0.100 / frame_period))

    in_feats = in_feats[start_frame:end_frame]
    out_feats = out_feats[start_frame:end_frame]

    # Save as NumPy files
    utt_id = lab_file.stem
    np.save(
        in_dir / f"{utt_id}-feats.npy", in_feats.astype(np.float32), allow_pickle=False
    )
    np.save(
        out_dir / f"{utt_id}-feats.npy",
        out_feats.astype(np.float32),
        allow_pickle=False,
    )


if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])

    # Read utterance IDs. Skip blank lines (e.g. a trailing newline at EOF);
    # they would otherwise become the bogus paths ".wav" / ".lab".
    with open(args.utt_list, encoding="utf-8") as f:
        utt_ids = [line.strip() for line in f if line.strip()]
    wav_files = [Path(args.wav_root) / f"{utt_id}.wav" for utt_id in utt_ids]
    lab_files = [Path(args.lab_root) / f"{utt_id}.lab" for utt_id in utt_ids]
    binary_dict, numeric_dict = hts.load_question_set(args.qst_file)

    # Linguistic features go to in_acoustic, acoustic features to out_acoustic
    in_dir = Path(args.out_dir) / "in_acoustic"
    out_dir = Path(args.out_dir) / "out_acoustic"
    in_dir.mkdir(parents=True, exist_ok=True)
    out_dir.mkdir(parents=True, exist_ok=True)

    with ProcessPoolExecutor(max_workers=args.n_jobs) as executor:
        futures = [
            executor.submit(
                preprocess,
                wav_file,
                lab_file,
                binary_dict,
                numeric_dict,
                args.sample_rate,
                in_dir,
                out_dir,
            )
            for wav_file, lab_file in zip(wav_files, lab_files)
        ]
        # result() re-raises any exception raised inside a worker
        for future in tqdm(futures):
            future.result()
