import json
import time
import torch
import utils
import numpy as np
from model import VAE
from config import args
from torch import optim
import torch.utils.data
from dataloader import MyDataLoader
from metrics import mean_corr_coef
from sklearn.cross_decomposition import CCA


def select_device(preferred='cuda:4'):
    """Return ``preferred`` as a ``torch.device`` if it exists on this machine.

    Falls back to ``cuda:0`` when CUDA is available but the requested GPU
    index is out of range, and to ``cpu`` when CUDA is not available at all.
    Non-CUDA devices are returned unchanged.
    """
    dev = torch.device(preferred)
    if dev.type == 'cuda':
        if not torch.cuda.is_available():
            return torch.device('cpu')
        if (dev.index or 0) >= torch.cuda.device_count():
            return torch.device('cuda:0')
    return dev


def angles_to_circle(angles):
    """Embed ``angles`` as ``(cos, sin)`` features concatenated on the last dim."""
    return torch.cat((angles.cos(), angles.sin()), dim=-1)


def sum_sq_error(pred, target):
    """Squared error summed over dim 1, then averaged over all remaining entries."""
    return torch.sum((pred - target).pow(2), dim=1).mean()


if __name__ == '__main__':
    # The script was written for GPU 4; fall back if that device is absent.
    device = select_device('cuda:4')
    torch.manual_seed(1234)
    # Override the time-grid settings from `config.args` for this evaluation.
    args.dim_t = 80
    args.dt = 0.025
    loader_test = MyDataLoader('data/data_valid.pt', 8, 5000, False)
    model = VAE(vars(args)).to(device)
    model.load_state_dict(torch.load('output5/model.pt', map_location=device))
    model.eval()
    with torch.no_grad():
        # NOTE(review): the .pt files below use fixed names, so if the loader
        # yields more than one batch only the last batch's tensors survive.
        for data, context, param, _ in loader_test:
            data = data.to(device)
            context = context.to(device)
            param = param.to(device)

            # --- Reconstruction: encode `data`, decode from its own first step ---
            omega_stat, z_aux1_stat, z_aux2_stat, _ = model.encode(data, context[:, :-1])
            omega, z_aux1, z_aux2 = model.draw(omega_stat, z_aux1_stat, z_aux2_stat, hard_z=False)
            init_y = data[:, 0].clone().view(-1, 2)
            x_mean, *_ = model.decode(omega, z_aux1, z_aux2, init_y, full=True)
            x_mean_hat = angles_to_circle(x_mean)
            data_hat = angles_to_circle(data)
            torch.save(data_hat.cpu(), 'GT-Meta-Hybrid-Rec.pt')
            torch.save(x_mean_hat.cpu(), 'XT-Meta-Hybrid-Rec.pt')
            print(sum_sq_error(x_mean_hat, data_hat))
            # Presumably param[:, 0] is the ground-truth omega -- confirm.
            print(torch.mean((omega_stat['mean'] - param[:, [0]]).pow(2)))

            # --- Latent identifiability: CCA between true params and latents ---
            param_np = param.cpu().numpy()
            z_aux_np = torch.cat((omega_stat['mean'], z_aux1_stat['mean']), dim=-1).cpu().numpy()
            cca = CCA(n_components=2, max_iter=5000)
            param_scores, z_scores = cca.fit(param_np, z_aux_np).transform(param_np, z_aux_np)
            mcc_weak_in = mean_corr_coef(param_scores, z_scores)
            print(mcc_weak_in)

            # --- Generation: decode starting from the last context sequence ---
            init_y = context[:, -1, 0].clone().view(-1, 2)
            x_mean, *_ = model.decode(omega, z_aux1, z_aux2, init_y, full=True)
            x_mean_hat = angles_to_circle(x_mean)
            target_hat = angles_to_circle(context[:, -1])
            print(sum_sq_error(x_mean_hat, target_hat))
            torch.save(target_hat.cpu(), 'GT-Meta-Hybrid-Gen.pt')
            torch.save(x_mean_hat.cpu(), 'XT-Meta-Hybrid-Gen.pt')

