# -*- coding: utf-8 -*-
import ray
import torch
from torch import nn
from utils.Optimizer_Mudag import mudag

@ray.remote(num_cpus=1)
class Device(object):
    """Ray actor wrapping one training node.

    Holds a model, a ``mudag`` optimizer built over that model's parameters,
    an iterator over the node's training loader and a cross-entropy
    criterion.
    """

    def __init__(self, device_index, args, model, train_loader, data_size, L, lambda_min):
        """Build the node state and its ``mudag`` optimizer.

        Args:
            device_index: Identifier stored on the instance.
            args: Configuration object; ``args.device`` and ``args.K`` are
                read here and the whole object is forwarded to ``mudag``.
            model: Model exposing ``to``, ``get_param``, ``parameters``,
                ``set_weights``, ``get_weights`` and ``get_gradients``.
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            data_size: Stored on the instance; not used in this class.
            L: Forwarded unchanged to ``mudag``.
            lambda_min: Forwarded unchanged to ``mudag``.
        """
        self.args = args
        self.data_size = data_size
        self.device_index = device_index
        self.device = args.device
        self.model = model.to(self.device)

        # Optimizer state buffers, one entry per tensor from model.get_param().
        x_k1 = []
        x_k2 = []
        y_k2 = []
        y_k1 = []
        y_k0 = []
        g_k0 = []
        g_k1 = []
        k_ = self.args.K
        for para in model.get_param():
            # NOTE(review): x_k1, y_k1 and y_k0 receive the *same* tensor
            # object, so an in-place update to one is visible through the
            # others - confirm against the mudag implementation that this
            # aliasing is intended.
            x_k1.append(para)
            y_k1.append(para)
            y_k0.append(para)
            x_k2.append(torch.zeros_like(para))
            g_k0.append(torch.zeros_like(para))
            g_k1.append(torch.zeros_like(para))
            y_k2.append(torch.zeros_like(para))
        self.optimizer = mudag(self.model.parameters(), args, L, lambda_min,
                               x_k1, x_k2, y_k0, y_k1, y_k2, g_k0, g_k1, k_)

        self.train_loader = train_loader
        # Persistent iterator so each training call consumes the next batch.
        self.data_iteration = iter(self.train_loader)
        self.criterion = nn.CrossEntropyLoss()

    # Decentralized training
    def decentralized_train(self, now_device_index, epoch, current_weights, weight_matrix):
        """Run one optimizer step on the next training batch.

        Args:
            now_device_index: Accepted for the caller's benefit; unused here.
            epoch: Accepted for the caller's benefit; unused here.
            current_weights: Weights loaded into the model before the step.
            weight_matrix: Object with a ``copy()`` method whose result is
                converted to a float32 tensor and passed to the optimizer
                step (presumably a NumPy mixing matrix - confirm at caller).

        Returns:
            Tuple ``(weights, gradients, loss)`` taken from
            ``model.get_weights()``, ``model.get_gradients()`` and the
            criterion output for this batch.
        """
        # Load the weights assigned to this node and switch to training mode.
        self.model.set_weights(current_weights)
        self.model.train()

        # Restart the loader transparently once it is exhausted.
        try:
            inputs, targets = next(self.data_iteration)
        except StopIteration:
            self.data_iteration = iter(self.train_loader)
            inputs, targets = next(self.data_iteration)

        # Clear stale gradients, then forward / backward / optimizer step.
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step(torch.tensor(weight_matrix.copy(), dtype=torch.float32))

        # Hand back the updated parameters, their gradients and the loss.
        return self.model.get_weights(), self.model.get_gradients(), loss

    def test(self, current_weights, test_loader):
        """Evaluate the given weights on ``test_loader``.

        Args:
            current_weights: Weights loaded into the model before evaluating.
            test_loader: Iterable yielding ``(inputs, targets)`` batches.

        Returns:
            Tuple ``(accuracy_percent, mean_batch_loss)``, both rounded to
            four decimal places.

        Raises:
            ZeroDivisionError: If ``test_loader`` yields no samples.
        """
        self.model.set_weights(current_weights)
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = self.model(inputs)

                # Accumulate as a Python float rather than a tensor.
                loss_sum += self.criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                num_batches += 1

        test_acc = format(correct / total * 100, '.4f')
        # Divide by the number of batches. The previous expression,
        # ``test_loss / batch_idx + 1``, divided by the last zero-based index
        # and then added 1 to the quotient because of operator precedence.
        test_loss = format(loss_sum / num_batches, '.4f')

        return float(test_acc), float(test_loss)

    # Set this node's model weights to the aggregated weights
    def decentralized_parallel_set_weights(self, weights):
        """Load ``weights`` into this node's model."""
        self.model.set_weights(weights)

