import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.autograd.functional as F
import torch.optim as optim
import torchdiffeq
import datetime
import numpy as np
import argparse
import json
import logging
import os
from matplotlib.pyplot import *
from mpl_toolkits.mplot3d import axes3d

import sys
sys.path.append('..')

from src.util import *

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 1. Define the dynamical system and the experiment settings.
dyn_sys = "lorenz"
dyn, dim = define_dyn_sys(dyn_sys)
time_step = 0.01
ind_func = 0
s = 0.2
hidden = 256
model = 'FNO'  # label interpolated into the plot file names below
num_trajectories = 5000
# 3000 time units expressed as a number of integration steps.
long_len_T = 3000 * int(1 / time_step)
init = "inside"

# Choose the starting state and the output plot path(s) from `init`.
# NOTE(review): "inside"/"outside" presumably refer to the attractor -- confirm.
if init != "inside":
    true_initial_condition = torch.tensor([-15, -15, -5.]).reshape(1, dim)
    pdf_path = f'../plot/dist_outside_all{model}.jpg'
    pdf_path_2 = '../plot/dist_outside_all_Res.jpg'
else:
    true_initial_condition = torch.tensor([-9.116407, -3.381641, 33.748295]).reshape(1, dim)
    pdf_path = f'../plot/dist_inside_{model}.jpg'

# From here on the label refers to the skip-connection MLP being evaluated.
model = 'MLP_skip'
MSE_MS_path = f"../plot/Vector_field/{dyn_sys}/best_model_MSE.pth"
JAC_MS_path = f"../plot/Vector_field/{dyn_sys}/best_model_JAC.pth"
# Two networks with different hidden widths (512 vs 1024); the checkpoint names
# suggest MSE-trained vs Jacobian-trained weights -- TODO confirm.
mse_ms_model = ODE_MLP_skip(y_dim=dim, n_hidden=512, n_layers=5)
mse_ms_model = mse_ms_model.to(device)
best_ms_model = ODE_MLP_skip(y_dim=dim, n_hidden=1024, n_layers=5)
best_ms_model = best_ms_model.to(device)

# model='MLP'
# MSE_mlp_path = "../plot/Vector_field/"+str(dyn_sys)+"/"+str(model)+"_MSE_fullbatch/best_model.pth"
# JAC_mlp_path = "../plot/Vector_field/"+str(dyn_sys)+"/"+str(model)+"_Jacobian_fullbatch/best_model.pth"
# mse_mlp_model = ODE_MLP(y_dim=dim, n_hidden=512, n_layers=7).to(device)
# best_mlp_model = ODE_MLP(y_dim=dim, n_hidden=512, n_layers=7).to(device)


# Restore the trained weights. `map_location=device` remaps tensors that were
# saved from another device (e.g. a CUDA checkpoint opened on a CPU-only host)
# onto the device the models were moved to, instead of failing to deserialize.
mse_ms_model.load_state_dict(torch.load(MSE_MS_path, map_location=device))
best_ms_model.load_state_dict(torch.load(JAC_MS_path, map_location=device))
# mse_mlp_model.load_state_dict(torch.load(MSE_mlp_path, map_location=device))
# best_mlp_model.load_state_dict(torch.load(JAC_mlp_path, map_location=device))

# Put the models in evaluation mode before using them for inference.
mse_ms_model.eval()
best_ms_model.eval()
# mse_mlp_model.eval()
# best_mlp_model.eval()

# Function to generate data
def generate_data(model, initial_condition, shortinitial_condition, is_dynamical=True):
    """Simulate one long orbit of ``model`` starting at ``initial_condition``.

    Reads the module-level ``time_step``, ``long_len_T`` and ``device``.
    ``shortinitial_condition`` and ``is_dynamical`` are accepted but not used:
    short-orbit generation is currently disabled.

    Returns:
        A ``(long_orbit, short_orbit)`` tuple. ``long_orbit`` is the simulated
        trajectory as a NumPy array; ``short_orbit`` is always ``None``.
    """
    # Two-point time grid [0, time_step] handed to the simulator.
    step_grid = torch.linspace(0, time_step, 2).to(device)

    simulated = vectorized_simulate(model, initial_condition, step_grid, long_len_T, device)
    long_orbit = simulated.detach().cpu().numpy()

    # No short orbit is produced; the slot is kept so callers can still unpack
    # two values.
    return long_orbit, None

def calculate_relative_error(model, dyn, device):
    """Print and return the pointwise relative vector-field error of ``model``.

    An orbit of the true dynamics ``dyn`` is integrated with fixed-step RK4
    from a random 3-dimensional initial state over t in [0, 100) with step
    0.01. At every state on that orbit the learned vector field is compared
    with the true one.

    Args:
        model: Callable ``model(t, x)``; it is called once on the whole orbit.
        dyn: Callable ``dyn(t, x)`` giving the true vector field; it is called
            on one orbit state at a time.
        device: Device on which the comparison is carried out, e.g. ``"cuda"``
            or ``"cpu"``. It should match the device ``model`` lives on.

    Returns:
        Tensor holding ``||f_nn(x) - f_true(x)|| / ||f_true(x)||`` with one
        entry per orbit state. Entries are inf/nan wherever the true vector
        field has zero norm.
    """
    # Simulate an orbit using the true dynamics.
    dt = 0.01
    time_grid = torch.arange(0, 100, dt)
    orbit = torchdiffeq.odeint(dyn, torch.randn(3), time_grid, method='rk4', rtol=1e-8)

    # Honour the caller's device instead of assuming CUDA is available.
    orbit = orbit.to(device)

    # Compute the vector field from the model and from the true dynamics.
    vf_nn = model(0, orbit).detach()
    vf_true = torch.stack([dyn(0, state) for state in orbit])

    # Calculate the relative error at each orbit state.
    err = torch.linalg.norm(vf_nn - vf_true, dim=1)
    mag = torch.linalg.norm(vf_true, dim=1)
    print("err", err.shape, mag.shape)
    relative_error = err / mag
    print("max", torch.max(relative_error))
    print("min", torch.min(relative_error))
    print("mean", torch.mean(relative_error))
    return relative_error

# Report the relative vector-field error of both models, using the same
# `device` the models were moved to rather than assuming CUDA is present.
# `rel` is rebound by the second call, so afterwards it holds the result for
# `best_ms_model`.
rel = calculate_relative_error(mse_ms_model, dyn, device)
rel = calculate_relative_error(best_ms_model, dyn, device)

