import numpy as np
import os
import pickle
import time
import torch
import torch.distributed as dist
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
from utils import print_or_log #, moveddp

try:
    from sinkornknopp import optimize_L_sk
except:
    from .sinkornknopp import optimize_L_sk


def round_down(num, divisor):
    """Round ``num`` down to a multiple of ``divisor``.

    For a positive ``divisor`` this is the largest multiple of ``divisor`` that
    does not exceed ``num`` (Python's ``%`` takes the sign of the divisor).
    """
    excess = num % divisor
    return num - excess


def collate_fn(batch):
    """Collate a batch of samples, dropping samples that are ``None``.

    Only the first five fields of every remaining sample are passed on to
    ``default_collate``. Returns ``None`` when no sample survives.
    """
    kept = []
    for sample in batch:
        if sample is None:
            continue
        kept.append(tuple(sample[pos] for pos in range(5)))

    if not kept:
        return None
    return default_collate(kept)


class Subset_Sampler(Sampler):
    """Sampler that yields a fixed sequence of indices, in exactly the given order."""

    def __init__(self, indices):
        # Kept as-is: no copying and no shuffling.
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        yield from self.indices
    

def _barrier(args, group=None):
    """Block until every process arrives; a no-op unless distributed with >1 process."""
    if not (args.distributed and args.world_size > 1):
        return
    if group is not None:
        dist.barrier(group=group)
    else:
        dist.barrier()


def _all_gather(tensor, world_size):
    """All-gather ``tensor`` over the default process group; returns one tensor per rank.

    Receive buffers come from ``torch.empty_like``, so their dtype, device and shape
    always match the tensor being sent (``all_gather`` requires identical dtypes),
    whatever dtype the model output or the dataset index happens to have.
    """
    buffers = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(buffers, tensor)
    return buffers


def _head_probs(head, features):
    """Apply a cluster head to stored numpy features; return float64 softmax probs on CUDA."""
    x = torch.from_numpy(features)
    # The stored features are float64 (see ``dtype`` in the caller). Feed the head in
    # its own parameter dtype so that e.g. a float32 head can consume them. A head
    # without parameters receives the features unchanged.
    param = next(head.parameters(), None)
    if param is not None:
        x = x.to(param.dtype)
    return torch.nn.functional.softmax(head(x.cuda()), dim=1, dtype=torch.float64)


