
import copy
import torch
import torchvision
import matplotlib.pyplot as plt


def state_dict_dif(state_dict_1, state_dict_2, rtol=0.0001, atol=1e-08):
    """Return a per-element "changed" mask for every entry of ``state_dict_1``.

    For each key of ``state_dict_1`` the result maps that key to
    ``~torch.isclose(state_dict_1[key], state_dict_2[key], rtol=rtol, atol=atol)``,
    i.e. a boolean tensor that is ``True`` where the two values differ beyond
    the tolerance.

    Args:
        state_dict_1: Reference mapping of name -> tensor. Its keys, key order
            and mapping type (e.g. ``OrderedDict``) are reused for the result.
        state_dict_2: Mapping that must contain every key of ``state_dict_1``.
        rtol: Relative tolerance passed to ``torch.isclose``.
        atol: Absolute tolerance passed to ``torch.isclose`` (torch's default).

    Returns:
        A mapping of the same type as ``state_dict_1`` holding boolean tensors.
        Neither input is modified.

    Raises:
        KeyError: If a key of ``state_dict_1`` is missing from ``state_dict_2``.
    """
    # A shallow copy preserves the mapping type and its instance attributes
    # without cloning every tensor (deepcopy did, only for the values to be
    # overwritten immediately below).
    out = copy.copy(state_dict_1)
    for key in out:
        # Re-assigning existing keys does not change the dict's size, so
        # mutating while iterating over its keys is safe here.
        out[key] = ~torch.isclose(
            state_dict_1[key], state_dict_2[key], rtol=rtol, atol=atol
        )
    return out

def display_state_dict(state_dict, save_path, cmap='magma'):
    """Render ``state_dict`` as a colour-mapped image and save it to ``save_path``.

    The layout image comes from ``create_state_dict_image``. It is divided by
    its maximum value and passed through the matplotlib colormap ``cmap``,
    which yields RGBA, so the saved image has 4 channels. Values are only
    scaled, not shifted, so negative entries fall below the colormap's
    [0, 1] input range. Every pixel whose layout value is exactly 0.5 (the
    padding value ``create_state_dict_image`` fills with) is painted neutral
    gray in the RGB channels. Genuine entries equal to 0.5 are therefore
    painted gray as well.

    Args:
        state_dict: Mapping of name -> tensor accepted by ``create_state_dict_image``.
        save_path: Destination passed to ``torchvision.utils.save_image``.
        cmap: Colormap name or ``Colormap`` instance for ``plt.get_cmap``.

    Raises:
        RuntimeError: From ``torch.max`` when the layout image is empty
            (no entries of rank >= 1).
    """
    image = create_state_dict_image(state_dict)
    colormap = plt.get_cmap(cmap)

    # The division already allocates a new tensor, so no defensive copy is needed.
    normalized = image / torch.max(image)
    # The colormap returns an (H, W, 4) RGBA array; move channels first for torchvision.
    rgba = torch.from_numpy(colormap(normalized.numpy())).permute(2, 0, 1)

    # Paint padding pixels gray in R, G and B at once; alpha is left untouched.
    rgba[:3, image == 0.5] = 0.5

    torchvision.utils.save_image(rgba, save_path)


def _state_dict_plane(item):
    """Return the 2-D slice of ``item`` that gets drawn, or None for 0-d tensors."""
    ndim = len(item.shape)
    if ndim == 0:
        return None
    if ndim == 1:
        # A vector is drawn as a single row.
        return item[None]
    # Keep the first two axes and take index 0 of every trailing axis,
    # e.g. conv weights (out, in, kh, kw) -> [:, :, 0, 0].
    return item[(slice(None), slice(None)) + (0,) * (ndim - 2)]


def create_state_dict_image(state_dict, margin=10):
    """Lay out the tensors of ``state_dict`` as one 2-D float image.

    Entries are stacked top to bottom in iteration order. Each entry is
    centred horizontally and followed by ``margin`` rows of padding. Unused
    pixels hold 0.5.

    * 1-D tensors occupy a single row.
    * 2-D tensors are drawn as-is (shape[0] rows x shape[1] columns).
    * Tensors with 3 or more dimensions are drawn through their
      ``[:, :, 0, ..., 0]`` slice.
    * 0-d tensors are skipped, their key is printed, and no rows are
      reserved for them.

    Args:
        state_dict: Mapping of name -> tensor (boolean masks from
            ``state_dict_dif`` are converted to 0.0 / 1.0).
        margin: Number of padding rows placed after every drawn entry.

    Returns:
        A ``(rows, cols)`` float tensor. ``cols`` is the widest entry, and the
        result is ``(0, 0)`` when nothing is drawable.
    """
    planes = []
    for key, item in state_dict.items():
        plane = _state_dict_plane(item)
        if plane is None:
            print(key)
            continue
        planes.append(plane)

    # Sizing uses exactly the planes that get drawn, so no rows are wasted.
    size_x = sum(plane.shape[0] + margin for plane in planes)
    size_y = max((plane.shape[1] for plane in planes), default=0)

    image = 0.5 * torch.ones((size_x, size_y))

    x_pos = 0
    for plane in planes:
        height, width = plane.shape[0], plane.shape[1]
        y_pos = (size_y - width) // 2  # centre horizontally
        image[x_pos:x_pos + height, y_pos:y_pos + width] = plane
        x_pos += height + margin

    return image

if __name__ == '__main__':
    # No script behavior is defined: running this module directly does nothing.
    pass
