import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import csv
import matplotlib.pyplot as plt
import cv2
import PIL
import pickle

def weight_init(args, poisonids, train_data):
    """Draw one random initial weight per poison sample.

    Each weight is a uniform sample rescaled to lie between 0.45 and 0.55.

    Args:
        args: Unused; kept for a uniform initializer signature.
        poisonids: Sized collection of poison indices; only its length is used.
        train_data: Unused; kept for a uniform initializer signature.

    Returns:
        A 1-D tensor with ``len(poisonids)`` entries.
    """
    low, high = 0.45, 0.55
    uniform = torch.rand(len(poisonids))
    # Map u in [0, 1) onto the interval between ``low`` and ``high``.
    return (low - high) * uniform + high


def pertubation_init(args, poisonids, train_data):
    """Randomly initialize one perturbation per poison sample.

    Gaussian noise is scaled by ``args.eps / std / 255`` (per channel, using
    ``train_data.data_std``) and then clipped to plus/minus that same bound.

    Args:
        args: Namespace providing ``eps``.
        poisonids: Sized collection of poison indices; only its length is used.
        train_data: Dataset exposing ``data_std`` and whose first item's first
            element gives the per-sample shape.

    Returns:
        Tensor of shape ``(len(poisonids), *sample_shape)``.
    """
    std = torch.tensor(train_data.data_std)[None, :, None, None]
    # Largest allowed magnitude per channel, expressed in normalized units.
    bound = args.eps / std / 255

    sample_shape = train_data[0][0].shape
    noise = torch.randn(len(poisonids), *sample_shape)
    noise.mul_(bound)

    # Clip element-wise into [-bound, bound].
    return torch.max(torch.min(noise, bound), -bound)

# Carlini-Wagner loss
def cw_loss(outputs, intended_classes, clamp=-100):
    """Carlini-Wagner style margin loss towards a single intended class.

    For every row of ``outputs`` the margin is the largest logit among all
    classes except ``intended_classes`` minus the logit of
    ``intended_classes``. Margins are clamped from below at ``clamp`` and
    averaged over the rows.

    Args:
        outputs: 2-D tensor of logits, indexed as ``[row, class]``.
        intended_classes: Index of the class whose logit should win.
        clamp: Lower bound applied to each per-row margin.

    Returns:
        Scalar tensor holding the mean clamped margin.
    """
    # Logits of every class except the intended one.
    other_logits = torch.cat(
        (outputs[:, :intended_classes], outputs[:, intended_classes + 1:]), -1
    )
    top_logits, _ = torch.max(other_logits, 1)
    # Column indexing selects the intended logit of every row at once.
    intended_logits = outputs[:, intended_classes]
    difference = torch.clamp(top_logits - intended_logits, min=clamp)
    return torch.mean(difference)

# untargeted loss
def ut_loss(outputs, targetclass, clamp=-100):
    """Untargeted margin loss pushing predictions away from ``targetclass``.

    For every row of ``outputs`` the margin is the logit of ``targetclass``
    minus the largest logit among all other classes. Margins are clamped
    from below at ``clamp`` and averaged over the rows.

    Args:
        outputs: 2-D tensor of logits, indexed as ``[row, class]``.
        targetclass: Index of the class to move away from.
        clamp: Lower bound applied to each per-row margin.

    Returns:
        Scalar tensor holding the mean clamped margin.
    """
    # Logits of every class except the target one.
    other_logits = torch.cat(
        (outputs[:, :targetclass], outputs[:, targetclass + 1:]), -1
    )
    top_logits, _ = torch.max(other_logits, 1)
    # Column indexing selects the target logit of every row at once.
    target_logits = outputs[:, targetclass]
    difference = torch.clamp(target_logits - top_logits, min=clamp)
    return torch.mean(difference)

# compute d/dW [-loss(z, model)], summed over all batches of base_loader
def get_grad_diff(args, model, poison_weight, base_loader):
    """Sum the gradients of the negated cross-entropy loss over a loader.

    The model is switched to eval mode. For each batch the gradient of
    ``-CrossEntropyLoss(model(images), labels)`` with respect to every
    parameter that requires grad is computed, and the per-batch gradients
    are added together parameter by parameter.

    Args:
        args: Namespace providing ``device``.
        model: ``nn.Module`` to differentiate.
        poison_weight: Unused; kept so existing callers keep working.
        base_loader: Iterable yielding ``(images, labels, _)`` triples.

    Returns:
        List with one summed gradient tensor per differentiable parameter;
        an empty list when ``base_loader`` yields no batches.
    """
    loss_func = nn.CrossEntropyLoss()
    model.eval()
    differentiable_params = [p for p in model.parameters() if p.requires_grad]

    # Keep a running sum instead of storing every batch's gradients.
    totals = None
    for images, labels, _ in base_loader:
        images, labels = images.to(args.device), labels.to(args.device)

        loss_diff = -loss_func(model(images), labels)
        gradients = torch.autograd.grad(loss_diff, differentiable_params)

        if totals is None:
            totals = list(gradients)
        else:
            totals = [torch.add(acc, g) for acc, g in zip(totals, gradients)]

    return totals if totals is not None else []


