import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import copy

def defense(args, rep, inputs, net, mask=None):
    """Apply the configured defense to a representation before it is sent.

    Args:
        args: Configuration object. ``args.defense_method`` selects the
            defense; ``args.laplacian_noise`` is read for
            ``"laplacian_noise"`` and ``args.compress`` for ``"compress"``.
        rep: Representation tensor to protect.
        inputs: Input tensor; only touched by ``"soteria"``, which sets its
            ``requires_grad`` flag to ``False``.
        net: Unused here; kept so existing callers keep working.
        mask: Keep-mask required by ``"soteria"``. It is multiplied with
            ``rep`` flattened to ``(rep.shape[0], -1)``, so it must
            broadcast against that shape.

    Returns:
        The defended representation, with the same shape as ``rep``. For
        ``"no"`` the very same tensor object is returned.

    Raises:
        ValueError: If ``args.defense_method`` is not one of the supported
            methods, or if ``"soteria"`` is selected without a ``mask``.
    """
    method = args.defense_method

    if method == "no":
        return rep

    if method == "laplacian_noise":
        # Zero-mean Laplacian noise whose scale is args.laplacian_noise.
        # The sample is drawn first and then moved to rep's device.
        noise = torch.distributions.laplace.Laplace(
            0, args.laplacian_noise
        ).sample(rep.shape)
        return rep + noise.to(rep.device)

    if method == "compress":
        # The threshold is the args.compress-quantile of |rep|, so the keep
        # mask must also be built from magnitudes; comparing the signed
        # values would discard every large negative entry.
        magnitude = rep.abs()
        thresh = torch.quantile(magnitude, args.compress)
        return rep * (magnitude > thresh)

    if method == "soteria":
        if mask is None:
            raise ValueError(
                "defense_method 'soteria' requires a precomputed mask"
            )
        # NOTE: the mask is supplied by the caller. An earlier, now removed,
        # in-function version derived it by thresholding
        # ||d rep_i / d inputs|| / |rep_i| at the args.compress quantile.
        raw_shape = rep.shape
        # reshape (rather than view) also accepts non-contiguous tensors.
        flat_rep = rep.reshape(raw_shape[0], -1)
        rep_send = (flat_rep * mask).reshape(raw_shape)
        inputs.requires_grad = False
        return rep_send

    raise ValueError(
        f"unsupported defense_method for defense(): {method!r}"
    )

