import torch
import utils
from mlp import MLP
from torch import nn
from torchdiffeq import odeint

# Hard [lower, upper] bounds for omega: VAE.draw clamps every drawn (or
# mean) omega into this interval before decoding.
omega_feasible_range = [4.5, 15.5]


class FeatureExtractor(nn.Module):
    """Map a flattened two-channel sequence to a fixed-width feature vector.

    Config keys read: ``dim_t`` (the input is flattened to ``2 * dim_t``
    values per sample), ``activation``, ``arch_feat`` (``'mlp'`` or ``'rnn'``),
    ``num_units_feat`` (output width), and additionally ``hidlayers_feat`` for
    the MLP variant or ``num_rnns_feat`` for the GRU variant.

    Note: the attribute ``self.dim_t`` stores the *flattened* length
    (``2 * config['dim_t']``), not the number of time steps.

    Raises:
        ValueError: if ``arch_feat`` is neither ``'mlp'`` nor ``'rnn'``.
    """

    def __init__(self, config):
        super().__init__()
        flat_len = 2 * config['dim_t']
        act = config['activation']
        kind = config['arch_feat']
        width = config['num_units_feat']
        self.dim_t = flat_len
        self.arch_feat = kind
        self.num_units_feat = width
        if kind == 'rnn':
            depth = config['num_rnns_feat']
            self.num_rnns_feat = depth
            # Scalar-input GRU: each flattened value is fed as one time step.
            self.func = nn.GRU(1, width, num_layers=depth, bidirectional=False)
        elif kind == 'mlp':
            hidden = config['hidlayers_feat']
            self.func = MLP([flat_len] + hidden + [width], act, actfun_output=True)
        else:
            raise ValueError('unknown feature type')

    def forward(self, x):
        """Return features of shape (batch, num_units_feat).

        ``x`` is reshaped to (batch, 2 * dim_t), so any layout holding that
        many trailing values per sample is accepted.
        """
        flat = x.reshape(-1, self.dim_t)
        if self.arch_feat == 'mlp':
            return self.func(flat)
        if self.arch_feat == 'rnn':
            batch = flat.shape[0]
            h0 = torch.zeros(self.num_rnns_feat, batch, self.num_units_feat, device=flat.device)
            # nn.GRU (batch_first=False) wants (seq_len, batch, input_size).
            seq = flat.transpose(0, 1).unsqueeze(-1)
            hidden_seq, _ = self.func(seq, h0)
            # Output of the last layer at the final time step.
            return hidden_seq[-1]


class Decoders(nn.Module):
    """Container for the decoder-side networks of the VAE.

    Always holds ``param_x_lnvar``, a one-element buffer initialised from
    ``config['x_lnvar']``. As a buffer, it is not a trainable parameter.
    Depending on the latent sizes it also builds:

    * ``func_aux1`` (when ``dim_z_aux1 >= 0``): MLP mapping
      ``[z_aux1, th1, th2, thdot1, thdot2]`` (``dim_z_aux1 + 4`` inputs, no
      time input) to 2 outputs, used as a correction to the two angular
      accelerations inside the ODE.
    * ``func_aux2_res`` (when ``dim_z_aux2 >= 0``): MLP mapping
      ``[z_phy, z_aux1, z_aux2]`` to ``2 * dim_t`` values added to the
      integrated trajectory. z_phy is omitted when ``no_phy``; z_aux1
      contributes ``max(0, dim_z_aux1)`` inputs.
    """

    def __init__(self, config):
        super().__init__()
        dim_t = config['dim_t']
        n_aux1 = config['dim_z_aux1']
        n_aux2 = config['dim_z_aux2']
        act = config['activation']
        no_phy = config['no_phy']
        x_lnvar = config['x_lnvar']

        # Observation-noise log-variance.
        self.register_buffer('param_x_lnvar', torch.ones(1) * x_lnvar)

        if n_aux1 >= 0:
            in_aux1 = n_aux1 + 4
            self.func_aux1 = MLP([in_aux1] + config['hidlayers_aux1_dec'] + [2], act)

        if n_aux2 >= 0:
            in_aux2 = (0 if no_phy else 1) + max(0, n_aux1) + n_aux2
            self.func_aux2_res = MLP([in_aux2] + config['hidlayers_aux2_dec'] + [2 * dim_t], act)


