import json
import time
import torch
import utils
from model import VAE
from config import args
from torch import optim
import torch.utils.data
from dataloader import MyDataLoader



def _kl_or_zero(active, stat, prior_stat, device):
    """Return KL(stat || prior_stat) when ``active``, else a ``zeros(1)`` placeholder.

    The KL itself is delegated to ``utils.kldiv_normal_normal`` using the
    ``'mean'`` / ``'lnvar'`` entries of both dicts. The placeholder is a
    one-element tensor on ``device`` so it can be summed with active KL terms.
    """
    if not active:
        return torch.zeros(1, device=device)
    return utils.kldiv_normal_normal(stat['mean'], stat['lnvar'],
                                     prior_stat['mean'], prior_stat['lnvar'])


def loss_function(args, data, omega_stat, z_aux1_stat, z_aux2_stat, x_mean, context, model=None):
    """Compute the squared reconstruction error and the KL-divergence term.

    Args:
        args: configuration namespace; ``dim_z_aux1 > 0``, ``dim_z_aux2 > 0``
            and ``not no_phy`` switch the z_aux1 / z_aux2 / omega KL terms on.
        data: observed batch; the squared error is summed over dim 1 and
            averaged over dim 0 (the batch dimension, ``n = data.shape[0]``).
        omega_stat, z_aux1_stat, z_aux2_stat: posterior statistics, dicts with
            ``'mean'`` and ``'lnvar'`` entries (only read for active terms).
        x_mean: reconstruction compared against ``data``.
        context: forwarded to ``model.priors``.
        model: object providing ``priors(n, device, context)`` returning the
            (omega, z_aux1, z_aux2) prior statistics. When omitted, the
            module-level ``model`` global is used, as the original code did.

    Returns:
        ``(recerr_sq, kldiv)``: 0-dim tensors holding the batch mean of the
        per-sample summed squared error and the mean of the summed KL terms.

    Raises:
        NameError: if no ``model`` is passed and no module-level one exists.
    """
    if model is None:
        # Backward-compatible fallback: callers in this script (train/valid)
        # rely on the global `model` created in the `__main__` block.
        model = globals().get('model')
        if model is None:
            raise NameError("loss_function() needs a model: pass model=... or define a "
                            "module-level 'model'")
    n = data.shape[0]
    device = data.device
    recerr_sq = torch.sum((x_mean - data).pow(2), dim=1).mean()
    prior_omega_stat, prior_z_aux1_stat, prior_z_aux2_stat = model.priors(n, device, context)
    KL_z_aux1 = _kl_or_zero(args.dim_z_aux1 > 0, z_aux1_stat, prior_z_aux1_stat, device)
    KL_z_aux2 = _kl_or_zero(args.dim_z_aux2 > 0, z_aux2_stat, prior_z_aux2_stat, device)
    KL_omega = _kl_or_zero(not args.no_phy, omega_stat, prior_omega_stat, device)
    # NOTE(review): assumes utils.kldiv_normal_normal returns a per-sample KL that
    # broadcasts with the zeros(1) placeholders — confirm in utils.
    kldiv = (KL_z_aux1 + KL_z_aux2 + KL_omega).mean()
    return recerr_sq, kldiv


