import argparse
import logging
import traceback

import math
import numpy as np
import os
import torch
import torch.nn as nn

from matplotlib import cm
import matplotlib.pyplot as plt

from datetime import datetime
import pandas as pd

import utils
from models import MLP
import json



# NOTE(review): `scale` is not referenced anywhere in the code visible in this
# file; confirm it is used elsewhere (e.g. imported by another script) or remove it.
scale = 0.2

def true_sol(pts):
    """Evaluate the closed-form solution u(t, x) = 2 * exp(-t) * sin(x).

    Args:
        pts: Points of shape (N, 2) whose column 0 is t and column 1 is x.
            A ``torch.Tensor`` is evaluated with torch ops (keeping its
            device, dtype and autograd graph); any other array-like (ndarray,
            nested list, ...) is converted with ``np.asarray`` and evaluated
            with NumPy.

    Returns:
        Shape-(N,) tensor or ndarray with the solution at each point.
    """
    if isinstance(pts, torch.Tensor):
        return 2 * torch.exp(-pts[:, 0]) * torch.sin(pts[:, 1])
    # Coerce array-likes instead of silently returning None for them.
    pts = np.asarray(pts)
    return 2 * np.exp(-pts[:, 0]) * np.sin(pts[:, 1])
    
def generate_interior_points(num_o, random=False, tensor=False, device='cpu'):
    """Sample collocation points inside the domain [0, 1] x [0, pi].

    Args:
        num_o: Number of points requested. For the uniform grid it is split
            into ``utils.closest_factors(num_o)`` counts along t and x.
        random: If True, draw ``num_o`` points uniformly at random
            (t in [0, 1), x in [0, pi)); otherwise build a regular grid that
            excludes the domain edges.
        tensor: If True, return a float32 ``torch.Tensor`` with
            ``requires_grad=True`` on ``device``; otherwise a NumPy array.
        device: Torch device used when ``tensor`` is True.

    Returns:
        Array/tensor of shape (N, 2) with columns (t, x); N is ``num_o`` for
        random sampling and the product of the two grid counts otherwise.
    """
    if random:  # random sample
        ts = np.random.random(size=num_o)  # t in [0, 1)
        xs = np.pi * np.random.random(size=num_o)  # x in [0, pi)
        pts_o = np.stack([ts, xs], axis=1)
    else:  # uniform grid, endpoints dropped so every point is interior
        num_ox, num_oy = utils.closest_factors(num_o)
        ts = np.linspace(0, 1, num_ox + 2)[1:-1]
        xs = np.linspace(0, np.pi, num_oy + 2)[1:-1]
        ts, xs = np.meshgrid(ts, xs)
        pts_o = np.stack([ts.flatten(), xs.flatten()], axis=1)

    if tensor:
        # Create the tensor directly on `device` so it is an autograd leaf
        # there; `torch.tensor(..., requires_grad=True).to(device)` would give
        # a non-leaf copy of a CPU leaf when device is a GPU.
        pts_o = torch.tensor(pts_o, dtype=torch.float32, device=device,
                             requires_grad=True)

    return pts_o
    
