from typing import Any, Mapping, Optional, Union
import os
import torch
import torch.nn as nn
import torchvision
import numpy as np
class CIFAR10Interpolation(torchvision.datasets.CIFAR10):
    """CIFAR-10 variant whose samples are the average of two images.

    Sample ``idx`` is the element-wise mean of image ``idx`` and image
    ``(idx + offset) % len(self)``. The original labels are discarded and
    every sample is labelled ``0``.

    NOTE(review): ``(a + b) / 2`` assumes ``transform`` returns tensors or
    arrays (e.g. ``ToTensor()``); without a transform the parent returns PIL
    images, which do not support this arithmetic — confirm at call sites.
    """

    # Index distance between the two images that get blended. This can be
    # overridden per instance; ``__len__`` is inherited and does not touch it.
    offset = 100

    def __getitem__(self, idx):
        """Return ``(blend, 0)``, where ``blend`` averages two transformed images.

        Args:
            idx: Index of the first image; the second image is ``offset``
                positions later, wrapping around the end of the split.
        """
        a, _ = super().__getitem__(idx)
        # Wrap by the real split size instead of a hard-coded 10000 so that
        # splits of any size pair images across the whole split.
        partner = (idx + self.offset) % super().__len__()
        b, _ = super().__getitem__(partner)
        return (a + b) / 2, 0

class DTD(torchvision.datasets.ImageFolder):
    """ImageFolder that exposes every image twice, cropped to its top-left 32x32 corner.

    Dataset index ``i`` maps to underlying image ``i // 2``, so indices
    ``2k`` and ``2k + 1`` yield the same crop and the same label.
    """

    def __len__(self):
        # Twice as many indices as there are image files on disk.
        base_count = super().__len__()
        return 2 * base_count

    def __getitem__(self, idx):
        source_index = idx // 2
        image, label = super().__getitem__(source_index)
        # Keep the leading axis whole; keep only the first 32 entries of the
        # other two. NOTE(review): assumes transform yields a (C, H, W)
        # tensor — confirm.
        cropped = image[:, :32, :32]
        return cropped, label


class SVHNrotation(torchvision.datasets.SVHN):
    """SVHN split whose images are rotated by 90 degrees; labels are unchanged."""

    def __getitem__(self, idx):
        """Return ``(rotated_image, label)`` for sample ``idx``.

        Args:
            idx: Sample index in ``[0, len(self))``.
        """
        # Index directly. ``__len__`` is not doubled for this dataset, so the
        # former ``idx // 2`` only ever read the first half of the split and
        # returned each of those images twice.
        x, y = super().__getitem__(idx)
        # Rotate once in the plane of dims 1 and 2 (direction: from dim 1
        # towards dim 2, per torch.rot90). NOTE(review): assumes the
        # transform yields a (C, H, W) tensor, so dims 1 and 2 are spatial.
        x = torch.rot90(x, 1, [1, 2])
        return x, y

