import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


from utils.cifar10_dataset import CIFAR10VFLNN
from options import args_parser

from model.ResNet_vfl import resnet18_vfl, resnet18_vfl_split

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10,
                   lr_mul=0.001, d_model=10, n_warmup_steps=5, step_size=30):
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        scheduler: One of ``"step"``, ``"multistep"``, ``"cosine"`` or
            ``"exponential"``.
        optimizer: The ``torch.optim.Optimizer`` whose learning rate is scheduled.
        milestones: Epoch indices at which ``"multistep"`` multiplies the LR by
            ``gamma``. An immutable default avoids the shared-mutable-default pitfall.
        gamma: Multiplicative decay factor for ``"step"``, ``"multistep"`` and
            ``"exponential"``.
        T_max: Maximum number of iterations for ``"cosine"`` (CosineAnnealingLR).
        lr_mul, d_model, n_warmup_steps: Not used by any scheduler built here;
            kept so existing call sites remain valid.
        step_size: Period (in scheduler steps) of the ``"step"`` decay.

    Returns:
        A ``torch.optim.lr_scheduler`` instance attached to ``optimizer``.

    Raises:
        ValueError: If ``scheduler`` is not one of the recognised names. An
            unknown name used to return ``None`` silently and fail later at
            ``.step()``.
    """
    if scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, list(milestones), gamma=gamma)
    if scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected 'step', 'multistep', "
        f"'cosine' or 'exponential'"
    )


def main():
    """Train a two-party vertical-FL split network on CIFAR-10 and probe it
    with a label-inference ("MC") attack.

    Party A owns ``net1``; party B owns ``net2_encoder``. The server side
    holds:

    - ``net2_clf``: a head on party B's features;
    - ``global_clf``: fuses both parties' outputs by concatenation or sum;
    - ``net2_clf_adv``: a simulated attacker head on party B's features.

    Each training batch runs three updates:

    1. ``net2_clf_adv`` is trained to predict the labels from party B's
       (detached) features.
    2. ``global_clf``, ``net2_clf`` and ``net1`` are trained on the task
       loss. With ``args.vfl_dropout``, party B's output is zeroed for the
       batch with probability ``args.dropout_rate``.
    3. Unless party B was dropped, ``net2_encoder`` is trained on a loss
       that keeps the task loss low while pushing the attacker head away
       from the true labels, weighted by ``args.l_adv``.

    Every 5 epochs the MC attack fine-tunes fresh copies of the attacker
    heads on a labelled subset of the training set (size controlled by
    ``args.num_mc_sample``), using party B's frozen features. The model with
    the best validation accuracy so far is saved to
    ``args.save_models_filename``.
    """
    args = args_parser()
    lr = 1e-2
    epoch_max = 1000
    bs = 32

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    # Decide once how the two parties' outputs are fused, so training and
    # validation always agree. ``== True`` keeps values such as 1 treated
    # the same way in both paths.
    use_concat = args.concat == True  # noqa: E712

    # --- Models ----------------------------------------------------------
    net1 = resnet18_vfl(10).to(device)  # kept on party A
    net2_encoder = resnet18_vfl_split(10).encoder.to(device)  # kept on party B
    net2_clf = resnet18_vfl_split(10).clf.to(device)
    net2_clf_adv_sim = resnet18_vfl_split(10).clf_adv_sim.to(device)
    net2_clf_adv = resnet18_vfl_split(10).clf_adv.to(device)
    global_clf = resnet18_vfl(10).classifier(args.concat).to(device)

    # Snapshots taken before any training: every MC attack starts from
    # these same initial weights.
    net2_clf_adv_init = copy.deepcopy(net2_clf_adv)
    net2_clf_adv_sim_init = copy.deepcopy(net2_clf_adv_sim)

    # --- Optimisers and learning-rate schedules ---------------------------
    optim_net1 = optim.SGD(net1.parameters(), lr=lr, momentum=0.9)
    if args.mc == "passive":
        optim_net2_encoder = optim.SGD(net2_encoder.parameters(), lr=lr, momentum=0.9)
    else:
        optim_net2_encoder = optim.SGD(
            net2_encoder.parameters(), lr=lr, momentum=0.9, dampening=0.1
        )
    optim_net2_clf = optim.SGD(net2_clf.parameters(), lr=lr, momentum=0.9)
    optim_net2_clf_adv = optim.SGD(net2_clf_adv.parameters(), lr=lr, momentum=0.9)
    optim_global_clf = optim.SGD(global_clf.parameters(), lr=lr, momentum=0.9)
    all_optimizers = (
        optim_net1,
        optim_net2_encoder,
        optim_net2_clf,
        optim_net2_clf_adv,
        optim_global_clf,
    )

    # Halve every learning rate each 30 epochs.
    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, 30, gamma=0.5) for opt in all_optimizers
    ]

    # --- Data --------------------------------------------------------------
    # For uint8 images ToTensor yields values in [0, 1]; Normalize(0.5, 0.5)
    # then maps them to [-1, 1].
    transform_train = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    transform_valid = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Each item is unpacked as (party A input, party B input, label).
    train_dataset = CIFAR10VFLNN(
        root="../data/cifar10_vertical",
        train=True,
        download=True,
        transform=transform_train,
        returns="all",
    )
    # Labelled subset available to the MC attacker.
    mc_train_dataset = CIFAR10VFLNN(
        root="../data/cifar10_vertical",
        train=True,
        download=True,
        transform=transform_train,
        returns="all",
        num_sample=args.num_mc_sample,
    )
    valid_dataset = CIFAR10VFLNN(
        root="../data/cifar10_vertical",
        train=False,
        download=True,
        transform=transform_valid,
        returns="all",
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=bs, shuffle=False, num_workers=2
    )
    mc_train_loader = torch.utils.data.DataLoader(
        mc_train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )

    # --- Evaluation helpers ------------------------------------------------
    def valid_mc(encoder, clf, data_loader, party_idx, device):
        """Return the accuracy of ``clf(encoder(x))`` over ``data_loader``.

        ``x`` is party A's input when ``party_idx == 1``; otherwise it is
        party B's input. Both modules are left in eval mode.
        """
        encoder.eval()
        clf.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs_a, inputs_b, labels in data_loader:
                inputs = (inputs_a if party_idx == 1 else inputs_b).to(device)
                labels = labels.to(device)
                pred_label = clf(encoder(inputs)).argmax(dim=1)
                total += inputs.size(0)
                correct += (pred_label == labels).sum().item()
        return correct / float(total)

    def mc_attack(args, net2_encoder, net2_clf_adv_init, mc_train_loader, valid_loader):
        """Fine-tune a fresh attacker head on party B's frozen features.

        A copy of ``net2_clf_adv_init`` is trained for ``args.epoch_mc``
        epochs on ``net2_encoder`` features of ``mc_train_loader``. The
        encoder itself is never modified.

        Returns:
            ``(train_acc, val_acc)`` of the attacker head at its last
            evaluation, or ``(0, 0)`` when ``args.epoch_mc`` is 0.
        """
        clf_adv_ft = copy.deepcopy(net2_clf_adv_init)
        optim_clfadvft = optim.SGD(clf_adv_ft.parameters(), lr=lr, momentum=0.9)
        adv_train_acc, adv_val_acc = 0, 0
        for i in range(args.epoch_mc):
            # The encoder is only a fixed feature extractor here. Eval mode
            # plus no_grad ensures the attack cannot alter it; train mode
            # would update any BatchNorm running statistics.
            net2_encoder.eval()
            clf_adv_ft.train()  # valid_mc switched it to eval mode
            for _, inputs, targets in mc_train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                with torch.no_grad():
                    feature = net2_encoder(inputs)
                optim_clfadvft.zero_grad()
                loss_adv = criterion(clf_adv_ft(feature), targets)
                loss_adv.backward()
                optim_clfadvft.step()

            # Evaluate every second epoch, and always after the final one,
            # so the returned accuracies describe the fully trained head.
            if i % 2 == 1 or i == args.epoch_mc - 1:
                adv_train_acc = valid_mc(net2_encoder, clf_adv_ft, mc_train_loader, 2, device)
                adv_val_acc = valid_mc(net2_encoder, clf_adv_ft, valid_loader, 2, device)
                print(f"MC attack epoch {i}: train_acc: {adv_train_acc:.4f}, val_acc: {adv_val_acc:.4f}")
        return adv_train_acc, adv_val_acc

    def valid(net1, net2_encoder, net2_clf, global_clf, data_loader, device, quit=False):
        """Return the end-to-end accuracy of the split model over ``data_loader``.

        If ``quit`` is True, party B's contribution is replaced by zeros,
        i.e. the model is evaluated as if party B had dropped out.
        """
        net1.eval()
        net2_encoder.eval()
        net2_clf.eval()
        global_clf.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs_a, inputs_b, labels in data_loader:
                inputs_a = inputs_a.to(device)
                inputs_b = inputs_b.to(device)
                labels = labels.to(device)
                g_a = net1(inputs_a)
                g_b = net2_clf(net2_encoder(inputs_b))
                if quit:
                    g_b = torch.zeros_like(g_b)
                fused = torch.cat((g_a, g_b), dim=1) if use_concat else g_a + g_b
                pred_label = global_clf(fused).argmax(dim=1)
                total += inputs_a.size(0)
                correct += (pred_label == labels).sum().item()
        return correct / float(total)

    # --- Training loop -------------------------------------------------------
    all_models = (net1, net2_encoder, net2_clf, net2_clf_adv, global_clf)
    best_acc = 0
    # Refreshed every 5 epochs (always at epoch 0); between refreshes the
    # most recent attack results are reported.
    mc_val_acc_sim, mc_val_acc = 0.0, 0.0
    for e in range(epoch_max):
        epoch_loss_clean = 0.0
        epoch_loss_adv = 0.0
        epoch_loss_encoder = 0.0
        epoch_dropout_num = 0
        for inputs_a, inputs_b, targets in train_loader:
            inputs_a = inputs_a.to(device)
            inputs_b = inputs_b.to(device)
            targets = targets.to(device)

            for model in all_models:
                model.train()
            for opt in all_optimizers:
                opt.zero_grad()

            # Local inference on each party.
            h_a = net1(inputs_a)  # kept on party A
            h_b = net2_encoder(inputs_b)  # kept on party B

            # Party B sends a detached copy of its features to the server.
            # The gradient w.r.t. this copy is sent back in Step 3 to update
            # the encoder.
            h_b_sent = h_b.detach().requires_grad_()

            g_a = h_a
            g_b = net2_clf(h_b_sent)
            pred_b_adv = net2_clf_adv(h_b_sent)

            # Step 1: update the simulated attacker head.
            loss_adv = criterion(pred_b_adv, targets)
            loss_adv.backward()
            optim_net2_clf_adv.step()
            optim_net2_clf_adv.zero_grad()

            # VFL dropout: zero party B's output for this batch with
            # probability args.dropout_rate.
            drop_b = 0
            if args.vfl_dropout:
                drop_b = np.random.binomial(size=1, n=1, p=args.dropout_rate).item()
            if drop_b == 1:
                g_b = torch.zeros_like(g_b)
            g = torch.cat((g_a, g_b), dim=1) if use_concat else g_a + g_b

            # Step 2: update global_clf, net2_clf and net1 on the task loss.
            pred = global_clf(g)
            loss = criterion(pred, targets)
            loss.backward()
            optim_global_clf.step()
            optim_net2_clf.step()
            optim_net1.step()
            optim_global_clf.zero_grad()
            optim_net2_clf.zero_grad()
            optim_net1.zero_grad()

            # Re-run the server side with the freshly updated heads. net1's
            # output is detached because net1 is not updated again.
            g_a = net1(inputs_a).detach()
            g_b = net2_clf(h_b_sent)
            if drop_b == 1:
                g_b = torch.zeros_like(g_b)
            g = torch.cat((g_a, g_b), dim=1) if use_concat else g_a + g_b

            # Step 3: update net2_encoder (skipped when party B was dropped).
            if drop_b != 1:
                pred = global_clf(g)
                pred_adv = net2_clf_adv(h_b_sent)
                targets_random = torch.LongTensor(
                    np.random.choice(range(10), size=len(targets))
                ).to(device)

                # Terms of the encoder loss:
                #   (1 - l_adv) * CE(pred, targets): keeps the task loss low;
                #   -l_adv * CE(pred_adv, targets): pushes the attacker head
                #     away from the true labels;
                #   +l_adv * CE(pred_adv, targets_random): pulls it towards
                #     uniformly random labels.
                loss_net2_encoder = (
                    -args.l_adv * criterion(pred_adv, targets)
                    + (1 - args.l_adv) * criterion(pred, targets)
                    + args.l_adv * criterion(pred_adv, targets_random)
                )
                # Only the gradient w.r.t. the sent features is needed.
                # Returning it directly avoids accumulating stray grads on
                # the server-side heads.
                (grad_h_b,) = torch.autograd.grad(loss_net2_encoder, h_b_sent)
                h_b.backward(gradient=grad_h_b)
                optim_net2_encoder.step()
                optim_net2_encoder.zero_grad()

                epoch_loss_encoder += loss_net2_encoder.item()

            epoch_loss_clean += loss.item()
            epoch_loss_adv += loss_adv.item()
            epoch_dropout_num += drop_b

        for sched in schedulers:
            sched.step()

        if e % 5 == 0:
            print("MC attack 40 samples sim...")
            mc_train_acc_sim, mc_val_acc_sim = mc_attack(args, net2_encoder, net2_clf_adv_sim_init, mc_train_loader, valid_loader)
            print("MC attack 40 samples large...")
            mc_train_acc, mc_val_acc = mc_attack(args, net2_encoder, net2_clf_adv_init, mc_train_loader, valid_loader)

        # Average over the batches actually produced by the loader, counting
        # the final partial batch.
        num_batches = len(train_loader)
        train_acc = valid(net1, net2_encoder, net2_clf, global_clf, train_loader, device)
        val_acc = valid(net1, net2_encoder, net2_clf, global_clf, valid_loader, device)
        val_acc_quit = valid(net1, net2_encoder, net2_clf, global_clf, valid_loader, device, quit=True)
        if val_acc > best_acc:
            torch.save({"net1": net1, "net2_encoder": net2_encoder, "net2_clf": net2_clf,
                        "net2_clf_adv": net2_clf_adv, "global_clf": global_clf,
                        "train_acc": train_acc, "val_acc": val_acc}, args.save_models_filename)
            best_acc = val_acc
        print(
            f"Epoch: {e}. clean loss: {epoch_loss_clean / num_batches:.4f}, "
            f"adv loss: {epoch_loss_adv / num_batches:.4f}, "
            f"encoder loss: {epoch_loss_encoder / num_batches:.4f}"
        )
        print(
            f"train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}, "
            f"mc_val_acc_sim: {mc_val_acc_sim:.4f}, mc_val_acc: {mc_val_acc:.4f}, "
            f"val_acc_quit: {val_acc_quit:.4f}, "
            f"dropout_rate: {epoch_dropout_num / num_batches:.4f}"
        )


if __name__ == "__main__":
    # Script entry point: run the full training / attack experiment.
    # Importing this module does not start training.
    main()
