# -----------------------------------------------------------------------------------
# SemanIR: Sharing Key Semantics in Transformer Makes Efficient Image Restoration
# -----------------------------------------------------------------------------------


from model import util
import os
from glob import glob
import argparse
import torch
from model.semanir import SemanIR


def get_data(testset, scale):
    """Return the sorted list of PNG image paths for a low-resolution test set.

    Images are looked up in ``./dataset/{testset}_X{scale}/`` relative to the
    current working directory. An empty list is returned when nothing matches.
    """
    pattern = f"./dataset/{testset}_X{scale}/*.png"
    matches = glob(pattern)
    # Sort so images are processed in a deterministic order.
    matches.sort()
    return matches


def load_model(upscale=4, version="small"):
    """Build a SemanIR super-resolution network and load its pretrained weights.

    Args:
        upscale: Upscaling factor passed to ``SemanIR(upscale=...)``. It also
            selects the checkpoint ``./model_zoo/sr_{version}_c_x{upscale}.ckpt``.
        version: Model variant. Only ``"small"`` is implemented.

    Returns:
        The constructed ``SemanIR`` module, left on the CPU. If the checkpoint
        file exists, its state dict is loaded with ``strict=True``. Otherwise
        the network keeps the weights set by its constructor and a
        ``UserWarning`` is emitted.

    Raises:
        NotImplementedError: If ``version`` is not ``"small"``.
    """
    import warnings

    if version != "small":
        raise NotImplementedError(f"Model version {version} not implemented!")

    # SemanIR Small
    model = SemanIR(
        window_size=8,
        top_k=32,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upscale=upscale,
        upsampler="pixelshuffle",
        conv_type="1conv",
        img_size=32,
        img_range=1.0,
        fairscale_checkpoint=False,
        offload_to_cpu=False,
        version="v2"
    )

    # Report model statistics. The exact output depends on
    # model.common.model_analysis, which is not shown here.
    from model.common import model_analysis
    model_analysis(model)

    checkpoint_path = f"./model_zoo/sr_{version}_c_x{upscale}.ckpt"
    if os.path.isfile(checkpoint_path):
        # map_location="cpu" lets checkpoints saved from a GPU run be loaded on
        # machines without CUDA. Callers can move the model with .to(device).
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
    else:
        # Without weights the outputs are meaningless, so make it visible.
        warnings.warn(
            f"Checkpoint {checkpoint_path} not found; the model will run without pretrained weights.",
            stacklevel=2,
        )

    return model

def main(args):
    """Super-resolve every image of a test set with SemanIR and save the results.

    Args:
        args: Parsed command-line namespace providing ``test_set``, ``scale``,
            ``model_size`` and ``data_range``.

    Each output is written to ``./results/sr`` under the same file name as its
    low-resolution input. If the test set contains no images, a message is
    printed and nothing else is done.
    """
    # Data
    img_list = get_data(args.test_set, args.scale)
    if not img_list:
        print(f"No images found for test set '{args.test_set}' at scale x{args.scale}; nothing to do.")
        return

    # Inference runs on the CPU. To use a GPU when available, switch to
    # torch.device("cuda" if torch.cuda.is_available() else "cpu").
    device = "cpu"

    # Model: eval() switches layers such as dropout/normalisation to
    # inference behaviour.
    model = load_model(args.scale, args.model_size)
    model = model.to(device)
    model.eval()

    # Create the output directory once, not on every loop iteration.
    save_path = "./results/sr"
    os.makedirs(save_path, exist_ok=True)

    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for img_path in img_list:
            img_lr = util.imread_uint(img_path, n_channels=3)
            img_lr = util.uint2tensor4(img_lr, args.data_range)
            img_lr = img_lr.to(device)

            img_sr = model(img_lr)
            img_sr = util.tensor2uint(img_sr, args.data_range)

            util.imsave(img_sr, os.path.join(save_path, os.path.basename(img_path)))


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog` (the program
    # name shown in usage lines). The title belongs in `description`.
    parser = argparse.ArgumentParser(description="SemanIR SR test code")
    parser.add_argument(
        "--data_range", default=1.0, type=float,
        help="Value range passed to util.uint2tensor4 / util.tensor2uint.",
    )
    parser.add_argument(
        "--test_set", default="Set5", type=str,
        help="Test set name; images are read from ./dataset/<test_set>_X<scale>/*.png.",
    )
    parser.add_argument(
        "--scale", default=4, type=int,
        help="Upscaling factor; also selects the checkpoint file.",
    )
    parser.add_argument(
        "--model_size", default="small", type=str,
        help="Model variant (only 'small' is implemented).",
    )
    args = parser.parse_args()
    main(args)
