import os
import math
import time
import torch
import random
import pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from models import Classifier, Predictor
from copy import deepcopy
from math import sqrt
from dataset import IKDataset, MDataset, IGLDataset

# Maximum number of epochs used when training the reward predictor in
# train_reward_predictor (training may stop earlier once the average epoch
# loss drops below that function's threshold).
num_epochs: int = 100

def set_seed(seed):
    """Seed every RNG used by this script and request deterministic kernels.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU plus all CUDA devices),
    disables cuDNN autotuning, and enables
    ``torch.use_deterministic_algorithms``.

    Note: assigning ``PYTHONHASHSEED`` here does not change the hash seed of
    the already-running interpreter; it only affects child processes.
    """
    # Environment knobs. PyTorch's reproducibility notes require
    # CUBLAS_WORKSPACE_CONFIG for deterministic cuBLAS, so set it before any
    # GPU work can happen.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark, cudnn.enabled = True, False, True

    torch.use_deterministic_algorithms(True)

# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

def train_reward_predictor(train_loader, epochs=None, lr=0.0002, loss_threshold=0.08):
    """Fit a ``Predictor`` on the exploration data and return it.

    Each batch from ``train_loader`` is unpacked as
    ``(images, action_one_hot, y_images, label, y_label)``; only the first
    three fields are used. ``model(images, y_images)`` is regressed onto
    ``action_one_hot`` with an MSE loss and the Adam optimizer.

    Args:
        train_loader: iterable of 5-tuples as described above.
        epochs: maximum number of passes over ``train_loader``; ``None``
            (default) uses the module-level ``num_epochs``.
        lr: Adam learning rate.
        loss_threshold: training stops after the first epoch whose mean
            batch loss is below this value.

    Returns:
        The trained ``Predictor`` on ``device``. It is not switched to eval
        mode here; callers should do that before inference.

    Raises:
        ValueError: if ``train_loader`` yields no batches (previously this
            surfaced as a ``ZeroDivisionError``).
    """
    if epochs is None:
        epochs = num_epochs

    model = Predictor().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for images, action_one_hot, y_images, _label, _y_label in train_loader:
            images = images.to(device)
            action_one_hot = action_one_hot.to(device)
            y_images = y_images.to(device)

            outputs = model(images, y_images)
            loss = criterion(outputs, action_one_hot)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # Early stop once the epoch-average loss is small enough.
        if total_loss / num_batches < loss_threshold:
            break

    return model

def IGW(fhat, gamma):
    """Sample one action per row of ``fhat`` by inverse gap weighting.

    For a score matrix ``fhat`` of shape ``[bs, A]``, let ``ahat`` be the
    row-wise argmax. Every action ``a`` gets weight
    ``1 / (A + gamma * sqrt(A) * (fhat[ahat] - fhat[a]))``. The leftover mass
    ``1 - sum(weights)``, clamped at 0, is added to ``ahat``. One action per
    row is then drawn with ``torch.multinomial``. Larger ``gamma`` puts more
    mass on the greedy action.

    Returns:
        ``(action, ahat)``: sampled and greedy action indices, each ``[bs]``.
    """
    top_score, greedy_action = fhat.max(dim=1)
    num_actions = fhat.shape[1]
    gamma *= sqrt(num_actions)

    # Distance of every action's score from the best score in its row.
    gap = top_score.unsqueeze(1) - fhat
    probs = 1 / (num_actions + gamma * gap)

    leftover = torch.clamp(1 - probs.sum(dim=1), min=0, max=None)
    rows = range(probs.shape[0])
    probs[rows, greedy_action] += leftover

    sampled = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled, greedy_action

def test_classifier(model, test_loader):
    """Return the greedy top-1 accuracy of ``model`` on ``test_loader``, in percent.

    ``model.predict(x)`` must return scores with classes along dim 1; the
    prediction is the highest-scoring class. Runs under ``torch.no_grad()``
    and does not change the model's train/eval mode.

    Raises:
        ValueError: if ``test_loader`` yields no samples (previously a
            ``ZeroDivisionError``).
    """
    correct = 0
    seen = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            _, greedy = model.predict(x).max(dim=1)
            correct += (greedy == y).sum().item()
            seen += x.size(0)

    if seen == 0:
        raise ValueError("test_loader yielded no samples")
    return correct * 100.0 / seen