def gradient_defense(args, feature):
    grad_w_defense = feature.grad.clone()
    device = feature.device
    if args.defense_method == "laplacian_noise":
        # print("added lap noise")
        # exit()
        # add laplacian noise with scale args.laplacian_scale
        grad_w_defense += (
            torch.distributions.laplace.Laplace(0, args.laplacian_noise)
            .sample(feature.shape)
            .to(device)
        )
        
    if args.defense_method == "compress":
        #send only proportion of gradients with largest absolute values (keep compression rate %)
        grad_w_defense = grad_w_defense.flatten()
        # sort gradients by absolute values to find threshold value
        sorted_absol, _ = torch.sort(torch.abs(grad_w_defense), descending=True)
        threshold = sorted_absol[int(args.compress * len(sorted_absol))]
        # set gradients with absolute values smaller than threshold to zero
        grad_w_defense[torch.abs(grad_w_defense) <= threshold] = 0
        grad_w_defense = grad_w_defense.reshape(feature.shape)
    
    if args.defense_method == "ppdl":
        """
        In each iteration, the server does the following steps to protect the gradients: 
        (1) randomly selects one gradient value, generates noise, and adds the noise to the gradient value; 
        (2) if the gradient value after adding noise is larger than a threshold value τ, keeps it, otherwise sets it to zero; 
        (3) loops the ﬁrst two steps until θu fraction of gradient values are gathered (ratio of θ over the total number of parameters as the parameter selection rate). 
        Both θu and τ are hyperparameters to balance the trade-off between model performance and defense performance.
        """
        grad_w_defense = grad_w_defense.flatten()
        # add noise
        grad_w_defense += (
            torch.distributions.laplace.Laplace(0, args.laplacian_noise)
            .sample(grad_w_defense.shape)
            .to(device)
        )
        # randomly zero out 1-theta % of elements in the grad_w_defense 
        inds = np.random.choice(grad_w_defense.shape[0], int(grad_w_defense.shape[0]*(1-args.ppdl_theta)), replace=False)
        grad_w_defense[inds] = 0
        # further zero out the elements if abs val is less than tau with noise
        grad_w_defense = torch.where(torch.abs(grad_w_defense) > args.ppdl_tau, grad_w_defense, torch.zeros_like(grad_w_defense))

        grad_w_defense = grad_w_defense.reshape(feature.shape)

    if args.defense_method == "discrete_sgd_wnoise":
        grad_w_defense = grad_w_defense.flatten()
        # get the mean and the standard deviation of the distribution
        mu = torch.mean(grad_w_defense)
        sigma = torch.std(grad_w_defense)
        # set an interval as [µ−2σ,µ+2σ]
        interval = [mu-2*sigma, mu+2*sigma]
        # slice the interval into N sub intervals (= N+1 interval endpoints)
        endpoints = torch.linspace(interval[0], interval[1], args.n_discretesgd+1)
        # create mask to keep only values inside the interval
        mask_keep = (grad_w_defense > interval[0]) & (grad_w_defense < interval[1])
        # NOTE: previous loop method WAY too slow
        # create matrix using repeated endpoints of length of gradient.shape[0] to allow broadcasting
        endpoints_mat = (
            endpoints.repeat(len(grad_w_defense))
            .reshape(len(grad_w_defense), len(endpoints))
            .to(device)
        )
        # find the closest subinterval endpoint to each gradient value
        inds = (
            torch.argmin(
                torch.abs(grad_w_defense.reshape(-1, 1) - endpoints_mat),
                dim=1,
            )
            .reshape(-1, 1)
            .to(device)
        )
        grad_w_defense = torch.gather(endpoints_mat, 1, inds).squeeze()
        # keep only the values inside the interval
        grad_w_defense = torch.mul(grad_w_defense, mask_keep).reshape(feature.shape) 

        # add laplacian noise with scale args.laplacian_scale
        grad_w_defense += (
            torch.distributions.laplace.Laplace(0, args.laplacian_noise)
            .sample(grad_w_defense.shape)
            .to(device)
        )


    if args.defense_method == "discrete_sgd":
        grad_w_defense = grad_w_defense.flatten()
        # get the mean and the standard deviation of the distribution
        mu = torch.mean(grad_w_defense)
        sigma = torch.std(grad_w_defense)
        # set an interval as [µ−2σ,µ+2σ]
        interval = [mu-2*sigma, mu+2*sigma]
        # slice the interval into N sub intervals (= N+1 interval endpoints)
        endpoints = torch.linspace(interval[0], interval[1], args.n_discretesgd+1)
        # create mask to keep only values inside the interval
        mask_keep = (grad_w_defense > interval[0]) & (grad_w_defense < interval[1])
        # NOTE: previous loop method WAY too slow
        # create matrix using repeated endpoints of length of gradient.shape[0] to allow broadcasting
        endpoints_mat = (
            endpoints.repeat(len(grad_w_defense))
            .reshape(len(grad_w_defense), len(endpoints))
            .to(device)
        )
        # find the closest subinterval endpoint to each gradient value
        inds = (
            torch.argmin(
                torch.abs(grad_w_defense.reshape(-1, 1) - endpoints_mat),
                dim=1,
            )
            .reshape(-1, 1)
            .to(device)
        )
        grad_w_defense = torch.gather(endpoints_mat, 1, inds).squeeze()
        # keep only the values inside the interval
        grad_w_defense = torch.mul(grad_w_defense, mask_keep).reshape(feature.shape)  