import numpy as np
import torch

from model import *
import argparse
import os
import shutil
from data import *
from train import *
from train_scaling_law import *
from train_DNN_averaged import *
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader 


def main(args, **kwargs):
    # 设置随机种子
    setup_seed(args.seed)

    for file in ['pic', 'loss', 'src', 'data', 'model']:
        os.makedirs(f'{args.working_dir}/{file}', exist_ok=True)

    if args.train_method == 'train_scaling_law':
        train_scaling_law(args, **kwargs)
    elif args.train_method == 'DNN_averaged':
        datas = get_data(args, **kwargs)
        train_DNN_averaged(args, datas, **kwargs)
    else:
        datas = get_data(args, **kwargs)

        # print(datas['13_xm0'][:100])

        # quit()
 
        print('prepare data done!')
        train(args, datas, **kwargs)



if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Pytorch distributed")

    # 数据集参数
    parser.add_argument('-data_size', '--data_size', type = int, default = 1000) 
    parser.add_argument('-sl', '--seq_len', type = int, default = 9, help='句子长度')
    parser.add_argument('-dmin', '--data_min', type = int, default = 20, help='数据集中数据的最小值')
    parser.add_argument('-dmax', '--data_max', type = int, default = 100, help='数据集中数据的最大值')
    parser.add_argument('-bs', '--batch_size', type = int, default = 10) 
    parser.add_argument('-seed', '--seed', type = int, default = 1)  

    parser.add_argument('-dmode', '--data_mode', nargs='*', type=str, default = [1], help='各类数据集的模式，不同任务中的数据集模式不同')
    parser.add_argument('-dp', '--data_percent', nargs='*', type=float, default = [1], help='各类数据集占比')
    parser.add_argument('-dn', '--data_name', nargs='*', type=str, default = ['full data'], help='各类数据集名称')
    parser.add_argument('-dtrain', '--data_train', nargs='*', type=int, default = [0], help='该类是否参与训练')
    parser.add_argument('-dshow', '--data_show', nargs='*', type=int, default = [0], help='画图时是否显示该类数据集，1表示显示，0表示不显示')
    parser.add_argument('-rdm', '--random_data_num', type = int, default = 1, help='随机的复合函数的种类')
    # 目标函数
    parser.add_argument('-func', '--target', type = str, default = '3x_to_x', help='任务')

    # 网络结构与超参数
    parser.add_argument('-m', '--model', type = str, default = 'GPT', help='模型') 
    parser.add_argument('-vs', '--vocab_size', type = int, default = 201) 
    parser.add_argument('-mp', '--max_pos', type = int, default = 20)
    parser.add_argument('-dm', '--d_model', type = int, default = 400)
    parser.add_argument('-d_ff', '--d_feedforward', type = int, default = 1200)
    parser.add_argument('-dk', '--d_k', type = int, default = 64)
    parser.add_argument('-dv', '--d_v', type = int, default = 64)
    parser.add_argument('-nl', '--n_layers', type = int, default = 4)
    parser.add_argument('-nh', '--n_heads', type = int, default = 4)
    parser.add_argument('-cl', '--clip', type = int, default = 1, help='梯度裁剪')
    
    # 训练超参数
    parser.add_argument('-ne', '--n_epoch', type = int, default = 3000) 
    parser.add_argument('-lr', '--lr', type = float, default = 1.e-4, help='初始学习率') 
    parser.add_argument('-op', '--optim', choices = ['Adam', 'SGD', 'AdamW'], default = 'AdamW', help='优化器')  
    parser.add_argument('-scheduler', '--scheduler', type = str, choices = ['StepLR', 'GradualWarmupScheduler_CosineAnnealingLR'], default = 'StepLR', help='调度器')
    parser.add_argument('-eps', '--eps', type = float, default = 1.e-8, help='adam epsilon') 
    parser.add_argument('-wd', '--weight_decay', type = float, default = 1.e-2, help='adam weight decay') 
    parser.add_argument('-beta1', '--beta1', type = float, default = 0.9, help='adam beta1')
    parser.add_argument('-beta2', '--beta2', type = float, default = 0.999, help='adam beta2') 



    parser.add_argument('-lds', '--lr_decay_step', type = int, default = 1000, help='使用StepLR调度器时，每隔多少epoch学习率衰减') 
    parser.add_argument('-ldr', '--lr_decay_rate', type = float, default = 1, help='使用StepLR调度器时，学习率变为原来的多少倍') 
    
    parser.add_argument('-optim_total_epoch', '--optim_total_epoch', type = int, default = 400, help='使用GradualWarmupScheduler时的预热的周期数')
    parser.add_argument('-optim_multiplier', '--optim_multiplier', type = float, default = 5, help='使用GradualWarmupScheduler时的最大学习率与初始学习率的比值')
    parser.add_argument('-optim_T_max', '--optim_T_max', type = int, default = 4000, help='使用CosineAnnealingLR时的周期长度，即从当前学习率下降到最小学习率所需的epoch，若继续训练则会按照cosine继续上升到最大学习率，然后再下降')
    parser.add_argument('-optim_eta_min', '--optim_eta_min', type = float, default = 1e-5, help='使用CosineAnnealingLR下降到的最小学习率')
    

    # 保存、输出信息和画图的间隔
    parser.add_argument('-sme', '--save_model_epoch', type = int, default = 100, help='每隔多少epoch保存一次模型') 
    parser.add_argument('-ple', '--print_loss_epoch', type = int, default = 10, help='每隔多少epoch输出一次loss')
    parser.add_argument('-pae', '--print_acc_epoch', type = int, default = 100, help='每隔多少epoch输出一次acc')
    parser.add_argument('-plae', '--plot_loss_acc_epoch', type = int, default = 500, help='每隔多少epoch画一次loss和acc')
    
    # 前缀与后缀
    parser.add_argument('-prefix', '--prefix', type = str, default = ' ', help='文件夹前缀')
    parser.add_argument('-suffix', '--suffix', type = str, default = ' ', help='文件夹后缀')
    parser.add_argument('-pname', '--proj_name', type = str, default = ' ', help='项目名称')

    # 大文件夹的后缀
    parser.add_argument('-dir_suffix', '--dir_suffix', type = str, default = ' ', help='上级文件夹的后缀')

    # scaling law
    parser.add_argument('-tm', '--train_method', type = str, default = ' ', help='训练方式，写train_scaling_law则调用train_scaling_law.py进行训练')
    parser.add_argument('-n_batch', '--n_batch', type = int, default = 10000, help='仅在train_scaling_law中使用，表示训练多少个batch') 
    parser.add_argument('-gdm', '--gen_data_mode', type = str, default = 'fix', help='仅在train_scaling_law中使用，表示生成数据的模式，可选on_the_fly或fix')

    #condense
    parser.add_argument('-sr', '--std_rate', type = float, default = 1, help='标准差的幂次') 


    # # gpu
    # parser.add_argument('-gpu', '--gpu', type = int, default = 0, help='使用的gpu编号')

    # 解析已知的参数和未知的参数
    args, remaining = parser.parse_known_args()

    # 将未知的参数转化为字典
    remaining_dict = {}
    for i in range(0, len(remaining), 2):
        key = remaining[i].lstrip('-')
        value = remaining[i+1]
        remaining_dict[key] = value

    # 生成主文件夹目录
    working_dir = f'{args.target}-N_{int(args.data_size)}'
    
    if args.prefix != ' ':
        working_dir = f'{args.prefix}-{working_dir}'
    if args.suffix != ' ':
        working_dir = f'{working_dir}-{args.suffix}'
    
    if args.dir_suffix != ' ':
        args.working_dir = f'./data/LLM/LLM_new/{args.proj_name}/{args.model}_{args.dir_suffix}/{working_dir}'
    else:
        args.working_dir = f'./data/LLM/LLM_new/{args.proj_name}/{args.model}/{working_dir}'

    print(args.working_dir)



    main(args, **remaining_dict)