import numpy as np


def create_subsamples(
        part_files=('50m_part1.npz', '50m_part2.npz', '50m_part3.npz', '50m_part4.npz'),
        subsample_specs=((10000, '1m'), (50000, '5m'), (100000, '10m')),
        num_classes=100):
    """ Subsample 1M, 5M, and 10M from 50M to avoid bias by image selection.

    The ``.npz`` parts are loaded and concatenated in the given order. For each
    ``(num_samples, file_name)`` pair, the first ``num_samples`` rows of every
    class run are kept and written to ``'<file_name>.npz'`` under the keys
    ``image`` and ``label``. Every subsample takes a per-class prefix, so
    smaller subsamples are contained in larger ones.

    Assumes the concatenated data is grouped by class, with each class
    occupying a contiguous run of ``len(labels) // num_classes`` rows.
    ``check_subsamples`` verifies the resulting label layout.

    Args:
        part_files: paths of the ``.npz`` parts to combine. Each must contain
            ``image`` and ``label`` arrays.
        subsample_specs: iterable of ``(num_samples_per_class, file_name)``.
        num_classes: number of classes (100 for CIFAR-100).

    Raises:
        ValueError: if a requested per-class count exceeds the per-class run
            length. Without this check, rows would silently be taken from the
            next class.
    """
    part_images, part_labels = [], []
    for part_file in part_files:
        # The context manager closes the npz archive once its arrays are read.
        with np.load(part_file) as part:
            part_images.append(part['image'])
            part_labels.append(part['label'])
    print('Loaded parts')
    images, labels = np.concatenate(part_images), np.concatenate(part_labels)
    # concatenate copies the data; drop the per-part arrays to lower peak memory.
    del part_images, part_labels
    print('Combined parts')

    step = labels.shape[0] // num_classes  # rows per class run
    for num_samples, file_name in subsample_specs:
        if num_samples > step:
            raise ValueError(
                'num_samples={} exceeds the {} rows available per class'.format(
                    num_samples, step))
        indices = np.full(labels.shape[0], False)
        for i in range(num_classes):
            # Keep the first num_samples rows of class i's run.
            indices[i * step:i * step + num_samples] = True
        sel_images, sel_labels = images[indices], labels[indices]
        print(sel_images.shape, sel_labels.shape)
        np.savez('{}.npz'.format(file_name), image=sel_images, label=sel_labels)


def check_subsamples(
        subsample_specs=((10000, '1m'), (50000, '5m'), (100000, '10m')),
        num_classes=100,
        image_shape=(32, 32, 3)):
    """ Verify that certain assumptions about 1M, 5M, and 10M remain valid.

    For each ``(num_samples, file_name)`` pair, loads ``'<file_name>.npz'`` and
    checks three things:

    * ``image`` has shape ``(num_classes * num_samples,) + image_shape``;
    * ``label`` has shape ``(num_classes * num_samples,)``;
    * labels are ``num_samples`` copies of class 0, then class 1, and so on,
      in ascending class order.

    Args:
        subsample_specs: iterable of ``(num_samples_per_class, file_name)``.
        num_classes: number of classes (100 for CIFAR-100).
        image_shape: expected per-image shape (32x32 RGB for CIFAR).

    Raises:
        AssertionError: if any check fails. The error is raised explicitly
            rather than via ``assert``, so the checks still run under
            ``python -O``.
    """
    for num_samples, file_name in subsample_specs:
        total = num_classes * num_samples
        path = '{}.npz'.format(file_name)
        # The context manager closes the npz archive once its arrays are read.
        with np.load(path) as data:
            images, labels = data['image'], data['label']
        expected_image_shape = (total,) + tuple(image_shape)
        if images.shape != expected_image_shape:
            raise AssertionError('{}: image shape {} != {}'.format(
                path, images.shape, expected_image_shape))
        if labels.shape != (total,):
            raise AssertionError('{}: label shape {} != {}'.format(
                path, labels.shape, (total,)))
        # Expected layout: [0]*num_samples + [1]*num_samples + ...
        reference = np.repeat(np.arange(num_classes), num_samples)
        if not np.array_equal(labels, reference):
            raise AssertionError(
                '{}: labels are not grouped by class in ascending order'.format(path))


def main():
    """Build the subsample files, then validate them, in that order."""
    pipeline = (create_subsamples, check_subsamples)
    for stage in pipeline:
        stage()


# Run the pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()

