import torch
import torch.nn.functional as nnf

class ShufflePatches(object):
  """Reorder the non-overlapping patches of an image tensor.

  The input is cut into patches with ``unfold`` (kernel size == stride ==
  ``patch_size``, no padding), the patches are reordered with a fixed index
  sequence, and the result is reassembled with ``fold``. The same ordering
  is applied to every image and on every call.

  Attributes:
    ps: Patch size forwarded to ``unfold``/``fold`` as both kernel size and
      stride.
    rand_perm: Index sequence used to reorder the patch axis.
  """

  def __init__(self, patch_size, randperm):
    """Store the patch size and the patch ordering.

    Args:
      patch_size: Kernel size and stride used to cut the image into patches.
      randperm: Indices applied to the last (patch) axis of the unfolded
        tensor. To get back a tensor where every patch appears once it
        should be a permutation of ``range(num_patches)``; this is not
        checked here.
    """
    self.ps = patch_size
    self.rand_perm = randperm

  def __call__(self, x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with its patches rearranged according to ``rand_perm``.

    Args:
      x: Image tensor whose last two dimensions are the spatial ones,
        typically ``(N, C, H, W)``. NOTE(review): H and W are presumably
        multiples of the patch size -- confirm against callers.

    Returns:
      A tensor folded back to the spatial size ``x.shape[-2:]``.
    """
    # Cut into non-overlapping patches; the last axis enumerates the patches.
    patches = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
    # Reorder the patch axis for all images at once. Indexing through `...`
    # avoids a per-image Python loop and does not assume a batch dimension.
    shuffled = patches[..., self.rand_perm]
    # Reassemble the reordered patches into images of the original size.
    return nnf.fold(
        shuffled, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0
    )
