# Copyright 2019 The MixMatch Authors.  All rights reserved.

import itertools

from absl import flags
from libml.data import DataSet, augment_cifar10, augment_svhn, augment_stl10
import tensorflow as tf

flags.DEFINE_integer('nu', 2, 'Number of augmentations for class-consistency.')
FLAGS = flags.FLAGS


def stack_augment(augment):
    """Wrap `augment` so it returns FLAGS.nu augmented copies of one example.

    Args:
        augment: callable taking a single example and returning a dict that
            contains at least the keys 'image' and 'label'.

    Returns:
        A callable that applies `augment` FLAGS.nu times to the same input.
        It returns a dict whose 'image' and 'label' entries are the per-call
        results combined with `tf.stack` (new leading axis of length
        FLAGS.nu). Any other keys produced by `augment` are dropped.
    """
    def repeated(example):
        # FLAGS.nu is read at call time, not when the wrapper is built.
        # Each call may draw different random augmentation parameters.
        draws = [augment(example) for _ in range(FLAGS.nu)]
        stacked = {}
        for key in ('image', 'label'):
            stacked[key] = tf.stack([draw[key] for draw in draws])
        return stacked

    return repeated


def _ssl_creators(name, augment, labels, valids, seeds=range(6), **kwargs):
    """Build `DataSet.creator` entries for every seed/label/valid combination.

    Each entry gets the augmentation pair `[augment, stack_augment(augment)]`:
    the plain augmentation and its FLAGS.nu-fold stacked variant.

    Args:
        name: dataset name forwarded to `DataSet.creator`.
        augment: augmentation function for this dataset.
        labels: values for the `label` argument of `DataSet.creator`
            (presumably labelled-set sizes -- confirm in libml.data).
        valids: values for the `valid` argument of `DataSet.creator`.
        seeds: values for the `seed` argument; defaults to 0..5.
        **kwargs: extra keyword arguments forwarded unchanged to
            `DataSet.creator` (e.g. `nclass`, `height`, `width`, `do_memoize`).

    Returns:
        A list of `DataSet.creator` results, in `itertools.product` order,
        suitable for passing to `dict.update`.
    """
    return [DataSet.creator(name, seed, label, valid, [augment, stack_augment(augment)], **kwargs)
            for seed, label, valid in itertools.product(seeds, labels, valids)]


# Label counts shared by CIFAR-10 and both SVHN variants.
_COMMON_LABELS = (250, 500, 1000, 2000, 4000, 8000)

# Registry of every dataset variant, keyed by the name DataSet.creator builds.
DATASETS = {}
DATASETS.update(_ssl_creators('cifar10', augment_cifar10, _COMMON_LABELS, [1, 5000]))
DATASETS.update(_ssl_creators('cifar100', augment_cifar10, [10000], [1, 5000], nclass=100))
DATASETS.update(_ssl_creators('stl10', augment_stl10, [1000, 5000], [1, 500],
                              height=96, width=96, do_memoize=False))
DATASETS.update(_ssl_creators('svhn', augment_svhn, _COMMON_LABELS, [1, 5000], do_memoize=False))
DATASETS.update(_ssl_creators('svhn_noextra', augment_svhn, _COMMON_LABELS, [1, 5000], do_memoize=False))