def lipschitz_reward(prob, th_center, width, out_device=None):
    """Map predicted probabilities to rewards in [0, 1] with a linear ramp.

    For each value ``v`` of the flattened ``prob``:

    * ``th_center - width < v < th_center + width`` ->
      ``(v - th_center + width) / (2 * width)``
    * ``v >= th_center + width`` -> 1
    * otherwise (including NaN) -> 0

    With ``width == 0`` this is a hard threshold: 1 iff ``v >= th_center``.

    Args:
        prob: tensor of any shape; it is flattened (non-contiguous is fine).
        th_center: centre of the ramp.
        width: half-width of the ramp.
        out_device: device of the returned tensor; ``None`` (default) uses
            the module-level ``device``.

    Returns:
        1-D tensor of the default floating dtype holding ``prob.numel()``
        rewards, detached from any autograd graph.
    """
    if out_device is None:
        out_device = device
    # Vectorised replacement for a per-element ``.item()`` loop. Computing in
    # float64 reproduces the former Python-float arithmetic exactly before the
    # final cast.
    v = prob.detach().reshape(-1).to(torch.float64)
    lo = th_center - width
    hi = th_center + width
    # torch.where evaluates both branches. With width == 0 the ramp divides by
    # zero, but those entries are never selected because lo < v < hi is empty.
    ramp = (v - th_center + width) * 1.0 / (2 * width)
    reward = torch.where(v >= hi, torch.ones_like(v), torch.zeros_like(v))
    reward = torch.where((v > lo) & (v < hi), ramp, reward)
    return reward.to(device=out_device, dtype=torch.get_default_dtype())

