import os
import math
import torch
import torchvision.datasets as datasets


class DataLoader:
    """Endless mini-batch iterator over a pair of indexable collections.

    The first pass yields batches in the order the data was given. Every
    time the cursor reaches the end of the data, ``X`` and ``y`` are
    re-shuffled with one shared random permutation and iteration restarts
    from the beginning. ``__next__`` never raises ``StopIteration``, so
    callers are expected to pull a fixed number of batches with ``next()``.

    Args:
        X: inputs, indexable by slice and by an index tensor.
        y: targets aligned with ``X``; ``len(y)`` defines the sample count.
        batch_size: maximum number of samples per batch (the last batch of
            a pass may be smaller).
    """

    def __init__(self, X, y, batch_size):
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.n_samples = len(y)
        self.idx = 0  # cursor: position of the next sample to serve

    def __len__(self):
        """Return the number of batches in one pass (partial batch included)."""
        full_batches, leftover = divmod(self.n_samples, self.batch_size)
        return full_batches + 1 if leftover > 0 else full_batches

    def __iter__(self):
        """Return the loader itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the next ``(batch_X, batch_y)`` pair, reshuffling on wrap."""
        exhausted = self.idx >= self.n_samples
        if exhausted:
            # Start a new pass: rewind and apply one permutation to both
            # collections so inputs and targets stay aligned.
            self.idx = 0
            permutation = torch.randperm(self.n_samples)
            self.X = self.X[permutation]
            self.y = self.y[permutation]

        start = self.idx
        stop = min(start + self.batch_size, self.n_samples)
        self.idx = stop

        return self.X[start:stop], self.y[start:stop]



def convert(dataset):
    """Turn a dataset into float inputs and binary float targets.

    Args:
        dataset: object exposing tensor attributes ``data`` and ``targets``.

    Returns:
        A tuple ``(X, y)`` where ``X`` is ``dataset.data`` cast to float and
        divided by 255 (presumably mapping 8-bit pixel values to [0, 1]),
        and ``y`` is a float tensor holding 1.0 where the original target is
        >= 5 and 0.0 otherwise.

    The dataset itself is left untouched: the binary labels are built as a
    new tensor rather than written into ``dataset.targets``, so the original
    class labels survive and repeated calls give the same result.
    """
    X = dataset.data.float() / 255
    # Out-of-place comparison: never overwrite the caller's label tensor.
    y = (dataset.targets >= 5).float()

    return X, y


def load_mnist(n=1000, batch_size=100):
    """Build train/test loaders for MNIST with labels from ``convert``.

    The training split is fetched into ``./data`` if needed and truncated
    to its first ``n`` samples; the test split is used in full. Images are
    reshaped to ``(-1, 1, 28, 28)``.

    Args:
        n: number of leading training samples to keep.
        batch_size: batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)`` as ``DataLoader`` instances.
    """
    train_set = datasets.MNIST('./data', train=True, download=True)
    test_set = datasets.MNIST('./data', train=False)

    train_images, train_labels = convert(train_set)
    test_images, test_labels = convert(test_set)

    # Only the training split is subsampled.
    train_images = train_images[0:n].view(-1, 1, 28, 28)
    train_labels = train_labels[0:n]
    test_images = test_images.view(-1, 1, 28, 28)
    test_labels = test_labels[:]

    return (
        DataLoader(train_images, train_labels, batch_size),
        DataLoader(test_images, test_labels, batch_size),
    )


def gen_rfm_data(n=200, d=10, batch_size=10):
    """Generate a synthetic regression task on unit-norm inputs.

    Inputs are Gaussian samples with each row normalised to unit length.
    Targets are a fixed nonlinear function of the first three coordinates,
    scaled by ``sqrt(d)``. The test set always has 10000 samples and is
    served in batches of 100.

    Args:
        n: number of training samples.
        d: input dimension (the target reads coordinates 0, 1 and 2).
        batch_size: training batch size.

    Returns:
        ``(train_loader, test_loader)`` as ``DataLoader`` instances.
    """

    def target_func(X):
        x0, x1, x2 = X[:, 0], X[:, 1], X[:, 2]
        return (0.2 * x0 + (x1 - 1) ** 2 / 3 + torch.sin(x2 * x0 / 4)) * math.sqrt(d)

    def sample_unit_rows(num_rows):
        # Gaussian rows projected onto the unit sphere.
        points = torch.randn(num_rows, d)
        return points / points.norm(dim=1, keepdim=True)

    # Train is sampled before test so the random stream is consumed in
    # the same order on every call.
    train_inputs = sample_unit_rows(n)
    train_targets = target_func(train_inputs)

    test_inputs = sample_unit_rows(10000)
    test_targets = target_func(test_inputs)

    return (
        DataLoader(train_inputs, train_targets, batch_size),
        DataLoader(test_inputs, test_targets, 100),
    )

def gen_linear_net_data(n=50, d=100, batch_size=4):
    """Generate a synthetic task whose target is the mean of the input.

    Inputs are standard Gaussian rows of dimension ``d``; each target is
    the mean of its row. The test set always has 10000 samples and is
    served in batches of 100.

    Args:
        n: number of training samples.
        d: input dimension.
        batch_size: training batch size.

    Returns:
        ``(train_loader, test_loader)`` as ``DataLoader`` instances.
    """

    def make_split(num_rows):
        inputs = torch.randn(num_rows, d)
        return inputs, inputs.mean(dim=1)

    # Train is drawn before test to keep the random stream order fixed.
    train_inputs, train_targets = make_split(n)
    test_inputs, test_targets = make_split(10000)

    return (
        DataLoader(train_inputs, train_targets, batch_size),
        DataLoader(test_inputs, test_targets, 100),
    )


if __name__ == '__main__':
    # Smoke test: pull a fixed number of batches from each MNIST loader and
    # print their shapes. 30 pulls from the train loader go past one pass
    # of 1000 samples at batch size 100, so the wrap-around is exercised.
    train_loader, test_loader = load_mnist(1000, 100)
    for loader, n_batches in ((train_loader, 30), (test_loader, 4)):
        for i in range(n_batches):
            batch_x, batch_y = next(loader)
            print(i, batch_x.shape, batch_y.shape)

