#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""TasNet-based speech separation model that extracts speakers one step at a time (pytorch)."""

import argparse
from itertools import groupby
import logging
import math
import os

import chainer
from chainer import reporter
import editdistance
import numpy as np
import six
import torch
import torch.nn as nn
import torch.nn.functional as F

import espnet.nets.pytorch_backend.separation.criterion as criterion
import espnet.nets.pytorch_backend.separation.bbs_eval as bbs_eval
from espnet.nets.pytorch_backend.separation.encoder import Encoder
from espnet.nets.pytorch_backend.separation.temporal_convnet import TemporalConvNet
from espnet.nets.pytorch_backend.separation.decoder import Decoder


class Reporter(chainer.Chain):
    """Thin chainer ``Chain`` used only to forward training losses to the reporter."""

    def report(self, loss_tas, loss_sil, mtl_loss):
        """Publish the per-iteration losses.

        :param loss_tas: separation (TasNet) loss value
        :param loss_sil: silence loss value; the model passes ``None`` when the
            silence step is disabled
        :param mtl_loss: total loss that is optimised
        """
        # One report call per component, in the original order.
        for name, value in (("loss_tas", loss_tas), ("loss_sil", loss_sil)):
            reporter.report({name: value}, self)
        logging.info("mtl loss:" + str(mtl_loss))
        reporter.report({"loss": mtl_loss}, self)