def bandit_feedback(ik_model, train_loader, test_loader, th_center, width, log_every=5000):
    """Train a fresh ``Classifier`` online from rewards inferred by ``ik_model``.

    For every batch ``(x_cls, x_ik, c_ik, w_ik, lbl)`` of ``train_loader``:

    1. score ``x_cls`` with the classifier and sample an action with
       ``IGW(fhat, gamma=sqrt(batch_idx))``;
    2. choose the feedback image per sample: ``c_ik[i]`` if the sampled
       action equals ``lbl[i]``, otherwise ``w_ik[i]``;
    3. run ``ik_model(x_ik, feedback)``, take the output column of the
       sampled action, and convert it to a reward with
       ``lipschitz_reward(prob, th_center, width)``;
    4. update the classifier with ``cls_model.bandit_learn(x_cls, action, reward)``.

    Args:
        ik_model: trained reward predictor; switched to eval mode here.
        train_loader: yields ``(x_cls, x_ik, c_ik, w_ik, lbl)`` batches.
        test_loader: evaluated with ``test_classifier`` at every log point.
        th_center, width: parameters forwarded to ``lipschitz_reward``.
        log_every: evaluate and print progress every this many batches;
            must be positive.

    Returns:
        ``(g_list, a_list, t_list)``. After every batch, ``g_list`` gets the
        running count of sampled actions that matched the true label, and
        ``a_list`` gets the running sum of inferred reward. ``t_list`` holds
        the test accuracy (percent) at each log point.

    Raises:
        ValueError: if ``log_every`` is not positive.
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")

    cls_model = Classifier().to(device)
    ik_model.eval()

    cnt = 0
    true_hits = 0.0        # running count of sampled actions equal to the label
    inferred_reward = 0.0  # running sum of rewards from lipschitz_reward
    tf = time.time()
    g_list, a_list, t_list = [], [], []
    for batch_idx, (x_cls, x_ik, c_ik, w_ik, lbl) in enumerate(train_loader):
        x_cls, x_ik, c_ik, w_ik, lbl = (
            x_cls.to(device), x_ik.to(device), c_ik.to(device),
            w_ik.to(device), lbl.to(device))
        bs = x_cls.size(0)
        with torch.no_grad():
            fhat = cls_model.predict(x_cls)
            # gamma = sqrt(batch_idx) grows over time, concentrating IGW's
            # sampling distribution on the greedy action.
            action, _greedy = IGW(fhat, gamma=sqrt(batch_idx))
            # [bs] bool: did the sampled action hit the true label?
            hit = action == lbl.reshape(-1)
            # Pick c_ik[i] on a hit and w_ik[i] otherwise in a single op. This
            # is the vectorised form of a per-index loop followed by torch.cat.
            mask = hit.view(-1, *([1] * (c_ik.dim() - 1)))
            y = torch.where(mask, c_ik, w_ik)
            output = ik_model(x_ik, y)
            # [bs, 1]: predictor output for the action actually taken.
            prob = torch.gather(output, 1, action.view(-1, 1))
            reward = lipschitz_reward(prob, th_center, width)

            g_acc = hit.float().sum().item()
            a_acc = torch.sum(reward).item()

        cls_model.bandit_learn(x_cls, action, reward)
        true_hits += g_acc
        inferred_reward += a_acc
        cnt += bs

        g_list.append(true_hits)
        a_list.append(inferred_reward)

        if (batch_idx + 1) % log_every == 0:
            train_g = true_hits * 100.0 / cnt
            train_a = inferred_reward * 100.0 / cnt
            test_acc = test_classifier(cls_model, test_loader)
            t_list.append(test_acc)
            print("Iterations:{}\tTime:{:.2f}\tTrain G:{:.2f}\tTrain A:{:.2f}\tTest:{:.2f}".format(batch_idx+1, time.time() - tf, train_g, train_a, test_acc))
            tf = time.time()

    return g_list, a_list, t_list

if __name__ == "__main__":
    # Classifier inputs get the usual MNIST mean/std normalisation; inputs to
    # the reward predictor are kept as plain ToTensor() output.
    transform_cls = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

    transform_ik = transforms.Compose([
        transforms.ToTensor()
        ])

    seed = 0
    set_seed(seed)
    # Number of training images reserved for fitting the reward predictor;
    # the remainder drives the online bandit loop.
    explore_size = 5000

    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
    n = len(train_dataset)
    split_sizes = [explore_size, n - explore_size]
    p1_dataset, p2_dataset = random_split(train_dataset, split_sizes)
    p1_dataset = IKDataset(p1_dataset, transform_ik, seed)
    p1_loader = DataLoader(p1_dataset, batch_size=64, shuffle=True)

    p2_dataset = IGLDataset(p2_dataset, transform_cls, transform_ik)
    p2_loader = DataLoader(p2_dataset, batch_size=1, shuffle=True)

    test_set = datasets.MNIST(root='./data', train=False, download=True)
    test_set = MDataset(test_set, transform_cls)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

    th_list = [0.25, 0.33]
    width_list = [0.0, 0.1, 0.2, 0.3, 0.4]
    num_seeds = 4

    # open(..., 'wb') below fails if ./res is missing. Create it up front so
    # the run does not abort after the first (long) configuration has trained.
    os.makedirs("./res", exist_ok=True)

    for th in th_list:
        for width in width_list:
            g_acc_list, a_acc_list, t_acc_list = [], [], []
            print("Th:{:.2f}\tWidth:{:.2f}".format(th, width))
            for s in range(num_seeds):
                print("Seed:{}".format(s))
                set_seed(s)
                ik_model = train_reward_predictor(p1_loader)
                g_list, a_list, t_list = bandit_feedback(ik_model, p2_loader, test_loader, th, width)
                # One row per seed; rows are stacked into a [num_seeds, T] array below.
                g_acc_list.append(np.array(g_list).reshape(1, -1))
                a_acc_list.append(np.array(a_list).reshape(1, -1))
                t_acc_list.append(np.array(t_list).reshape(1, -1))

            res = {
                "greedy": np.concatenate(g_acc_list, axis=0),
                "action": np.concatenate(a_acc_list, axis=0),
                "test": np.concatenate(t_acc_list, axis=0),
            }

            path = "./res/MNIST_{:.2f}_{:.2f}.pickle".format(th, width)
            with open(path, 'wb') as f:
                pickle.dump(res, f)

    

    