def generate_boundary_points(num_b, random=False, tensor=False, device='cpu'):
    """Sample points on the two spatial boundaries and on the initial line.

    Args:
        num_b: Number of points in each of the three returned sets.
        random: If True, draw coordinates uniformly at random (t in [0, 1),
            x in [0, pi)); otherwise use evenly spaced values that exclude
            the endpoints.
        tensor: If True, return float32 ``torch.Tensor``s with
            ``requires_grad=True`` on ``device``; otherwise NumPy arrays.
        device: Torch device used when ``tensor`` is True.

    Returns:
        ``(pts_bl, pts_br, pts_i)``, each of shape (num_b, 2) with columns
        (t, x): ``pts_bl`` lies on x = 0, ``pts_br`` on x = pi and ``pts_i``
        on t = 0.
    """
    # points for the boundary conditions
    pts_bl = np.zeros(shape=(num_b, 2))          # (t, 0)
    pts_br = np.pi * np.ones(shape=(num_b, 2))   # (t, pi)
    # points for the initial condition
    pts_i = np.zeros(shape=(num_b, 2))           # (0, x)

    if random:  # random sample
        pts_bl[:, 0] = np.random.random(size=num_b)
        pts_br[:, 0] = np.random.random(size=num_b)
        pts_i[:, 1] = np.pi * np.random.random(size=num_b)
    else:  # uniform sample, endpoints dropped
        ts = np.linspace(0, 1, num_b + 2)[1:-1]
        pts_bl[:, 0] = ts
        pts_br[:, 0] = ts
        pts_i[:, 1] = np.linspace(0, np.pi, num_b + 2)[1:-1]

    if tensor:
        # Create each tensor directly on `device` so it is an autograd leaf
        # there; `.to(device)` after `requires_grad=True` would yield a
        # non-leaf copy of a CPU leaf when device is a GPU.
        pts_bl, pts_br, pts_i = (
            torch.tensor(pts, dtype=torch.float32, device=device, requires_grad=True)
            for pts in (pts_bl, pts_br, pts_i)
        )

    return pts_bl, pts_br, pts_i

    
def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to bool.

    argparse's ``type=bool`` cannot be used for flags like this: it calls
    ``bool(str)``, which is True for every non-empty string, including 'False'.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in {'1', 'true', 't', 'yes', 'y', 'on'}:
        return True
    if lowered in {'0', 'false', 'f', 'no', 'n', 'off'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse the training configuration from the command line.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with fields seed, gpu, random, no, nb, m,
        activation, l, lr, optimizer and epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=5884, help='random seed')
    parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: 0]')
    # Accepts `--random`, `--random true` or `--random false`.
    parser.add_argument('--random', type=_str2bool, nargs='?', const=True, default=False,
                        help='resample collocation points every epoch')
    parser.add_argument('--no', type=int, default=10000, help='number of interior points')
    parser.add_argument('--nb', type=int, default=100,
                        help='number of points in each boundary / initial set')

    parser.add_argument('--m', type=int, default=1000)
    parser.add_argument('--activation', type=str, default='relu')
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--epochs', type=int, default=400000)
    return parser.parse_args(argv)
    
    
def _select_activation(name):
    """Map an activation name to ``(activation module, list of MLP p values)``.

    Args:
        name: 'relu' or 'tanh' (case-insensitive).

    Returns:
        ``(nn.ReLU(), [2])`` or ``(nn.Tanh(), [1])``.

    Raises:
        ValueError: for any other name, instead of failing later on an
            undefined activation.
    """
    key = name.lower()
    if key == 'relu':
        return nn.ReLU(), [2]
    if key == 'tanh':
        return nn.Tanh(), [1]
    raise ValueError(f"Unsupported activation {name!r}; expected 'relu' or 'tanh'")


def _make_optimizer(name, params, lr):
    """Build ``torch.optim.<name>(params, lr=lr)``, e.g. name='Adam' or 'SGD'.

    Raises:
        ValueError: if ``name`` is not an optimizer class in ``torch.optim``.
    """
    optim_cls = getattr(torch.optim, name, None)
    if not (isinstance(optim_cls, type) and issubclass(optim_cls, torch.optim.Optimizer)):
        raise ValueError(f'Unknown optimizer {name!r}: not a torch.optim optimizer class')
    return optim_cls(params, lr=lr)


