import torch
import copy
import numpy as np
from itertools import product
from scipy.cluster.hierarchy import fcluster
from copy import deepcopy
from torch import nn
import torch.optim as optim

def Accuracy(y,y_predict):
    """Return the fraction of positions where ``y`` and ``y_predict`` agree.

    Elements are compared index by index for positions ``0 .. len(y) - 1``.
    Extra trailing entries in ``y_predict`` are ignored. A shorter
    ``y_predict`` raises IndexError, and an empty ``y`` raises
    ZeroDivisionError.
    """
    total = len(y)
    mismatches = sum(1 for pos in range(total) if not y[pos] == y_predict[pos])
    correct = total - mismatches
    return correct / total


def average_weights(w):
    """
    Return the element-wise mean of a list of local model state dicts.

    w: non-empty list of state dicts that share the same keys.
    The first dict is deep-copied, so none of the inputs is modified. Each
    entry is summed in place and then divided with ``torch.div``; integer
    tensors therefore come back as floating-point tensors.
    """
    n_models = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged.keys():
        running = averaged[name]
        for state in w[1:]:
            # In-place add keeps the dtype of the deep-copied first tensor.
            running += state[name]
        averaged[name] = torch.div(running, n_models)
    return averaged


class LocalUpdate(object):
    """
    Train a private copy of a model on one client's data and return its updated weights.

    args: argument namespace. This class reads ``args.lr`` (the Adam learning rate)
        and ``args.local_ep`` (the number of passes over ``Loader_train`` per update call).
    model: the model to train. It is deep-copied, so the caller's instance is never modified.
    Loader_train: iterable of ``(X, y)`` batches, used for training and by ``train_loss``.
    idx: index of this client's data. It is stored but not used inside this class.
    device: device that the model and every batch are moved to before any computation.
    """
    def __init__(self, args, model,Loader_train,idx, device):
        self.args = args
        self.trainloader = Loader_train
        self.idx = idx
        self.ce = nn.CrossEntropyLoss()
        self.device = device
        # Private copy, so local training never touches the caller's (global) model.
        self.model = copy.deepcopy(model)
        # The optimizer is reused across update calls, so Adam's moment estimates
        # carry over from one call to the next.
        self.optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
        # NOTE(review): the scheduler is created but never stepped inside this class.
        # Callers must call ``self.scheduler.step()`` themselves if they want LR decay.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)

    def _train_local(self, regularizer=None):
        """Run ``args.local_ep`` epochs of optimization over ``trainloader``.

        regularizer: optional zero-argument callable. Its result is added to the
            cross-entropy loss on every batch.
        Returns the model's ``state_dict()``. The tensors in it are the live
        parameters, not copies.
        """
        self.model.to(self.device)
        self.model.train()
        for _epoch in range(self.args.local_ep):
            for X, y in self.trainloader:
                X = X.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()
                p = self.model(X).double()
                loss = self.ce(p, y)
                if regularizer is not None:
                    loss = loss + regularizer()
                loss.backward()
                self.optimizer.step()
        return self.model.state_dict()

    def update_weights_prox(self,global_round, mu):
        """FedProx-style local update.

        Adds ``(mu / 2) * ||w - w0||^2``, summed over all parameters, to the loss.
        Here ``w0`` is a snapshot of the model's weights taken at the start of this
        call. Presumably those are the global weights just loaded with ``load_model``;
        confirm with the caller.
        global_round: unused; kept for interface compatibility.
        Returns the updated ``state_dict()``.
        """
        self.model.to(self.device)
        # Detached snapshot: the anchor is a constant, so backward must not compute
        # gradients for it.
        anchor = [param.detach().clone() for param in self.model.parameters()]

        def proximal_term():
            reg = 0.0
            for param, ref in zip(self.model.parameters(), anchor):
                reg += (mu / 2) * torch.norm(param - ref) ** 2
            return reg

        return self._train_local(proximal_term)

    def update_weights_avg(self,global_round):
        """FedAvg-style local update using plain cross-entropy training.

        global_round: unused; kept for interface compatibility.
        Returns the updated ``state_dict()``.
        """
        return self._train_local()

    def train_loss(self):
        """Return the mean per-batch cross-entropy loss on ``trainloader``.

        Runs without gradient tracking and leaves the model in eval mode.
        """
        self.model.to(self.device)
        self.model.eval()
        batch_loss = []
        # Evaluation only: skip building autograd graphs to save memory and time.
        with torch.no_grad():
            for X, y in self.trainloader:
                X = X.to(self.device)
                y = y.to(self.device)
                p = self.model(X).double()
                batch_loss.append(self.ce(p, y).item())
        return np.mean(batch_loss)

    def load_model(self,global_weights):
        """Copy ``global_weights`` (a state dict) into the local model in place."""
        self.model.load_state_dict(global_weights)
        
def get_gradients(sampling, global_m, local_models):
    """Return each local model's per-parameter difference from ``global_m``.

    The "representative gradient" of a local model is ``local_param - global_param``
    for every tensor, in ``parameters()`` order, computed on CPU NumPy copies.
    The pairing uses ``zip``, so it stops at the shorter parameter list.

    sampling: unused; kept for interface compatibility.
    global_m: the model that was sent to the clients.
    local_models: iterable of locally trained models.
    Returns a list with one entry per local model; each entry is a list of ``numpy.ndarray``.
    """
    def as_arrays(module):
        return [t.detach().cpu().numpy() for t in module.parameters()]

    local_snapshots = [as_arrays(m) for m in local_models]
    reference = as_arrays(global_m)
    return [
        [mine - theirs for mine, theirs in zip(snapshot, reference)]
        for snapshot in local_snapshots
    ]


def get_gradients_fc(sampling, global_m, local_models):
    """Return each local model's difference from ``global_m`` on the last two parameters only.

    Only the final two tensors of ``parameters()`` are used. Presumably these are
    the last layer's weight and bias, but that depends on the model's parameter
    registration order; confirm for the models passed in. The pairing uses ``zip``,
    so it stops at the shorter list.

    sampling: unused; kept for interface compatibility.
    global_m: the model that was sent to the clients.
    local_models: iterable of locally trained models.
    Returns a list with one entry per local model; each entry is a list of ``numpy.ndarray``.
    """
    def tail_arrays(module):
        trailing = list(module.parameters())[-2:]
        return [t.detach().cpu().numpy() for t in trailing]

    per_model = [tail_arrays(m) for m in local_models]
    reference = tail_arrays(global_m)

    deltas = []
    for arrays in per_model:
        deltas.append([mine - theirs for mine, theirs in zip(arrays, reference)])
    return deltas




