import torch
import torch.optim as optim
import time
from tqdm import trange
import numpy as np
import os
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from torch import nn
import wandb 

def train(model, train_loader, valid_loader, args):
    """Train ``model`` with Adam and checkpoint it based on validation loss.

    Every epoch performs one optimisation pass over ``train_loader`` followed
    by one gradient-free pass over ``valid_loader``. The validation score is
    the sum of the per-batch total losses divided by the number of validation
    samples (``y.size(0)`` summed over batches). Whenever that score improves,
    the model's ``state_dict`` is written to ``<args.save_path>/model.pth``.

    Args:
        model: Module exposing ``calculate_loss(r, d, a, y)`` which returns the
            4-tuple ``(loss, x_recon_loss, y_recon_loss, u_kl_loss)`` of
            single-element tensors.
        train_loader: Iterable yielding ``(r, d, a, y)`` tensor batches used
            for optimisation.
        valid_loader: Iterable yielding ``(r, d, a, y)`` tensor batches used
            for model selection.
        args: Configuration object providing ``device``, ``lr``, ``n_epochs``,
            ``wandb``, ``save_path`` (``str`` or path-like), ``early_stop``
            and ``break_epoch``.

    Returns:
        None. The trained weights are persisted to disk as a side effect.

    Raises:
        ValueError: If ``valid_loader`` yields no samples, since the
            validation score is then undefined.
    """
    device = args.device
    model.to(device)
    model = model.train()
    opt = optim.Adam(model.parameters(), lr=args.lr)

    # os.path.join accepts both str and path-like objects, so the checkpoint
    # location works regardless of how save_path was configured.
    checkpoint_path = os.path.join(args.save_path, 'model.pth')

    # Explicit comparisons are kept on purpose: a non-boolean value such as
    # None enables neither early stopping nor the per-epoch save.
    early_stop_enabled = args.early_stop == True  # noqa: E712
    save_every_epoch = args.early_stop == False  # noqa: E712

    best_epoch = 0
    best_loss = 1e10
    start_time = time.time()

    for epoch_i in trange(args.n_epochs):
        # ------------- Training pass ------------- #
        model.train()
        train_loss_sum = 0.0
        x_recon_losses, y_recon_losses, u_kl_losses = [], [], []
        for r, d, a, y in train_loader:
            loss, x_recon_loss, y_recon_loss, u_kl_loss = model.calculate_loss(
                r.to(device), d.to(device), a.to(device), y.to(device))

            opt.zero_grad()
            loss.backward()
            opt.step()

            x_recon_losses.append(x_recon_loss.item())
            y_recon_losses.append(y_recon_loss.item())
            u_kl_losses.append(u_kl_loss.item())
            train_loss_sum += loss.item()

        if args.wandb:
            # 'Loss' is the sum over batches; the components are batch means.
            # An empty training loader is reported as NaN.
            nan = float('nan')
            wandb.log(
                {
                    'Loss': train_loss_sum,
                    'BCE(x)': np.mean(x_recon_losses) if x_recon_losses else nan,
                    'KL(u)': np.mean(u_kl_losses) if u_kl_losses else nan,
                    'BCE(y)': np.mean(y_recon_losses) if y_recon_losses else nan,
                },
                step=epoch_i,
            )

        # ------------- Validation pass ------------- #
        model.eval()
        val_loss_sum = 0.0
        n_val_samples = 0
        with torch.no_grad():
            for r, d, a, y in valid_loader:
                loss_val, _x_recon_val, _y_recon_val, _u_kl_val = model.calculate_loss(
                    r.to(device), d.to(device), a.to(device), y.to(device))
                val_loss_sum += loss_val.item()
                n_val_samples += y.size(0)

        if n_val_samples == 0:
            raise ValueError(
                'valid_loader yielded no samples; cannot compute the validation loss')
        loss_check = val_loss_sum / n_val_samples

        print('now best epoch is, best loss, loss_check', best_epoch, best_loss, loss_check)
        print('loss_check < best_loss', loss_check < best_loss)

        if loss_check < best_loss:
            torch.save(model.state_dict(), checkpoint_path)
            best_epoch = epoch_i
            best_loss = loss_check
            print('best epoch update by loss, epoch is ', epoch_i)

        # Stop once the validation loss has not improved for more than
        # break_epoch consecutive epochs.
        if early_stop_enabled and epoch_i - best_epoch > args.break_epoch:
            print('time elapsed: {:.4f}min'.format((time.time() - start_time) / 60.0))
            break

        # Without early stopping the checkpoint always holds the latest epoch,
        # overwriting any best-epoch weights saved above.
        if save_every_epoch:
            torch.save(model.state_dict(), checkpoint_path)

        print('time elapsed: {:.4f}min'.format((time.time() - start_time) / 60.0))



