# -*- coding: utf-8 -*-
import copy
import ray
import torch
import numpy as np
from torch import optim, nn
from optimizer.mr_sgd import MR_SGD

from model.regularization import Regularization

# Signed power of a tensor
def pow_tensor(tensor, p):
    """Return the sign-preserving element-wise power ``sign(x) * |x| ** p``.

    The input tensor is left untouched; a new tensor is returned.

    Args:
        tensor: torch tensor to transform.
        p: exponent applied to the absolute value of every element.
    """
    magnitude = torch.abs(tensor).pow(p)
    # Re-apply the original signs in place on the freshly created tensor.
    return magnitude.mul_(torch.sign(tensor))

# Signed power applied to a whole model
def model_power(model, name, num):
    """Apply ``pow_tensor(..., num)`` to every layer whose name lacks 'conv'.

    The list ``model`` is modified in place (non-conv entries are replaced by
    new tensors) and the same list object is returned.

    Args:
        model: list of layer tensors.
        name: sequence of layer names, aligned with ``model`` by index.
        num: exponent forwarded to ``pow_tensor``.
    """
    for position, layer in enumerate(model):
        if 'conv' in name[position]:
            # Layers with 'conv' in their name are left as they are.
            continue
        model[position] = pow_tensor(layer, num)
    return model

# Build an all-zero model with a given structure
def generate_zero_model(model):
    """Return a new list holding one ``torch.zeros_like`` tensor per layer of ``model``."""
    return [torch.zeros_like(layer) for layer in model]

# Custom model arithmetic
# Model scaling
def model_mul(model, weight):
    """Scale every layer of ``model`` by ``weight``.

    Each list slot is rebound to a new tensor (``tensor.mul`` is out of
    place), so the tensors previously stored in the list are not modified.
    The list itself is updated in place and returned.
    """
    model[:] = [layer.mul(weight) for layer in model]
    return model
# Model addition
def model_add(model1, model2):
    """Add ``model2`` onto ``model1`` layer by layer, in place.

    The tensors stored in ``model1`` are mutated via ``add_``; ``model2`` is
    only read. Nothing is returned.
    """
    for layer_idx, addend in enumerate(model2):
        accumulator = model1[layer_idx]
        accumulator.add_(addend.data)
# Model subtraction
def model_sub(model1, model2):
    """Subtract ``model2`` from ``model1`` layer by layer, in place.

    NOTE: the tensors of ``model1`` are mutated via ``sub_``. The returned
    list is a new list, but its elements are those same (now modified)
    ``model1`` tensors, one per layer of ``model2``.
    """
    return [model1[layer_idx].sub_(subtrahend.data)
            for layer_idx, subtrahend in enumerate(model2)]
# Model mean (per-layer average)
def model_mean(model, num):
    """Divide every layer of ``model`` by ``num`` (multiplying by ``1/num``).

    The list is updated in place with newly created tensors; nothing is
    returned.
    """
    model[:] = [layer.mul(1 / num) for layer in model]
# Model norm
def model_norm(model):
    """Return the L2 norm over all layers of ``model`` as a Python float.

    All layer tensors are flattened and concatenated into one vector before
    the norm is taken.
    """
    stacked = torch.cat([torch.flatten(layer) for layer in model])
    return float(stacked.norm(p=2))

# Model average (over the topology)
def model_avg(all_models_list, weight_matrix):
    """Return, for every worker, the plain average of its neighbours' models.

    Worker ``j`` counts as a neighbour of worker ``i`` when
    ``weight_matrix[i][j]`` is non-zero; the magnitude of the weight is
    ignored (every neighbour contributes equally).

    Args:
        all_models_list: list of models, each a list of layer tensors.
        weight_matrix: 2-D indexable (``weight_matrix[i][j]``) connectivity
            matrix.

    Returns:
        A list with one averaged model (list of new tensors) per worker. The
        input models are not modified. A worker without any neighbour gets an
        all-zero model.
    """
    avged_models_list = []
    for i, own_model in enumerate(all_models_list):
        neighbours = [
            model
            for j, model in enumerate(all_models_list)
            if weight_matrix[i][j] != 0
        ]

        # Start from zeros shaped like this worker's model and sum neighbours.
        avged_model = [torch.zeros_like(layer) for layer in own_model]
        for model in neighbours:
            for accumulator, layer in zip(avged_model, model):
                accumulator.add_(layer.data)

        if neighbours:
            # Divide (not multiply) by the neighbour count so that the sum
            # becomes a mean.
            scale = 1 / len(neighbours)
            avged_model = [layer.mul(scale) for layer in avged_model]

        avged_models_list.append(avged_model)

    return avged_models_list
    

