from rts_base import gsm_pipeline
from tfm_constants import DC40_ORIGIN_DEPS
import sys

# Weights file used as the default ``init_weights`` argument of gsm_dc40.
pretrained_weights = "dc40_weights.hdf5"

def gsm_dc40(try_arg, zero_rate, warmup_epochs, lr_values, lr_epoch_boundaries, max_epochs, data_dir,
             init_weights=pretrained_weights,
             batch_size=None, momentum=0.9, num_gpus=1, num_steps_per_hdf5=100000,
             use_dense_layer=False, specify_l2=None):
    """Launch ``gsm_pipeline`` for the 'dc40' network.

    This is a thin wrapper: it always passes the network name ``'dc40'`` and
    ``deps=DC40_ORIGIN_DEPS``, and forwards everything else by keyword.

    Args:
        try_arg: Second positional argument to ``gsm_pipeline``. The entry
            point passes a run name here (e.g. 'gsm_dc40_reimpl').
        zero_rate: Forwarded unchanged as ``zero_rate``.
        warmup_epochs: Forwarded as ``lr_warmup_epochs`` (note the rename).
        lr_values: Forwarded unchanged as ``lr_values``.
        lr_epoch_boundaries: Forwarded unchanged as ``lr_epoch_boundaries``.
        max_epochs: Forwarded unchanged as ``max_epochs``.
        data_dir: Forwarded unchanged as ``data_dir``.
        init_weights: Forwarded as ``init_weights``. Defaults to the
            module-level ``pretrained_weights`` file name.
        batch_size, momentum, num_gpus, num_steps_per_hdf5, use_dense_layer,
        specify_l2: Forwarded unchanged under the same keyword names.

    Returns:
        None. Any value returned by ``gsm_pipeline`` is discarded.
    """
    pipeline_kwargs = dict(
        init_weights=init_weights,
        zero_rate=zero_rate,
        # gsm_pipeline expects the LR-warmup length under a different name.
        lr_warmup_epochs=warmup_epochs,
        lr_values=lr_values,
        lr_epoch_boundaries=lr_epoch_boundaries,
        max_epochs=max_epochs,
        # Fixed for this wrapper; callers cannot override it.
        deps=DC40_ORIGIN_DEPS,
        num_gpus=num_gpus,
        use_dense_layer=use_dense_layer,
        batch_size=batch_size,
        momentum=momentum,
        num_steps_per_hdf5=num_steps_per_hdf5,
        specify_l2=specify_l2,
        data_dir=data_dir,
    )
    gsm_pipeline('dc40', try_arg, **pipeline_kwargs)

if __name__ == '__main__':
    # Script entry point: run the 'gsm_dc40_reimpl' experiment via gsm_dc40.
    # The data directory must be supplied as the first command-line argument.
    if len(sys.argv) < 2:
        # Without this check, sys.argv[1] below raises a bare IndexError;
        # exit with a usage message (non-zero status) instead.
        sys.exit('usage: %s DATA_DIR' % sys.argv[0])

    gsm_dc40('gsm_dc40_reimpl', zero_rate=0.85,
             warmup_epochs=5,
             lr_values=[5e-3, 5e-4, 5e-5],
             # NOTE(review): three lr_values with two boundaries presumably forms a
             # piecewise-constant schedule; confirm semantics in gsm_pipeline.
             lr_epoch_boundaries=[400, 500],
             max_epochs=600,
             batch_size=64,
             momentum=0.98,
             num_steps_per_hdf5=8000,
             use_dense_layer=False,
             data_dir=sys.argv[1])