import numpy as np
import os
import pickle
import time
import torch
import torch.distributed as dist
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
from utils import print_or_log, trigger_job_requeue

try:
    from sinkornknopp import optimize_L_sk,optimize_L_sk_gpu
except:
    from .sinkornknopp import optimize_L_sk,optimize_L_sk_gpu


def round_down(num, divisor):
    """Return ``num`` lowered to the nearest multiple of ``divisor``.

    The remainder of ``num`` modulo ``divisor`` is subtracted from ``num``,
    so the result follows Python's ``%`` semantics (e.g. for negative values).
    """
    remainder = num % divisor
    return num - remainder


def collate_fn(batch):
    """Collate a batch while skipping samples that are ``None``.

    Only the first five fields of every remaining sample are kept and handed
    to ``default_collate``. Returns ``None`` when no sample survives the
    filtering.
    """
    kept = []
    for sample in batch:
        if sample is None:
            continue
        kept.append((sample[0], sample[1], sample[2], sample[3], sample[4]))

    if not kept:
        return None
    return default_collate(kept)


class Subset_Sampler(Sampler):
    """Sampler that yields a fixed collection of indices in the given order."""

    def __init__(self, indices):
        # Kept as-is (no copy); iteration order is exactly the order of `indices`.
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices."""
        stored = self.indices
        return iter(stored)

    def __len__(self):
        """Return how many indices this sampler yields."""
        stored = self.indices
        return len(stored)


def get_cluster_assignments_gpu(args, dataset, model, logger=None, writer=None, group=None, iter_num=0):
    """Compute per-head cluster assignments for the dataset with Sinkhorn-Knopp.

    The model is put in eval mode and run over ``dataset``. Rank 0 stores the
    video and audio outputs of every batch, then for every head multiplies the
    (softmaxed) video and audio outputs element-wise and passes the product to
    ``optimize_L_sk_gpu``. The labels it returns are written into ``L`` at the
    dataset indices of the corresponding samples. Before returning, the model
    is switched back to train mode and ``model.module.return_features`` is set
    to ``False``.

    Args:
        args: configuration object. Fields read here: ``distributed``,
            ``world_size``, ``global_rank``, ``workers``, ``headcount``,
            ``ind_groups``, ``match``, ``groups`` and ``output_dir``.
            ``args.resume`` is set to ``'True'`` when the environment variable
            ``SIGNAL_RECEIVED`` equals ``'True'``.
        dataset: dataset whose collated batches unpack into
            ``(video, audio, _, _, idx)``.
        model: wrapped model exposing ``model.module`` (presumably
            DataParallel/DistributedDataParallel -- confirm with caller).
            Called as ``model(video, audio)`` and expected to return
            ``(feat_v, feat_a)``.
        logger: optional logger forwarded to ``print_or_log``.
        writer: optional object with ``add_scalar`` used to log the mean cost.
        group: optional process group for barriers and the final broadcast.
        iter_num: step passed to ``writer.add_scalar``; when it is 0 and
            ``args.match`` is truthy, ``match_order`` is run first.

    Returns:
        ``L``: long tensor of shape ``(N, args.headcount)`` on ``'cuda'``.

    Raises:
        AssertionError: if ``args.ind_groups > args.headcount``.
        KeyError: if the environment variable ``SIGNAL_RECEIVED`` is not set.

    NOTE(review): gathering across processes and broadcasting ``L`` only
    happen when ``args.distributed and args.world_size > 8``. In every other
    configuration each process iterates the full dataset, only rank 0 fills
    ``L``, and other ranks return ``L`` as initialised (all zeros).
    """
    # clear cache at beginning
    torch.cuda.empty_cache()
    model.eval()
    N = len(dataset)
    # this process deals only with a subset of the dataset
    sampler = None
    if args.distributed and args.world_size > 8:
        # NOTE(review): N // world_size truncates, so the last
        # N % world_size samples are never visited on this path.
        local_nmb_data = N // args.world_size
        train_indices = torch.arange(
            args.global_rank * local_nmb_data,
            (args.global_rank + 1) * local_nmb_data
        ).int()
        # create subset sampler
        sampler = SubsetRandomSampler(train_indices)
    else:
        # every process visits all N samples (in random order)
        train_indices = torch.arange(0, N).int()
        sampler = SubsetRandomSampler(train_indices)

    # A DistributedSampler was tried here and introduced a bug; kept for reference.
    # sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    # sampler.set_epoch(iter_num) # just dummy for shuffling.

    # we need a data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=64 if args.world_size > 16 else 64,  # both branches currently give 64
        sampler=sampler,
        shuffle=sampler is None,  # sampler is always set above, so this is False
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None
    )

    # Ensure processes reach to end of optim clusters
    if args.distributed and args.world_size > 8:
        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()

    assert args.ind_groups <= args.headcount, "can't have more independent head-groups than heads"

    if args.headcount > 1:
        # aggregate GAP features when using multi heads && not every head is its own group
        model.module.return_features = True
    # Single head: outputs are stored as float64 softmax probabilities.
    # Multiple heads: outputs are stored as float32 and the per-head modules
    # are applied to them later, in step 2.
    aggregtensor = torch.cuda.DoubleTensor if args.headcount == 1 else torch.cuda.FloatTensor
    dtype = torch.float64 if args.headcount == 1 else torch.float32
    # L[n, h] will hold the label of dataset sample n for head h
    L = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
    # heads are split over the independent groups in a random order
    order_heads = list(range(args.headcount))
    np.random.shuffle(order_heads) # is inplace

    for hd_grp_idx in range(args.ind_groups):
        # 1. aggregate inputs:
        for batch_idx, batch in enumerate(dataloader):
            # Get data
            video, audio, _, _, idx = batch

            # Move to GPU
            video = video.cuda(non_blocking=True)
            audio = audio.cuda(non_blocking=True)
            idx = idx.cuda(non_blocking=True)

            # Forward pass
            # NOTE(review): no torch.no_grad() here -- presumably the caller
            # disables gradient tracking; confirm.
            feat_v, feat_a = model(video, audio)
            if args.headcount == 1:
                feat_v = torch.nn.functional.softmax(feat_v, dim=1, dtype=torch.float64)  # now float64
                feat_a = torch.nn.functional.softmax(feat_a, dim=1, dtype=torch.float64)  # now float64

            if args.global_rank == 0 and batch_idx % 10 == 0:
                print_or_log((batch_idx, video.shape, audio.shape), logger=logger)

            if args.distributed and args.world_size > 8:
                # gather the features computed by all processes
                # receive buffers, one per rank; their contents are overwritten by all_gather
                all_feat_v_list = [aggregtensor(feat_v.size()) for src in range(args.world_size)]
                all_feat_a_list = [aggregtensor(feat_a.size()) for src in range(args.world_size)]
                # int32 receive buffers -- assumes idx from the dataset is int32; TODO confirm
                all_indices_list = [torch.IntTensor(feat_v.size(0)).random_(0, N).cuda() for src in
                                    range(args.world_size)]

                dist.all_gather(all_feat_v_list, feat_v)
                dist.all_gather(all_feat_a_list, feat_a)
                dist.all_gather(all_indices_list, idx)

                # only main process stores all features
                if args.global_rank == 0:
                    all_feat_v = torch.cat(all_feat_v_list)
                    all_feat_a = torch.cat(all_feat_a_list)
                    all_indices = torch.cat(all_indices_list).cpu()
            else:
                all_feat_v = feat_v
                all_feat_a = feat_a
                all_indices = idx.cpu()


            # allocate the storage on the first batch of every head group (rank 0 only)
            if batch_idx == 0 and (args.global_rank == 0):
                fr = 0
                K = feat_v.size(1)
                print_or_log(f"storing features of size {K}", logger=logger)
                PS_v = torch.zeros((N, K), dtype=dtype, device='cuda')
                PS_a = torch.zeros((N, K), dtype=dtype, device='cuda')
                # indices[r] is the dataset index of the sample stored in row r of PS_v / PS_a
                indices = torch.zeros(N, dtype=torch.long)

            # fill in arrays on main node
            if args.global_rank == 0:
                to = fr + all_feat_v.shape[0]
                PS_v[fr: to] = all_feat_v
                PS_a[fr: to] = all_feat_a
                indices[fr: to] = all_indices
                fr = to

            # signal received, relaunch experiment
            if os.environ['SIGNAL_RECEIVED'] == 'True':
                args.resume = 'True'
                if args.global_rank == 0:
                    print_or_log("Beginning requeue", logger=logger)
                    trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

            # keep all ranks in lock-step after every batch
            if args.distributed and args.world_size > 8:
                if group is not None:
                    dist.barrier(group=group)
                else:
                    dist.barrier()

        # 2. solve label assignment via sinkhorn-knopp:
        if args.global_rank == 0:
            print_or_log("Optimizing via sinkhorn-knopp on master GPU", logger=logger)
            if os.environ['SIGNAL_RECEIVED'] == 'True':
                args.resume = 'True'
                if args.global_rank == 0:
                    print_or_log("Beginning requeue", logger=logger)
                    trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

            # per-head bookkeeping; entries of heads outside this group stay 0
            _costs = [0 for i in range(args.headcount)]
            _times = [0 for i in range(args.headcount)]

            # Debug output of the stored features (the completeness asserts below are disabled)
            print(PS_v.shape, PS_a.shape, PS_v.max(), PS_a.max(),flush=True)
            # assert(np.sum(~PS_v.cpu().numpy().any(1)) == 0)
            # assert(np.sum(~PS_a.cpu().numpy().any(1)) == 0)

            # optimize heads
            # this group handles every ind_groups-th head of the shuffled order
            for head in order_heads[hd_grp_idx::args.ind_groups]:
                # optimize to get labels
                if args.headcount == 1:
                    # stored values are already softmax probabilities
                    PS_a_sk = PS_a
                    PS_v_sk = PS_v
                    head_a = model.module.mlp_a
                else:
                    # apply this head's modules to the stored features
                    head_a = getattr(model.module, f'mlp_a{head}')
                    head_v = getattr(model.module, f'mlp_v{head}')
                    PS_a_sk = torch.nn.functional.softmax(head_a.forward(PS_a),
                                                          dim=1, dtype=torch.float64)
                    PS_v_sk = torch.nn.functional.softmax(head_v.forward(PS_v),
                                                          dim=1, dtype=torch.float64)
                # align heads of audio and video:
                if args.match:
                    if iter_num == 0:
                        print_or_log("very first SK: doing matching-based permutation of audio-subnetwork's last fc")
                        match_order(PS_v_sk,
                                    PS_a_sk,
                                    list(head_a.modules())[-1] if model.module.use_mlp else head_a,
                                    steps=50000,
                                    restarts=2
                        )

                # element-wise product of both modalities, written in place into PS_v_sk
                # (with a single head PS_v_sk aliases PS_v, so PS_v is overwritten too)
                torch.mul(PS_v_sk, PS_a_sk, out=PS_v_sk)  # move activations to PS_v_sk
                sk_start = time.time()
                print_or_log(f"GROUP (== 1): {args.groups}", logger=logger)
                cost, L_head = optimize_L_sk_gpu(args, PS_v_sk)
                L[indices, head] = L_head.to('cuda')  # put it in correct order

                _costs[head] = cost
                _times[head] = time.time() - sk_start
                print_or_log(f"Head {head}, Cost: (video): {_costs[head]:.3f}; time: {_times[head]:.3f}", logger=logger)

            print_or_log(f"Final Cost: (video): {np.mean(_costs):.3f}; time: {np.mean(_times):.3f}", logger=logger)
            # drop the references; the buffers are re-allocated on the next group's first batch
            del PS_v
            del PS_a

            # processes wait on main process compute PS features
            # Write costs to log
            if writer:
                writer.add_scalar('train/LP-cost', np.mean(_costs), iter_num)

    if args.global_rank == 0:
        print_or_log(f"{args.global_rank} finished clustering", logger=logger)

    if args.distributed and args.world_size > 8:
        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()

        torch.cuda.synchronize()

        # send the labels computed on rank 0 to every other rank
        if group is not None:
            torch.distributed.broadcast(L, 0, group)
        else:
            torch.distributed.broadcast(L, 0)
        if args.global_rank == 0:
            print_or_log(f"{args.global_rank} finished broadcasting", logger=logger)

        if args.distributed and args.world_size > 1:
            if group is not None:
                dist.barrier(group=group)
            else:
                dist.barrier()
        # restore the training configuration of the model
        model.module.return_features = False
        model.train()
        return L  # change this later
    else:
        # no broadcast on this path: only rank 0 holds computed labels
        model.module.return_features = False
        model.train()
        return L

def get_cluster_assignments_gpu_vid_only(args, dataset, model, logger=None, writer=None, group=None, iter_num=0):
    """Compute per-head cluster assignments from the video branch only.

    The model is put in eval mode and run over ``dataset``. Rank 0 stores the
    video output of every batch, then for every head squares the (softmaxed)
    video output element-wise and passes it to ``optimize_L_sk_gpu``. The
    labels it returns are written into ``L`` at the dataset indices of the
    corresponding samples. With more than one distributed process, ``L`` is
    broadcast from rank 0 to all ranks.

    The model is always switched back to train mode (with
    ``model.module.return_features = False``) and ``L`` is always returned,
    including in single-process / non-distributed runs.

    Args:
        args: configuration object. Fields read here: ``distributed``,
            ``world_size``, ``global_rank``, ``workers``, ``headcount``,
            ``ind_groups``, ``groups`` and ``output_dir``. ``args.resume`` is
            set to ``'True'`` when the environment variable
            ``SIGNAL_RECEIVED`` equals ``'True'``.
        dataset: dataset whose collated batches unpack into
            ``(video, audio, _, _, idx)``.
        model: wrapped model exposing ``model.module``; called as
            ``model(video[:, :, 0, :, :], audio)`` and expected to return
            ``(feat_v, feat_a)``. Only ``feat_v`` is used.
        logger: optional logger forwarded to ``print_or_log``.
        writer: optional object with ``add_scalar`` used to log the mean cost.
        group: optional process group for barriers and the final broadcast.
        iter_num: step passed to ``writer.add_scalar``.

    Returns:
        ``L``: long tensor of shape ``(N, args.headcount)`` on ``'cuda'``.

    Raises:
        AssertionError: if ``args.ind_groups > args.headcount``.
        KeyError: if the environment variable ``SIGNAL_RECEIVED`` is not set.
    """
    # clear cache at beginning
    torch.cuda.empty_cache()
    model.eval()
    N = len(dataset)
    # this process deals only with a subset of the dataset
    sampler = None
    if args.distributed:
        # NOTE(review): N // world_size truncates, so the last
        # N % world_size samples are never visited on this path.
        local_nmb_data = N // args.world_size
        train_indices = torch.arange(
            args.global_rank * local_nmb_data,
            (args.global_rank + 1) * local_nmb_data
        ).int()
        # create subset sampler
        sampler = SubsetRandomSampler(train_indices)

    # A DistributedSampler was tried here and introduced a bug; kept for reference.
    # sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    # sampler.set_epoch(iter_num) # just dummy for shuffling.

    # we need a data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,  # i.e larger batchsize
        sampler=sampler,
        shuffle=sampler is None,  # shuffle only when no sampler was created above
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None
    )

    # Ensure processes reach to end of optim clusters
    if args.distributed and args.world_size > 1:
        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()

    assert args.ind_groups <= args.headcount, "can't have more independent head-groups than heads"

    if args.headcount > 1:
        # aggregate GAP features when using multi heads && not every head is its own group
        model.module.return_features = True
    # Single head: outputs are stored as float64 softmax probabilities.
    # Multiple heads: outputs are stored as float32 and the per-head modules
    # are applied to them later, in step 2.
    aggregtensor = torch.cuda.DoubleTensor if args.headcount == 1 else torch.cuda.FloatTensor
    dtype = torch.float64 if args.headcount == 1 else torch.float32
    # L[n, h] will hold the label of dataset sample n for head h
    L = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
    # heads are split over the independent groups in a random order
    order_heads = list(range(args.headcount))
    np.random.shuffle(order_heads) # is inplace

    for hd_grp_idx in range(args.ind_groups):
        # 1. aggregate inputs:
        for batch_idx, batch in enumerate(dataloader):
            # Get data
            video, audio, _, _, idx = batch

            # Move to GPU
            video = video.cuda(non_blocking=True)
            audio = audio.cuda(non_blocking=True)
            idx = idx.cuda(non_blocking=True)

            # Forward pass on index 0 of the third video dimension; feat_a is unused
            feat_v, feat_a = model(video[:, :, 0, :, :], audio)
            if args.headcount == 1:
                feat_v = torch.nn.functional.softmax(feat_v, dim=1, dtype=torch.float64)  # now float64

            if args.global_rank == 0 and batch_idx % 10 == 0:
                print_or_log((batch_idx, video.shape, audio.shape), logger=logger)

            if args.distributed:
                # gather the features computed by all processes
                # receive buffers, one per rank; their contents are overwritten by all_gather
                all_feat_v_list = [aggregtensor(feat_v.size()) for src in range(args.world_size)]
                # int32 receive buffers -- assumes idx from the dataset is int32; TODO confirm
                all_indices_list = [torch.IntTensor(feat_v.size(0)).random_(0, N).cuda() for src in
                                    range(args.world_size)]

                dist.all_gather(all_feat_v_list, feat_v)
                dist.all_gather(all_indices_list, idx)

                # only main process stores all features
                if args.global_rank == 0:
                    all_feat_v = torch.cat(all_feat_v_list)
                    all_indices = torch.cat(all_indices_list).cpu()
            else:
                all_feat_v = feat_v
                all_indices = idx.cpu()

            # allocate the storage on the first batch of every head group (rank 0 only)
            if batch_idx == 0 and (args.global_rank == 0):
                fr = 0
                K = feat_v.size(1)
                print_or_log(f"storing features of size {K}", logger=logger)
                PS_v = torch.zeros((N, K), dtype=dtype, device='cuda')
                # indices[r] is the dataset index of the sample stored in row r of PS_v
                indices = torch.zeros(N, dtype=torch.long)

            # fill in arrays on main node
            if args.global_rank == 0:
                to = fr + all_feat_v.shape[0]
                PS_v[fr: to] = all_feat_v
                indices[fr: to] = all_indices
                fr = to

            # signal received, relaunch experiment
            if os.environ['SIGNAL_RECEIVED'] == 'True':
                args.resume = 'True'
                if args.global_rank == 0:
                    print_or_log("Beginning requeue", logger=logger)
                    trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

            # keep all ranks in lock-step after every batch
            if args.distributed and args.world_size > 1:
                if group is not None:
                    dist.barrier(group=group)
                else:
                    dist.barrier()

        # 2. solve label assignment via sinkhorn-knopp:
        if args.global_rank == 0:
            print_or_log("Optimizing via sinkhorn-knopp on master GPU", logger=logger)
            if os.environ['SIGNAL_RECEIVED'] == 'True':
                args.resume = 'True'
                if args.global_rank == 0:
                    print_or_log("Beginning requeue", logger=logger)
                    trigger_job_requeue(os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth'))

            # per-head bookkeeping; entries of heads outside this group stay 0
            _costs = [0 for i in range(args.headcount)]
            _times = [0 for i in range(args.headcount)]

            # Debug output of the stored features (the completeness asserts below are disabled)
            print(PS_v.shape, PS_v.max(),flush=True)
            # assert(np.sum(~PS_v.cpu().numpy().any(1)) == 0)
            # assert(np.sum(~PS_a.cpu().numpy().any(1)) == 0)

            # optimize heads
            # this group handles every ind_groups-th head of the shuffled order
            for head in order_heads[hd_grp_idx::args.ind_groups]:
                # optimize to get labels
                if args.headcount == 1:
                    # stored values are already softmax probabilities
                    PS_v_sk = PS_v
                else:
                    head_v = getattr(model.module, f'mlp_v{head}')
                    PS_v_sk = torch.nn.functional.softmax(head_v.forward(PS_v),
                                                          dim=1, dtype=torch.float64)

                # NOTE(review): this squares the video probabilities in place
                # (it mirrors the video*audio product of the audio-visual
                # variant) -- confirm that squaring is intended.
                torch.mul(PS_v_sk, PS_v_sk, out=PS_v_sk)  # move activations to PS_v_sk
                sk_start = time.time()
                print_or_log(f"GROUP (== 1): {args.groups}", logger=logger)
                cost, L_head = optimize_L_sk_gpu(args, PS_v_sk)
                L[indices, head] = L_head.to('cuda')  # put it in correct order

                _costs[head] = cost
                _times[head] = time.time() - sk_start
                print_or_log(f"Head {head}, Cost: (video): {_costs[head]:.3f}; time: {_times[head]:.3f}", logger=logger)

            print_or_log(f"Final Cost: (video): {np.mean(_costs):.3f}; time: {np.mean(_times):.3f}", logger=logger)
            # drop the reference; the buffer is re-allocated on the next group's first batch
            del PS_v

            # processes wait on main process compute PS features
            # Write costs to log
            if writer:
                writer.add_scalar('train/LP-cost', np.mean(_costs), iter_num)

    if args.global_rank == 0:
        print_or_log(f"{args.global_rank} finished clustering", logger=logger)

    if args.distributed and args.world_size > 1:
        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()

        torch.cuda.synchronize()

        # send the labels computed on rank 0 to every other rank
        if group is not None:
            torch.distributed.broadcast(L, 0, group)
        else:
            torch.distributed.broadcast(L, 0)
        if args.global_rank == 0:
            print_or_log(f"{args.global_rank} finished broadcasting", logger=logger)

        if group is not None:
            dist.barrier(group=group)
        else:
            dist.barrier()

    # Restore the training configuration and return the labels on EVERY path.
    # These statements must stay outside the distributed branch: otherwise a
    # single-process run returns None and leaves the model in eval mode.
    model.module.return_features = False
    model.train()
    return L


def match_order(emb1, emb2_in, W2, steps=50000, restarts=2):
    """Permute the outputs of ``W2`` so the columns of ``emb2_in`` line up with ``emb1``.

    A randomised pairwise-swap search looks for a column permutation ``perm``
    that lowers the total absolute difference between ``emb1`` and
    ``emb2_in[:, perm]``. The best permutation over all restarts is then
    applied in place to the rows of ``W2.weight`` (and ``W2.bias`` if present).
    If no restart beats the initial cost, the identity permutation is applied.

    Args:
        emb1: reference 2-D tensor; its second dimension gives the number of
            columns ``K``. It is only read, never modified.
        emb2_in: tensor with the same shape as ``emb1`` whose columns are to
            be matched. It is not modified; the search works on a clone.
        W2: ``torch.nn.Linear`` whose output dimension is permuted
            (presumably the layer that produced ``emb2_in`` -- confirm with caller).
        steps: maximum number of swap proposals per restart.
        restarts: number of independent searches starting from the identity.

    Returns:
        None. ``W2`` is modified in place.

    Raises:
        AssertionError: if ``W2`` is not a ``torch.nn.Linear``.
    """
    with torch.no_grad():
        assert isinstance(W2, torch.nn.Linear)
        K = emb1.shape[1]

        def c(a, b):
            # total absolute difference between two equally shaped tensors
            return (torch.abs(a - b)).sum(0).sum(0)

        cost = c(emb1, emb2_in)
        best_cost = cost
        print_or_log(f'initial cost: {cost:.1f}')

        fin_perm = torch.arange(0, K)
        for retries in range(restarts):
            # The early-stopping patience is counted per restart, so the
            # position of the last improvement must start from 0 each time.
            last_iter = 0
            cost_try = cost.item()
            perm = torch.arange(0, K)
            emb2 = emb2_in.clone().detach()
            for _iter in range(steps):
                # what would happen if we switch cluster i with j
                i, j = np.random.choice(K, 2, replace=False)
                current = c(emb1[:, i], emb2[:, i]) + c(emb1[:, j], emb2[:, j])
                future = c(emb1[:, i], emb2[:, j]) + c(emb1[:, j], emb2[:, i])
                delta = current - future
                if delta > 0:
                    # Swap columns i and j of the working copy emb2 ONLY.
                    # emb1 is the fixed reference and must stay untouched so
                    # that emb2 == emb2_in[:, perm] holds after every swap.
                    col_j = emb2[:, j].clone()
                    emb2[:, j] = emb2[:, i]
                    emb2[:, i] = col_j
                    cost_try -= delta
                    # right-hand side is evaluated to plain ints before assignment
                    perm[i], perm[j] = int(perm[j]), int(perm[i])
                    last_iter = _iter
                    if _iter % 50 == 0:
                        print(f'cost:          {cost_try:.1f}, iter={_iter}', end='\r')
                # stop this restart after 1000 proposals without improvement
                if _iter - last_iter > 1000:
                    break

            # recompute the exact cost of this restart's permutation
            cost_try = c(emb1, emb2_in[:, perm])
            print_or_log(f"cost of this try: {cost_try:.1f}")
            if cost_try < best_cost:
                best_cost = cost_try
                fin_perm = perm
        print_or_log(f"final cost:    {best_cost:.1f}")
        # print_or_log(f"final permutation: {fin_perm}")
        # Reorder the layer outputs: new output i is old output fin_perm[i].
        if W2.bias is not None:
            W2.bias.data = W2.bias.data[fin_perm]
        W2.weight.data = W2.weight.data[fin_perm]