# def test(test_loader, args):
#     device = args.device
#     model_path = os.path.join(args.save_path, 'model.pth')
#     test_model = torch.load(model_path)
#     test_model.to(device)
#     test_model.eval()
#     correct, _all, o1s, o2s, o3s, o4s, o1s_bin, o2s_bin, o3s_bin, o4s_bin, ys, ys_bin = \
#         0, 0, None, None, None, None, None, None, None, None, None, None
#     with torch.no_grad():
#         for idx, (r, d, a, y) in enumerate(test_loader):
#             if args.use_label:
#                 loss_val, x_recon_loss_val, y_recon_loss_val, y_p_val, y_p_counter_val, u_kl_loss_val, fair_loss_val\
#                     = test_model.calculate_loss(r.to(device), d.to(device), a.to(device), y.to(device))  # (*cur_batch)
#             else:
#                 loss_val, x_recon_loss_val, u_kl_loss_val = test_model.calculate_loss(r.to(device), d.to(device), a.to(device), y.to(device))

#             # For saving the result:
#             if args.tSNE == True or args.u_dim == 2:
#                 u_mu, u_logvar = test_model.q_u(r.to(device), d.to(device), a.to(device), y.to(device))
#                 u_prev = test_model.reparameterize(u_mu, u_logvar)
#                 u = torch.cat((u, u_prev), 0) if idx != 0 else u_prev
#                 a_all = torch.cat((a_all, a), 0) if idx != 0 else a
            
#             if args.use_label:
#                 y_p_val = nn.Sigmoid()(y_p_val)
#                 y_p_counter_val = nn.Sigmoid()(y_p_counter_val)
#                 label_predicted = torch.eq(y_p_val.gt(0.5).byte(), y.to(device).byte())
#                 correct += torch.sum(label_predicted)
#                 _all += float(label_predicted.size(0))

#                 y_p_np = y_p_val.cpu().detach().numpy()
#                 y_cf_np = y_p_counter_val.cpu().detach().numpy()
#                 mask_a = np.where(a == 1, -1, 1)
#                 cf_effect = (y_cf_np - y_p_np) * mask_a
#                 cf_bin = (np.greater(y_cf_np, 0.5).astype(int) - np.greater(y_p_np, 0.5).astype(int)) * mask_a
            
#                 m = r.cpu().detach().numpy()[:, 1:3]
#                 mask1 = (m == [False, False]).all(axis=1)
#                 mask2 = (m == [False, True]).all(axis=1)
#                 mask3 = (m == [True, False]).all(axis=1)
#                 mask4 = (m == [True, True]).all(axis=1)

#                 o1 = cf_effect[mask1 == [True]]
#                 o2 = cf_effect[mask2 == [True]]
#                 o3 = cf_effect[mask3 == [True]]
#                 o4 = cf_effect[mask4 == [True]]

#                 o1s = np.concatenate((o1s, o1), axis=0) if idx != 0 else o1
#                 o2s = np.concatenate((o2s, o2), axis=0) if idx != 0 else o2
#                 o3s = np.concatenate((o3s, o3), axis=0) if idx != 0 else o3
#                 o4s = np.concatenate((o4s, o4), axis=0) if idx != 0 else o4

#                 o1_bin = cf_bin[mask1 == [True]]
#                 o2_bin = cf_bin[mask2 == [True]]
#                 o3_bin = cf_bin[mask3 == [True]]
#                 o4_bin = cf_bin[mask4 == [True]]

#                 o1s_bin = np.concatenate((o1s_bin, o1_bin), axis=0) if idx != 0 else o1_bin
#                 o2s_bin = np.concatenate((o2s_bin, o2_bin), axis=0) if idx != 0 else o2_bin
#                 o3s_bin = np.concatenate((o3s_bin, o3_bin), axis=0) if idx != 0 else o3_bin
#                 o4s_bin = np.concatenate((o4s_bin, o4_bin), axis=0) if idx != 0 else o4_bin

#                 ys = np.concatenate((ys, cf_effect), axis=0) if idx != 0 else cf_effect
#                 ys_bin = np.concatenate((ys_bin, cf_bin), axis=0) if idx != 0 else cf_bin


#         #if args.u_dim == 2:
#         #    draw_2dim(u, a_all, args, 'U')

#         #if args.tSNE == True:
#         #    draw_tSNE(u, a_all, args, 'U')

#         if wandb:
#             if args.use_label:
#                 wandb.log({'cf': np.sum(ys) / ys.shape[0]})
#                 wandb.log({'o1': np.sum(o1s) / o1s.shape[0]})
#                 wandb.log({'o2': np.sum(o2s) / o2s.shape[0]})
#                 wandb.log({'o3': np.sum(o3s) / o3s.shape[0]})
#                 wandb.log({'o4': np.sum(o4s) / o4s.shape[0]})

#             line = '{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(np.sum(ys) / ys.shape[0], np.sum(o1s)/o1s.shape[0], \
#                     np.sum(o2s)/o2s.shape[0], np.sum(o3s)/o3s.shape[0], np.sum(o4s)/o4s.shape[0])
#         else:
#             line = ""
#         file_dir = os.path.abspath(os.path.join(args.save_path, os.pardir))
#         file_dir = os.path.join(file_dir, 'whole_log.txt')
#         if not os.path.exists(file_dir):
#             f = open(file_dir, 'w')
#         else:
#             f = open(file_dir, 'a')
#         f.write('a_r_{:s}_a_d_{:s}_a_y_{:s}_a_f_{:s}_u_{:d}_run_{:d}\n'\
#                           .format(str(args.a_r), str(args.a_d), str(args.a_y), str(args.a_f), args.u_dim, args.run))
#         f.write(line)
#         f.close()
