import torch
from torch.utils.data.sampler import Sampler
import numpy as np


###usage: 
# train_loader = DataLoader(train_dataset,batch_size= args.batch_size,
# sampler=DataSampler(train_labels,args.batch_size, multi_tasks = args.multi_tasks),
# num_workers=args.num_workers)
class DataSampler(Sampler):
    """Class-balanced index sampler for multi-task binary labels.

    ``labels`` is converted to a 2-D array with one row per sample and one
    column per task; an entry equal to 1 marks a positive sample for that task
    and an entry equal to 0 a negative one (other values are never sampled
    for that task).  Every batch of ``batchSize`` indices is built from
    ``multi_tasks`` tasks, each contributing ``posNum`` positive indices
    followed by ``negNum`` negative indices.

    One pass over the sampler yields ``batchNum * batchSize`` indices, where
    ``batchNum = len(labels) // batchSize``.  Shuffling uses the global
    ``np.random`` state, and the per-task read positions persist between
    passes.

    Args:
        labels: array-like of shape (num_samples, num_tasks).
        batchSize: number of indices that make up one batch.
        posNum: accepted for backward compatibility but ignored; the number
            of positives per task is always ``batchSize // multi_tasks // 2``.
        multi_tasks: number of tasks sampled in each batch; must divide the
            total number of tasks.

    Raises:
        AssertionError: if the number of tasks is not divisible by
            ``multi_tasks``.
        ValueError: if a task has no positive (or no negative) sample while
            at least one such sample is required per batch.
    """

    def __init__(self, labels, batchSize, posNum=-1, multi_tasks = 1):
        self.labels = np.array(labels)
        num_tasks = self.labels.shape[1]
        # We assume the total number of tasks is divisible by multi_tasks
        # (how many tasks are sampled in one batch).
        assert num_tasks % multi_tasks == 0

        self.multi_tasks = multi_tasks
        self.batchSize = batchSize

        # Each task gets an equal share of the batch: half positive, half
        # negative (the negative half takes the odd sample, if any).
        per_task = self.batchSize // multi_tasks
        self.posNum = per_task // 2
        self.negNum = per_task - self.posNum

        # task index -> (shuffled positive indices, shuffled negative indices)
        self.label_dict = {}
        for i in range(num_tasks):
            task_label = self.labels[:, i]
            pos_index = self._build_pool(
                np.flatnonzero(task_label == 1), self.posNum, "positive", i)
            neg_index = self._build_pool(
                np.flatnonzero(task_label == 0), self.negNum, "negative", i)
            self.label_dict[i] = (pos_index, neg_index)

        # Per-task read positions inside the positive / negative pools.
        self.posPtr = np.zeros(num_tasks, dtype=np.int64)
        self.negPtr = np.zeros(num_tasks, dtype=np.int64)
        # Tasks are visited in a random order that is reshuffled once exhausted.
        self.taskPtr, self.tasks = 0, np.random.permutation(num_tasks)

        # An epoch is the largest whole number of batches that fits the data.
        self.batchNum = self.labels.shape[0] // self.batchSize
        self.ret = np.empty(self.batchNum * self.batchSize, dtype=np.int64)

    @staticmethod
    def _build_pool(indices, need, kind, task):
        """Return a shuffled pool holding at least ``need`` indices."""
        if need > 0 and len(indices) == 0:
            raise ValueError(
                "task %d has no %s samples but %d are needed per batch"
                % (task, kind, need))
        # Repeat the indices until one draw of `need` items always fits.
        while len(indices) < need:
            indices = np.concatenate((indices, indices))
        np.random.shuffle(indices)
        return indices

    @staticmethod
    def _draw(pool, ptr, count):
        """Take ``count`` indices from ``pool`` at ``ptr``; return (indices, new ptr).

        When the pool runs out, the leftover tail is used first, the pool is
        reshuffled in place, and the remainder is taken from its new head.
        """
        end = ptr + count
        if end <= len(pool):
            return pool[ptr:end], end
        # Copy the tail: a plain slice is a view that the in-place shuffle
        # below would overwrite.
        head = pool[ptr:].copy()
        np.random.shuffle(pool)
        rest = count - len(head)
        return np.concatenate((head, pool[:rest])), rest

    def __iter__(self):
        """Fill the epoch buffer with fresh batches and return an iterator over it."""
        num_samples = self.labels.shape[0]
        for batch_id in range(self.batchNum):
            task_ids = self.tasks[self.taskPtr:self.taskPtr + self.multi_tasks]
            self.taskPtr += self.multi_tasks
            if self.taskPtr >= len(self.tasks):
                np.random.shuffle(self.tasks)
                self.taskPtr = self.taskPtr % len(self.tasks)

            beg = batch_id * self.batchSize
            for task_id in task_ids:
                pos_pool, neg_pool = self.label_dict[task_id]
                for pool, ptrs, count in ((pos_pool, self.posPtr, self.posNum),
                                          (neg_pool, self.negPtr, self.negNum)):
                    picked, ptrs[task_id] = self._draw(pool, ptrs[task_id], count)
                    self.ret[beg:beg + count] = picked
                    beg += count

            # When batchSize is not divisible by multi_tasks a few slots are
            # left over; fill them with random valid indices so the batch
            # never contains uninitialised memory.
            spare = (batch_id + 1) * self.batchSize - beg
            if spare > 0:
                self.ret[beg:beg + spare] = np.random.randint(
                    0, num_samples, size=spare)

        return iter(self.ret)

    def __len__ (self):
        """Number of indices produced by one pass over the sampler."""
        return len(self.ret)