def _pinn_loss(models, pts_o, pts_bl, pts_br, pts_i, g):
    """Compute the split-network PINN loss on one set of points.

    ``models`` is ``(model_1, model_2, model_3, model_4)``, fitted to
    u, (u_t, u_x), u_xx and u_xxx respectively. Derivatives are taken with
    respect to the (t, x) columns of ``pts_o``, so ``pts_o`` must require grad.

    Returns:
        ``(loss, loss_o, loss_gm, loss_b)`` with
        ``loss = loss_o + loss_gm + 10 * loss_b``.
    """
    model_1, model_2, model_3, model_4 = models
    n_o = len(pts_o)

    # Interior PDE residual u_t + u_xxxx = 0, with u_xxxx = d/dx phi_4.
    out_o1 = model_1(pts_o).squeeze()
    out_o2 = model_2(pts_o).squeeze()
    out_o3 = model_3(pts_o).squeeze()
    out_o4 = model_4(pts_o).squeeze()
    u_t = out_o2[:, 0]
    u_xxxx = torch.autograd.grad(out_o4.sum(), pts_o, create_graph=True)[0][:, 1]
    loss_o = torch.square(u_t + u_xxxx).sum() / n_o

    # Gradient matching: grad(phi) = phi_2, d/dx phi_2[:, 1] = phi_3, d/dx phi_3 = phi_4.
    grad_u = torch.autograd.grad(out_o1.sum(), pts_o, create_graph=True)[0]
    loss_gm1 = torch.square(grad_u - out_o2).sum() / n_o
    u_xx = torch.autograd.grad(out_o2[:, 1].sum(), pts_o, create_graph=True)[0][:, 1]
    loss_gm2 = torch.square(u_xx - out_o3).sum() / n_o
    u_xxx = torch.autograd.grad(out_o3.sum(), pts_o, create_graph=True)[0][:, 1]
    loss_gm3 = torch.square(u_xxx - out_o4).sum() / n_o
    loss_gm = loss_gm1 + loss_gm2 + loss_gm3

    # Boundary conditions u = 0 and u_xx = 0 at x = 0 and at x = pi.
    loss_bl = torch.square(model_1(pts_bl).squeeze()).sum() + torch.square(model_3(pts_bl).squeeze()).sum()
    loss_br = torch.square(model_1(pts_br).squeeze()).sum() + torch.square(model_3(pts_br).squeeze()).sum()
    # Initial condition u(0, x) = g.
    loss_i = torch.square(model_1(pts_i).squeeze() - g).sum()
    loss_b = (loss_bl + loss_br + loss_i) / (len(pts_bl) + len(pts_br) + len(pts_i))

    loss = loss_o + loss_gm + 10 * loss_b
    return loss, loss_o, loss_gm, loss_b


def _save_best_models(models, save_dir, epoch, prev_epoch):
    """Replace the previous best checkpoints with the current weights.

    Files are named ``VS_all_model_{k}({epoch}).pth`` for k = 1..4, in the
    order of ``models``. Returns ``epoch`` (the new best epoch).
    """
    for k, model in enumerate(models, start=1):
        if prev_epoch is not None:
            try:
                os.remove(os.path.join(save_dir, f'VS_all_model_{k}({prev_epoch}).pth'))
            except FileNotFoundError:
                pass
        torch.save(model.state_dict(), os.path.join(save_dir, f'VS_all_model_{k}({epoch}).pth'))
    return epoch


def _plot_solution(model, plot_pts, path, grid_size=101):
    """Save 3-D surfaces of ``true_sol`` and ``model`` side by side to ``path``.

    Assumes ``plot_pts`` is a flattened ``grid_size x grid_size`` (t, x) mesh.
    The figure is always closed so repeated saves do not accumulate figures.
    """
    with torch.no_grad():
        X = plot_pts[:, 0].cpu().numpy().reshape(grid_size, grid_size)
        Y = plot_pts[:, 1].cpu().numpy().reshape(grid_size, grid_size)
        Z = model(plot_pts).squeeze().cpu().numpy()
        Z_true = true_sol(plot_pts).cpu().numpy()
    mse = np.square(Z - Z_true).mean()

    fig, axs = plt.subplots(1, 2, subplot_kw={"projection": "3d"}, figsize=(10 * 2, 10))
    try:
        axs[0].plot_surface(X, Y, Z_true.reshape(grid_size, grid_size),
                            cmap=cm.coolwarm, linewidth=0, antialiased=False)
        axs[1].plot_surface(X, Y, Z.reshape(grid_size, grid_size),
                            cmap=cm.coolwarm, linewidth=0, antialiased=False)
        axs[0].set_title('True solution')
        axs[1].set_title(f'Predicted solution\n(MSE={mse})')
        fig.savefig(path)
    finally:
        plt.close(fig)


