
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)

def compl_mul1d(a, b):
    """Complex channel-mixing product for tensors stored as (real, imag) pairs.

    Both operands keep complex numbers in a trailing dimension of size 2
    (index 0 = real part, index 1 = imaginary part).

    Args:
        a: (batch, in_channel, x, 2) tensor.
        b: (in_channel, out_channel, x, 2) tensor.

    Returns:
        (batch, out_channel, x, 2) tensor holding, per mode x, the complex
        product of ``a`` and ``b`` summed over ``in_channel``.
    """
    contract = partial(torch.einsum, "bix,iox->box")
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    # (a_re + i a_im)(b_re + i b_im) = (a_re b_re - a_im b_im) + i (a_im b_re + a_re b_im)
    real_part = contract(a_re, b_re) - contract(a_im, b_im)
    imag_part = contract(a_im, b_re) + contract(a_re, b_im)
    return torch.stack([real_part, imag_part], dim=-1)



class SpectralConv1d_fast(nn.Module):
    """1-D spectral convolution: FFT, multiply low modes by learned weights, inverse FFT.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of lowest Fourier modes to multiply. A signal of length
            N only has N // 2 + 1 modes; any excess requested here is ignored.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels * out_channels))
        # Complex weights stored as a trailing (real, imag) pair:
        # shape (in_channels, out_channels, modes1, 2).
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, 2))

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: real tensor of shape (batch, in_channels, N).

        Returns:
            Real tensor of shape (batch, out_channels, N).
        """
        batchsize = x.shape[0]
        n = x.size(-1)
        n_freq = n // 2 + 1
        # Only modes that exist at this resolution can be multiplied.
        modes = min(self.modes1, n_freq)

        if hasattr(torch, "rfft"):
            # Legacy API (torch < 1.8): complex values as a trailing (re, im) dim.
            x_ft = torch.rfft(x, 1, normalized=True, onesided=True)
            out_ft = torch.zeros(batchsize, self.out_channels, n_freq, 2,
                                 dtype=x_ft.dtype, device=x.device)
            out_ft[:, :, :modes] = compl_mul1d(x_ft[:, :, :modes], self.weights1[:, :, :modes])
            return torch.irfft(out_ft, 1, normalized=True, onesided=True, signal_sizes=(n,))

        # torch.fft API (torch >= 1.8); norm="ortho" is the 1/sqrt(N) scaling
        # that normalized=True gave in the legacy API.
        x_ft = torch.fft.rfft(x, norm="ortho")
        weights = torch.view_as_complex(self.weights1)[:, :, :modes]

        # The output spectrum has out_channels channels (not in_channels).
        out_ft = torch.zeros(batchsize, self.out_channels, n_freq,
                             dtype=x_ft.dtype, device=x.device)
        out_ft[:, :, :modes] = torch.einsum("bix,iox->box", x_ft[:, :, :modes], weights)

        # Return to physical space at the original resolution.
        return torch.fft.irfft(out_ft, n=n, norm="ortho")

class SimpleBlock1d(nn.Module):
    """Lift -> four (spectral conv + 1x1 conv) layers -> pointwise projection.

    Args:
        modes: Fourier modes kept by each SpectralConv1d_fast layer.
        width: channel width of the hidden layers.

    forward() maps a (batch, N, 2) tensor to a (batch, N, 1) tensor
    (fc0 consumes 2 features, fc2 produces 1).
    """

    def __init__(self, modes, width):
        super(SimpleBlock1d, self).__init__()

        self.modes1 = modes
        self.width = width
        self.fc0 = nn.Linear(2, self.width)

        # Sub-modules are registered in the same order as the hand-unrolled
        # version (all conv*, then all w*, then all bn*), so parameter order
        # and random-initialisation draws are unchanged.
        for k in range(4):
            setattr(self, 'conv%d' % k, SpectralConv1d_fast(self.width, self.width, self.modes1))
        for k in range(4):
            setattr(self, 'w%d' % k, nn.Conv1d(self.width, self.width, 1))
        # Batch-norm layers are constructed but not applied in forward().
        for k in range(4):
            setattr(self, 'bn%d' % k, torch.nn.BatchNorm1d(self.width))

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        h = self.fc0(x)
        h = h.permute(0, 2, 1)  # channels-first for the conv layers

        pairs = [
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        ]
        final = len(pairs) - 1
        for k, (spectral, pointwise) in enumerate(pairs):
            h = spectral(h) + pointwise(h)
            # (batch-norm intentionally skipped here)
            if k != final:
                h = F.relu(h)  # no activation after the last layer

        h = h.permute(0, 2, 1)
        h = F.relu(self.fc1(h))
        return self.fc2(h)