def train_mudag(args, devices, current_epoch, model_parameters, weight_matrix):
    """Run one local training step for every node and collect the results.

    The ``devices`` actors are reused across
    ``int(args.world_size / args.num_dev)`` rounds; in round ``r`` the actor
    at position ``i`` trains logical node ``r * args.num_dev + i``. If
    ``args.world_size`` is not a multiple of ``args.num_dev`` the trailing
    nodes are not trained.

    Args:
        args: Configuration providing ``world_size`` and ``num_dev``.
        devices: Sequence of Ray actor handles exposing
            ``decentralized_train``.
        current_epoch: Forwarded to each actor.
        model_parameters: Indexable by logical node index; the entry is sent
            to the actor as that node's current weights.
        weight_matrix: Forwarded unchanged to each actor.

    Returns:
        Tuple ``(all_weights_list, all_gradients_list)`` ordered by logical
        node index.
    """
    all_weights_list = []
    all_gradients_list = []
    num_rounds = int(args.world_size / args.num_dev)

    for round_idx in range(num_rounds):
        first_idx = round_idx * args.num_dev

        # Launch every actor of this round before waiting on any of them.
        pending = []
        for offset, device in enumerate(devices):
            real_idx = first_idx + offset
            pending.append(
                device.decentralized_train.remote(real_idx, current_epoch,
                                                  model_parameters[real_idx],
                                                  weight_matrix))

        # ray.get on a list blocks until every task has finished and returns
        # the results in submission order, so no separate ray.wait is needed.
        for weights, gradients, _loss in ray.get(pending):
            all_weights_list.append(weights)
            all_gradients_list.append(gradients)

    return all_weights_list, all_gradients_list

def agg_mudag(args, model, all_weights_list):
    """Average the per-node weights with uniform coefficient ``1 / world_size``.

    For every node ``i`` the result is
    ``W_i / world_size + sum_{j != i} W_j / world_size``, computed key by key
    over the keys returned by ``model.get_weights()``.

    Args:
        args: Configuration providing ``world_size``.
        model: Object whose ``get_weights()`` iterates over the keys to
            aggregate.
        all_weights_list: One mapping of key -> tensor per node; must hold at
            least ``args.world_size`` entries. The mappings are not modified.

    Returns:
        List of ``args.world_size`` new mappings holding the averaged tensors.
    """
    world_size = args.world_size
    if world_size <= 0:
        # Nothing to aggregate; also avoids dividing by zero below.
        return []

    scale = 1 / world_size
    # The key set does not change during aggregation, so query the model once
    # instead of once per (i, j) pair.
    keys = list(model.get_weights())

    agg_weights_list = []
    for i in range(world_size):
        # Shallow copy: entries are rebound below, so the caller's mapping
        # for node i stays untouched.
        curr_weight = all_weights_list[i].copy()
        for key in keys:
            curr_weight[key] = torch.mul(curr_weight[key], scale)

        for j in range(world_size):
            if j == i:
                continue
            # Peer weights are only read, so no copy is required.
            peer_weight = all_weights_list[j]
            for key in keys:
                curr_weight[key] = torch.add(curr_weight[key],
                                             torch.mul(peer_weight[key], scale))

        agg_weights_list.append(curr_weight)

    return agg_weights_list