def hvp_train(model, x, y, v):
    """Hessian-vector product of the training loss with ``v``.

    The loss gradient is taken with a retained graph, dotted with ``v``, and
    that scalar is differentiated again with respect to the parameters that
    require grad.

    Args:
        model: ``nn.Module`` whose parameters define the Hessian.
        x: Input batch passed to the model.
        y: Labels for ``x``.
        v: Sequence of tensors, one per differentiable parameter.

    Returns:
        List of tensors, one per differentiable parameter.
    """
    first_order = get_gradients_train(model, x, y, v)
    trainable = [p for p in model.parameters() if p.requires_grad]

    # <grad L, v>, summed over every parameter tensor.
    partial_dots = [torch.sum(g * v_i) for g, v_i in zip(first_order, v)]
    grad_dot_v = torch.stack(partial_dots).sum()

    second_order = torch.autograd.grad(grad_dot_v, trainable, retain_graph=True)
    return list(second_order)


def get_gradients_train(model, x, y, v):
    """Calculate dL/dW for the cross-entropy loss on ``(x, y)``.

    The gradients are produced with ``create_graph=True`` so they can be
    differentiated again (e.g. for a Hessian-vector product).

    Args:
        model: ``nn.Module`` to differentiate.
        x: Input batch passed to the model.
        y: Labels for ``x``.
        v: Unused; kept so existing callers keep working.

    Returns:
        Tuple of gradient tensors, one per parameter that requires grad.
    """
    loss = F.cross_entropy(model(x), y)
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.autograd.grad(
        loss,
        trainable,
        retain_graph=True,
        create_graph=True,
        only_inputs=True,
    )


def get_inv_hvp_train(args, model, data_loader, v, damping=0.1, scale=200, rounds=1):
    """Iteratively estimate an inverse-Hessian-vector product for ``v``.

    In each round a running estimate starts at zero and is updated once per
    batch as ``v + (1 - damping) * estimate - hvp / scale``. The final
    running estimate is divided by ``scale`` and the rounds are averaged.

    Args:
        args: Namespace providing ``device``.
        model: ``nn.Module`` passed through to ``hvp_train``.
        data_loader: Iterable yielding ``(images, labels, _)`` triples.
        v: Sequence of tensors, one per differentiable parameter.
        damping: Damping factor of the recursion.
        scale: Scaling factor applied to the Hessian-vector product and to
            the final estimate.
        rounds: Number of independent passes to average.

    Returns:
        List of tensors shaped like ``v``; ``None`` when ``rounds`` is 0.
    """
    estimate = None
    for _ in range(rounds):
        # Running estimate for this round; it must be carried from batch to
        # batch, otherwise only the last batch would contribute.
        cur_estimate = [torch.zeros_like(v_i) for v_i in v]
        for images, labels, _ in data_loader:
            images, labels = images.to(args.device), labels.to(args.device)
            # NOTE(review): the HVP is taken with respect to ``v`` rather than
            # the running estimate, mirroring ``get_inv_hvp`` - confirm that
            # this is the intended recursion.
            batch_hvp = hvp_train(model, images, labels, v)

            cur_estimate = [
                a + (1 - damping) * b - c / scale
                for a, b, c in zip(v, cur_estimate, batch_hvp)
            ]

        res_upscaled = [t / scale for t in cur_estimate]
        if estimate is None:
            estimate = [t / rounds for t in res_upscaled]
        else:
            for j in range(len(estimate)):
                estimate[j] += res_upscaled[j] / rounds

    return estimate