def train(epoch, args, device, loader, model, optimizer):
    """Run one training epoch over ``loader`` and return averaged loss terms.

    Per batch the minimised objective is::

        recerr_sq + (balance_kld + balance_lact_enc) * x_var * kldiv
            + balance_unmix * reg_unmix
            + balance_dataug * reg_dataug
            + balance_lact_dec * reg_lact_dec

    where ``x_var = exp(x_lnvar)`` is detached, and the three ``reg_*`` physics
    regularisers are 0-dim zeros when ``args.no_phy`` is set. Gradients are
    value-clipped to ``[-grad_clip, grad_clip]`` when ``args.grad_clip > 0``.

    Args:
        epoch: epoch number, only used in the printed summary.
        args: configuration namespace (balance weights, ``no_phy``,
            ``range_omega``, ``grad_clip``, ...).
        device: device each batch is moved to.
        loader: iterable of ``(data, context, _, _)`` batches that also
            exposes ``loader.dataset`` (its length is used for averaging).
        model: provides ``encode``, ``draw``, ``decode``, ``generate_physonly``
            and ``enc.func_feat_phy`` / ``enc.func_omega_mean``.
        optimizer: stepped once per batch.

    Returns:
        dict with keys 'recerr_sq', 'kldiv', 'unmix', 'dataug', 'lact_dec';
        each is the sum over batches of (batch value * batch size) divided by
        ``len(loader.dataset)``.
    """
    model.train()
    logs = {'recerr_sq': .0, 'kldiv': .0, 'unmix': .0, 'dataug': .0, 'lact_dec': .0}
    for data, context, _, _ in loader:
        data = data.to(device)
        context = context.to(device)
        batch_size = len(data)
        optimizer.zero_grad()

        omega_stat, z_aux1_stat, z_aux2_stat, unmixed = model.encode(data, context)
        omega, z_aux1, z_aux2 = model.draw(omega_stat, z_aux1_stat, z_aux2_stat, hard_z=False)
        # NOTE(review): column 0 of each sample, reshaped to (-1, 2), is the decoder's
        # initial state — assumes a 2-dimensional state; confirm against the data layout.
        init_y = data[:, 0].clone().view(-1, 2)
        x_PAB, x_PA, x_PB, x_P, x_lnvar = model.decode(omega, z_aux1, z_aux2, init_y, full=True)
        x_var = torch.exp(x_lnvar)
        recerr_sq, kldiv = loss_function(args, data, omega_stat, z_aux1_stat, z_aux2_stat, x_PAB, context)

        if args.no_phy:
            # Physics branch disabled: every physics regulariser is a 0-dim zero.
            reg_unmix = torch.zeros(1, device=device).squeeze()
            reg_dataug = torch.zeros(1, device=device).squeeze()
            reg_lact_dec = torch.zeros(1, device=device).squeeze()
        else:
            # Unmixing: pull the encoder's `unmixed` output toward the detached
            # physics-only reconstruction x_P.
            reg_unmix = torch.sum((unmixed - x_P.detach()).pow(2), dim=1).mean()

            # Data augmentation: sample omega uniformly in args.range_omega, generate
            # physics-only signals in eval mode without gradients, then train the
            # encoder's physics path to regress the sampled omega.
            model.eval()
            with torch.no_grad():
                omega_lo, omega_hi = args.range_omega[0], args.range_omega[1]
                aug_omega = torch.rand((batch_size, 1), device=device) * (omega_hi - omega_lo) + omega_lo
                aug_x_P = model.generate_physonly(aug_omega, init_y.detach())
            model.train()
            aug_feature_phy = model.enc.func_feat_phy(aug_x_P.detach())
            reg_dataug = (model.enc.func_omega_mean(aug_feature_phy) - aug_omega).pow(2).mean()

            # Decoder regulariser: equal-weighted squared distances between the
            # partial decoder outputs (PA, PB vs P; PAB vs PA, PB).
            dif_PA_P = torch.sum((x_PA - x_P).pow(2), dim=1).mean()
            dif_PB_P = torch.sum((x_PB - x_P).pow(2), dim=1).mean()
            dif_PAB_PA = torch.sum((x_PAB - x_PA).pow(2), dim=1).mean()
            dif_PAB_PB = torch.sum((x_PAB - x_PB).pow(2), dim=1).mean()
            reg_lact_dec = 0.25*dif_PA_P + 0.25*dif_PB_P + 0.25*dif_PAB_PA + 0.25*dif_PAB_PB

        # NOTE(review): backward() needs a single-element loss, so x_var presumably
        # comes from a scalar/one-element log-variance — confirm.
        kldiv_balanced = (args.balance_kld + args.balance_lact_enc) * x_var.detach() * kldiv
        loss = (recerr_sq + kldiv_balanced
                + args.balance_unmix * reg_unmix
                + args.balance_dataug * reg_dataug
                + args.balance_lact_dec * reg_lact_dec)
        loss.backward()
        if args.grad_clip > 0.0:
            torch.nn.utils.clip_grad_value_(model.parameters(), args.grad_clip)
        optimizer.step()

        # Weight batch means by batch size so the epoch average is per sample.
        for key, value in (('recerr_sq', recerr_sq), ('kldiv', kldiv), ('unmix', reg_unmix),
                           ('dataug', reg_dataug), ('lact_dec', reg_lact_dec)):
            logs[key] += value.detach() * batch_size
    for key in logs:
        logs[key] /= len(loader.dataset)
    print('====> Epoch: {}  Training (rec. err.)^2: {:.4f}  kldiv: {:.4f}  unmix: {:.4f}  dataug: {:.4f}  lact_dec: {:.4f}'.format(
        epoch, logs['recerr_sq'], logs['kldiv'], logs['unmix'], logs['dataug'], logs['lact_dec']))
    return logs


