#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Training/decoding definition for the speech recognition task."""

import copy
import json
import logging
import math
import os
import sys

from chainer import reporter as reporter_module
from chainer import training
from chainer.training import extensions
from chainer.training.updater import StandardUpdater
import numpy as np
import soundfile as sf
from tensorboardX import SummaryWriter
import torch
from torch.nn.parallel import data_parallel

from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import format_mulenc_args
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import plot_spectrogram
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr import CustomUpdater
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.e2e_asr import pad_list
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
from espnet.transform.spectrogram import IStft
from espnet.transform.transformation import Transformation
from espnet.utils.cli_writers import file_writer_helper
from espnet.utils.dataset import ChainerDataLoader
from espnet.utils.dataset import TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop
import espnet.nets.pytorch_backend.separation.bbs_eval as bbs_eval

import matplotlib

# Select the non-interactive "Agg" backend so that plotting performed during
# training (e.g. the loss.png plot report) does not need a display.
matplotlib.use("Agg")

# Bind ``zip_longest`` under a single name on both Python 2 and Python 3.
# NOTE(review): this module uses f-strings and a python3 shebang, so the
# Python 2 branch looks unreachable, and ``zip_longest`` is not referenced in
# the visible part of this module — confirm before removing.
if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest as zip_longest


class CustomConverter(object):
    """Custom batch converter for Pytorch.

    Pads the inputs and the per-speaker targets of one mini-batch and moves
    them to the requested device.  When every target entry is a dict holding
    a ``'seq'`` and a ``'wav'`` item, the reference waveforms are padded and
    returned as well.

    Args:
        dtype (torch.dtype): Data type the padded inputs and the padded
            reference waveforms are converted to.
        num_spkrs (int): Number of speakers used to reshape the padded
            targets to (B, num_spkrs, Tmax).

    """

    def __init__(self, dtype=torch.float32, num_spkrs=2):
        """Construct a CustomConverter object."""
        self.ignore_id = -1
        self.dtype = dtype
        self.num_spkrs = num_spkrs

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.  It must hold exactly one
                element whose first item is the list of inputs and whose
                remaining items are the targets, one item per speaker.
            device (torch.device): The device to send to.

        Returns:
            tuple: ``(xs_pad, ilens, ys_pad, ys_wav_pad)``.  ``xs_pad`` is a
            tensor, or a ``{"real": ..., "imag": ...}`` dict of tensors for
            complex inputs.  ``ys_wav_pad`` is ``None`` when the targets
            carry no reference waveform.

        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys = batch[0][0], batch[0][1:]

        # get batch of lengths of input sequences
        ilens = np.array([x.shape[0] for x in xs])

        # perform padding and convert to tensor
        # currently only support real number
        if xs[0].dtype.kind == "c":
            xs_pad_real = pad_list(
                [torch.from_numpy(x.real).float() for x in xs], 0
            ).to(device, dtype=self.dtype)
            xs_pad_imag = pad_list(
                [torch.from_numpy(x.imag).float() for x in xs], 0
            ).to(device, dtype=self.dtype)
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it E2E here
            # because torch.nn.DataParallel can't handle it.
            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
        else:
            xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
                device, dtype=self.dtype
            )

        ilens = torch.from_numpy(ilens).to(device)

        # Always bind the waveform outputs so that the return statement is
        # valid for every target layout (previously they were unbound when the
        # targets were plain lists).
        ys_wav = None
        ys_wav_pad = None

        # separate ref. wav unpack: only dict entries carry 'seq' and 'wav'.
        # Testing for dict explicitly (instead of "not a list") keeps plain
        # ndarray targets from being indexed with a string key.
        if isinstance(ys[0][0], dict):
            n_utts = len(xs)
            ys_wav = [[spk[n]["wav"] for n in range(n_utts)] for spk in ys]
            ys = [[spk[n]["seq"] for n in range(n_utts)] for spk in ys]

        if not isinstance(ys[0], np.ndarray):
            # flatten speaker-major, pad jointly, then fold the speaker axis back
            ys_pad = pad_list(
                [torch.from_numpy(y).long() for spk in ys for y in spk],
                self.ignore_id,
            )
            ys_pad = (
                ys_pad.view(self.num_spkrs, -1, ys_pad.size(1))
                .transpose(0, 1)
                .to(device)
            )  # (B, num_spkrs, Tmax)
            if ys_wav is not None:  # separate ref. wav
                # NOTE(review): the waveform goes through an int64 cast before
                # being converted to self.dtype, which drops any fractional
                # part — presumably the loader yields integer PCM samples;
                # confirm against the wav loader.
                ys_wav_pad = pad_list(
                    [torch.from_numpy(w).long() for spk in ys_wav for w in spk], 0
                ).to(device, dtype=self.dtype)
                ys_wav_pad = (
                    ys_wav_pad.view(self.num_spkrs, -1, ys_wav_pad.size(1))
                    .transpose(0, 1)
                    .to(device)
                )  # (B, num_spkrs, Tmax)
        else:
            ys_pad = pad_list(
                [torch.from_numpy(y).long() for y in ys], self.ignore_id
            ).to(device)

        return xs_pad, ilens, ys_pad, ys_wav_pad


