import logging
import os
import datetime

import torch
import torch.nn as nn
import higher
import torch.backends.cudnn as cudnn
import copy
import numpy as np

from torch.utils.data import DataLoader
from torch import optim

from utils.parameters import get_parameter
from utils.dataset import construct_datasets, targetData, poisonData
from utils.model import Linear, ResNet18, VGG16, MobileNetV2
from utils.utils import weight_init, cw_loss, ut_loss, save_results, set_random_seed, test_one_image, decay_lrate
from pretrain import train, test
from first_order_victim import victim


def _checkpoint_path(args):
    """Return the model checkpoint path ``<moddir>/<dataset>_<net>_<modname>.pth``."""
    return os.path.join(args.moddir, args.dataset + '_' + args.net + '_' + str(args.modname) + '.pth')


def _setup_logger(args):
    """Route root-logger INFO output to a fresh timestamped file under ``args.logdir``.

    Returns:
        The root logger.
    """
    os.makedirs(args.logdir, exist_ok=True)

    log_path = args.craftproj + '-log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S"))
    log_path = log_path + '.txt'
    logging.basicConfig(
        filename=os.path.join(args.logdir, log_path),
        format="%(asctime)s - %(name)s - %(message)s",
        datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    return logger


def _build_model(args, input_shape, num_classes):
    """Instantiate the network named by ``args.net`` and move it to ``args.device``.

    Raises:
        ValueError: if ``args.net`` is not a supported architecture (previously this
            surfaced later as an UnboundLocalError on ``model``).
    """
    if args.net == 'ResNet18':
        model = ResNet18(num_classes)
    elif args.net == 'VGG16':
        # NOTE(review): VGG16 and MobileNetV2 are built without num_classes; confirm their
        # default class count matches args.num_classes.
        model = VGG16()
    elif args.net == 'MobileNetV2':
        model = MobileNetV2()
    elif args.net == 'Linear':
        model = Linear(input_shape, num_classes)
    else:
        raise ValueError(f'Unsupported network: {args.net}')
    return model.to(args.device)


def _pretrain_or_load(args, model, train_loader):
    """Load ``model`` weights from the checkpoint, or train it and save a checkpoint.

    Raises:
        ValueError: if training is required and ``args.opt`` is neither 'Adam' nor 'SGD'.
    """
    path = _checkpoint_path(args)
    if args.pretrained:
        checkpoint = torch.load(path, map_location=args.device)
        model.load_state_dict(checkpoint['net'])
        return

    loss_func = nn.CrossEntropyLoss()
    if args.opt == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.opt == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError(f'Unsupported optimizer: {args.opt}')
    train(args, logging, model, loss_func, optimizer, train_loader)

    # temporary save model
    os.makedirs(args.moddir, exist_ok=True)
    state = {
        'net': model.state_dict(),
        'epoch': args.epoch,
        'batch_size': args.batchsize,
        # Save the optimizer's state_dict, not the optimizer object: a pickled optimizer
        # cannot be read back by torch.load(weights_only=True), the default since
        # PyTorch 2.6, which would break the pretrained branch above.
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)


def _choose_classes(args, test_data):
    """Randomly fill in ``args.targetclass`` / ``args.poisonclass`` when left at -1.

    The randomly drawn class always differs from the other one.  Previously a random
    pick only happened when *both* were -1, so leaving just one unspecified kept -1 as a
    class id.
    """
    num_classes = args.num_classes
    randomized = args.targetclass == -1 or args.poisonclass == -1

    if args.targetclass == -1:
        if args.poisonclass == -1:
            args.targetclass = np.random.randint(num_classes)
        else:
            args.targetclass = np.random.choice(
                [c for c in range(num_classes) if c != args.poisonclass])
    if args.poisonclass == -1:
        args.poisonclass = np.random.choice(
            [c for c in range(num_classes) if c != args.targetclass])

    if randomized:
        targetclass_label = test_data.classes[args.targetclass]
        print(f'Target class: {targetclass_label}, id: {args.targetclass}')
        logging.info(f'Target class: {targetclass_label}, id: {args.targetclass}')

        poisonclass_label = test_data.classes[args.poisonclass]
        print(f'Poison class: {poisonclass_label}, id: {args.poisonclass}')
        logging.info(f'Poison class: {poisonclass_label}, id: {args.poisonclass}')


def _choose_target_ids(args, test_data, model):
    """Append ``args.ntargets`` random test indices of ``args.targetclass`` to ``args.targetids``.

    A sample qualifies only if ``test_one_image`` predicts ``args.targetclass`` and the
    probability it reports is below ``args.threshold``.  Each index is tried at most once.

    Raises:
        ValueError: if the target class runs out of qualifying samples.
    """
    candidates = [i for i in range(len(test_data)) if test_data[i][1] == args.targetclass]
    for _ in range(args.ntargets):
        while True:
            if not candidates:
                raise ValueError(
                    f'Only {len(args.targetids)} of {args.ntargets} target samples of class '
                    f'{args.targetclass} satisfy prob < {args.threshold}')
            test_id = np.random.choice(candidates)
            candidates.remove(test_id)
            test_pred, prob = test_one_image(args, test_data[test_id][0], model)
            if test_pred == args.targetclass and prob < args.threshold:
                args.targetids.append(test_id)
                break


def _craft_poison_weights(args, model, base_loader, target_loader, poison_weight):
    """Optimise ``poison_weight`` in place for ``args.ncraftstep`` Adam steps, clamped to [0, 1].

    For each training batch, a ``higher`` differentiable copy of ``model`` takes one
    optimizer step on the *negated* weighted cross-entropy (i.e. gradient ascent).  The
    adversarial loss of that updated copy on the target samples is differentiated w.r.t.
    the batch's slice of ``poison_weight``.  The collected gradients drive one Adam step
    per crafting step.

    Raises:
        NotImplementedError: if ``args.atsetting`` is not 'targeted' or 'untargeted'.
        ValueError: if ``target_loader`` yields no batches.
    """
    if args.atsetting == 'targeted':
        adv_loss, adv_class = cw_loss, args.poisonclass
    elif args.atsetting == 'untargeted':
        adv_loss, adv_class = ut_loss, args.targetclass
    else:
        raise NotImplementedError("Not support!")
    if len(target_loader) == 0:
        raise ValueError('No target samples selected; cannot craft poisons')

    att_optimizer = torch.optim.Adam([poison_weight], lr=args.craftrate, weight_decay=0)
    poison_weight.grad = torch.zeros_like(poison_weight)
    loss_func = nn.CrossEntropyLoss(reduction='none')

    for idx in range(args.ncraftstep):
        print(f'Step {idx}')
        logging.info(f'Step {idx}')

        # Since PyTorch 2.0, zero_grad() defaults to set_to_none=True and leaves .grad as
        # None; the slice assignments below need an allocated gradient tensor.
        if poison_weight.grad is None:
            poison_weight.grad = torch.zeros_like(poison_weight)

        loss, n_batch = 0, 0

        unlearned_model = copy.deepcopy(model)
        if args.opt == 'Adam':
            optimizer_unlearned = optim.Adam(unlearned_model.parameters(), lr=args.tau)
        else:
            optimizer_unlearned = optim.SGD(
                unlearned_model.parameters(), lr=args.tau, momentum=0.9, weight_decay=5e-4)

        for batch, (inputs, targets, _) in enumerate(base_loader):
            inputs, targets = inputs.to(args.device), targets.to(args.device)

            # base_loader is unshuffled, so batch k maps to this contiguous slice.
            start = batch * args.batchsize
            end = start + len(inputs)
            weight_slice = poison_weight[start:end]
            weight_slice.requires_grad_()
            # Sigmoid of theta * (2w - 1): per-sample multiplier on the training loss.
            h2 = 1 / (1 + torch.exp(-args.theta * (2 * weight_slice - 1)))

            unlearned_model.zero_grad()
            # NOTE(review): cuDNN is disabled here, presumably because some cuDNN kernels
            # lack double-backward support — confirm.
            with torch.backends.cudnn.flags(enabled=False):
                with higher.innerloop_ctx(unlearned_model, optimizer_unlearned) as (net, opt):
                    per_sample_loss = loss_func(net(inputs), targets)
                    opt.step(-torch.mul(per_sample_loss, h2).mean())

                    net.eval()
                    # Sum over every target batch; previously only the last batch's loss
                    # survived the loop.
                    target_loss = sum(
                        adv_loss(net(images.to(args.device)), adv_class)
                        for images, _, _ in target_loader
                    )
                    grads = torch.autograd.grad(target_loss, weight_slice)[0].detach()
                    poison_weight.grad[start:end] = grads
                    net.train()

            n_batch += 1
            loss += target_loss.item()

        print('Target loss: %.4f' % (loss / n_batch))
        logging.info('Target loss: %.4f ' % (loss / n_batch))
        att_optimizer.step()
        att_optimizer.zero_grad()
        with torch.no_grad():
            poison_weight.clamp_(0, 1)


def main():
    """Run the pipeline: pretrain or load a model, pick targets, craft poison weights, run the victim.

    All settings come from ``get_parameter()``.  The crafted weights are passed to
    ``save_results`` and ``victim``.
    """
    args = get_parameter()

    # set_random_seed(args.seed)  # NOTE: seeding is disabled, so runs are not reproducible.

    logger = _setup_logger(args)
    logger.info(str(args))

    if torch.cuda.is_available():
        cudnn.benchmark = True
    else:
        args.device = "cpu"
    print(f'device: {args.device}')
    logger.info(f'device: {args.device}')

    print('==> loading dataset')
    logging.info('==> loading dataset')

    train_data, test_data = construct_datasets(args, args.dataset, args.datadir, load=True)
    train_loader = DataLoader(train_data, batch_size=args.batchsize, shuffle=True, num_workers=1)
    test_loader = DataLoader(test_data, batch_size=args.batchsize, shuffle=True, num_workers=1)

    print('==> building model')
    logging.info('==> building model')

    input_shape = len(train_data[0][0])
    model = _build_model(args, input_shape, args.num_classes)
    _pretrain_or_load(args, model, train_loader)

    test(args, logging, model, test_loader)

    print('==> preparing data')
    logging.info('==> preparing data')
    _choose_classes(args, test_data)
    if not args.targetids:
        _choose_target_ids(args, test_data, model)
        print(f'Target ids: {args.targetids}')
        logging.info(f'Target ids: {args.targetids}')

    target_data = torch.utils.data.Subset(test_data, args.targetids)
    target_loader = DataLoader(target_data, batch_size=args.batchsize, shuffle=False, num_workers=1)

    # Every training sample is a poison candidate with its own weight.
    poisonids = np.arange(len(train_data))
    base_data = torch.utils.data.Subset(train_data, poisonids)
    base_loader = DataLoader(base_data, batch_size=args.batchsize, shuffle=False, num_workers=1)
    poison_weight = weight_init(args, poisonids, train_data).to(args.device)

    print('==> begin crafting')
    logging.info('==> begin crafting')
    _craft_poison_weights(args, model, base_loader, target_loader, poison_weight)

    save_results(args, poison_weight)
    victim(args, logging, poison_weight, train_data, test_data)

if "__main__" == __name__:
    # Run the pipeline only when executed as a script, never on import.
    main()