import argparse
import os
# Command-line interface of the training script.
#
# NOTE(review): --use_tb, --max_grad_norm and --dropout_type are parsed but are
# not referenced anywhere else in the visible script -- confirm whether
# TensorBoard logging / gradient clipping / dropout-type selection were meant
# to be wired in.
parser = argparse.ArgumentParser('Training FNO')
parser.add_argument('--lr', type=float, default=1e-3,
                    help="Learning rate passed to the AdamW optimizer")
parser.add_argument('--epochs', type=int, default=500,
                    help="Number of training epochs")
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help="Weight decay passed to the AdamW optimizer")
parser.add_argument("--n1", type=int, default=32,
                    help="n1 argument forwarded to FNO2d / FNO2dMLP")
parser.add_argument("--n2", type=int, default=32,
                    help="n2 argument forwarded to FNO2d / FNO2dMLP")
parser.add_argument("--width", type=int, default=32, help="Width")
parser.add_argument('--batch-size', type=int, default=16,
                    help="Batch size of the train and test data loaders")
parser.add_argument("--use_tb", type=int, default=0, help="Use TensorBoard: 1 for True, 0 for False")
parser.add_argument("--gpu", type=str, default='1', help="GPU index to use")
parser.add_argument('--max_grad_norm', type=float, default=1)
parser.add_argument('--train_downsample', type=int, default=1,
                    help="Spatial subsampling stride applied to the training data")
parser.add_argument('--test_downsample', type=int, default=1,
                    help="Spatial subsampling stride applied to the test data")
parser.add_argument('--dropout', type=float, default=0.,
                    help="mlp_dropout forwarded to FNO2d (model 'ofno')")
parser.add_argument('--dropout_type', type=str, default="GD", help="Dropout Type: MC for typical dropout, GD for Gaussian dropout")
# Restricting the choices makes a mistyped model name fail immediately with a
# usage message, instead of surfacing later as a NameError on `model` once the
# data has already been loaded.
parser.add_argument('--model', type=str, default="ofno",
                    choices=("ofno", "ufno", "ffno", "fno", "fnoall", "fnomlp"),
                    help="Model architecture to train")
args = parser.parse_args()

# Select the GPU through CUDA_VISIBLE_DEVICES. The variable is set before
# `import torch` on purpose, so it is already in place when torch is loaded;
# keep this statement above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
import torch
from timeit import default_timer
# NOTE(review): `np` (used below) as well as MatReader, LpLoss, count_params
# and `device` (used further down) are not defined in this file; presumably
# they are provided by this star import -- confirm against utilities3.
from utilities3 import *
from AM_FNO import  FNO2d, FNO2dMLP
from FNOs import UFNO2d , FNOFactorizedMesh2D, vannilaFNO2d
# Seed the torch and NumPy random number generators with a fixed value.
torch.manual_seed(42)
np.random.seed(42)



# Dataset locations handed to MatReader. They are empty placeholders here.
TRAIN_PATH = ''
TEST_PATH = ''

# Number of samples taken from the start of the training file and from the end
# of the test file.
ntrain = 1000
ntest = 200

# Follow the --epochs flag. The training loop iterates over args.epochs, and the
# cosine schedule length is derived from `epochs`, so a hard-coded value here
# would make the learning-rate schedule disagree with the actual run length.
epochs = args.epochs

# Spatial subsampling strides and the resulting grid sizes. The data is sliced
# with [::sub], which keeps ceil(64 / sub) points per axis, so ceiling division
# is used to stay consistent with the sliced arrays even when `sub` does not
# divide 64. (64 is the hard-coded native resolution -- TODO confirm against
# the dataset.)
sub_train = args.train_downsample
sub_test = args.test_downsample
S_train = (64 + sub_train - 1) // sub_train
S_test = (64 + sub_test - 1) // sub_test

T_in = 10   # number of leading time frames fed to the model as input
T = 10      # number of subsequent time frames to predict
step = 1    # number of frames predicted per model call during the rollout



# Load the 'u' field of each dataset once and cut both the input window (first
# T_in frames) and the target window (following T frames) out of the same
# array, instead of reading the field twice per file.
# (Assumes read_field returns the same array on every call -- TODO confirm
# against MatReader.)
reader = MatReader(TRAIN_PATH)
train_field = reader.read_field('u')
train_a = train_field[:ntrain, ::sub_train, ::sub_train, :T_in]
train_u = train_field[:ntrain, ::sub_train, ::sub_train, T_in:T + T_in]

# Test samples are taken from the end of the test file.
reader = MatReader(TEST_PATH)
test_field = reader.read_field('u')
test_a = test_field[-ntest:, ::sub_test, ::sub_test, :T_in]
test_u = test_field[-ntest:, ::sub_test, ::sub_test, T_in:T + T_in]

print(train_u.shape)
print(test_u.shape)