def valid(epoch, args, device, loader, model):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Args:
        epoch: epoch number, only used in the printed summary.
        args: configuration namespace forwarded to ``loss_function``.
        device: device each batch is moved to.
        loader: iterable of ``(data, context, _, _)`` batches that also
            exposes ``loader.dataset`` (its length is used for averaging).
        model: called as ``model(data, context)``; must return a 5-tuple whose
            first four items are the omega / z_aux1 / z_aux2 posterior stats
            and the reconstruction passed to ``loss_function``.

    Returns:
        dict with 'recerr_sq' and 'kldiv': batch values weighted by batch
        size, summed, and divided by ``len(loader.dataset)``.
    """
    model.eval()
    totals = {'recerr_sq': .0, 'kldiv': .0}
    with torch.no_grad():
        for batch_data, batch_context, _, _ in loader:
            batch_data = batch_data.to(device)
            batch_context = batch_context.to(device)
            weight = len(batch_data)
            omega_stat, z_aux1_stat, z_aux2_stat, recon, _ = model(batch_data, batch_context)
            rec, kl = loss_function(args, batch_data, omega_stat, z_aux1_stat, z_aux2_stat,
                                    recon, batch_context)
            totals['recerr_sq'] += rec.detach() * weight
            totals['kldiv'] += kl.detach() * weight
    n_samples = len(loader.dataset)
    logs = {key: total / n_samples for key, total in totals.items()}
    print('====> Epoch: {}  Validation (rec. err.)^2: {:.4f}  kldiv: {:.4f}'.format(
        epoch, logs['recerr_sq'], logs['kldiv']))
    return logs


if __name__ == '__main__':
    import os  # script-only dependency, kept next to its use

    device = 'cuda:4'  # NOTE(review): hard-coded GPU index; it must exist on the host
    torch.manual_seed(123456)
    # Run-specific overrides applied on top of the parsed `args`.
    args.dim_t = 80
    args.dt = 0.025
    args.outdir = 'output3/'
    os.makedirs(args.outdir, exist_ok=True)
    # One path for both the checkpoint we resume from and the best-model save
    # target, so editing args.outdir cannot make them diverge.
    best_model_path = os.path.join(args.outdir, 'model.pt')

    loader_train = MyDataLoader('../Meta-Hybrid-VAE/data/data_train.pt', 7, 500, True)
    loader_valid = MyDataLoader('../Meta-Hybrid-VAE/data/data_valid.pt', 7, 500, False)
    # Keep `model` a module-level global: loss_function() looks it up there.
    model = VAE(vars(args)).to(device)
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    kwargs = {'lr': args.learning_rate, 'weight_decay': args.weight_decay, 'eps': args.adam_eps}
    optimizer = optim.Adam(model.parameters(), **kwargs)
    print('start training with device', device)

    with open(os.path.join(args.outdir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)
    log_path = os.path.join(args.outdir, 'log.txt')
    # NOTE(review): mode 'w' truncates the log of the run being resumed — confirm intended.
    with open(log_path, 'w') as f:
        print('# epoch recerr_sq kldiv unmix dataug lact_dec valid_recerr_sq valid_kldiv duration', file=f)

    # NOTE(review): the best score starts at 1e10 rather than the resumed model's
    # validation error, so epoch 1 always overwrites the resumed checkpoint.
    info = {'bestvalid_epoch': 0, 'bestvalid_recerr': 1e10}
    dur_total = .0  # cumulative training time in seconds; validation is not timed
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        logs_train = train(epoch, args, device, loader_train, model, optimizer)
        dur_total += time.time() - start_time
        logs_valid = valid(epoch, args, device, loader_valid, model)
        with open(log_path, 'a') as f:
            print('{} {:.7e} {:.7e} {:.7e} {:.7e} {:.7e} {:.7e} {:.7e} {:.7e}'.format(epoch,
                logs_train['recerr_sq'], logs_train['kldiv'], logs_train['unmix'], logs_train['dataug'], logs_train['lact_dec'],
                logs_valid['recerr_sq'], logs_valid['kldiv'], dur_total), file=f)
        if logs_valid['recerr_sq'] < info['bestvalid_recerr']:
            info['bestvalid_epoch'] = epoch
            # Store a plain float rather than keeping a (possibly GPU) tensor alive.
            info['bestvalid_recerr'] = float(logs_valid['recerr_sq'])
            torch.save(model.state_dict(), best_model_path)
            print('best model saved')
        if epoch % 1000 == 0:
            torch.save(model.state_dict(), os.path.join(args.outdir, 'model_e{}.pt'.format(epoch)))
    print('end training')


