import torch
import os
import numpy as np
from fno.fno import FNO2dWithBackbone
import glob

# Checkpoint file(s) to extend. This must be a list of paths: iterating a
# bare string such as "model.pt" would loop over its characters and try to
# torch.load each single letter.
ckpt_pths = ["model.pt"]
print(len(ckpt_pths))

# NOTE(review): the three settings below are not referenced by the code in
# this script; presumably left over from a related experiment — confirm
# before removing.
pretrain_netsize = 4
wider_netsize = 12

bsln_dir = '/bsln_dir'

def extend_fc0_weight(weight, n_repeat=10, scale=10):
    """Widen an ``fc0`` weight so it accepts ``n_repeat`` copies of the first two inputs.

    Builds ``[w[:, 0:2] / scale] * n_repeat + [w[:, 2:4]]`` along the last
    dimension, then divides the whole result by ``scale``. With the defaults,
    a ``(out, 4)`` weight becomes ``(out, 2 * 10 + 2) = (out, 22)``.

    Args:
        weight: 2-D tensor; columns 0-1 are the repeated inputs and columns
            2-3 are appended once at the end (assumed here to be the
            original 4-input ``fc0`` weight — TODO confirm for other shapes).
        n_repeat: how many copies of columns 0-1 to stack.
        scale: divisor applied to each copy and again to the whole result.

    Returns:
        A new tensor; ``weight`` is not modified.
    """
    # Advanced (list) indexing already returns a copy, so no clone is needed.
    field = weight[:, [0, 1]] / scale
    grid = weight[:, [2, 3]]
    extended = torch.cat([field] * n_repeat + [grid], dim=-1)
    # NOTE(review): the original script divides by `scale` both per copy and
    # over the concatenated tensor, so the repeated columns end up scaled by
    # 1/scale**2 and the last two columns by 1/scale. Preserved exactly —
    # confirm this double scaling is intended.
    return extended / scale


def extended_ckpt_path(ckpt_pth):
    """Return the output path for an extended checkpoint: ``<stem>_extend.pt``."""
    # splitext handles any extension length (".pt", ".pth", none), unlike a
    # fixed [:-3] slice, which mangles anything but a 3-char suffix.
    root, _ = os.path.splitext(ckpt_pth)
    return root + '_extend.pt'


if __name__ == "__main__":
    # Accept a single path as well as a list: iterating a bare string would
    # yield its individual characters.
    paths = [ckpt_pths] if isinstance(ckpt_pths, str) else ckpt_pths
    for ckpt_pth in paths:
        # torch.load unpickles the file; only run this on trusted checkpoints.
        ckpt = torch.load(ckpt_pth, map_location=torch.device('cpu'))
        print(ckpt_pth)

        state_dict = ckpt['model_state_dict']
        # Select the keys first so the dict is not written to while iterating it.
        fc0_keys = [k for k in state_dict if "fc0.weight" in k]
        for k in fc0_keys:
            new_tensor = extend_fc0_weight(state_dict[k])
            print(new_tensor.shape)
            state_dict[k] = new_tensor

        torch.save(ckpt, extended_ckpt_path(ckpt_pth))