def _mix_modalities(ps_a, ps_v, ratio):
    """Randomly copy rows between the audio and video probability matrices, in place.

    * ``ratio == 1``: a random half of the rows take the audio row in both matrices,
      and the other half take the video row.
    * Any other non-zero ``ratio``: ``int(n * ratio)`` random rows of ``ps_a`` are
      overwritten from ``ps_v``. Then, independently, ``int(n * ratio)`` random rows
      of ``ps_v`` are overwritten from the already-updated ``ps_a``. The two choices
      may overlap, which is tolerated as long as it is small.
    * ``ratio == 0``: both matrices are left untouched.

    Called with numpy arrays (single head) or torch tensors (multi-head).
    """
    n_rows = ps_a.shape[0]
    if ratio == 1:
        chosen = np.random.choice(n_rows, n_rows // 2, replace=False)
        take_aud_only = np.zeros(n_rows, dtype=bool)
        take_aud_only[chosen] = True
        take_vid_only = ~take_aud_only
        ps_v[take_aud_only] = ps_a[take_aud_only]
        ps_a[take_vid_only] = ps_v[take_vid_only]
    elif ratio != 0:
        n_swap = int(n_rows * ratio)
        a_choice = np.random.choice(n_rows, n_swap, replace=False)
        v_choice = np.random.choice(n_rows, n_swap, replace=False)
        ps_a[v_choice, :] = ps_v[v_choice, :]
        ps_v[a_choice, :] = ps_a[a_choice, :]


def get_cluster_assignments(args, dataset, model, logger=None, writer=None, group=None, iter_num=0):
    """Assign pseudo-labels to the items of ``dataset`` via Sinkhorn-Knopp.

    Each process runs ``model`` over its own contiguous shard of ``dataset``, visited
    in random order. The per-batch video/audio outputs are all-gathered to rank 0.
    Rank 0 solves the label assignment for every head with ``optimize_L_sk``, and
    the resulting labels are broadcast to all processes.

    Args:
        args: run configuration. Reads ``world_size``, ``global_rank``,
            ``distributed``, ``workers``, ``headcount`` and ``stoch_sk_modality``,
            and is passed on to ``optimize_L_sk``.
        dataset: indexable dataset whose batches unpack as
            ``(video, audio, _, _, idx)``. ``idx`` is used as the row index into the
            returned tensor, so it must lie in ``[0, len(dataset))``.
        model: wrapped model exposing ``.module``. When ``headcount > 1``,
            ``.module`` must also provide ``mlp_a{h}`` / ``mlp_v{h}`` heads and a
            ``return_features`` attribute.
        logger: optional logger forwarded to ``print_or_log``.
        writer: optional object with ``add_scalar``. Rank 0 logs the mean cost
            under ``'LP-cost'``.
        group: optional process group used for the barriers and the final broadcast.
        iter_num: step value for the ``writer`` scalar.

    Returns:
        ``torch.long`` CUDA tensor of shape ``(len(dataset), args.headcount)``.
        Rows whose index never appeared in a gathered ``idx`` (e.g. the
        ``len(dataset) % world_size`` tail dropped by sharding) stay 0.

    Raises:
        ValueError: if the dataset has fewer items than there are processes.
    """
    # clear cache at beginning
    torch.cuda.empty_cache()

    dtype = np.float64
    N = len(dataset)
    world_size = args.world_size
    is_master = args.global_rank == 0

    # Every process handles an equally sized contiguous shard, so the per-batch
    # all_gather calls line up across processes.
    local_nmb_data = N // world_size
    if local_nmb_data == 0:
        raise ValueError(
            f"dataset of size {N} is too small to shard across {world_size} processes"
        )
    # Number of items the model actually sees; the N % world_size tail is dropped.
    n_processed = local_nmb_data * world_size
    if is_master and n_processed < N:
        print_or_log(
            f"{N - n_processed} trailing items are not clustered and keep label 0",
            logger=logger,
        )
    train_indices = torch.arange(
        args.global_rank * local_nmb_data,
        (args.global_rank + 1) * local_nmb_data,
    ).int()

    # create subset sampler and a data loader over this process' shard
    sampler = SubsetRandomSampler(train_indices)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=96,  # 32 but 96 might work as no backward passes.
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
    )

    # Ensure processes reach to end of optim clusters
    _barrier(args, group)
    if args.headcount > 1:
        # Presumably makes the model return pre-head features. The heads are applied
        # on rank 0 below.
        model.module.return_features = True

    with torch.no_grad():
        # 1. aggregate inputs: run the model and collect every output on rank 0
        for batch_idx, (video, audio, _, _, idx) in enumerate(dataloader):
            video = video.cuda(non_blocking=True)
            audio = audio.cuda(non_blocking=True)
            idx = idx.cuda(non_blocking=True)

            feat_v, feat_a = model(video, audio)
            if args.headcount == 1:
                # single head: the model output becomes float64 probabilities directly
                feat_v = torch.nn.functional.softmax(feat_v, dim=1, dtype=torch.float64)
                feat_a = torch.nn.functional.softmax(feat_a, dim=1, dtype=torch.float64)

            if is_master and batch_idx % 10 == 0:
                print_or_log((batch_idx, video.shape, audio.shape), logger=logger)

            if world_size > 1:
                # gather the features computed by all processes
                gathered_v = _all_gather(feat_v, world_size)
                gathered_a = _all_gather(feat_a, world_size)
                gathered_idx = _all_gather(idx, world_size)
                # only main process stores all features
                if is_master:
                    all_feat_v = torch.cat(gathered_v).cpu().numpy()
                    all_feat_a = torch.cat(gathered_a).cpu().numpy()
                    all_indices = torch.cat(gathered_idx).cpu().numpy()
            else:
                all_feat_v = feat_v.cpu().numpy()
                all_feat_a = feat_a.cpu().numpy()
                all_indices = idx.cpu().numpy()

            # fill in arrays on main node (allocated once the feature width K is known)
            if is_master:
                if batch_idx == 0:
                    fr = 0
                    K = feat_v.size(1)
                    PS_v = np.zeros((n_processed, K), dtype=dtype)
                    PS_a = np.zeros((n_processed, K), dtype=dtype)
                    indices = np.zeros(n_processed, dtype=np.int64)
                to = fr + all_feat_v.shape[0]
                PS_v[fr:to] = all_feat_v
                PS_a[fr:to] = all_feat_a
                indices[fr:to] = all_indices
                fr = to

            _barrier(args, group)

        # 2. solve label assignment via sinkhorn-knopp:
        L = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
        if is_master:
            print_or_log("Optimizing via sinkhorn-knopp on master GPU", logger=logger)
            # ensure there are no all-zero rows, i.e. every PS row was filled above
            assert np.sum(~PS_v.any(1)) == 0
            assert np.sum(~PS_a.any(1)) == 0
            # destination row in L of each PS row (long dtype, as tensor indexing needs)
            rows = torch.as_tensor(indices, dtype=torch.long, device=L.device)
            _costs = []
            _times = []
            # optimize heads
            for head in range(args.headcount):
                if args.headcount == 1:
                    # aliases: the in-place ops below also modify PS_a / PS_v
                    PS_a_sk = PS_a
                    PS_v_sk = PS_v
                else:
                    PS_a_sk = _head_probs(getattr(model.module, f'mlp_a{head}'), PS_a)
                    PS_v_sk = _head_probs(getattr(model.module, f'mlp_v{head}'), PS_v)
                _mix_modalities(PS_a_sk, PS_v_sk, args.stoch_sk_modality)

                PS_v_sk *= PS_a_sk  # take mean under correct maths. note the taking of sqrt in the sk algorithms.
                sk_start = time.time()
                cost_v, L_head = optimize_L_sk(args, PS_v_sk)
                L[rows, head] = L_head.to(L.device)  # put it in correct order

                _costs.append(cost_v)
                _times.append(time.time() - sk_start)
                print_or_log(f"Head {head}, Cost: (video): {_costs[head]:.3f}; time: {_times[head]:.3f}", logger=logger)

            print_or_log(f"Final Cost: (video): {np.mean(_costs):.3f}; time: {np.mean(_times):.3f}", logger=logger)
            del PS_v
            del PS_a

            # Write costs to log
            if writer is not None:
                writer.add_scalar('LP-cost', np.mean(_costs), iter_num)

            print_or_log(f"{args.global_rank} finished clustering", logger=logger)

        # processes wait for the main process to finish computing the labels
        _barrier(args, group)
        torch.cuda.synchronize()

        # Share rank 0's labels with every process. Skipped for a single process,
        # where torch.distributed may not be initialised at all.
        if world_size > 1:
            if group is not None:
                dist.broadcast(L, 0, group)
            else:
                dist.broadcast(L, 0)
        if is_master:
            print_or_log(f"{args.global_rank} finished broadcasting", logger=logger)

        _barrier(args, group)
        model.module.return_features = False
        return L