def _train(args, device, save_dir_temp, logger):
    """Run the training described by ``args``, writing artifacts to ``save_dir_temp``.

    For each p, four MLPs are trained jointly with ``args.optimizer`` and
    ``args.lr``. Every ``val_freq`` epochs the PINN loss is logged and the
    models/plot are checkpointed when it is the lowest logged so far; at the
    end the logged PINN losses and L2 errors go to losses.csv and
    L2_losses.csv.
    """
    logger.info(args)
    if device == 'cpu':
        logger.warning('CUDA is not available; training on CPU.')

    m = args.m
    D = 2
    activation, ps = _select_activation(args.activation)

    random = args.random
    num_o = args.no
    num_b = args.nb

    # Fixed collocation set (replaced every epoch when `random` is set).
    pts_o = generate_interior_points(num_o, random=False)
    pts_bl, pts_br, pts_i = generate_boundary_points(num_b, random=False)
    # Initial condition target u(0, x) = 2 sin(x); a constant, so no grad.
    g = torch.tensor(2 * np.sin(pts_i[:, 1]), dtype=torch.float32, device=device)
    # Create the point tensors directly on `device` so they are autograd leaves there.
    pts_o, pts_bl, pts_br, pts_i = (
        torch.tensor(pts, dtype=torch.float32, device=device, requires_grad=True)
        for pts in (pts_o, pts_bl, pts_br, pts_i)
    )

    # 101 x 101 evaluation grid over [0, 1] x [0, pi] for result.png.
    grid_t, grid_x = np.meshgrid(np.linspace(0, 1, 101), np.linspace(0, np.pi, 101))
    plot_pts = torch.tensor(np.stack([grid_t.flatten(), grid_x.flatten()], axis=1),
                            dtype=torch.float32, device=device)

    epochs = args.epochs
    val_freq = 200

    df_m = pd.DataFrame()
    df_l2 = pd.DataFrame()
    start = datetime.now()
    logger.info(f"Case: m={m}\t({start.strftime('%y.%m.%d-%H:%M:%S')})")
    for p in ps:
        logger.info(f'\tp={p}')
        # Keep this creation order (1, 3, 4, 2): it fixes each network's
        # random initialisation for a given seed.
        model_1 = MLP(m, d=2, d_out=1, p=p, D=D, activation=activation)  # phi
        model_3 = MLP(m, d=2, d_out=1, p=p, D=D, activation=activation)  # phi_3 = (u_xx)
        model_4 = MLP(m, d=2, d_out=1, p=p, D=D, activation=activation)  # phi_4 = (u_xxx)
        model_2 = MLP(m, d=2, d_out=2, p=p, D=D, activation=activation)  # phi_2 = (u_t, u_x)
        models = (model_1, model_2, model_3, model_4)
        for model in models:
            model.to(device)
            model.train()
        # Honour --optimizer / --lr (the run directory is named after them).
        optimizers = [_make_optimizer(args.optimizer, model.parameters(), args.lr)
                      for model in models]

        PINN_losses = []
        L2_losses = []
        best_loss = math.inf
        best_epoch = None
        running_start = datetime.now()
        for epoch in range(epochs + 1):
            if random:
                pts_o = generate_interior_points(num_o, random=True, tensor=True, device=device)
                pts_bl, pts_br, pts_i = generate_boundary_points(num_b, random=True, tensor=True, device=device)
                # Detach: g is a target, not something to differentiate through.
                g = 2 * torch.sin(pts_i[:, 1].detach())

            loss, loss_o, loss_gm, loss_b = _pinn_loss(models, pts_o, pts_bl, pts_br, pts_i, g)

            if epoch % val_freq == 0:
                loss_value = loss.item()
                # Checkpoint only when this is the lowest loss logged so far.
                if loss_value < best_loss:
                    best_loss = loss_value
                    best_epoch = _save_best_models(models, save_dir_temp, epoch, best_epoch)
                    _plot_solution(model_1, plot_pts, os.path.join(save_dir_temp, 'result.png'))
                PINN_losses.append(loss_value)
                logger.info(f"epoch:{epoch}\t{loss_value}\t={loss_o.item()}\t+{loss_gm.item()}\t+10*{loss_b.item()}")

            for optimizer in optimizers:
                optimizer.zero_grad()
            loss.backward()
            if epoch == 0:
                computation_cost = {
                    'num_params': sum(param.numel() for model in models for param in model.parameters()),
                }
                if device.startswith('cuda'):
                    computation_cost['memory'] = torch.cuda.memory_allocated(device=device)
            elif epoch == 49:
                # Average wall-clock time per epoch over the first 50 epochs.
                running_time = (datetime.now() - running_start) / 50
                computation_cost['running_time'] = running_time.total_seconds()
                with open(os.path.join(save_dir_temp, 'computation_cost.txt'), mode='w') as fp:
                    fp.writelines(f'{key}: {item}\n' for key, item in computation_cost.items())
            for optimizer in optimizers:
                optimizer.step()
            # Drop references to this step's graph before the validation pass.
            del loss, loss_o, loss_gm, loss_b

            if epoch % val_freq == 0:
                with torch.no_grad():
                    out_o = model_1(pts_o).squeeze()
                    # Estimate of the L2 error norm over [0, 1] x [0, pi] (area pi).
                    sq_err = torch.square(out_o - true_sol(pts_o)).sum()
                    L2_losses.append(torch.sqrt(sq_err * (np.pi / len(pts_o))).item())

        df_m[p] = PINN_losses
        df_l2[p] = L2_losses

    df_m.transpose().to_csv(os.path.join(save_dir_temp, 'losses.csv'), index=True)
    df_l2.transpose().to_csv(os.path.join(save_dir_temp, 'L2_losses.csv'), index=True)
    logger.info('Finished')


