import copy
import logging
import time

import torch


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: Score tensor whose dim 1 is ranked with ``topk`` (expected to
            be laid out as ``(sample, class)``).
        target: Ground-truth class index per sample; its first dimension is
            taken as the batch size.
        topk: Values of k to evaluate. The largest one decides how many top
            predictions are extracted.

    Returns:
        A list with one 1-element float tensor per entry of ``topk``, each
        holding ``hits * (100.0 / batch_size)``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Best `largest_k` class indices per sample, sorted by descending
        # score; after transposing, row r holds every sample's (r+1)-th guess.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        # For each k, count matches within the first k ranked rows.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


class SynchronousValidation(object):
    """Collect per-worker evaluation results and gate the return to training.

    Each worker pushes its partial metrics into shared containers created from
    ``manager``. ``cal_accu_metrics`` blocks until ``worker_num`` workers have
    reported, merges the partial ``top1``/``top5`` meters and returns
    ``sum / count`` for each. ``clear`` drops the results and raises the
    "evaluation done" flag; ``done_evaluation`` then acts as a barrier: every
    worker waits for that flag and checks in, and the worker with ``pid == 0``
    waits for all check-ins before resetting the flag and the check-in list.

    Args:
        worker_num: Number of workers expected to report / check in.
        manager: Object providing ``dict()``, ``list()`` and
            ``Value(typecode, value)`` factories -- presumably a
            ``multiprocessing`` manager so the state is shared across
            processes (confirm at the call site).
    """

    def __init__(self, worker_num, manager):
        self.worker_num = worker_num
        # pid -> {"output": ..., "target": ...} raw tensors for AUC.
        self._total_eval_auc_result = manager.dict()
        # key -> concatenated tensor; plain (unshared) dict.
        self._merged_auc_result = {}
        # pid -> {"top1": meter, "top5": meter}; meters expose .sum and .count.
        self._total_eval_accu_result = manager.dict()
        # key -> merged meter; plain (unshared) dict.
        self._merged_accu_result = {}
        # Raised by clear(), lowered by pid 0 at the end of done_evaluation().
        self._done_evaluation = manager.Value('done_evaluation', False)
        # pids that have passed the wait in done_evaluation() this round.
        self._ready_training_workers = manager.list()

    # Not tested
    def collect_auc_result(self, pid, output, target):
        """Store one worker's model outputs and targets for later AUC merging."""
        logging.info(f"Collecting evaluation result from workers {pid}")
        self._total_eval_auc_result[pid] = {"output": output, "target": target}

    # Not tested
    def _merge_auc_result(self):
        """Concatenate every worker's AUC tensors along dim 0, per key.

        The merged dict is rebuilt from scratch on each call, so repeated
        merges never append the same tensors twice.
        """
        pieces = {}
        for value in self._total_eval_auc_result.values():
            for key, tensor in value.items():
                pieces.setdefault(key, []).append(tensor)
        # One cat per key instead of re-concatenating once per worker.
        self._merged_auc_result = {key: torch.cat(parts, 0) for key, parts in pieces.items()}

    def collect_accu_result(self, pid, top1, top5):
        """Store one worker's ``top1``/``top5`` meters (objects with ``.sum``/``.count``)."""
        logging.info(f"Collecting evaluation result from workers {pid}")
        self._total_eval_accu_result[pid] = {"top1": top1, "top5": top5}
        logging.info(f"ready worker: {self._total_eval_accu_result.keys()}")

    def _merge_accu_result(self):
        """Sum the per-worker meters' ``sum`` and ``count`` into ``_merged_accu_result``.

        Blocks until ``worker_num`` workers have reported. The merged dict is
        rebuilt on every call and seeded with deep copies, so calling this
        repeatedly neither double-counts nor mutates the collected meters.
        """
        if not self._wait_collecting_evaluation_result_ready():
            return
        logging.info("merge evaluation results from other workers...")
        # Snapshot the collected results once instead of re-reading the shared dict.
        collected = list(self._total_eval_accu_result.items())
        for pid, value in collected:
            logging.info(f"results before merged in pid {pid} is: "
                         f"count = {value['top1'].count}, sum = {value['top1'].sum}")
        merged = {}
        for _, value in collected:
            for key, meter in value.items():
                if key not in merged:
                    # Copy so the += below never touches a collected meter.
                    merged[key] = copy.deepcopy(meter)
                else:
                    merged[key].sum += meter.sum
                    merged[key].count += meter.count
        self._merged_accu_result = merged

    def cal_accu_metrics(self):
        """Merge all workers' meters and return ``(top1, top5)`` as ``sum / count``.

        Blocks until every expected worker has called ``collect_accu_result``.
        """
        self._merge_accu_result()
        top1 = self._merged_accu_result["top1"]
        top5 = self._merged_accu_result["top5"]
        logging.info(f"results have been merged. count is {top1.count}, sum is {top1.sum}")
        return top1.sum / top1.count, top5.sum / top5.count

    def _wait_collecting_evaluation_result_ready(self):
        """Poll every 0.5 s until ``worker_num`` workers have reported accuracy results.

        Returns:
            True once enough results are present (it never returns False;
            it waits indefinitely if a worker never reports).
        """
        while True:
            ready_worker_indexes = list(self._total_eval_accu_result.keys())
            ready_worker_num = len(ready_worker_indexes)
            logging.info(
                "Waiting other workers' result[%s/%s], ready worker indexes are: %s",
                ready_worker_num,
                self.worker_num,
                ready_worker_indexes,
            )
            # Check before sleeping so complete results are not delayed.
            if ready_worker_num >= self.worker_num:
                return True
            time.sleep(0.5)

    def clear(self):
        """Drop collected and merged results, then raise the "evaluation done" flag."""
        self._total_eval_auc_result.clear()
        self._merged_auc_result = {}
        self._total_eval_accu_result.clear()
        self._merged_accu_result = {}
        # Releases workers polling in done_evaluation().
        self._done_evaluation.value = True

    def done_evaluation(self, pid):
        """Barrier between evaluation and training.

        Waits until ``clear()`` has raised the done flag, then records ``pid``.
        The worker with ``pid == 0`` additionally waits for ``worker_num``
        check-ins and then resets the flag and check-in list for the next round.
        """
        while not self._done_evaluation.value:
            logging.info("waiting for the chief worker to finish evaluation...")
            time.sleep(0.5)
        self._ready_training_workers.append(pid)
        # NOTE(review): assumes each worker calls this once per round; a worker
        # re-entering before pid 0 resets the flag would be appended twice.
        if pid == 0:
            while len(self._ready_training_workers) < self.worker_num:
                logging.info("ready training workers are: %s", self._ready_training_workers)
                time.sleep(0.5)
            logging.info("All workers are ready to train")
            self._done_evaluation.value = False
            self._ready_training_workers[:] = []