@ray.remote(num_cpus=0.2)
class Client(object):
    """Ray actor holding one participant of the decentralized training.

    Each client owns a private deep copy of the model, two optimizers (plain
    SGD for local training, ``MR_SGD`` for the separate gradient step) and
    its own train/test data loaders.

    The ``model`` object is expected to provide ``get_weights()``,
    ``get_model()``, ``set_model()``, ``get_gradients()`` and
    ``set_gradients()``; ``args`` must provide ``device``, ``lr``,
    ``iter_method`` and ``model``.
    """

    def __init__(self, client_index, args, model, train_loader, test_loader):
        """Store the configuration and build the model, optimizers and loss."""
        self.p_list = [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
        self.args = args  # hyper-parameters
        self.client_index = client_index  # labelled "GPU index" by the original author
        self.device = args.device  # name of the device used for training

        self.model = copy.deepcopy(model).to(self.device)  # private model copy
        # Optimizer used for local training.
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr,
                                   momentum=0.9, weight_decay=5e-4)
        # Optimizer used for the gradient step (see gradients_step).
        self.mr_optimizer = MR_SGD(self.model.parameters(), lr=self.args.lr)
        self.train_loader = train_loader  # training data
        self.test_loader = test_loader  # test data
        self.layer_name = list(self.model.get_weights().keys())  # name of every layer

        if self.args.iter_method == 'iteration':
            # Persistent iterator so that each round consumes the next batch.
            self.data_iteration = iter(self.train_loader)

        self.criterion = nn.CrossEntropyLoss()  # cross-entropy loss

    def train(self):
        """Run one round of local training.

        With ``iter_method == 'epoch'`` every mini-batch of the train loader
        is used; with ``'iteration'`` exactly one mini-batch is used (the
        loader is restarted when exhausted).

        Returns:
            ``(model, gradients, loss)`` where ``loss`` is the loss tensor of
            the last processed batch, or ``None`` if no batch was processed.

        Raises:
            ValueError: if ``args.iter_method`` is neither 'epoch' nor
                'iteration'.
        """
        self.model.train()

        # Stays None only when the epoch loader yields no batch at all.
        loss = None

        if self.args.iter_method == 'epoch':
            # NOTE(review): unlike 'iteration' mode, no Lasso/RR penalty is
            # added here - confirm whether that is intended.
            for data, target in self.train_loader:
                data, target = data.to(self.device), target.to(self.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

        elif self.args.iter_method == 'iteration':
            try:
                inputs, targets = next(self.data_iteration)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                self.data_iteration = iter(self.train_loader)
                inputs, targets = next(self.data_iteration)

            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Regularization term.
            if self.args.model == 'Lasso':
                # Labelled "L1 Regularization" by the original author.
                loss += Regularization(self.model, 0.001, p=1)(self.model)
            elif self.args.model == 'RR':
                # Labelled "L2 Regularization" by the original author although
                # p=0 is passed. NOTE(review): confirm how Regularization
                # interprets p.
                loss += Regularization(self.model, 0.001, p=0)(self.model)

            loss.backward()
            self.optimizer.step()

        else:
            raise ValueError(
                f"unsupported iter_method: {self.args.iter_method!r}")

        # Return the model and its gradients.
        return self.model.get_model(), self.model.get_gradients(), loss

    def set_model(self, model):
        """Overwrite this client's model with ``model`` (e.g. the aggregated one)."""
        self.model.set_model(model)

    def gradients_step(self, gradients, scale_factor):
        """Apply one ``MR_SGD`` step using externally supplied gradients.

        The MR_SGD optimizer is created once in ``__init__`` and reused here
        instead of being rebuilt on every call.

        Returns:
            The model after the step, as given by ``self.model.get_model()``.
        """
        self.mr_optimizer.zero_grad()
        self.model.set_gradients(gradients)

        # Gradient step.
        self.mr_optimizer.step(scale_factor, self.layer_name)

        return self.model.get_model()

    def test(self):
        """Evaluate the model on the test loader.

        Returns:
            ``(accuracy_percent, mean_batch_loss)``, both floats rounded to
            four decimals.

        Raises:
            ValueError: if the test loader yields no samples.
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = 0

        with torch.no_grad():
            for inputs, targets in self.test_loader:
                # The model lives on self.device, so the batch must too.
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)

                # .item() keeps the running sum a plain Python float.
                loss_sum += self.criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                num_batches += 1

        if total == 0:
            raise ValueError("test_loader yielded no samples")

        test_acc = format(correct / total * 100, '.4f')
        # Mean over the number of batches actually processed.
        test_loss = format(loss_sum / num_batches, '.4f')

        return float(test_acc), float(test_loss)


def train_wpm(clients):
    """Launch ``train`` on every client actor and collect the results.

    Returns:
        ``(all_models_list, all_gradients_list)``, both ordered like
        ``clients``. The loss returned by each client is discarded.
    """
    # Launch training on all clients; keep the returned object references.
    process_ids = [client.train.remote() for client in clients]

    ray.wait(process_ids, num_returns=len(process_ids))

    all_models_list = []
    all_gradients_list = []
    for object_id in process_ids:
        model, gradient, _loss = ray.get(object_id)
        all_models_list.append(model)
        all_gradients_list.append(gradient)

    return all_models_list, all_gradients_list

def agg_wpm(layer_name, all_models_list, weight_matrix, P):
    """Raise every model to the power ``P`` and mix them with ``weight_matrix``.

    NOTE: ``all_models_list`` and the model lists it contains are modified in
    place by the power step (via ``model_power``).

    Args:
        layer_name: layer names forwarded to ``model_power``.
        all_models_list: list of models, each a list of layer tensors.
        weight_matrix: mixing weights, indexed as ``weight_matrix[i][j]``.
        P: exponent forwarded to ``model_power``.

    Returns:
        One aggregated model per worker: the weighted sum, with weights
        ``weight_matrix[i][j]``, of the powered models.
    """
    # P-th power of the model parameters.
    all_models_list[:] = [model_power(model, layer_name, P)
                          for model in all_models_list]

    agged_models_list = []
    for i, own_model in enumerate(all_models_list):
        # Zero model that accumulates the weighted sum for worker i.
        accumulator = generate_zero_model(own_model)

        for j, neighbour in enumerate(all_models_list):
            # Scale a shallow copy so the neighbour's own list stays untouched.
            contribution = model_mul(neighbour.copy(), weight_matrix[i][j])
            model_add(accumulator, contribution)

        agged_models_list.append(accumulator)

    return agged_models_list

def grad_step_wpm(layer_name, clients, agg_models_list, all_gradients_list, P):
    """Load aggregated models, take a gradient step, then apply the 1/P power.

    For every client: install its aggregated model, run ``gradients_step``
    with its previously computed gradients, apply ``model_power(..., 1 / P)``
    to the result and push that final model back to the client.

    Returns:
        The list of final models, ordered like ``clients``.
    """
    # The gradient is scaled by 10 ** -P, except that P == 1 means no scaling.
    scale_factor = 1 if P == 1 else pow(10, -P)

    # Second gradient descent using the previous gradients.
    process_ids = []
    for index, client in enumerate(clients):
        client.set_model.remote(agg_models_list[index])
        process_ids.append(
            client.gradients_step.remote(all_gradients_list[index], scale_factor))

    ray.wait(process_ids, num_returns=len(process_ids))

    # 1/P power of the model parameters.
    all_models_list = [
        model_power(ray.get(object_id).copy(), layer_name, 1 / P)
        for object_id in process_ids
    ]

    # Install the final weights on every client.
    for client, final_model in zip(clients, all_models_list):
        client.set_model.remote(final_model)

    return all_models_list

def test_wpm(clients, weight_matrix):
    """Evaluate every client, print per-client results and average them.

    Returns:
        ``(avg_acc, avg_loss)`` as strings formatted with four decimals.
    """
    # Launch the evaluation on all clients; keep the object references.
    process_ids = [client.test.remote() for client in clients]

    ray.wait(process_ids, num_returns=len(process_ids))

    total_acc = 0
    total_loss = 0
    # idx is 1-based for display; row idx-1 of weight_matrix belongs to it.
    for idx, object_id in enumerate(process_ids, start=1):
        acc, loss = ray.get(object_id)
        print(f'== Client [{idx}]: acc={acc}, loss={loss}, connected num={np.count_nonzero(weight_matrix[idx-1][:])}')
        total_acc += acc
        total_loss += loss

    client_count = len(clients)
    avg_acc = format(total_acc/client_count, '.4f')
    avg_loss = format(total_loss/client_count, '.4f')

    return avg_acc, avg_loss


def consensus_distance(all_models_list, weight_matrix):
    """Return, per worker, the mean L2 distance between its neighbourhood
    average and each neighbour's model.

    Worker ``j`` is a neighbour of worker ``i`` when ``weight_matrix[i][j]``
    is non-zero. The neighbourhood averages come from ``model_avg``.

    Args:
        all_models_list: list of models, each a list of layer tensors.
        weight_matrix: connectivity matrix, indexed as ``weight_matrix[i][j]``.

    Returns:
        A list with one distance per worker, rounded to four decimals
        (``0.0`` for a worker without neighbours). Neither the input models
        nor the averaged models are modified.
    """
    avg_models = model_avg(all_models_list, weight_matrix)

    dis_list = []
    for i, avg_model in enumerate(avg_models):
        # Build each difference out of place: subtracting in place would
        # corrupt avg_model for every neighbour after the first one.
        neighbour_dists = [
            model_norm([avg_layer - layer.data
                        for avg_layer, layer in zip(avg_model, model)])
            for j, model in enumerate(all_models_list)
            if weight_matrix[i][j] != 0
        ]

        if neighbour_dists:
            dis = sum(neighbour_dists) / len(neighbour_dists)
        else:
            dis = 0.0
        dis_list.append(round(dis, 4))

    return dis_list
            
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        