def main():
    """Train the split-network PINN for u_t + u_xxxx = 0 from command-line args.

    Outputs go to a run directory under ./beam/ that carries a '(doing)'
    suffix while training and is renamed without it on success. On failure
    the exception propagates with the '(doing)' directory left in place and
    results.log already closed.
    """
    args = parse_args()
    utils.set_random_seed(args.seed)

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'

    # NOTE: the error handler under `if __name__ == '__main__'` rebuilds this path.
    save_dir = f'./beam/{args.optimizer}/split_all(m{args.m})/{args.activation}/{args.seed}(lr{args.lr})'
    save_dir_temp = save_dir + '(doing)'
    utils.mkdir(save_dir_temp)
    utils.save_configs(save_dir_temp, vars(args))

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(filename=os.path.join(save_dir_temp, 'results.log'),
                                       mode='w', encoding='utf-8')
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    try:
        _train(args, device, save_dir_temp, logger)
    finally:
        # Release results.log before the directory is renamed (here or by the
        # caller's error handler); an open handle blocks rename on Windows.
        for handler in (stream_handler, file_handler):
            logger.removeHandler(handler)
            handler.close()
    os.rename(save_dir_temp, save_dir)

if __name__ == '__main__':
    try:
        main()
    except SystemExit:
        # e.g. argparse --help or usage errors: no run directory to clean up.
        raise
    except BaseException:
        # Deliberately broad (covers KeyboardInterrupt too) so any aborted
        # run is marked as failed on disk.
        logging.error(traceback.format_exc())
        args = parse_args()
        # Must match the run-directory naming used in main().
        save_dir = f'./beam/{args.optimizer}/split_all(m{args.m})/{args.activation}/{args.seed}(lr{args.lr})'
        save_dir_temp = save_dir + '(doing)'
        # The '(doing)' directory exists only if main() got far enough to create it.
        if os.path.isdir(save_dir_temp):
            os.rename(save_dir_temp, save_dir + '(error)')
        # Report the failure to the shell / job scheduler instead of exiting 0.
        raise SystemExit(1)