import os
import sys
# sys.path.append(os.path.realpath('../..'))
sys.path.append(os.path.realpath('.'))
import toy.ops as ops
import toy.data as data
import toy.net as net
import toy.train as train
import toy.ground_truth as gt
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import re
import toy.ground_truth as gt

# Values of `rho` to sweep; each one is written into model_config['rho'].
# An earlier version of this sweep used the same list without 0.5 and 0.35.
rhos = [
    1.0000, 0.9599, 0.800, 0.6457,
    0.5, 0.35,
    0.1451, 0.0173, 0.0020,
]

# Values of `datanum` to sweep; each one is written into
# model_config['datanum']. Order matters because runs execute in list order,
# so the two largest values (16384, 8192) are trained last within each fold.
datanums = [
    4096, 2048, 1024, 512, 256, 128, 64, 32,
    16384, 8192,
]

# Number of folds. The fold index is also forwarded to train_linear as `seed`.
fold = 5

_kd_root = './experiment/KD_training'

# Output root for this experiment; one sub-directory per (rho, datanum, fold).
current_dir = _kd_root + '/Perfect_Teacher_Distillation'
# Provides the init_net / train_data / test_data sub-paths used by the sweep.
# NOTE(review): 'Stendent' appears to be the on-disk spelling of these
# directories; do not correct it here without renaming them as well.
generate_dir = _kd_root + '/Teacher_NN_Stendent_Linear_NN_Infinite_Data'
# Forwarded to train_linear as `target_function_path`.
gt_pth = _kd_root + '/Teacher_Gaussian_Stendent_Real_NN/network/student'
# Device string forwarded to train_linear as `device_name`.
device_name = 'cuda:0'

# Train one linear student for every (fold, datanum, rho) combination.
# The fold index is the outer loop and rho the inner loop. A run whose output
# directory already exists is treated as done and skipped, so re-running the
# script only trains the missing combinations.
for fold_idx in range(fold):
    for n_data in datanums:
        for rho in rhos:
            run_dir = '{}/rho-{:.04f}/datanum-{:06d}/fold-{:02d}'.format(
                current_dir, rho, n_data, fold_idx
            )
            if os.path.exists(run_dir):
                continue

            # Build fresh config dicts for every run rather than sharing one
            # object across calls, in case train_linear mutates its arguments.
            model_cfg = dict(
                rho=rho,
                T=10.0,
                teacher_reduction=0.3,
                datanum=n_data,
                regenerate_data=True,
            )
            strategy = dict(
                batch_size=512,
                lr=0.01,
                epoch=4096,
                test_interval=64,
                display_interval=16,
                save_interval=16,
                record_interval=8,
                test_datanum=32768,
            )

            # NOTE: the keyword is spelled `training_strategry` to match the
            # train_linear signature; do not correct it here alone.
            train.train_linear(
                run_dir,
                target_function_path=gt_pth,
                init_net_path=generate_dir + '/init_net',
                train_dataset_path=generate_dir + '/train_data',
                test_data_path=generate_dir + '/test_data',
                model_config=model_cfg,
                training_strategry=strategy,
                seed=fold_idx,
                device_name=device_name,
            )
