import copy
# from cv2 import log
import numpy as np

import torch

from core.function import gather_flat_hyper_params
from utils.Fed import FedAvg, FedAvgGradient, FedAvgP
from core.SGDClient_hr import SGDClient
from core.SVRGClient_hr import SVRGClient
from core.Client_hr import Client
from core.ClientManage import ClientManage


class ClientManageHR(ClientManage):
    """Drive SVRGClient local updates over hyper-parameters, parameters and ``v``.

    For every client id in ``client_idx`` a fresh ``SVRGClient`` is built and
    run for ``args.inner_ep[idx]`` local steps. The three variable groups are:

    * ``x`` -- the first ``len(hyper_param)`` tensors yielded by
      ``client.net.parameters()`` (this code assumes the hyper-parameters come
      first in that ordering);
    * ``y`` -- the remaining tensors of ``client.net.parameters()``;
    * ``v`` -- an auxiliary list of tensors indexed in parallel with ``param``.

    Round-start snapshots (``param_old``, ``v_old``, ``hyper_param_old``) are
    kept so that :meth:`client_job` can report each client's average per-step
    displacement.
    """

    def __init__(self, args, net_glob, client_idx, dataset, dict_users, hyper_param, param, v=None) -> None:
        """Store detached copies of the current and round-start variables.

        Args:
            args: run configuration; ``client_job`` reads ``args.inner_ep[idx]``
                (number of local steps per client id).
            net_glob: global network, forwarded to ``ClientManage`` and to each
                ``SVRGClient``.
            client_idx: iterable of client ids participating in this round.
            dataset: dataset forwarded to each client.
            dict_users: per-client data index mapping forwarded to each client.
            hyper_param: list of tensors for the hyper-parameters ``x``.
            param: list of tensors for the parameters ``y``.
            v: list of tensors indexed in parallel with ``param``. When
                ``None``, it is initialised to zeros with the same shapes as
                ``param``. Previously ``None`` crashed while being iterated.
        """
        super().__init__(args, net_glob, client_idx, dataset, dict_users)

        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users

        if v is None:
            # v is indexed alongside param in client_job, so mirror its shapes.
            v = [torch.zeros_like(x) for x in param]

        self.param = [x.clone().detach() for x in param]
        self.v = [x.clone().detach() for x in v]
        self.hyper_param = [x.clone().detach() for x in hyper_param]

        # Round-start snapshots; client_job measures displacement against these.
        self.param_old = [x.clone().detach() for x in param]
        self.v_old = [x.clone().detach() for x in v]
        self.hyper_param_old = [x.clone().detach() for x in hyper_param]

    def client_job(self, eta):
        """Run local updates on every client and return normalised displacements.

        Args:
            eta: sequence of three step sizes -- ``eta[0]`` for ``y``,
                ``eta[1]`` for ``v`` and ``eta[2]`` for ``x``.

        Returns:
            ``(h_y, h_v, h_x)``. Each is a list with one entry per client, in
            ``client_idx`` order. Each entry is a list of tensors equal to
            ``(round_start - round_end) / (eta_k * args.inner_ep[idx])``, i.e.
            the average per-step update direction for that variable group.

        NOTE(review): if ``args.inner_ep[idx]`` is 0, the normalisation divides
        by zero and yields inf/nan tensors rather than raising.
        """
        h_y, h_v, h_x = [], [], []
        for idx in self.client_idx:
            # Give each client its own copy of v. The step loop rebinds
            # client.v[i]; if SVRGClient keeps a reference to the list it is
            # given (not visible here), updates would otherwise leak into
            # self.v and into later clients' starting point.
            client = SVRGClient(self.args, idx, self.net_glob, self.dataset, self.dict_users, self.param,
                                self.hyper_param, [x.clone() for x in self.v])
            steps = self.args.inner_ep[idx]
            for _ in range(steps):
                self._local_step(client, eta)

            h_yi, h_vi, h_xi = self._normalized_deltas(client, eta, steps)
            h_x.append(h_xi)
            h_y.append(h_yi)
            h_v.append(h_vi)
        return h_y, h_v, h_x

    def _local_step(self, client, eta):
        """Apply one sequential y -> v -> x update to ``client`` in place."""
        n_hyper = len(self.hyper_param)

        # update y: every network parameter after the first n_hyper ones.
        inner_grad = client.grad_d_in_d_y()
        for i, p in enumerate(client.net.parameters()):
            if i >= n_hyper:
                p.data = p.data - eta[0] * inner_grad[i - n_hyper]

        # update v (the gradient is taken after the y update above).
        v_grad = client.grad_v_R()
        for i in range(len(self.param)):
            client.v[i] = client.v[i] - eta[1] * v_grad[i]

        # update x: the first n_hyper network parameters.
        x_update = client.grad_f_bar()
        for i, p in enumerate(client.net.parameters()):
            if i >= n_hyper:
                break
            p.data = p.data - eta[2] * x_update[i]

    def _normalized_deltas(self, client, eta, steps):
        """Return ``(h_yi, h_vi, h_xi)``: round-start minus current, per step."""
        n_hyper = len(self.hyper_param)
        h_yi, h_vi, h_xi = [], [], []
        for i, p in enumerate(client.net.parameters()):
            if i < n_hyper:
                h_xi.append((-p.data + self.hyper_param_old[i]) / (eta[2] * steps))
            else:
                j = i - n_hyper
                h_yi.append((-p.data + self.param_old[j]) / (eta[0] * steps))
                h_vi.append((-client.v[j] + self.v_old[j]) / (eta[1] * steps))
        return h_yi, h_vi, h_xi