class Encoders(nn.Module):
    """Container for the encoder-side networks of the VAE.

    Depending on ``config`` it builds:

    * ``dim_z_aux1 > 0``: ``func_feat_aux1`` (a ``FeatureExtractor``) plus the
      ``func_z_aux1_mean`` / ``func_z_aux1_lnvar`` heads.
    * ``dim_z_aux2 > 0``: the same trio for z_aux2.
    * ``not no_phy``:
      * ``func_unmixer_res``, whose output is added to the flattened input to
        form the "unmixed" sequence;
      * ``func_feat_phy``;
      * the omega heads. The mean head ends in ``Softplus``, so it is
        positive.
    """

    def __init__(self, config):
        super().__init__()
        dim_t = config['dim_t']
        n_aux1 = config['dim_z_aux1']
        n_aux2 = config['dim_z_aux2']
        act = config['activation']
        no_phy = config['no_phy']
        n_feat = config['num_units_feat']

        def head(hidden, out_dim):
            # feature vector (n_feat) -> one statistic of size out_dim
            return MLP([n_feat] + hidden + [out_dim], act)

        if n_aux1 > 0:
            hidden_aux1 = config['hidlayers_aux1_enc']
            self.func_feat_aux1 = FeatureExtractor(config)
            self.func_z_aux1_mean = head(hidden_aux1, n_aux1)
            self.func_z_aux1_lnvar = head(hidden_aux1, n_aux1)
        if n_aux2 > 0:
            hidden_aux2 = config['hidlayers_aux2_enc']
            self.func_feat_aux2 = FeatureExtractor(config)
            self.func_z_aux2_mean = head(hidden_aux2, n_aux2)
            self.func_z_aux2_lnvar = head(hidden_aux2, n_aux2)
        if not no_phy:
            hidden_unmix = config['hidlayers_unmixer']
            hidden_omega = config['hidlayers_omega']
            # [x (2*dim_t values), z_aux1, z_aux2] -> (unmixed - x)
            unmix_in = 2 * dim_t + max(n_aux1, 0) + max(n_aux2, 0)
            self.func_unmixer_res = MLP([unmix_in] + hidden_unmix + [2 * dim_t], act)
            self.func_feat_phy = FeatureExtractor(config)
            self.func_omega_mean = nn.Sequential(head(hidden_omega, 1), nn.Softplus())
            self.func_omega_lnvar = head(hidden_omega, 1)


class Physics(nn.Module):
    """Right-hand side of the planar double-pendulum ODE.

    The state ``yy`` is ``[th1, th2, thdot1, thdot2]`` per row. ``forward``
    returns its time derivative ``[thdot1, thdot2, thdotdot1, thdotdot2]``.

    Args (constructor):
        m1, m2: bob masses.
        l1, l2: rod lengths.

    All four default to 1.0. With the defaults the equations reduce exactly to
    the original unit-mass / unit-length form, so ``Physics()`` is unchanged.
    """

    def __init__(self, m1=1.0, m2=1.0, l1=1.0, l2=1.0):
        super(Physics, self).__init__()
        self.m1 = m1
        self.m2 = m2
        self.l1 = l1
        self.l2 = l2

    def forward(self, G, yy):
        """Compute d(yy)/dt.

        Args:
            G: gravity-like coefficient, broadcastable against (n, 1) — the
                VAE passes its inferred ``omega`` here.
            yy: state tensor of shape (n, 4).

        Returns:
            Tensor of shape (n, 4).
        """
        th1, th2 = yy[:, 0:1], yy[:, 1:2]
        thdot1, thdot2 = yy[:, 2:3], yy[:, 3:4]
        m1, m2, l1, l2 = self.m1, self.m2, self.l1, self.l2
        c, s = torch.cos(th1 - th2), torch.sin(th1 - th2)
        thdot1sq, thdot2sq = thdot1 ** 2, thdot2 ** 2
        # Shared factor (m1 + m2 sin^2(th1 - th2)) of both denominators.
        denominator = m1 + m2 * s ** 2
        term1 = (m2 * G * torch.sin(th2) * c
                 - m2 * s * (l1 * thdot1sq * c + l2 * thdot2sq)
                 - (m1 + m2) * G * torch.sin(th1))
        thdotdot1 = term1 / (l1 * denominator)
        term2 = ((m1 + m2) * (l1 * thdot1sq * s - G * torch.sin(th2) + G * torch.sin(th1) * c)
                 + m2 * l2 * thdot2sq * s * c)
        thdotdot2 = term2 / (l2 * denominator)
        return torch.cat((thdot1, thdot2, thdotdot1, thdotdot2), dim=-1)