class Net1d(nn.Module):
    """Wrapper around SimpleBlock1d that can also report its parameter count.

    Args:
        modes: Fourier modes kept by each spectral layer.
        width: hidden channel width.
    """

    def __init__(self, modes, width):
        super(Net1d, self).__init__()

        self.conv1 = SimpleBlock1d(modes, width)

    def forward(self, x):
        return self.conv1(x)

    def count_params(self):
        """Return the total number of scalar entries across all parameters.

        Uses ``numel()`` so that 0-dim (scalar) parameters count as 1; the
        product-of-shape approach raised TypeError on their empty shape.
        """
        return sum(p.numel() for p in self.parameters())

PATH_DATA = 'data/KS_L64pi_N200_s8192_T100_t100.mat'

# total_index = 0
# total_error = np.zeros((8,500,4))
# for sub in [1, 2, 4, 8, 16, 32, 64, 128]:

# ---- sample counts --------------------------------------------------------
Ntrain = 100  # trajectories used for training
Ntest = 100   # trajectories used for testing

T_in = 50          # first input time index
T = 20             # number of time indices taken per trajectory
t_sub = 1          # temporal stride
t = T // t_sub     # time steps per trajectory

# Trajectory and time axes are flattened together by reshape(-1, s, 1) below.
ntrain = Ntrain * t
ntest = Ntest * t

sub = 16           # spatial stride
s = 8192 // sub    # spatial points kept (8192-point grid per the data file name)

# ---- optimisation / model hyper-parameters --------------------------------
batch_size = 20
learning_rate = 0.0025

epochs = 10
step_size = 100
gamma = 0.5
ep_print = 1

modes = 32
width = 32

# ---- output locations ------------------------------------------------------
path = f'KS_fourier_res_N{ntrain}_s{s}_T{T}_ep{epochs}_m{modes}_w{width}'
path_model = f'model/{path}'
path_pred = f'pred/{path}.mat'
path_error = f'results/{path}_error.mat'
path_image = f'image/{path}'

# ---- one-step (u(t) -> u(t+1)) pairs ---------------------------------------
# NOTE(review): assumes field 'u' is indexed (trajectory, time, space).
input_steps = slice(T_in, T_in + T, t_sub)
target_steps = slice(T_in + 1, T_in + T + 1, t_sub)

dataloader = MatReader(PATH_DATA)
x_train = dataloader.read_field('u')[:Ntrain, input_steps, ::sub].reshape(-1, s, 1)
y_train = dataloader.read_field('u')[:Ntrain, target_steps, ::sub].reshape(-1, s, 1)
print(x_train.shape, y_train.shape)
x_test = dataloader.read_field('u')[-Ntest:, input_steps, ::sub].reshape(-1, s, 1)
y_test = dataloader.read_field('u')[-Ntest:, target_steps, ::sub].reshape(-1, s, 1)
print(x_test.shape, y_test.shape)

dim = x_train.shape[-1]

# Append a uniform coordinate in [0, 1] as a second input channel.
grid = torch.tensor(np.linspace(0, 1, s).reshape(1, s, 1), dtype=torch.float)
x_train = torch.cat([x_train.reshape(ntrain, s, dim), grid.repeat(ntrain, 1, 1)], dim=2)
x_test = torch.cat([x_test.reshape(ntest, s, dim), grid.repeat(ntest, 1, 1)], dim=2)

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test, y_test),
    batch_size=batch_size,
    shuffle=False,
)

# for trajectory testing: initial state at T_in plus the next T ground-truth steps
dataloader = MatReader(PATH_DATA)
x_test2 = dataloader.read_field('u')[-Ntest:, T_in, ::sub].reshape(-1,s,1)
y_test2 = dataloader.read_field('u')[-Ntest:, T_in+1:T_in+T+1, ::sub].reshape(-1,t,s)
x_test2 = torch.cat([x_test2.reshape(Ntest,s,1), grid.repeat(Ntest,1,1)], dim=2)
test_loader2 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test2, y_test2), batch_size=batch_size, shuffle=False)


# model = Net1d(modes, width).cuda()
# The checkpoint is a whole pickled nn.Module. Unpickling can run arbitrary
# code, so only load trusted files. torch >= 2.6 defaults to
# weights_only=True, which rejects full-module pickles. Opt out explicitly
# where the argument exists (torch >= 1.13).
import inspect  # local: only used to probe torch.load's signature

MODEL_CHECKPOINT = 'model/KS_fourier_res_N2000_s512_T20_ep500_m32_w32'
_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
model = torch.load(MODEL_CHECKPOINT, **_load_kwargs)

print(model.count_params())

myloss = LpLoss(size_average=False)


# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
grid = grid.cuda()
t1 = default_timer()
# Only referenced by the disabled multi-`sub` sweep (total_error) at the end of the file.
error = np.zeros((500, 4))

for ep in range(epochs):

    # model.train()
    # train_mse = 0
    # train_l2 = 0
    # for x, y in train_loader:
    #     x, y = x.cuda(), y.cuda()
    #
    #     optimizer.zero_grad()
    #     out = model(x)
    #
    #     mse = F.mse_loss(out.view(-1), y.view(-1), reduction='mean')
    #     # mse.backward()
    #
    #     # y = y_normalizer.decode(y)
    #     # out = y_normalizer.decode(out)
    #     l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
    #     l2.backward()
    #
    #     optimizer.step()
    #     train_mse += mse.item()
    #     train_l2 += l2.item()
    #
    # scheduler.step()


    model.eval()
    test_l2 = 0.0
    test_l2_traj = 0.0
    if ep % ep_print == ep_print-1:
        with torch.no_grad():
            # One-step error: predict the next state from the true current state.
            for x, y in test_loader:
                x, y = x.cuda(), y.cuda()
                # Use the actual batch size: the last batch may be smaller than batch_size.
                nb = x.shape[0]

                out = model(x)
                # out = y_normalizer.decode(out)

                test_l2 += myloss(out.view(nb, -1), y.view(nb, -1)).item()

            # Rollout error: feed each prediction back in as the next input for t steps.
            for xx, yy in test_loader2:
                loss = 0
                # .cuda() like the loop above; `device` is not defined in this file.
                x = xx.cuda()
                yy = yy.cuda()
                nb = x.shape[0]
                # Coordinate channel re-appended to every predicted state.
                batch_grid = grid.repeat([nb, 1, 1])

                for i in range(0, t):
                    y = yy[:, i:i + 1, :]
                    im = model(x)
                    loss += myloss(im.reshape(nb, -1), y.reshape(nb, -1)).item()
                    x = torch.cat((im, batch_grid), dim=-1)

                test_l2_traj += loss

    # train_mse /= len(train_loader)
    # train_l2 /= ntrain
    test_l2 /= ntest
    test_l2_traj /= ntest


    t2 = default_timer()
    print(ep, t2-t1, test_l2, test_l2_traj)

    # print(ep, t2-t1, train_mse, train_l2, test_l2, test_l2_traj)
    t1 = default_timer()


# torch.save(model, path_model )


# test: long-horizon rollout on a separate data file
import scipy.io  # used by savemat below; explicit rather than relying on a star import

T = 1000
dataloader = MatReader('data/KS_L64pi_s8192_T1000_t10000_test.mat')
# First entry along axis 0, every 10th stored time index starting at 10*T_in.
data_test = dataloader.read_field('u')[0, 10*T_in:10*(T+T_in):10, ::sub]

x_test = data_test[:1]  # initial state
y_test = data_test[1:]  # ground truth for each subsequent step
n_steps = len(y_test)

print(x_test.shape)

grid = grid.cuda()
pred = torch.zeros(T, s)   # autoregressive predictions; rows [0, n_steps) are filled
pred2 = torch.zeros(T, s)  # one-step predictions; rows [0, n_steps) are filled

errors = torch.zeros(T,)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(y_test), batch_size=1, shuffle=False)


with torch.no_grad():
    test_l2 = 0
    out = x_test.cuda()

    # trajectory: each prediction becomes the next input
    for index, (y,) in enumerate(test_loader):
        x_in = out.view(1, s, 1)
        # x_in = x_normalizer.encode(x_in)
        x_in = torch.cat([x_in, grid], dim=2)

        y = y.cuda()

        out = model(x_in)
        # out = y_normalizer.decode(out)
        pred[index] = out.reshape(s)

        l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
        test_l2 += l2
        errors[index] = l2

    # one-step: the true state is the input at every step
    for index, (y,) in enumerate(test_loader):
        x_in = y.cuda().view(1, s, 1)
        # x_in = x_normalizer.encode(x_in)
        x_in = torch.cat([x_in, grid], dim=2)

        out = model(x_in)
        # out = y_normalizer.decode(out)
        pred2[index] = out.reshape(s)


# Average over the rollout steps. Do not divide by `ntest`, which counts the
# samples of the short-horizon test set and is unrelated to this loop.
print(test_l2 / n_steps)
scipy.io.savemat(path_pred, mdict={'pred': pred.cpu().numpy(),'pred2': pred2.cpu().numpy()})


#     total_error[total_index] = error
#     total_index = total_index + 1
#
# scipy.io.savemat(path_error, mdict={'error': total_error})