# These reshapes also act as a shape check: they raise if the loaded data does
# not contain exactly the expected number of samples / grid points / frames.
train_a = train_a.reshape(ntrain, S_train, S_train, T_in)
test_a = test_a.reshape(ntest, S_test, S_test, T_in)

# Only the training set is shuffled; the test set keeps its order.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_a, train_u),
    batch_size=args.batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a, test_u),
    batch_size=args.batch_size,
    shuffle=False,
)


# Map every supported --model value to a zero-argument factory, so that only
# the selected architecture is actually constructed.
#
# input_dim is the number of input frames the network receives (T_in), not the
# prediction horizon T; the two only happen to be equal in this configuration.
#
# NOTE(review): "fnomlp" is built with mlp_dropout = 0 and therefore ignores
# --dropout -- confirm this is intentional.
model_factories = {
    "ofno": lambda: FNO2d(n1=args.n1, n2=args.n2, width=args.width,
                          input_dim=T_in, output_dim=1,
                          mlp_dropout=args.dropout, H=S_train, W=S_train),
    "ufno": lambda: UFNO2d(12, 12, 32, input_dim=T_in, output_dim=1),
    "ffno": lambda: FNOFactorizedMesh2D(modes_x=12, modes_y=12, width=32,
                                        input_dim=T_in, output_dim=1),
    "fno": lambda: vannilaFNO2d(12, 12, 32, input_dim=T_in, output_dim=1),
    "fnoall": lambda: vannilaFNO2d(64, 33, 32, input_dim=T_in, output_dim=1),
    "fnomlp": lambda: FNO2dMLP(n1=args.n1, n2=args.n2, width=args.width,
                               input_dim=T_in, output_dim=1,
                               mlp_dropout=0, H=S_train, W=S_train),
}

# Fail with an explicit message on an unknown name instead of leaving `model`
# undefined and crashing later with a NameError.
if args.model not in model_factories:
    raise ValueError(
        "Unknown --model {!r}; expected one of: {}".format(
            args.model, ", ".join(sorted(model_factories))
        )
    )
model = model_factories[args.model]().cuda()

#model = UFNO2d(20,20,32,input_dim=T_in).cuda()
#model = FNOFactorizedMesh2D(modes_x=20, modes_y=20, width=32, input_dim=12).cuda()
# Report the parameter count, the parsed options and the architecture, in that
# order.
for report in (count_params(model), args, model):
    print(report)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# The scheduler is stepped once per training batch, so the cosine period is
# expressed in optimizer steps: epochs times batches per epoch.
total_steps = epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

# Loss used for both the per-step and the full-trajectory error.
myloss = LpLoss(size_average=False)
def _rollout(net, frames_in, frames_out, criterion, horizon, stride):
    """Autoregressively roll `net` forward and accumulate the per-step loss.

    At every step the network is applied to the current input window, its
    output is compared with the matching slice of `frames_out`, and the window
    is advanced by dropping its oldest `stride` frames and appending the
    prediction.

    Args:
        net: model called as ``net(window)``.
        frames_in: initial input window; the last axis is the frame axis.
        frames_out: target frames; the last axis is the frame axis.
        criterion: loss called on tensors flattened to ``(batch, -1)``.
        horizon: number of target frames to predict.
        stride: number of frames produced per model call.

    Returns:
        Tuple ``(step_loss, prediction)``: the sum of the per-step losses and
        all predictions concatenated along the last axis.
    """
    batch = frames_in.shape[0]
    step_loss = 0
    outputs = []
    window = frames_in
    for t in range(0, horizon, stride):
        target = frames_out[..., t:t + stride]
        out = net(window)
        step_loss = step_loss + criterion(out.reshape(batch, -1), target.reshape(batch, -1))
        outputs.append(out)
        # Slide the window: drop the oldest frames, append the new prediction.
        window = torch.cat((window[..., stride:], out), dim=-1)
    return step_loss, torch.cat(outputs, dim=-1)


# NOTE(review): `device` is not defined in this file; presumably it is provided
# by `from utilities3 import *` -- confirm.
for ep in range(args.epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for xx, yy in train_loader:
        bsz = xx.shape[0]
        xx = xx.to(device)
        yy = yy.to(device)

        loss, pred = _rollout(model, xx, yy, myloss, T, step)

        train_l2_step += loss.item()
        train_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The learning-rate schedule advances once per batch.
        scheduler.step()

    # Switch to evaluation mode so that dropout (and any normalisation layers)
    # are not left in training mode while the test metrics are computed.
    # model.train() at the top of the loop restores training mode next epoch.
    model.eval()
    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy in test_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            bsz = xx.shape[0]

            loss, pred = _rollout(model, xx, yy, myloss, T, step)

            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()

    # The reported time covers both the training and the evaluation pass.
    t2 = default_timer()
    # Columns: epoch, seconds, train per-step error, train full-trajectory
    # error, test per-step error, test full-trajectory error (each averaged
    # over the number of samples; per-step errors also over the step count).
    print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),
          test_l2_full / ntest)
    
#torch.save(model.state_dict(), "".format(args.model))


