import einops
import torch
from icecream import ic
from tqdm import tqdm
import torch

from testing import get_predicted

def measure(net, loaders, mode, iters, problem, device):
    """Run the measurement selected by ``mode`` once per loader.

    Args:
        net: Model forwarded unchanged to the measurement routine.
        loaders: Iterable of dataloaders; each one is measured independently.
        mode: Name of the measurement to run. Only ``'measure_fgsm'`` is
            supported.
        iters: Iteration counts forwarded to the measurement routine.
        problem: Problem identifier forwarded to the measurement routine.
        device: Device forwarded to the measurement routine.

    Returns:
        A pair ``(accs, metric_vals)`` of lists with one entry per loader.
        ``metric_vals`` holds ``None`` for a loader whose measurement routine
        returned only an accuracy.

    Raises:
        ValueError: If ``mode`` is not a supported measurement mode.
    """
    supported_modes = ('measure_fgsm',)
    # Fail fast with a clear message instead of hitting an unbound local
    # inside the loop when the mode is not recognised.
    if mode not in supported_modes:
        raise ValueError(
            f"Unsupported measure mode {mode!r}; expected one of {supported_modes}"
        )

    accs = []
    metric_vals = []
    for loader in loaders:
        result = measure_fgsm_adversarial(net, loader, iters, problem, device)
        # The measurement routine may return a bare accuracy rather than an
        # (accuracy, metric) pair; unpacking a single 0-d tensor would raise.
        if isinstance(result, (tuple, list)) and len(result) == 2:
            accuracy, metric = result
        else:
            accuracy, metric = result, None

        accs.append(accuracy)
        metric_vals.append(metric)
    return accs, metric_vals

def _count_exact_matches(predicted, targets):
    """Count samples whose prediction equals the target at every position."""
    return torch.amin(predicted == targets, dim=[1]).sum().item()


def _measure_fixed_point_attack(net, dataloader, iters, problem, device, epsilon,
                                total_steps, grad_to_direction, log_prefix):
    """Shared loop for attacks that perturb the state returned by ``net``.

    For every batch the network is first run without gradients to obtain the
    clean outputs and a reference state. The state is then perturbed for
    ``total_steps`` steps so that the state the network returns becomes less
    cosine-similar to the reference, and accuracy is measured on the outputs
    of the last attack step.

    Args:
        net: Callable as ``net(inputs, train_step=0, iters_to_do=..., return_fp=True)``
            and additionally with ``interim_thought=...``; returns
            ``(outputs, state)``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        iters: Collection of iteration counts; only its maximum is used.
        problem: Forwarded to ``get_predicted``.
        device: Device the batches are moved to.
        epsilon: Step size of each perturbation.
        total_steps: Number of attack steps per batch.
        grad_to_direction: Maps the gradient of the loss w.r.t. the injected
            state to the direction that is subtracted (scaled by ``epsilon``).
        log_prefix: Text prepended to the per-step log line.

    Returns:
        0-d tensor with the accuracy (in percent) after the attack.
    """
    max_iters = max(iters)
    # Clean and attacked results are tracked separately so the returned value
    # is the post-attack accuracy rather than a blend of both.
    clean_corrects = torch.zeros(max_iters)
    adv_corrects = torch.zeros(max_iters)
    total = 0
    net.eval()
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    for inputs, targets in tqdm(dataloader, leave=False):
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.view(targets.size(0), -1)

        with torch.no_grad():
            init_outputs, fp_val1 = net(inputs, train_step=0, iters_to_do=max_iters, return_fp=True)

            predicted = get_predicted(inputs, init_outputs, problem)
            clean_corrects += _count_exact_matches(predicted, targets)

            total += targets.size(0)
            orig_accuracy = 100.0 * clean_corrects / total
            print("Initial Acc", orig_accuracy.max())

        # The reference stays a constant. ``requires_grad_()`` works in place,
        # so the attacked tensor must be a separate copy; otherwise gradients
        # would also flow through the reference side of the loss.
        reference_fp = fp_val1.detach().view(inputs.shape[0], -1)
        new_fp = fp_val1.detach().clone().requires_grad_()

        # With zero attack steps the attacked outputs are the clean outputs.
        next_outputs = init_outputs
        for step in range(total_steps):
            next_outputs, fp_val2 = net(inputs, interim_thought=new_fp, train_step=0, iters_to_do=max_iters, return_fp=True)

            # minimizing cosine -> making more dissimilar
            loss = cos(reference_fp, fp_val2.view(inputs.shape[0], -1)).mean()
            net.zero_grad()
            loss.backward()

            print(f"{log_prefix}Step {step} Loss {loss.item()}")

            # Step against the gradient so the similarity goes down. The update
            # is built from detached tensors, so nothing is modified in place
            # on the autograd graph.
            direction = grad_to_direction(new_fp.grad)
            new_fp = (fp_val2.detach() - epsilon * direction).requires_grad_()

        with torch.no_grad():
            predicted = get_predicted(inputs, next_outputs, problem)
            adv_corrects += _count_exact_matches(predicted, targets)

        accuracy = 100.0 * adv_corrects / total
        print(accuracy.max())

    accuracy = 100.0 * adv_corrects / total
    return accuracy[-1]


def measure_fgsm_adversarial(net, dataloader, iters, problem, device, epsilon=0.01, total_steps=100):
    """Measure accuracy after an FGSM-style (signed-gradient) attack on the state.

    Each step moves the state by ``epsilon`` against the sign of the gradient
    of the cosine similarity to the clean state, i.e. it descends the
    similarity, matching the gradient-descent variant in this file.

    Args:
        net: Model to attack; see ``_measure_fixed_point_attack``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        iters: Collection of iteration counts; only its maximum is used.
        problem: Forwarded to ``get_predicted``.
        device: Device the batches are moved to.
        epsilon: Magnitude of each signed step.
        total_steps: Number of attack steps per batch.

    Returns:
        0-d tensor with the accuracy (in percent) after the attack.
    """
    return _measure_fixed_point_attack(
        net, dataloader, iters, problem, device, epsilon, total_steps,
        grad_to_direction=torch.sign, log_prefix="",
    )


def measure_adversarial_gd_asc(net, dataloader, iters, problem, device, epsilon=0.01, total_steps=100):
    """Measure accuracy after a plain gradient-descent attack on the state.

    Each step moves the state by ``epsilon`` times the raw gradient of the
    cosine similarity to the clean state, in the descending direction.

    Args:
        net: Model to attack; see ``_measure_fixed_point_attack``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        iters: Collection of iteration counts; only its maximum is used.
        problem: Forwarded to ``get_predicted``.
        device: Device the batches are moved to.
        epsilon: Learning rate applied to the raw gradient.
        total_steps: Number of attack steps per batch.

    Returns:
        0-d tensor with the accuracy (in percent) after the attack.
    """
    return _measure_fixed_point_attack(
        net, dataloader, iters, problem, device, epsilon, total_steps,
        grad_to_direction=lambda grad: grad, log_prefix="[GD based] ",
    )