def hvp(model, x, y, v):
    """Hessian-vector product of the loss with ``v``.

    Works on models whose ``parameters`` attribute is a mapping of tensors
    (its ``values()`` are differentiated).

    Args:
        model: Callable model exposing a ``parameters`` mapping.
        x: Input batch passed to the model.
        y: Labels for ``x``.
        v: Sequence of tensors, one per parameter.

    Returns:
        List of tensors, one per parameter.
    """
    first_order = get_gradients(model, x, y)

    # <grad L, v>, summed over every parameter tensor.
    partial_dots = [torch.sum(g * v_i) for g, v_i in zip(first_order, v)]
    grad_dot_v = torch.stack(partial_dots).sum()

    second_order = torch.autograd.grad(
        grad_dot_v, model.parameters.values(), retain_graph=True
    )
    return list(second_order)


def get_gradients(model, x, y):
    """Calculate dL/dW for the cross-entropy loss on ``(x, y)``.

    The gradients are taken with respect to ``model.parameters.values()``
    and produced with ``create_graph=True`` so they can be differentiated
    again.

    Args:
        model: Callable model exposing a ``parameters`` mapping.
        x: Input batch passed to the model.
        y: Labels for ``x``.

    Returns:
        Tuple of gradient tensors, one per parameter.
    """
    loss = F.cross_entropy(model(x), y)
    return torch.autograd.grad(
        loss,
        model.parameters.values(),
        retain_graph=True,
        create_graph=True,
        only_inputs=True,
    )


def get_inv_hvp(args, model, data_loader, v, damping=0.1, scale=200, rounds=1):
    """Iteratively estimate an inverse-Hessian-vector product for ``v``.

    In each round a running estimate starts at zero and is updated once per
    batch as ``v + (1 - damping) * estimate - hvp / scale``. The final
    running estimate is divided by ``scale`` and the rounds are averaged.

    Args:
        args: Namespace providing ``device``.
        model: Model passed through to ``hvp``.
        data_loader: Iterable yielding ``(images, labels, _)`` triples.
        v: Sequence of tensors, one per parameter.
        damping: Damping factor of the recursion.
        scale: Scaling factor applied to the Hessian-vector product and to
            the final estimate.
        rounds: Number of independent passes to average.

    Returns:
        List of tensors shaped like ``v``; ``None`` when ``rounds`` is 0.
    """
    averaged = None
    for _ in range(rounds):
        running = [torch.zeros_like(v_i) for v_i in v]

        for images, labels, _ in data_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            batch_hvp = hvp(model, images, labels, v)

            running = [
                v_i + (1 - damping) * prev - h / scale
                for v_i, prev, h in zip(v, running, batch_hvp)
            ]

        # Undo the scaling, then weight this round by 1 / rounds.
        contribution = [t / scale / rounds for t in running]
        if averaged is None:
            averaged = contribution
        else:
            averaged = [total + part for total, part in zip(averaged, contribution)]

    return averaged


def read_results(args):
    """Load a crafting result saved under ``args.outdir``.

    The file ``<args.craftproj>.pt`` is read with ``torch.load`` and mapped
    onto ``args.device``.

    Returns:
        Tuple ``(targetclass, poisonclass, targetids, poison_weight)``.
    """
    checkpoint_file = os.path.join(args.outdir, args.craftproj + '.pt')
    saved = torch.load(checkpoint_file, map_location=args.device)
    wanted = ('targetclass', 'poisonclass', 'targetids', 'poison_weight')
    return tuple(saved[key] for key in wanted)


