import os
import sys
# sys.path.append(os.path.realpath('../..'))
sys.path.append(os.path.realpath('.'))
import toy.ops as ops
import toy.data as data
import toy.net as net
import toy.train as train
import toy.ground_truth as gt
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import re
import toy.ground_truth as gt

# Device used both for materializing the datasets and for training every student.
# Defined once so that switching GPUs only requires editing this line.
device_name = 'cuda:0'

# Root directory holding every artifact of this experiment.
# NOTE: "Stendent" (sic) is part of on-disk paths that existing runs rely on,
# so the spelling is intentionally preserved.
generate_dir = './experiment/KD_training/Teacher_NN_Stendent_Linear_NN_Infinite_Data'
os.makedirs(generate_dir, exist_ok=True)

train_data_path = os.path.join(generate_dir, 'train_data')
test_data_path = os.path.join(generate_dir, 'test_data')
init_net_path = os.path.join(generate_dir, 'init_net')

# Cache the datasets and the initial network on disk. Each object is only built
# when its file is missing, so re-runs reuse the saved artifacts instead of
# constructing (and allocating device memory for) objects that would be discarded.
if not os.path.exists(train_data_path):
    print('*')
    train_data = data.Input_Dataset(
        input_std=5.0,
        input_dim=2,
        device=torch.device(device_name),
        datanum=512,
        online=True,
    )
    torch.save(train_data, train_data_path)

if not os.path.exists(test_data_path):
    print('*')
    test_data = data.Input_Dataset(input_std=5.0, input_dim=2, device=torch.device(device_name))
    torch.save(test_data, test_data_path)

if not os.path.exists(init_net_path):
    print('*')
    init_net = net.Net(hidden_layer_num=3, hidden_layer_dim=2048, input_dim=2)
    torch.save(init_net, init_net_path)

# Saved network that train_linear is given as `target_function_path`
# (produced by the Teacher_Gaussian_Stendent_Real_NN experiment).
TARGET_FUNCTION_PATH = './experiment/KD_training/Teacher_Gaussian_Stendent_Real_NN/network/student'

# Optimisation / logging schedule shared by every run in this script.
TRAINING_STRATEGY = {
    'batch_size': 512,
    'lr': 0.01,
    'epoch': 4096 * 8,
    'test_interval': 64,
    'display_interval': 16,
    'save_interval': 64,
    'record_interval': 8,
    'test_datanum': 32768,
}


def _train_student(run_dir, model_config):
    """Run ``train.train_linear`` into *run_dir* unless that directory already exists.

    All runs share the cached init net, datasets, target function, training
    schedule, seed and device; only the output directory and ``model_config``
    differ between calls.

    NOTE(review): an existing directory is treated as a completed run, so a run
    that was interrupted after creating its directory is skipped too — confirm
    whether train_linear creates the directory only at the start or at the end.
    """
    if os.path.exists(run_dir):
        return
    train.train_linear(
        dir=run_dir,
        target_function_path=TARGET_FUNCTION_PATH,
        init_net_path=init_net_path,
        train_dataset_path=train_data_path,
        test_data_path=test_data_path,
        model_config=model_config,
        # Pass a fresh copy each time, as the original per-call literal did, so a
        # callee that mutates the dict cannot affect later runs.
        training_strategry=dict(TRAINING_STRATEGY),
        seed=0,
        device_name=device_name,
    )


# Sweep over rho; each value gets its own output directory, e.g. 'rho-0.9599'.
rhos = [1.0000, 0.9599, 0.800, 0.6457, 0.5, 0.35, 0.1451, 0.0173, 0.0020]
for rho in rhos:
    _train_student(
        os.path.join(generate_dir, 'rho-{:.04f}'.format(rho)),
        {'rho': rho, 'T': 10.0, 'teacher_reduction': 0.3},
    )

# Reference run with rho=1.0, T=1.0 and teacher_reduction=0.0.
_train_student(
    os.path.join(generate_dir, 'zero'),
    {'rho': 1.0, 'T': 1.0, 'teacher_reduction': 0.0},
)