def train(args):
    """Train with the given args.

    Builds the model named by ``args.model_module``, writes the arguments to
    ``<args.outdir>/model.json``, sets up the optimizer, the data iterators
    and the chainer trainer with its reporting/snapshot extensions, and runs
    the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is neither "adam" nor "noam".
        ImportError: If an apex opt level is requested via
            ``args.train_dtype`` but apex is not installed.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    model_class = dynamic_import(args.model_module)
    model = model_class(args)

    # write model config (exist_ok avoids the check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to %s", model_conf)
        f.write(
            json.dumps(
                (vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: %s: %s", key, vars(args)[key])

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise
        if args.opt == "noam":
            # the noam wrapper keeps the real torch optimizer in `.optimizer`
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    converter = CustomConverter(dtype=dtype)

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    # every minibatch must hold at least one sample per GPU
    min_batch_size = args.ngpu if args.ngpu > 1 else 1
    train_batches = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.drop_len,
        args.minibatches,
        min_batch_size=min_batch_size,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # NOTE(review): unlike the training call above, args.drop_len is not
    # passed here, so args.minibatches lands in the positional slot that
    # receives args.drop_len above — confirm against make_batchset's
    # signature whether this is intended.
    valid_batches = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=min_batch_size,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        load_wav=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        load_wav=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(
            train_batches, lambda data: converter([load_tr(data)])
        ),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(
            valid_batches, lambda data: converter([load_cv(data)])
        ),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "main/loss_tas",
                "main/loss_sil",
            ],
            "epoch",
            file_name="loss.png",
        )
    )

    # NOTE: valid_iter is built above but no evaluator extension is attached
    # here, and saving the best model on "main/loss" (snapshot_object with a
    # MinValueTrigger) is currently disabled.

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_tas",
        "main/loss_sil",
        "elapsed_time",
    ]
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir)),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)


def recog(args):
    """Decode with the given args.

    Loads the trained model, separates every utterance listed in
    ``args.recog_json`` one at a time, logs SDR / SDRi / SI-SNRi per
    utterance, writes one peak-normalised wav file per speaker into
    ``args.decode_dir`` and finally prints the averaged metrics.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    train_args = get_model_conf(
        args.model, os.path.join(os.path.dirname(args.model), "model.json")
    )
    model_module = getattr(
        train_args, "model_module", "espnet.nets.pytorch_backend.e2e_separate:E2E"
    )
    model_class = dynamic_import(model_module)
    model = model_class(train_args)
    torch_load(args.model, model)
    # NOTE(review): the model is not switched to eval mode here; presumably
    # model.recognize takes care of that — confirm.

    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device=device)

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    converter = CustomConverter()
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        load_wav=True,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # make sure the separated wav files have somewhere to go
    os.makedirs(args.decode_dir, exist_ok=True)
    num_utts = len(js)

    sdr_list = np.array([])
    sdri_list = np.array([])
    si_snri_list = np.array([])

    # batch size = 1
    with torch.no_grad():
        for idx, name in enumerate(js.keys(), 1):
            # pass the name as an argument: a '%' inside it must not be
            # interpreted as a format directive
            logging.info("(%d/%d) decoding %s", idx, num_utts, name)
            batch = [(name, js[name])]
            batch = load_inputs_and_targets(batch)
            xs_pad, _, _, ys_wav_pad = converter([batch], device=device)
            assert xs_pad.size(0) == 1
            preds_wav = model.recognize(xs_pad, ys_wav_pad.size(1))

            # compute SDR, SDRi and SI-SNR
            mix = xs_pad[0].cpu().numpy()
            pred = preds_wav[0].cpu().numpy()
            truth = ys_wav_pad[0].cpu().numpy()
            sdr, _, _, popt = bbs_eval.bss_eval_sources(truth, pred)
            sdri = cal_sdri(truth, pred[popt], mix)
            si_snri = cal_si_snri(truth, pred[popt], mix)
            logging.info(
                "SDR: %.2f; SDRi: %.2f; SI-SNRi: %.2f",
                sdr.mean(),
                sdri.mean(),
                si_snri.mean(),
            )

            # save separation wav
            # popt: e.g. [1, 0]; pred: (num_spk, T)
            # pred[popt] is aligned with the references (see the metric calls
            # above), so reference speaker j is written from pred[popt[j]].
            # For two speakers this yields the same files as labelling
            # pred[i] with popt[i]; it stays correct for more speakers.
            for spk, est_idx in enumerate(popt):
                save_name = name + "_spk" + str(spk + 1) + ".wav"
                save_path = os.path.join(args.decode_dir, save_name)
                wav = pred[est_idx]
                peak = np.max(np.abs(wav))
                # an all-zero estimate is written unchanged rather than
                # divided by zero
                wav_norm = wav / peak if peak > 0 else wav
                sf.write(save_path, wav_norm, 8000)

            sdr_list = np.append(sdr_list, sdr)
            sdri_list = np.append(sdri_list, sdri)
            si_snri_list = np.append(si_snri_list, si_snri)

    # The tuple form of these prints is kept as-is so the stdout format does
    # not change for anything that parses it.
    print(('SDR Avg:', sdr_list.mean()))
    print(('SDRi Avg:', sdri_list.mean()))
    print(('SI-SNRi Avg:', si_snri_list.mean()))