class E2E(torch.nn.Module):
    """Speaker-by-speaker TasNet separation model.

    A TasNet encoder / temporal-conv separator builds a representation of the
    mixture. An ``nn.LSTMCell`` then runs over "speaker steps": at every step it
    sees the separator output concatenated with an encoded condition waveform
    (all zeros at the first step, afterwards the waveform chosen at the previous
    step), and its hidden state is turned into a mask that the decoder applies to
    the encoder output to produce one separated waveform.

    :param Namespace args: argument Namespace containing the options added by
        :meth:`add_arguments`
    """

    @staticmethod
    def add_arguments(parser):
        """Add the model's command-line arguments to ``parser`` and return it."""
        group = parser.add_argument_group("Tasnet model parameters.")
        group.add_argument(
            "--N",
            default=256,
            type=int,
            help="Number of filters in encoder layers"
        )
        group.add_argument(
            "--L",
            default=20,
            type=int,
            help="Length of the filters (in samples)"
        )
        group.add_argument(
            "--B",
            default=256,
            type=int,
            help="Number of channels in bottleneck 1 * 1-conv block"
        )
        group.add_argument(
            "--H",
            default=512,
            type=int,
            help="Number of channels in convolutional blocks"
        )
        group.add_argument(
            "--P",
            default=3,
            type=int,
            help="Kernel size in convolutional blocks"
        )
        group.add_argument(
            "--X",
            default=8,
            type=int,
            help="Number of convolutional blocks in each repeat"
        )
        group.add_argument(
            "--R",
            default=4,
            type=int,
            help="Number of repeats"
        )
        group.add_argument(
            "--C",
            default=2,
            type=int,
            help="Number of speakers"
        )
        group.add_argument(
            "--norm-type",
            default="gLN",
            type=str,
            choices=[
                "BN",
                "gLN",
                "cLN"
            ],
            help="BN, gLN, cLN"
        )
        group.add_argument(
            "--causal",
            default=0,
            type=int,
            help="causal or non-causal"
        )
        group.add_argument(
            "--mask-nonlinear",
            default='relu',
            type=str,
            help="use which non-linear function to generate mask"
        )
        group.add_argument(
            "--end-separation-mode",
            default=0,
            type=int,
            help="end-separation-mode"
        )
        group.add_argument(
            '--greedy-tf',
            default=0,
            type=int,
            help='greedy-tf'
        )
        group.add_argument(
            '--add-last-silence',
            default=0,
            type=int,
            help='Add last silence'
        )
        group.add_argument(
            '--pit-without-tf',
            default=0,
            type=int,
            help='PIT without teacher force.'
        )

        return parser

    def __init__(self, args):
        """Construct an E2E object.

        :param Namespace args: parsed options (see :meth:`add_arguments`)
        """
        # A single base-class init is enough; the original also called
        # torch.nn.Module.__init__ a second time, which was redundant.
        super(E2E, self).__init__()

        self.add_last_silence = args.add_last_silence
        self.pit_without_tf = args.pit_without_tf  # stored; not read in this class
        self.greedy_tf = args.greedy_tf
        self.reporter = Reporter()

        # TasNet encoder / separator / decoder.
        self.encoder = Encoder(args.L, args.N)
        self.separator = TemporalConvNet(
            args.N, args.B, args.H, args.P, args.X, args.R, args.C,
            args.norm_type, args.causal, args.mask_nonlinear,
            args.end_separation_mode)
        self.mask_conv1x1 = nn.Conv1d(args.B, args.N, 1, bias=False)
        self.decoder = Decoder(args.N, args.L)
        # Xavier init for every weight matrix created so far. The speaker LSTM
        # below is created afterwards and therefore keeps PyTorch's default init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

        # LSTM over the speaker steps; input is [separator output; condition].
        self.spk_lstm = nn.LSTMCell(args.B + args.N, args.B)

    def choose_candidate(self, pred_wav, ilens, ys_wav_dict, BS):
        """Greedily pick, per utterance, the remaining reference closest to the estimate.

        For every utterance the reference with the highest
        ``criterion.cal_sdr_with_order`` score against the estimate is selected
        and removed from that utterance's candidate dict, so later steps can only
        choose among the speakers that are still unassigned.

        :param torch.Tensor pred_wav: (BS, T) estimated waveforms of this step
        :param torch.Tensor ilens: (BS,) valid lengths
        :param list ys_wav_dict: length-BS list of ``{speaker_index: waveform}``
            dicts; modified in place
        :param int BS: batch size
        :return: (list of (1, T) chosen references, list of chosen speaker
            indices, the updated ``ys_wav_dict``)
        :raises ValueError: if an utterance has no reference left to choose
        """
        spk_list = []
        cand_wavs_list = []

        for idx in range(BS):
            est_wav = pred_wav[idx].view(1, 1, -1)
            candidates = ys_wav_dict[idx]
            if not candidates:
                raise ValueError(
                    "no reference waveform left for utterance {}".format(idx))

            best_key, best_sdr = None, None
            for key, cand_wav in candidates.items():
                sdr = criterion.cal_sdr_with_order(
                    cand_wav.view(1, 1, -1), est_wav, ilens[idx].view(1))
                # Strict ">" keeps the first candidate on ties.
                if best_sdr is None or sdr > best_sdr:
                    best_key, best_sdr = key, sdr

            spk_list.append(best_key)
            # pop() both returns the chosen reference and removes it.
            cand_wavs_list.append(candidates.pop(best_key).unsqueeze(0))

        return cand_wavs_list, spk_list, ys_wav_dict

    def _decode_step(self, xs_pad, enc_output, sep_output, condition, state):
        """Run one speaker step and return ``(waveform (BS, T), lstm state)``.

        ``state`` is ``None`` for the first step (``LSTMCell`` then starts from
        zeros) and the ``(h, c)`` tuple returned by the previous call afterwards.
        The output is padded/cropped to the input length ``xs_pad.size(-1)``.
        """
        N, B = self.encoder.N, self.separator.B
        K = enc_output.size(-1)

        # (BS, B, K) + (BS, N, K) --> (BS, B+N, K) --> (BS*K, B+N)
        lstm_in = torch.cat((sep_output, condition), 1)
        lstm_in = lstm_in.transpose(1, 2).contiguous().view(-1, B + N)
        lstm_h, lstm_c = self.spk_lstm(lstm_in, state)  # (BS*K, B) each

        # (BS*K, B) --> (BS, K, B) --> (BS, B, K) --> (BS, N, K)
        mask = self.mask_conv1x1(lstm_h.view(-1, K, B).transpose(1, 2))
        mask = F.relu(mask).unsqueeze(1)  # (BS, 1, N, K)
        wav = self.decoder(enc_output, mask).squeeze(1)  # (BS, T')

        # The decoder length can differ from the input length; align to it.
        wav = F.pad(wav, (0, xs_pad.size(-1) - wav.size(-1)))
        return wav, (lstm_h, lstm_c)

    @staticmethod
    def _silence_loss(pred_wav, ilens):
        """Mean power of ``pred_wav`` over the valid (unpadded) samples.

        :param torch.Tensor pred_wav: (BS, T) waveform predicted at the extra step
        :param torch.Tensor ilens: (BS,) valid lengths
        :return: scalar tensor, zero when the prediction is exactly silent
        """
        time_idx = torch.arange(pred_wav.size(-1), device=pred_wav.device)
        lengths = ilens.to(device=pred_wav.device, dtype=torch.long)
        mask = (time_idx.unsqueeze(0) < lengths.unsqueeze(1)).to(pred_wav.dtype)
        return (pred_wav.pow(2) * mask).sum() / mask.sum().clamp(min=1.0)

    def forward(self, xs_pad, ilens, ys_pad, ys_wav_pad):
        """Compute the training loss with greedy teacher forcing.

        :param torch.Tensor xs_pad: (BS, T) mixture waveforms
        :param torch.Tensor ilens: (BS,) valid lengths
        :param ys_pad: unused here; kept for interface compatibility
        :param torch.Tensor ys_wav_pad: (BS, num_spkrs, T) reference waveforms
            (T assumed equal to the mixture length -- TODO confirm)
        :return: scalar loss tensor (also stored in ``self.loss``)
        """
        batch_size, num_spkrs = ys_wav_pad.size(0), ys_wav_pad.size(1)
        # Per utterance: remaining references keyed by their speaker index.
        ys_wav_dict = [dict(enumerate(ys_wav_pad[i])) for i in range(batch_size)]

        enc_output = self.encoder(xs_pad)  # (BS, N, K)
        sep_output = self.separator(enc_output)  # (BS, B, K), before mask conv

        # The first step is conditioned on all zeros.
        condition = torch.zeros_like(enc_output)
        state = None

        preds_wav = []
        ys_wav_resorted = []
        repeat_time = num_spkrs + 1 if self.add_last_silence else num_spkrs

        for step_idx in range(repeat_time):
            pred_wav_this_step, state = self._decode_step(
                xs_pad, enc_output, sep_output, condition, state)
            preds_wav.append(pred_wav_this_step)
            if step_idx >= num_spkrs:
                # Extra "silence" step: no reference and no further condition.
                continue

            y_wav_list, spk_list_this_step, ys_wav_dict = self.choose_candidate(
                pred_wav_this_step, ilens, ys_wav_dict, batch_size)
            logging.info("Step: {}, spk list: {}".format(step_idx, spk_list_this_step))

            y_wav_this_step = torch.cat(y_wav_list, 0)  # (BS, T)
            # Teacher forcing: encode the chosen reference, perturbed with
            # Gaussian noise (std 0.5), as the next step's condition.
            condition = self.encoder(
                y_wav_this_step + 0.5 * torch.randn_like(y_wav_this_step))
            ys_wav_resorted.append(y_wav_this_step)

        preds_wav = torch.stack(preds_wav, dim=1)  # (BS, repeat_time, T)
        ys_wav_resorted = torch.stack(ys_wav_resorted, dim=1)  # (BS, num_spkrs, T)

        # Only the first num_spkrs steps have references. Including the extra
        # silence step here would mismatch the speaker dimension.
        loss_tas = criterion.cal_loss_with_sdr_order(
            preds_wav[:, :num_spkrs], ys_wav_resorted, ilens)[0]

        if self.add_last_silence:
            if self.greedy_tf:
                # Kept from the original: the silence loss is disabled with greedy TF.
                loss_sil = 0.0
            else:
                # Previously loss_sil was never assigned here (NameError).
                # NOTE(review): the extra step is pushed toward silence by
                # penalising its mean power; confirm this is the intended target.
                loss_sil = self._silence_loss(preds_wav[:, -1], ilens)
            self.loss = loss_tas + 0.1 * loss_sil
            loss_sil_data = float(loss_sil)
        else:
            self.loss = loss_tas
            loss_sil_data = None
        loss_tas_data = float(loss_tas)

        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(loss_tas_data, loss_sil_data, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss

    def recognize(self, xs_pad, num_spkrs=2):
        """Separate ``xs_pad`` by conditioning every step on its own previous estimate.

        :param torch.Tensor xs_pad: (BS, T) mixture waveforms
        :param int num_spkrs: number of speakers to extract
        :return: (BS, num_spkrs [+1 when add_last_silence], T) estimated waveforms
        """
        enc_output = self.encoder(xs_pad)  # (BS, N, K)
        sep_output = self.separator(enc_output)  # (BS, B, K), before mask conv

        condition = torch.zeros_like(enc_output)  # first step: all zeros
        state = None
        preds_wav = []
        repeat_time = num_spkrs + 1 if self.add_last_silence else num_spkrs

        for _ in range(repeat_time):
            pred_wav_this_step, state = self._decode_step(
                xs_pad, enc_output, sep_output, condition, state)
            # At inference the next condition is the model's own estimate.
            condition = self.encoder(pred_wav_this_step)
            preds_wav.append(pred_wav_this_step)

        return torch.stack(preds_wav, dim=1)
