import torch
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision import datasets

from easydict import EasyDict
from fling.utils.registry_utils import MODEL_REGISTRY

from fling.utils.visualize_utils import plot_2d_loss_landscape

if __name__ == '__main__':
    # Load a ResNet-8 checkpoint and plot its 2-D loss landscape on a fixed
    # 512-sample subset of CIFAR-100.

    # Step 1: prepare the dataset.
    transform = transforms.Compose([transforms.ToTensor()])
    # NOTE: no `train=` / `download=` arguments are passed, so torchvision's
    # defaults apply (training split, no download): the data must already
    # exist under ./data/cifar100.
    dataset = datasets.CIFAR100('./data/cifar100', transform=transform)

    # A fixed subset of the first 512 samples is used to evaluate the loss
    # landscape. Materializing it as a list applies the transform once,
    # rather than on every evaluation pass the plotter makes over the loader.
    test_dataset = [dataset[i] for i in range(512)]
    test_dataloader = DataLoader(test_dataset, batch_size=32)

    # Step 2: build the model and restore its weights from a checkpoint.
    model_arg = EasyDict(
        dict(
            name='resnet8',
            input_channel=3,
            class_number=100,
        )
    )
    model_name = model_arg.pop('name')
    model = MODEL_REGISTRY.build(model_name, **model_arg)
    # map_location='cpu': the freshly built model lives on the CPU here, and
    # this lets a checkpoint saved from a GPU run load on a CPU-only machine.
    state_dict = torch.load(
        './logging/cifar100_fedpart_resnet8_iid_no_warm/before_model.ckpt',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)

    # Use the GPU when available; fall back to CPU so the script still runs.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Step 3 (optional, currently disabled): the weights above come from a
    # checkpoint, so no training happens here. To plot a model trained from
    # scratch instead, skip the checkpoint load and uncomment this loop.
    # dataloader = DataLoader(dataset, batch_size=100)
    # model = model.to(device)
    # model.train()
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # criterion = torch.nn.CrossEntropyLoss()
    # for _ in range(5):
    #     for _, (data_x, data_y) in enumerate(dataloader):
    #         data_x, data_y = data_x.to(device), data_y.to(device)
    #         pred_y = model(data_x)
    #         loss = criterion(pred_y, data_y)
    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    # model.to('cpu')

    # Step 4: plot the loss landscape of the loaded model.
    print("visualize")
    plot_2d_loss_landscape(
        model=model,
        dataloader=test_dataloader,
        device=device,
        caption='Loss Landscape Trained',
        save_path='./landscape.pdf',
        noise_range=(-0.1, 0.1),
        resolution=30,
        log_scale=True,
    )