class VAE(nn.Module):
    """VAE whose decoder integrates the ``Physics`` (double-pendulum) ODE.

    The encoder infers a physics parameter ``omega`` (passed to ``Physics``
    as ``G``) and auxiliary latents ``z_aux1`` / ``z_aux2`` from an observed
    sequence together with a set of context sequences. The decoder
    integrates the ODE from the first observed sample. It is optionally
    corrected by:

    * ``dec.func_aux1``, inside the ODE;
    * ``dec.func_aux2_res``, added to the trajectory afterwards.

    Expected layouts (inferred from the reshapes below and the ``__main__``
    smoke test — TODO confirm against the data loader): ``x`` is
    (n, dim_t, 2) and ``x_context`` is (n, K, dim_t, 2).
    """

    def __init__(self, config):
        super(VAE, self).__init__()
        self.dim_t = config['dim_t']
        self.dim_z_aux1 = config['dim_z_aux1']
        self.dim_z_aux2 = config['dim_z_aux2']
        self.range_omega = config['range_omega']
        self.activation = config['activation']
        self.dt = config['dt']
        self.intg_lev = config['intg_lev']
        self.ode_solver = config['ode_solver']
        self.no_phy = config['no_phy']
        # Decoding part
        self.dec = Decoders(config)
        # Encoding part
        self.enc = Encoders(config)
        # Physics
        self.physics_model = Physics()
        # The ODE is integrated on a grid intg_lev times finer than the
        # observation step dt. len_intg points cover the same interval as
        # dim_t observations, so every intg_lev-th point is an observation.
        self.dt_intg = self.dt / float(self.intg_lev)
        self.len_intg = (self.dim_t - 1) * self.intg_lev + 1
        self.register_buffer('t_intg', torch.linspace(0.0, self.dt_intg * (self.len_intg - 1), self.len_intg))

    @staticmethod
    def _average_features(extractor, data):
        """Apply ``extractor`` to every set member ``data[:, k]``; return the mean feature."""
        num = data.shape[1]
        return sum(extractor(data[:, k]) for k in range(num)) / num

    def priors(self, n, device, context):
        """Return prior stats ``(omega, z_aux1, z_aux2)``, each a {'mean', 'lnvar'} dict.

        The omega prior is a fixed Gaussian:
        * mean at the centre of ``range_omega``;
        * std ``max(1e-3, 0.866 * (range_omega[1] - range_omega[0]))``.

        Auxiliary priors with positive dimension come from the encoder heads
        applied to the context set only. Otherwise they are zero-width
        tensors.
        """
        prior_omega_stat = {
            'mean': torch.ones(n, 1, device=device) * 0.5 * (self.range_omega[0] + self.range_omega[1]),
            'lnvar': 2.0 * torch.log(
                torch.ones(n, 1, device=device) * max(1e-3, 0.866 * (self.range_omega[1] - self.range_omega[0])))}
        K = context.shape[1]
        context = context.reshape(-1, K, self.dim_t * 2)
        if self.dim_z_aux1 > 0:
            feature_aux1 = self._average_features(self.enc.func_feat_aux1, context)
            prior_z_aux1_stat = {
                'mean': self.enc.func_z_aux1_mean(feature_aux1),
                'lnvar': self.enc.func_z_aux1_lnvar(feature_aux1)}
        else:
            prior_z_aux1_stat = {
                'mean': torch.zeros(n, max(0, self.dim_z_aux1), device=device),
                'lnvar': torch.zeros(n, max(0, self.dim_z_aux1), device=device)}
        if self.dim_z_aux2 > 0:
            feature_aux2 = self._average_features(self.enc.func_feat_aux2, context)
            prior_z_aux2_stat = {
                'mean': self.enc.func_z_aux2_mean(feature_aux2),
                'lnvar': self.enc.func_z_aux2_lnvar(feature_aux2)}
        else:
            prior_z_aux2_stat = {
                'mean': torch.zeros(n, max(0, self.dim_z_aux2), device=device),
                'lnvar': torch.zeros(n, max(0, self.dim_z_aux2), device=device)}
        return prior_omega_stat, prior_z_aux1_stat, prior_z_aux2_stat

    def generate_physonly(self, omega, init_y):
        """Integrate the pure physics model (no auxiliary corrections).

        Starts from angles ``init_y`` (n, 2) with zero angular velocities and
        returns the angle trajectory on the observation grid, (n, dim_t, 2).
        """
        n = omega.shape[0]
        device = omega.device

        def ODEfunc(t, yy):
            return self.physics_model(omega, yy)

        initcond = torch.cat([init_y, torch.zeros(n, 2, device=device)], dim=1)
        yy_seq = odeint(ODEfunc, initcond, self.t_intg, method=self.ode_solver)
        # Every intg_lev-th integration point lies on the observation grid.
        y_seq = yy_seq[::self.intg_lev, :, :2].transpose(1, 0)
        return y_seq

    def decode(self, omega, z_aux1, z_aux2, init_y, full=False):
        """Integrate the (physics + auxiliary) ODE and decode a trajectory.

        Args:
            omega: physics parameter, (n, 1), or (n, 0) when ``no_phy``.
            z_aux1, z_aux2: auxiliary latents (zero-width when disabled).
            init_y: initial angles, (n, 2); initial angular velocities are 0.
            full: also return the intermediate trajectories.

        Returns:
            ``(x_PAB, param_x_lnvar)``, or with ``full=True``
            ``(x_PAB, x_PA, x_PB, x_P, param_x_lnvar)``.

            Trajectories are (n, dim_t, 2). Suffix letters mark the
            components applied:
            * P = physics ODE;
            * A = ``func_aux1`` correction inside the ODE;
            * B = additive ``func_aux2_res`` residual.
        """
        n = omega.shape[0]
        device = omega.device

        def ODEfunc(t, _yy):
            # Columns 0:4 hold the corrected (P+A) system. With ``full``,
            # the physics-only (P) system rides along in columns 4:8.
            yy_PA = _yy[:, 0:4]
            if full:
                yy_P = _yy[:, 4:8]
            if not self.no_phy:
                # omega is passed unsquared as the physics coefficient G.
                yy_dot_phy_PA = self.physics_model(omega, yy_PA)
                if full:
                    yy_dot_phy_P = self.physics_model(omega, yy_P)
            else:
                yy_dot_phy_PA = torch.zeros(n, 4, device=device)
                if full:
                    yy_dot_phy_P = torch.zeros(n, 4, device=device)
            if self.dim_z_aux1 >= 0:
                # func_aux1 only corrects the two angular accelerations; the
                # angle derivatives receive no auxiliary term.
                yy_dot_aux_PA = torch.cat(
                    [torch.zeros(n, 2, device=device),
                     self.dec.func_aux1(torch.cat([z_aux1, yy_PA], dim=1))], dim=1)
            else:
                yy_dot_aux_PA = torch.zeros(n, 4, device=device)
            yy_dot_PA = yy_dot_phy_PA + yy_dot_aux_PA
            if full:
                return torch.cat([yy_dot_PA, yy_dot_phy_P], dim=1)
            return yy_dot_PA

        rest = torch.zeros(n, 2, device=device)
        if full:
            initcond = torch.cat([init_y, rest, init_y, rest], dim=1)
        else:
            initcond = torch.cat([init_y, rest], dim=1)
        yy_seq = odeint(ODEfunc, initcond, self.t_intg, method=self.ode_solver)
        # Keep only the points on the observation grid.
        yy_seq = yy_seq[::self.intg_lev]
        x_PA = yy_seq[:, :, :2].transpose(1, 0)
        x_PAB = x_PA.clone()
        if full:
            x_P = yy_seq[:, :, 4:6].transpose(1, 0)
            x_PB = x_P.clone()
        if self.dim_z_aux2 >= 0:
            # The residual depends only on the latents, so compute it once and
            # add it to both outputs.
            res = self.dec.func_aux2_res(torch.cat((omega, z_aux1, z_aux2), dim=1)).reshape(n, self.dim_t, 2)
            x_PAB += res
            if full:
                x_PB += res
        if full:
            return x_PAB, x_PA, x_PB, x_P, self.dec.param_x_lnvar
        return x_PAB, self.dec.param_x_lnvar

    def encode(self, x, x_context):
        """Infer posterior statistics from ``x`` and its context set.

        Args:
            x: observed sequences, reshaped to (n, 2 * dim_t).
            x_context: context sequences, reshaped to (n, K, 2 * dim_t).

        Returns:
            ``(omega_stat, z_aux1_stat, z_aux2_stat, unmixed)``.
            * Each stat is a {'mean', 'lnvar'} dict; disabled latents are
              zero-width.
            * ``unmixed`` is (n, dim_t, 2): the sequence fed to the omega
              encoder, or zeros when ``no_phy``.
        """
        # reshape (not view) so non-contiguous inputs are also accepted
        x_ = x.reshape(-1, self.dim_t * 2)
        n = x_.shape[0]
        device = x_.device
        K = x_context.shape[1]
        x_context = x_context.reshape(-1, K, self.dim_t * 2)
        # The target sequence is pooled together with its context set.
        data = torch.cat((x_.unsqueeze(1), x_context), dim=1)
        if self.dim_z_aux1 > 0:
            feature_aux1 = self._average_features(self.enc.func_feat_aux1, data)
            z_aux1_stat = {'mean': self.enc.func_z_aux1_mean(feature_aux1),
                           'lnvar': self.enc.func_z_aux1_lnvar(feature_aux1)}
        else:
            z_aux1_stat = {'mean': torch.empty(n, 0, device=device),
                           'lnvar': torch.empty(n, 0, device=device)}
        if self.dim_z_aux2 > 0:
            feature_aux2 = self._average_features(self.enc.func_feat_aux2, data)
            z_aux2_stat = {'mean': self.enc.func_z_aux2_mean(feature_aux2),
                           'lnvar': self.enc.func_z_aux2_lnvar(feature_aux2)}
        else:
            z_aux2_stat = {'mean': torch.empty(n, 0, device=device),
                           'lnvar': torch.empty(n, 0, device=device)}
        if not self.no_phy:
            unmixed = x_ + self.enc.func_unmixer_res(torch.cat((x_, z_aux1_stat['mean'], z_aux2_stat['mean']), dim=1))
            feature_phy = self.enc.func_feat_phy(unmixed)
            omega_stat = {'mean': self.enc.func_omega_mean(feature_phy),
                          'lnvar': self.enc.func_omega_lnvar(feature_phy)}
        else:
            # Needs 2*dim_t values per sample so the reshape to
            # (n, dim_t, 2) below is valid.
            unmixed = torch.zeros_like(x_)
            omega_stat = {'mean': torch.empty(n, 0, device=device),
                          'lnvar': torch.empty(n, 0, device=device)}
        return omega_stat, z_aux1_stat, z_aux2_stat, unmixed.reshape(n, self.dim_t, 2)

    def draw(self, omega_stat, z_aux1_stat, z_aux2_stat, hard_z=False):
        """Draw latents via ``utils.draw_normal``, or take the means if ``hard_z``.

        ``omega`` is afterwards clamped to ``omega_feasible_range``.
        NOTE(review): the clamp uses the module-level range, not
        ``self.range_omega`` that parameterises the prior — confirm intended.
        """
        if not hard_z:
            omega = utils.draw_normal(omega_stat['mean'], omega_stat['lnvar'])
            z_aux1 = utils.draw_normal(z_aux1_stat['mean'], z_aux1_stat['lnvar'])
            z_aux2 = utils.draw_normal(z_aux2_stat['mean'], z_aux2_stat['lnvar'])
        else:
            omega = omega_stat['mean'].clone()
            z_aux1 = z_aux1_stat['mean'].clone()
            z_aux2 = z_aux2_stat['mean'].clone()
        omega = torch.max(torch.ones_like(omega) * omega_feasible_range[0], omega)
        omega = torch.min(torch.ones_like(omega) * omega_feasible_range[1], omega)
        return omega, z_aux1, z_aux2

    def forward(self, x, x_context, reconstruct=True, hard_z=False):
        """Encode ``x`` (with its context) and, if ``reconstruct``, decode it.

        Returns the three posterior stats, followed by ``x_mean``
        (n, dim_t, 2) and ``x_lnvar`` when ``reconstruct`` is True.
        """
        omega_stat, z_aux1_stat, z_aux2_stat, _ = self.encode(x, x_context)
        if not reconstruct:
            return omega_stat, z_aux1_stat, z_aux2_stat
        # Decoding starts from the first observed sample.
        init_y = x[:, 0].clone().reshape(-1, 2)
        x_mean, x_lnvar = self.decode(*self.draw(omega_stat, z_aux1_stat, z_aux2_stat, hard_z=hard_z),
                                      init_y, full=False)
        return omega_stat, z_aux1_stat, z_aux2_stat, x_mean, x_lnvar


if __name__ == '__main__':
    # Smoke test: build the model from the CLI config and run one forward pass
    # on zero-valued dummy data: x is (batch, dim_t, 2) and the context is
    # (batch, K, dim_t, 2).
    from config import args
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    x = torch.zeros([100, 80, 2], device=device)
    x_context = torch.zeros([100, 7, 80, 2], device=device)
    args.dim_t = x.shape[1]
    args.dt = 0.025
    model = VAE(vars(args)).to(device)
    model(x, x_context)