def save_results(args, poison_weight):
    """Save the crafting result to ``<args.outdir>/<args.craftproj>.pt``.

    The output directory is created if needed. The stored dictionary holds
    ``targetclass``, ``poisonclass`` and ``targetids`` taken from ``args``
    together with ``poison_weight``.

    Args:
        args: Namespace providing ``outdir``, ``craftproj``, ``targetclass``,
            ``poisonclass`` and ``targetids``.
        poison_weight: Object stored under the ``'poison_weight'`` key.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.outdir, exist_ok=True)

    path = os.path.join(args.outdir, args.craftproj + '.pt')
    res = {
        'targetclass': args.targetclass,
        'poisonclass': args.poisonclass,
        'targetids': args.targetids,
        'poison_weight': poison_weight,
    }
    torch.save(res, path)


def export_results(args, trainset, testset, poisonids, targetids, poison_delta):
    """Export poisons, the target image and the base indices as pickles.

    Three files are written to ``<args.att_path>/<args.craftproj>``:
    ``poisons.pickle`` (list of ``(PIL image, int label)``),
    ``target.pickle`` (``(PIL image, label)`` of ``testset[targetids[0]]``)
    and ``base_indices.pickle`` (``poisonids``).

    Args:
        args: Namespace providing ``att_path`` and ``craftproj``.
        trainset: Dataset exposing ``data_std``/``data_mean`` and yielding
            ``(input, label, _)`` triples.
        testset: Dataset yielding ``(input, label, _)`` triples.
        poisonids: Indices into ``trainset`` of the poisoned samples.
        targetids: Indices into ``testset``; only the first one is exported.
        poison_delta: Perturbations, indexed in the same order as
            ``poisonids``.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
    # import it explicitly instead of relying on another import doing so.
    from PIL import Image

    sub_path = os.path.join(args.att_path, args.craftproj)
    os.makedirs(sub_path, exist_ok=True)

    ds = torch.tensor(trainset.data_std)[None, :, None, None]
    dm = torch.tensor(trainset.data_mean)[None, :, None, None]

    def _torch_to_PIL(image_tensor):
        """Torch->PIL pipeline as in torchvision.utils.save_image."""
        # Undo the normalization and clip to the valid [0, 1] range.
        image_denormalized = torch.clamp(image_tensor * ds + dm, 0, 1).squeeze()
        # Round to the nearest 8-bit value and move channels last.
        image_torch_uint8 = image_denormalized.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8)
        return Image.fromarray(image_torch_uint8.numpy())

    # Poisons
    benchmark_poisons = []
    for i, poison_id in enumerate(poisonids):
        sample, label, _ = trainset[poison_id]
        # Work on a copy so a dataset that hands out its stored tensors is
        # not modified by the in-place addition.
        poisoned = sample.clone()
        poisoned += poison_delta[i].detach().cpu()
        benchmark_poisons.append((_torch_to_PIL(poisoned), int(label)))

    with open(os.path.join(sub_path, 'poisons.pickle'), 'wb') as file:
        pickle.dump(benchmark_poisons, file, protocol=pickle.HIGHEST_PROTOCOL)

    # Target
    target, target_label, _ = testset[targetids[0]]
    with open(os.path.join(sub_path, 'target.pickle'), 'wb') as file:
        pickle.dump((_torch_to_PIL(target), target_label), file, protocol=pickle.HIGHEST_PROTOCOL)

    # Indices
    with open(os.path.join(sub_path, 'base_indices.pickle'), 'wb') as file:
        pickle.dump(poisonids, file, protocol=pickle.HIGHEST_PROTOCOL)


def test_one_image(args, image, model):
    """Classify a single image and report the top-class probability.

    The model is switched to eval mode and run without gradient tracking on
    the image with a leading batch dimension added.

    Args:
        args: Namespace providing ``device``.
        image: Single input tensor without a batch dimension.
        model: ``nn.Module`` producing logits indexed as ``[row, class]``.

    Returns:
        Tuple ``(predicted, top_prob)``: a NumPy array with the predicted
        class index and the softmax probability of the top class as a float.
    """
    model.eval()
    with torch.no_grad():
        batch = image.unsqueeze(0).to(args.device)
        logits = model(batch)
        predicted = torch.max(logits.data, 1)[1]
        # probability
        probabilities = F.softmax(logits, dim=1)
        top_prob = probabilities.topk(1, dim=1)[0]

    prediction_np = predicted.detach().cpu().numpy()
    return prediction_np, top_prob.mean().item()


def set_random_seed(seed=11):
    """Seed the torch, CUDA, NumPy and ``random`` generators.

    Each generator gets its own offset from ``seed``: torch ``seed + 1``,
    CUDA (all devices) ``seed + 5``, NumPy ``seed + 4`` and Python's
    ``random`` ``seed + 6``.

    Args:
        seed: Base seed from which the per-library seeds are derived.
    """
    torch.manual_seed(seed + 1)
    # Earlier CUDA seeding calls (seed + 2, seed + 3) were overwritten by
    # this one, so it is the only CUDA seed that ever took effect.
    torch.cuda.manual_seed_all(seed + 5)
    np.random.seed(seed + 4)
    random.seed(seed + 6)


# craftrate decay
def decay_lrate(args, iteration):
    """Return the step-decayed craft rate for ``iteration``.

    The base rate ``args.craftrate`` is used for the first third of
    ``args.ncraftstep``, a tenth of it for the second third and a hundredth
    of it afterwards.

    Args:
        args: Namespace providing ``ncraftstep`` and ``craftrate``.
        iteration: Current crafting iteration.

    Returns:
        The craft rate to use at ``iteration``.
    """
    first_drop = args.ncraftstep * 1 / 3
    second_drop = args.ncraftstep * 2 / 3

    if iteration < first_drop:
        return args.craftrate
    if iteration < second_drop:
        return args.craftrate * 0.1
    return args.craftrate * 0.01