import os
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize
import h5py
from tqdm import tqdm
from pdb import set_trace as bp


# Directory that holds the input "<name>.h5" datasets; the matching
# "<name>_scale.npy" outputs are written to the same place.
datapath = "datapath"

# NOTE: this value is rebound by the `for train_file in train_files` loop
# further down in this file, so it only acts as an initial placeholder.
train_file = "train_o1_10_32k"

# Base names (without extension) of the datasets to process.
train_files = [
    "poisson_64_e%s_train" % tag
    for tag in ("1_5", "1_20", "5_15", "15_20")
]


# Grid resolution (cells per axis) and domain lengths used to weight the
# source norm. They do not depend on the dataset, so set them once here
# instead of on every loop iteration.
nx = ny = 64
lx = ly = 1

# For every dataset, compute median-based scales and store them next to the
# data as "<name>_scale.npy" with the layout
#   [source_scale, tensor_scale_0, ..., tensor_scale_{num_ten-1}, sol_scale, lx, ly]
for train_file in train_files:
    source_norm = []
    sol_max = []

    # A context manager guarantees the HDF5 handle is released even if a key
    # is missing or a read fails part-way through.
    with h5py.File(os.path.join(datapath, "%s.h5" % train_file), "r") as f:
        x_train = f['fields']  # N, 2, n_demos, H, W
        x_tensor = f['tensor']

        # Inputs stored as (N, 2, H, W) get a singleton n_demos axis so the
        # loops below can treat both layouts the same way.
        if len(x_train.shape) == 4:
            x_train = np.expand_dims(x_train, 2)

        num_samples = x_train.shape[0]
        num_ten = x_tensor.shape[1]

        for i in tqdm(range(num_samples)):
            # One read per sample rather than one read per individual field.
            sample = x_train[i]
            for j in range(sample.shape[1]):
                # Channel 0 is treated as the source term: its norm is
                # weighted by the cell size lx/nx * ly/ny.
                sn = np.linalg.norm(sample[0, j]) * lx/nx * ly/ny
                source_norm.append(sn)
                # Channel 1 is treated as the solution: peak absolute value.
                sol_max.append(np.max(np.abs(sample[1, j])))

        # Absolute tensor entries for the same samples, fetched in one read.
        # NOTE(review): despite the name, no maximum is taken here. If the
        # 'tensor' entries are fields rather than scalars, the median below
        # runs over every value of each component - confirm this is intended.
        tensor_max = np.abs(x_tensor[:num_samples])

    # Medians across all samples (and demos) give outlier-robust scales.
    source_scale = np.median(source_norm)
    sol_scale = np.median(sol_max)
    tensor_scale = [np.median(tensor_max[:, j]) for j in range(num_ten)]

    scale = [source_scale] + tensor_scale + [sol_scale] + [lx, ly]
    print(scale)

    np.save(os.path.join(datapath, "%s_scale.npy" % train_file), scale)