def cal_sdri(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).

    The improvement is the SDR of the estimates minus the SDR obtained when
    the unprocessed mixture is used as the estimate of every source,
    averaged over all sources.

    NOTE: bss_eval_sources is very very slow.

    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SDRi
    """
    num_src = src_ref.shape[0]
    # baseline: the mixture itself stands in for each source estimate
    src_anchor = np.stack([mix] * num_src, axis=0)
    if num_src == 1:
        src_anchor = src_anchor[0]
    sdr, _, _, _ = bbs_eval.bss_eval_sources(src_ref, src_est)
    sdr0, _, _, _ = bbs_eval.bss_eval_sources(src_ref, src_anchor)
    # average over however many sources there are (previously hard-coded to
    # exactly two, which failed for one source and ignored any beyond two)
    avg_SDRi = np.mean(np.asarray(sdr) - np.asarray(sdr0))
    return avg_SDRi


def cal_si_snri(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    For every source the SI-SNR of the mixture against the reference is
    subtracted from the SI-SNR of the estimate against the reference; the
    result is the average of these improvements over all sources.

    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    # iterate over all C sources (previously hard-coded to exactly two)
    improvements = [
        cal_si_snr(src_ref[c], src_est[c]) - cal_si_snr(src_ref[c], mix)
        for c in range(len(src_ref))
    ]
    avg_SISNRi = np.mean(improvements)
    return avg_SISNRi


def cal_si_snr(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR) in dB.

    Both signals are made zero-mean, the estimate is projected onto the
    reference, and the energy of that projection is compared with the energy
    of what is left over.

    Args:
        ref_sig: numpy.ndarray, [T]
        out_sig: numpy.ndarray, [T]
        eps: small constant added to the reference energy, the residual
            energy and the ratio to keep the computation finite
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)

    # remove the DC offset of both signals
    target = ref_sig - np.mean(ref_sig)
    estimate = out_sig - np.mean(out_sig)

    # part of the estimate that lies along the reference
    target_energy = np.sum(target ** 2) + eps
    scaled_target = np.sum(target * estimate) * target / target_energy

    # everything else counts as noise
    residual = estimate - scaled_target
    signal_power = np.sum(scaled_target ** 2)
    residual_power = np.sum(residual ** 2) + eps

    # 10 * log10(.) written with natural logarithms
    return 10 * np.log(signal_power / residual_power + eps) / np.log(10.0)


def extract_feats(args):
    """Extract fbank features with torchaudio for verification.

    For the first ``args.num_utt`` utterances of ``args.recog_json`` the
    reference waveform of every speaker is turned into Kaldi-compatible
    fbank features.  Both the raw features (``<name>_spk<k>.npy``) and the
    features normalised with ``args.global_mean`` / ``args.global_std``
    (``<name>_spk<k>_norm.npy``) are saved into ``args.output_dir``.

    Args:
        args (namespace): The program arguments.

    """
    # imported once here instead of on every loop iteration
    import torchaudio

    set_deterministic_pytorch(args)

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    # NOTE(review): the converter uses its default num_spkrs (2) while the
    # loop below iterates over args.num_spkrs — confirm that the two are
    # meant to be independent.
    converter = CustomConverter()
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        load_wav=True,
        sort_in_input_length=False,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},
    )

    # keyword arguments forwarded to torchaudio.compliance.kaldi.fbank
    params = {
        "dither": 1.0,
        "sample_frequency": 8000,
        "frame_length": 25,
        "low_freq": 20,
        "num_mel_bins": 80,
    }

    global_mean = torch.tensor(args.global_mean)
    global_std = torch.tensor(args.global_std)

    # make sure the feature files have somewhere to go
    os.makedirs(args.output_dir, exist_ok=True)
    num_utts = len(js)

    for idx, name in enumerate(js.keys(), 1):
        if idx > args.num_utt:
            break
        # pass the name as an argument: a '%' inside it must not be
        # interpreted as a format directive
        logging.info("(%d/%d) generate feats for: %s", idx, num_utts, name)
        batch = [(name, js[name])]
        batch = load_inputs_and_targets(batch)
        _, _, _, ys_wav_pad = converter([batch])

        for spk in range(args.num_spkrs):
            waveform = ys_wav_pad[:, spk, :]
            feat = torchaudio.compliance.kaldi.fbank(waveform, **params)
            feat_norm = (feat - global_mean) / global_std
            prefix = name + "_spk" + str(spk + 1)
            np.save(os.path.join(args.output_dir, prefix + ".npy"), feat)
            np.save(os.path.join(args.output_dir, prefix + "_norm.npy"), feat_norm)
