
import os, json, random
from config.options import *
from config.utils import *
from config.learning_rates import get_learning_rate_scheduler
# Restrict the GPUs visible to this process to the ones selected in the options.
os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
# Effective batch size: per-step batch size x gradient-accumulation steps x GPU count.
opts.BatchSize = opts.batch_size * opts.accumulation_steps * opts.gpu_num
# Raise the element count above which printed tensors are summarised.
# NOTE(review): `torch` is only imported further down in this file; presumably
# one of the star imports above already provides it -- confirm.
torch.set_printoptions(threshold=10_000)

from InpaintRewardDataset import InpaintRewardDataset
from ImageReward import ImageReward
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
from torch.backends import cudnn

import sys

# Module-level mean-squared-error criterion. It is not referenced anywhere in
# the visible part of this file.
mse_loss = nn.MSELoss()


def std_log():
    """Redirect ``sys.stdout`` to a log file on the rank-0 process.

    The log file is ``<config['log_base']>/<make_path()>.txt`` and is opened in
    write mode. Processes for which ``get_rank()`` is non-zero are left
    untouched. The file handle is never closed here because it becomes the
    process-wide standard output.
    """
    if get_rank() != 0:
        return

    run_name = make_path()
    log_dir = config['log_base']
    makedir(log_dir)
    log_file = os.path.join(log_dir, "{}.txt".format(run_name))
    sys.stdout = open(log_file, "w")


def init_seeds(seed, cuda_deterministic=True):
    """Seed the torch RNG and configure the cuDNN determinism switches.

    Args:
        seed: Value passed to ``torch.manual_seed``.
        cuda_deterministic: When truthy, cuDNN runs in deterministic mode with
            benchmarking off (slower, more reproducible). When falsy, the
            reverse is set (faster, less reproducible).
    """
    torch.manual_seed(seed)

    # The two cuDNN flags are always set to opposite values.
    deterministic = bool(cuda_deterministic)
    cudnn.deterministic = deterministic
    cudnn.benchmark = not deterministic


def loss_func(reward, target, rank):
    """Compute regression losses and a rank-based accuracy indicator.

    Args:
        reward: Predicted reward tensor.
        target: Ground-truth score tensor; moved to ``reward``'s device.
        rank: Indexable whose first element selects the accuracy rule.

    Returns:
        A tuple ``(l2_loss, squared_diffs, accuracy, loss_l1)``:
        the mean squared error, the element-wise squared errors, the accuracy
        indicator and the element-wise absolute errors.

    The accuracy indicator is 1 when ``rank[0] == 1`` and the reward is
    positive, or when ``rank[0] == 3`` and the reward is negative; it is 0 when
    the sign is wrong for those ranks, and 0.5 for any other rank.
    ``reward > 0`` / ``reward < 0`` are used as plain truth values, so
    ``reward`` is expected to hold a single element in those branches.
    """
    target = target.to(reward.device)

    diff = reward - target
    squared_diffs = diff ** 2
    l2_loss = squared_diffs.mean()
    loss_l1 = torch.abs(diff)

    first_rank = rank[0]
    if first_rank == 1:
        accuracy = 1 if reward > 0 else 0
    elif first_rank == 3:
        accuracy = 1 if reward < 0 else 0
    else:
        accuracy = 0.5

    return l2_loss, squared_diffs, accuracy, loss_l1


def comput_V(fea, ridge=0.01):
    """Return the ridge-regularised Gram matrix ``fea^T @ fea + ridge * I``.

    Args:
        fea: 2-D feature tensor; the result is square with side ``fea.size(1)``.
        ridge: Coefficient of the identity matrix added to the Gram matrix.
            Defaults to 0.01.

    Returns:
        The regularised Gram matrix, on the same device as ``fea``.
    """
    gram = torch.matmul(fea.t(), fea)
    # Build the identity on the input's device rather than a hard-coded 'cuda',
    # so the function also works on CPU and on non-default GPUs.
    identity = torch.eye(gram.size(0), device=gram.device)
    return gram + ridge * identity



if __name__ == "__main__":
    # Evaluation script. On the rank-0 process it:
    #   1. loads a trained reward model from a fixed checkpoint,
    #   2. averages comput_V(...) over a few shuffled test batches to get a matrix V,
    #   3. scores every test batch with sum((feature @ V^-1) ** 2),
    #   4. writes the {img_id: score} mapping to a JSON file.

    # Number of test batches sampled to estimate V; also used as the suffix of
    # the output file name.
    num_sample_batches = 10

    if opts.std_log:
        std_log()

    if opts.distributed:
        torch.distributed.init_process_group(backend="nccl")
        # NOTE(review): get_rank() is the global rank; it is used here as the
        # CUDA device index -- confirm this is only run on a single node.
        local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        init_seeds(opts.seed + local_rank)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        init_seeds(opts.seed)

    writer = visualizer()

    test_dataset = InpaintRewardDataset("test", config['root_path'])
    test_loader = DataLoader(
        test_dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        collate_fn=collate_fn if not opts.rank_pair else None,
    )

    model = ImageReward(device).to(device)
    if opts.preload_path:
        model = preload_model(model)
    if opts.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Only the rank-0 process runs the evaluation.
    if get_rank() == 0:
        print("test: ")
        model = load_model(model, './checkpoint/ade20k5688_exp_all_gpu7_bs2_fix=0.7_lr=1e-05cosine/best_lr=49.pt')
        model.eval()

        # Pass 1: accumulate comput_V over the first `num_sample_batches`
        # batches. The accumulator is created from the first result so its
        # size and device always match what the model produces.
        sum_features = None
        cnt = 0
        with torch.no_grad():
            for step, batch_data_package in tqdm(enumerate(test_loader)):
                if cnt >= num_sample_batches:
                    break
                cnt += 1
                _, last_feature = model(batch_data_package)
                gram = comput_V(last_feature[0])
                sum_features = gram if sum_features is None else sum_features + gram

        if cnt == 0:
            raise RuntimeError("test loader produced no batches; cannot estimate V")

        # Average over the batches actually processed, which can be fewer than
        # `num_sample_batches` when the test set is small.
        sum_features /= cnt
        print(sum_features)
        print(sum_features.shape)
        V_inverse = torch.inverse(sum_features)
        print(V_inverse)

        # Pass 2: score every batch. Only index 0 of the batch package and of
        # `last_feature` is used, as in the sampling pass.
        V_values = {}
        with torch.no_grad():
            for step, batch_data_package in tqdm(enumerate(test_loader)):
                img_id = batch_data_package[0]['img_id_v123']
                _, last_feature = model(batch_data_package)
                projected = torch.matmul(last_feature[0], V_inverse)
                value = torch.sum(torch.pow(projected, 2))
                V_values[img_id] = value.item()

        # Serialise first so a serialisation error does not leave a partial file.
        json_data = json.dumps(V_values, indent=4)
        with open(f'V_values_img_id_sample_{num_sample_batches}.json', 'w') as file:
            file.write(json_data)

