import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import time
import os
os.environ["WANDB__SERVICE_WAIT"] = "1000"
import sys
sys.path.append('../')
from rnn.evaluation import eval_VAE
from rnn.saving import save_model
file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)

import matplotlib.pyplot as plt

# wandb is an optional dependency: it is only used by train_VAE when
# sync_wandb=True. Catch ImportError only, so that unrelated failures
# (KeyboardInterrupt, SystemExit, ...) are not silently swallowed.
try:
    import wandb
except ImportError:
    # Keep the name bound so the module namespace is the same either way.
    wandb = None
    print("wandb not installed... continuing")


# Optimizer epsilon. NOTE(review): nothing in the visible part of this file
# reads this constant (only a commented-out Adam line mentions an 'opt_eps'
# training parameter) - confirm whether other modules import it.
opt_eps = .1

# Keys under which per-epoch training / evaluation histories are accumulated
# inside the training_params dict (bit of a hack, kept for compatibility).
_TRAINING_HISTORY_KEYS = (
    "ll", "ll_x", "ll_z", "H", "loss", "reg_loss", "KL_x", "PSH", "PSC",
    "mean_error", "noise_z", "noise_x", "alphan",
)

# Values of training_params['loss_f'] for which a forward pass is defined.
_SUPPORTED_LOSS_FUNCTIONS = ("opt_VGTF", "VGTF", "VMPF", "VGTF_dreg")


def _select_device(use_cuda):
    """Return the CUDA device if requested and available, otherwise the CPU."""
    if use_cuda:
        if torch.cuda.is_available():
            return torch.device("cuda")
        print("Warning: CUDA not available on this machine, switching to CPU")
    return torch.device("cpu")


def _build_scheduler(optimizer, training_params):
    """Create the default LR scheduler (cosine warm restarts or exponential decay)."""
    if training_params["CosineRestarts"] > 0:
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,
            T_mult=training_params["CosineRestarts"],
            eta_min=training_params["lr_end"],
        )
    # Per-epoch factor chosen so that lr decays from 'lr' to 'lr_end' over
    # 'n_epochs' scheduler steps.
    gamma = np.exp(
        np.log(training_params["lr_end"] / training_params["lr"])
        / training_params["n_epochs"]
    )
    print("Learning rate decay factor " + str(gamma))
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)


def _forward_pass(vae, inputs, stim, training_params):
    """Run the forward pass selected by training_params['loss_f'] and return its outputs."""
    loss_f = training_params["loss_f"]
    if loss_f == "opt_VGTF":
        return vae.forward_Optimal_VGTF(
            inputs,
            u=stim,
            k=training_params["k"],
            dreg_p=training_params["dreg_p"],
            MC_p=training_params["MC_p"],
            dreg_q=training_params["dreg_q"],
            MC_q=training_params["MC_q"],
            resample=training_params["resample"],
            out_likelihood=training_params["observation_likelihood"],
            bootstrap=training_params["bootstrap"],
        )
    if loss_f == "VGTF":
        return vae.forward_VGTF(
            inputs,
            u=stim,
            k=training_params["k"],
            resample=training_params["resample"],
            out_likelihood=training_params["observation_likelihood"],
            t_forward=training_params["t_forward"],
        )
    if loss_f == "VMPF":
        return vae.forward_VMPF(
            inputs,
            u=stim,
            k=training_params["k"],
            resample=training_params["resample"],
            out_likelihood=training_params["observation_likelihood"],
            t_forward=training_params["t_forward"],
        )
    if loss_f == "VGTF_dreg":
        return vae.forward_VGTF_dreg(
            inputs,
            u=stim,
            k=training_params["k"],
            dreg_p=training_params["dreg_p"],
            MC_p=training_params["MC_p"],
            dreg_q=training_params["dreg_q"],
            MC_q=training_params["MC_q"],
            resample=training_params["resample"],
            out_likelihood=training_params["observation_likelihood"],
            bootstrap=training_params["bootstrap"],
            n_ahead=training_params["n_ahead"],
        )
    raise ValueError(
        "Unknown training_params['loss_f'] = {!r}; expected one of {}".format(
            loss_f, _SUPPORTED_LOSS_FUNCTIONS
        )
    )


