import random
import numpy as np
from math import log
from collections import Counter

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.parallel

from utils.config import parse_args
from utils.data_loader import get_data_loader

from models.wgan_gradient_penalty import WGAN_GP
from gragh import My_Graph

def _build_topology(topo, nodes):
    """Return the ``nodes x nodes`` weight matrix for communication topology *topo*.

    Accepted values of *topo* (the name or its numeric alias, case-insensitive):
      '1' / 'sep'   -- identity matrix.
      '2' / 'full'  -- every entry equals 1/nodes.
      '3' / 'exp'   -- column i is nonzero at rows i and (i + 2**j) % nodes for
                       j in range(peer), peer = floor(log2(nodes - 1)) + 1; all
                       nonzero entries are 1/(peer + 1).
      '4' / 'ring'  -- row i is nonzero at columns i-1, i, i+1 (mod nodes), with
                       equal weights over the distinct members of that set.
      '5' / 'dense' -- identity plus entries (i, i+1) and (i + 2*j, i) for
                       j in range(nodes // 2); all nonzero entries are
                       1/(nodes // 2 + 1).

    A single-node run always gets the 1x1 identity, whatever *topo* is.

    Raises:
        ValueError: if *topo* is none of the values above (and nodes != 1).
    """
    aliases = {'1': 'sep', '2': 'full', '3': 'exp', '4': 'ring', '5': 'dense'}
    name = str(topo).lower()
    name = aliases.get(name, name)

    if name == 'sep' or nodes == 1:
        return np.eye(nodes)

    if name == 'full':
        return np.ones((nodes, nodes)) / nodes

    if name == 'exp':
        # (nodes - 1).bit_length() == floor(log2(nodes - 1)) + 1, computed exactly;
        # a float log() can round to just below an integer and lose a peer.
        peer = (nodes - 1).bit_length()
        matrix = np.eye(nodes) / (peer + 1)
        for i in range(nodes):
            for j in range(peer):
                matrix[(i + 2 ** j) % nodes, i] = 1 / (peer + 1)
        return matrix

    if name == 'ring':
        matrix = np.zeros((nodes, nodes))
        for i in range(nodes):
            # For nodes == 2 the left and right neighbours coincide; weighting the
            # distinct members uniformly keeps the row summing to 1 (1/3 each when
            # nodes >= 3, exactly as before).
            members = {i, (i - 1) % nodes, (i + 1) % nodes}
            for k in members:
                matrix[i, k] = 1 / len(members)
        return matrix

    if name == 'dense':
        peer = nodes // 2
        matrix = np.eye(nodes) / (peer + 1)
        for i in range(nodes):
            matrix[i, (i + 1) % nodes] = 1 / (peer + 1)
            # NOTE(review): j == 0 rewrites the diagonal, so column i gets offsets
            # 0, 2, ..., 2*(peer-1); confirm whether 2**j or j+1 was intended.
            # Kept as originally written.
            for j in range(peer):
                matrix[(i + 2 * j) % nodes, i] = 1 / (peer + 1)
        return matrix

    raise ValueError(
        f"unknown communication topology {topo!r}; expected one of "
        "'1'/'sep', '2'/'full', '3'/'exp', '4'/'ring', '5'/'dense'"
    )


def main_worker(rank, args):
    """Run WGAN-GP training as worker *rank* of ``args.nodes`` processes.

    Meant to be the ``fn`` given to ``torch.multiprocessing.spawn``, which passes
    the process index as *rank*. The worker:
      1. seeds every RNG with a fixed value (the same on every rank),
      2. builds the communication weight matrix,
      3. joins the process group,
      4. builds the model and data loaders,
      5. hands the topology to ``model.train``.

    Args:
        rank: index of this worker; also used as its process-group rank.
        args: parsed options. This function reads ``dist_backend``,
            ``dist_url``, ``nodes`` and ``topo``; the model and data loader
            may read more.

    Raises:
        ValueError: if ``args.topo`` is not a supported topology. It is raised
            before the process group is initialised.
    """
    # Code reproducibility: fixed seed, identical across ranks.
    random.seed(8)
    np.random.seed(8)
    torch.manual_seed(8)
    torch.cuda.manual_seed_all(8)

    # Build (and validate) the topology first so a bad --topo fails before any
    # distributed setup or data loading.
    matrix = _build_topology(args.topo, args.nodes)

    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.nodes, rank=rank)

    model = WGAN_GP(args, rank)

    # Load datasets; only the training loader is used here.
    train_loader, _ = get_data_loader(args, rank)

    graph = My_Graph(rank=rank, world_size=dist.get_world_size(), weight_matrix=matrix)
    out_edges, in_edges = graph.get_edges()
    if rank == 0:
        print("communication topo matrix :\n", matrix)

    comm_param = {
        'Weight_matrix': torch.from_numpy(matrix),
        'out_edges': out_edges,
        'in_edges': in_edges,
    }

    # Start model training.
    model.train(train_loader, comm_param)

    # Tear the process group down once training returns normally.
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
    # Parse command-line options once in the parent process.
    cli_args = parse_args()
    print(cli_args.cuda)

    # One worker process per node. mp.spawn calls main_worker(i, cli_args),
    # where i is the process index that main_worker uses as its rank.
    world = cli_args.nodes
    mp.spawn(main_worker, args=(cli_args,), nprocs=world)
