import json
import time
import torch
import utils
import numpy as np
from model import VAE
from config import args
from torch import optim
import torch.utils.data
from dataloader import MyDataLoader


def angle_features(angles):
    """Embed angles on the unit circle as ``[cos(angles), sin(angles)]``.

    Comparing (cos, sin) pairs rather than raw values makes the error
    insensitive to 2*pi wrap-around (presumably the inputs are angles in
    radians -- TODO confirm against the data pipeline).

    Args:
        angles: Tensor of any shape.

    Returns:
        Tensor with the same shape as ``angles`` except that the last
        dimension is doubled (cosines first, then sines).
    """
    return torch.cat((angles.cos(), angles.sin()), dim=-1)


def per_sample_squared_error(pred, target):
    """Return ``(pred - target) ** 2`` summed over dim 1.

    This is the reduction the evaluation reports; callers take ``.mean()`` of
    the result to get a single scalar.
    """
    return (pred - target).pow(2).sum(dim=1)


def select_device(preferred='cuda:4'):
    """Return ``preferred`` if usable, otherwise fall back to ``cuda:0`` or CPU.

    Args:
        preferred: Device spec such as ``'cuda:4'``, ``'cuda'`` or ``'cpu'``.

    Returns:
        A ``torch.device`` that exists on this machine.
    """
    dev = torch.device(preferred)
    if dev.type != 'cuda':
        return dev
    if not torch.cuda.is_available():
        fallback = torch.device('cpu')
    elif dev.index is not None and dev.index >= torch.cuda.device_count():
        fallback = torch.device('cuda:0')
    else:
        return dev
    print(f'{preferred} is unavailable; using {fallback} instead.')
    return fallback


def main(device='cuda:4'):
    """Evaluate the trained VAE and dump reconstructions / generations.

    Both evaluations reuse the same latents drawn from ``model.encode``:

    * ``Rec``: decode starting from ``init`` and compare with ``data``.
    * ``Gen``: decode starting from ``init_D[:, -1]`` and compare with
      ``context[:, -1]``.

    For each, the mean of :func:`per_sample_squared_error` in cos/sin space is
    printed. Ground truth and predictions are saved to
    ``GT-Meta-Hybrid-<tag>.pt`` / ``XT-Meta-Hybrid-<tag>.pt``. Results are
    accumulated over all batches, so the files cover the whole loader rather
    than only its last batch.

    Args:
        device: Preferred device; see :func:`select_device` for the fallback.
    """
    device = select_device(device)
    torch.manual_seed(1234)
    # Override config values before VAE(...) reads them.
    args.dim_t = 20
    args.dt = 1/100
    # Debug output: loads the whole file once just to show the 'state' shape.
    print(torch.load('data/data_train.pt')['state'].shape)
    # NOTE(review): the "test" loader reads the *training* file -- confirm this
    # is intended. The positional args (7, 17220, False) are specific to
    # MyDataLoader; see its definition for their meaning.
    loader_test = MyDataLoader('data/data_train.pt', 7, 17220, False)
    model = VAE(vars(args)).to(device)
    model.load_state_dict(torch.load('output2/model.pt', map_location=device))
    model.eval()

    # tag -> (ground-truth chunks, prediction chunks, per-sample error chunks)
    results = {'Rec': ([], [], []), 'Gen': ([], [], [])}
    with torch.no_grad():
        for data, context, init, init_D in loader_test:
            data, context = data.to(device), context.to(device)
            init, init_D = init.to(device), init_D.to(device)
            omega_stat, z_aux1_stat, z_aux2_stat, _ = model.encode(data, context)
            omega, z_aux1, z_aux2 = model.draw(
                omega_stat, z_aux1_stat, z_aux2_stat, hard_z=False)

            # clone() guards the loader's tensors in case decode() mutates its
            # initial-state argument in place (not visible from here).
            # Insertion order (Rec, then Gen) matches the original decode order.
            cases = {
                'Rec': (init.clone().view(-1, 4), data),
                'Gen': (init_D[:, -1].clone().view(-1, 4), context[:, -1]),
            }
            for tag, (init_y, target) in cases.items():
                x_mean, *_ = model.decode(omega, z_aux1, z_aux2, init_y, full=True)
                x_mean_hat = angle_features(x_mean)
                data_hat = angle_features(target)
                gts, preds, errs = results[tag]
                gts.append(data_hat.cpu())
                preds.append(x_mean_hat.cpu())
                errs.append(per_sample_squared_error(x_mean_hat, data_hat).cpu())

    for tag, (gts, preds, errs) in results.items():
        if not gts:
            print(f'{tag}: loader produced no batches; nothing saved.')
            continue
        print(f'{tag} mean squared error: {torch.cat(errs).mean().item():.6g}')
        torch.save(torch.cat(gts), f'GT-Meta-Hybrid-{tag}.pt')
        torch.save(torch.cat(preds), f'XT-Meta-Hybrid-{tag}.pt')


if __name__ == '__main__':
    main()


        