def _log_generated_trajectories(vae, task, training_params, epoch):
    """Log a simulated latent time series and the generated observations to wandb."""
    with torch.no_grad():
        data, _ = task[0]
        dim_x, _ = data.shape
        z_hat, _, _, _ = vae.encoder(data.unsqueeze(0))
        # Use the first encoded time step as initial condition of the simulation.
        z0 = z_hat[:, :, :1].squeeze()
        Z = vae.prior.get_latent_time_series(
            time_steps=1000, z0=z0, noise_scale=training_params["sim_latent_noise"]
        )
        data_gen = (
            vae.prior.get_observation(Z, noise_scale=training_params["sim_obs_noise"])
            .permute(0, 2, 1, 3)
            .reshape(1000, dim_x)
        )

    # Figures are closed right after logging; otherwise one figure per
    # evaluation would stay alive for the whole training run.
    fig = plt.figure()
    plt.plot(Z[0, :, :, 0].detach().cpu().T)
    plt.xlim(0)
    wandb.log({"latent" + str(epoch): plt})
    plt.close(fig)

    fig = plt.figure()
    plt.plot(data_gen.detach().cpu())
    wandb.log({"reconstruction" + str(epoch): plt})
    plt.close(fig)


def train_VAE(
    vae,
    training_params,
    task,
    sync_wandb=False,
    out_dir=None,
    fname=None,
    training_stats = None,
    optimizer = None,
    scheduler = None,
    curr_epoch = 0,
):
    """
    Train a VAE

    Args:
        vae: initialized VAE
        training_params: dictionary of training parameters. It is modified in
            place: per-epoch histories are appended under the keys listed in
            _TRAINING_HISTORY_KEYS.
        task: Pytorch Dataset; its `data` and `data_eval` attributes are moved
            to the training device.
        sync_wandb (optional): Bool, indicates synchronisation with WandB
        out_dir: string designating where to store model
        fname: currently ignored; the name returned by save_model is used
        training_stats (optional): callable, invoked as training_stats(vae)
            once per epoch before the batches are processed
        optimizer (optional): optimizer to use; defaults to RAdam with
            learning rate training_params['lr']
        scheduler (optional): learning rate scheduler to use; by default one
            is built from training_params (cosine warm restarts if
            'CosineRestarts' > 0, exponential decay otherwise)
        curr_epoch (optional): index of the first epoch to run

    Returns:
        losses: list with the (detached) loss tensor of every batch
        reg_losses: list with the (detached) regularisation loss tensor of
            every batch

    Raises:
        ValueError: if training_params['loss_f'] is not a supported loss
    """
    # Fail early on a misspelled loss name instead of crashing in the first batch.
    if training_params["loss_f"] not in _SUPPORTED_LOSS_FUNCTIONS:
        raise ValueError(
            "Unknown training_params['loss_f'] = {!r}; expected one of {}".format(
                training_params["loss_f"], _SUPPORTED_LOSS_FUNCTIONS
            )
        )

    stop_training = False  # not found any NANs yet

    # add losses to training_params dict, bit of a hack
    for key in _TRAINING_HISTORY_KEYS:
        training_params.setdefault(key, [])

    # cuda management, gpu potentially speeds up training
    device = _select_device(training_params["cuda"])
    vae.to_device(device)
    print("Training on : " + str(device))

    # set up dataloader, keeping the full dataset on the training device
    dataloader = DataLoader(
        task, batch_size=training_params['batch_size'], shuffle=True
    )
    dataloader.dataset.data = dataloader.dataset.data.to(device=device)
    dataloader.dataset.data_eval = dataloader.dataset.data_eval.to(device=device)

    # initialize wandb
    if sync_wandb:
        wandb.init(
            project="vi_rnns",
            group=task.task_params["name"],
            config={**vae.vae_params, **task.task_params, **training_params},
        )
        wandb.watch(vae, log="all")

    # optimizer and learning rate scheduler (caller-supplied ones take precedence)
    if optimizer is None:
        optimizer = torch.optim.RAdam(vae.parameters(), lr=training_params["lr"])
    if scheduler is None:
        scheduler = _build_scheduler(optimizer, training_params)

    # set rnn to training mode
    vae.train()

    # per-batch losses returned to the caller
    losses = []
    reg_losses = []
    # start timer before training
    time0 = time.time()

    for i in range(curr_epoch, training_params['n_epochs']):

        # DO EVALUATION
        if i % training_params['eval_epochs'] == 0 and training_params['run_eval']:
            vae.eval()
            with torch.no_grad():
                klx_bin, psc, psH, mean_rate_error = eval_VAE(
                    vae,
                    task,
                    cut_off=0,
                    smoothing=training_params['smoothing'],
                    freq_cut_off=training_params['freq_cut_off'],
                    sim_obs_noise=training_params['sim_obs_noise'],
                    sim_latent_noise=training_params['sim_latent_noise'],
                    smooth_at_eval=training_params['smooth_at_eval'],
                )
            training_params['KL_x'].append(klx_bin)
            training_params['PSH'].append(psH)
            training_params['PSC'].append(psc)
            training_params['mean_error'].append(mean_rate_error)

            if sync_wandb:
                wandb.log(
                    {
                        "KL_data": klx_bin,
                        "power_spectr_correlation": psc,
                        "power_spectr_distance": psH,
                        "mean_rate_error": mean_rate_error,
                    }
                )
                # plot latent time series and reconstructions
                _log_generated_trajectories(vae, task, training_params, i)

        vae.train()

        # running sums over the batches of this epoch
        batch_h_loss = 0
        batch_ll = 0
        batch_ll_z = 0
        batch_ll_x = 0
        batch_loss = 0
        batch_reg_loss = 0

        if training_stats is not None:
            training_stats(vae)

        for inputs, stim in dataloader:

            optimizer.zero_grad()

            # forward pass
            Loss_it, Z, Esample, ll_x, ll_z, H, log_likelihood, alphas = _forward_pass(
                vae, inputs, stim, training_params
            )

            reg_loss = torch.zeros(1, device=device)
            if training_params['L2_reg']:
                reg_loss += torch.mean(vae.prior.get_rates(Z)**2) * training_params['L2_reg']
            batch_reg_loss += reg_loss.item()

            batch_ll += log_likelihood.mean().item()
            batch_ll_x += ll_x.mean().item()
            batch_ll_z += ll_z.mean().item()
            batch_h_loss += H.mean().item()
            # the forward pass returns an objective to maximise
            loss = -Loss_it.mean()

            batch_loss += loss.item()

            # check for nans
            if torch.isnan(loss):
                print("UH OH FOUND NAN, stopping training...")
                stop_training = True
                break

            (loss + reg_loss).backward()

            if training_params['grad_norm']:
                nn.utils.clip_grad_norm_(
                    parameters=vae.parameters(), max_norm=training_params['grad_norm']
                )

            # Adjust learning weights
            optimizer.step()

            # Store detached copies so the history does not keep autograd
            # graph nodes alive for the whole run.
            losses.append(loss.detach())
            reg_losses.append(reg_loss.detach())
        if stop_training:
            break

        # Epoch averages. ll_z, ll_x and H are recorded with flipped sign.
        n_batches = len(dataloader)
        batch_ll /= n_batches
        batch_ll_z /= -n_batches
        batch_h_loss /= -n_batches
        batch_ll_x /= -n_batches
        batch_loss /= n_batches
        batch_reg_loss /= n_batches

        training_params['ll_z'].append(batch_ll_z)
        training_params['ll_x'].append(batch_ll_x)
        training_params['ll'].append(batch_ll)
        training_params['H'].append(batch_h_loss)
        training_params['reg_loss'].append(batch_reg_loss)
        training_params['loss'].append(batch_loss)

        noise_z = vae.prior.std_embed_z(vae.prior.R_z).detach()
        noise_x = vae.prior.std_embed_x(vae.prior.R_x).detach()

        training_params["noise_z"].append(noise_z)
        training_params["noise_x"].append(noise_x)
        # alphas comes from the last batch of the epoch
        alpha = torch.mean(alphas).item()
        training_params["alphan"].append(alpha)
        # NOTE: the printed 'reg' value is the one of the last batch, whereas
        # the epoch average (batch_reg_loss) is what gets stored and logged.
        print('epoch {} loss: {:.4f}, ll: {:.4f}, ll_x: {:.4f}, ll_z: {:.4f} H: {:.4f}, alpha: {:.2f}, lr: {:.6f}, N_z: {:.4f}, N_x: {:.4f}, reg: {:4f}'.format(
            i + 1, batch_loss, batch_ll, batch_ll_x, batch_ll_z, batch_h_loss, alpha,
            scheduler.get_last_lr()[0], noise_z.mean().item(), noise_x.mean().item(), reg_loss.item()))
        if sync_wandb:
            wandb.log(
                {
                    "loss": batch_loss,
                    "ll": batch_ll,
                    "likelihood_data": batch_ll_x,
                    "likelihood_latent": batch_ll_z,
                    "entropy": batch_h_loss,
                    "reg": batch_reg_loss,
                    "alpha": alpha,
                    "noise_z": noise_z.mean().item(),
                    "noise_x": noise_x.mean().item(),
                    "noise_e": torch.exp(vae.encoder.logvar/2).mean().item(),
                }
            )

        # stop decaying once the final learning rate has been reached
        if scheduler.get_last_lr()[0] > training_params["lr_end"]:
            scheduler.step()
    print("\nDone. Training took %.1f sec." % (time.time() - time0))

    # save trained network (also done when training stopped early on a NaN)
    fname = save_model(vae, training_params, task.task_params, directory=out_dir)
    print("Saved: " + fname)
    # upload trained models to WandB
    if sync_wandb:
        # store to wandb
        print(fname + "_state_dict_enc.pkl")
        wandb.save(fname + "_state_dict_enc.pkl")
        wandb.save(fname + "_state_dict_prior.pkl")
        wandb.save(fname + "_vae_params.pkl")
        wandb.save(fname + "_task_params.pkl")
        wandb.save(fname + "_training_params.pkl")
        wandb.finish()

    return losses, reg_losses



