import torch
import numpy as np
import os,sys
sys.path.append('src/model/DIN_dir')
sys.path.append('src/model/ExpConfig_dir')
sys.path.append('src/script')

from load_data import DataLoad
from ExpConfig_f import ExpConfig
import numpy as np
from DIN_trainer import DINTrain
from pandas import DataFrame
import numpy as np
from measure_f import *




def _score_all_items(train_model, test_batch, all_items, item_num, device, chunk_size=2000):
    """Build the user-by-item preference matrix for one validation batch.

    Each row of ``test_batch`` is scored against every row of ``all_items``
    in slices of ``chunk_size`` items, by calling
    ``train_model.calculate_preference``.

    Args:
        train_model: object exposing ``calculate_preference(users, items)``.
        test_batch: 2-D batch of user feature rows (``shape[1]`` is read as
            the per-user feature count).
        all_items: column of item ids, shape ``(item_num, 1)``.
        item_num: number of items, i.e. the width of the returned matrix.
        device: device each user row is moved to before scoring.
        chunk_size: number of items scored per ``calculate_preference`` call.

    Returns:
        A float32 tensor of shape ``(len(test_batch), item_num)`` created by
        ``torch.full`` without a device argument; every cell is overwritten
        with the score copied back through ``.cpu()``.
    """
    f_num = test_batch.shape[1]
    preference_matrix = torch.full(
        (len(test_batch), item_num), -1.0e9, dtype=torch.float32)
    # Scoring only: no gradients are needed anywhere in this loop.
    with torch.no_grad():
        for i, user in enumerate(test_batch):
            # Move the user row once instead of once per item chunk.
            user_on_device = user.to(device)
            for start_id in range(0, item_num, chunk_size):
                part_labels = all_items[start_id:start_id + chunk_size, :]
                scores = train_model.calculate_preference(
                    user_on_device.expand(len(part_labels), f_num), part_labels)
                preference_matrix[i, start_id:start_id + chunk_size] = \
                    scores.view(1, -1).cpu()
    return preference_matrix


def din_train_func(para_dict, dls, ec):
    """Train a DIN model and periodically evaluate it on validation batches.

    Every 10 batches the current loss is logged through ``ec.print``. Every
    500 batches one validation batch is scored against all items, the top-k
    precision / recall / F-measure / novelty are appended to their histories
    and logged, and the model is saved to
    ``ec.save_path + '/DINModel_state_dict.pth'`` whenever the latest
    precision equals the best precision seen so far. Training stops once
    ``train_model.batch_num`` exceeds the configured batch budget or the
    training record generator is exhausted.

    Args:
        para_dict: mapping providing ``"device"`` (``'cpu'`` or a value
            accepted by ``torch.cuda.set_device``) and
            ``"filiting_list_path"`` (path to an ``.npy`` index array used to
            subset the training data, or the string ``"-1"`` to disable it).
        dls: object providing the attributes ``item_num_node_num_file``,
            ``train_instances_file``, ``validation_instances_file`` and
            ``test_instances_file``.
        ec: object providing ``print(msg)`` and ``save_path``.

    Returns:
        None.

    Raises:
        StopIteration: if the validation batch generator yields nothing even
            right after being recreated.
    """
    # parameters
    device = para_dict["device"]
    filiting_list_path = para_dict["filiting_list_path"]
    topk = 20

    def optimizer(params):
        return torch.optim.Adam(params, lr=1e-3, amsgrad=True)

    emb_dim = 96
    sum_pooling = False
    sample_negative_num = 60
    feature_groups = [20, 20, 10, 10, 2, 2, 2, 1, 1, 1]
    # the training data is located in the train_sample_seg_cnt datafiles
    train_sample_seg_cnt = 10
    parall = 10

    train_batch_size = 256
    test_batch_size = 100
    batch_number = 20000
    log_every = 10
    eval_every = 500

    if device != 'cpu':
        # Select the requested GPU, then address it through the generic name.
        torch.cuda.set_device(device)
        device = 'cuda'

    [user_num, item_num] = np.loadtxt(
        dls.item_num_node_num_file, dtype=np.int32, delimiter=',')
    print('user num is {}, item is {}'.format(user_num,item_num))

    train_model = DINTrain(item_num=item_num,
                           sample_negative_num=sample_negative_num,
                           emb_dim=emb_dim,
                           device=device,
                           sum_pooling=sum_pooling,
                           feature_groups=feature_groups,
                           optimizer=optimizer)
    print(train_model.DINModel)

    from generate_training_batches import Train_instance
    train_instances = Train_instance(parall=parall)
    training_data, training_labels = train_instances.get_training_data(
        dls.train_instances_file, train_sample_seg_cnt, item_num)

    if filiting_list_path != "-1":
        # Keep only the training rows selected by the saved index array.
        filiting_list = np.load(filiting_list_path)
        training_data = training_data[filiting_list]
        training_labels = training_labels[filiting_list]

    # NOTE(review): the returned test instances are not used below; the call
    # is retained in case it populates state on train_instances -- confirm
    # and drop it if it has no side effects.
    train_instances.read_test_instances_file(dls.test_instances_file, item_num)

    def make_validation_batches():
        return train_instances.validation_batches(
            dls.validation_instances_file, item_num, batchsize=test_batch_size)

    validation_batch_generator = make_validation_batches()

    total_precision_history, total_recall_history = [], []
    total_f_measure_history, total_novelty_history = [], []

    # Item id column is loop-invariant, so build it once for all evaluations.
    all_items = torch.arange(item_num, device=device).view(-1, 1)

    for (batch_x, batch_y) in train_instances.generate_training_records(
            training_data, training_labels, batch_size=train_batch_size):
        loss = train_model.update_DIN(batch_x, batch_y)

        if train_model.batch_num % log_every == 0:
            loss_value = loss.item()
            ec.print(f"{train_model.batch_num = }, {loss_value = }")

        if train_model.batch_num % eval_every == 0:
            # ###start to test
            train_model.DINModel.eval()
            try:
                test_batch, test_index = next(validation_batch_generator)
            except StopIteration:
                # The validation generator ran out: start over from the
                # beginning rather than aborting training with StopIteration.
                validation_batch_generator = make_validation_batches()
                test_batch, test_index = next(validation_batch_generator)
            gt_history = [train_instances.validation_labels[i.item()]
                          for i in test_index]

            preference_matrix = _score_all_items(
                train_model, test_batch, all_items, item_num, device)
            # argsort is ascending, so the last topk columns are the best items.
            result_history = preference_matrix.argsort(dim=-1)[:, -topk:].numpy()
            total_precision_history.append(presision(result_history, gt_history, topk))
            total_recall_history.append(recall(result_history, gt_history))
            total_f_measure_history.append(f_measure(result_history, gt_history, topk))
            total_novelty_history.append(novelty(result_history, test_batch.tolist(), topk))

            ec.print(f"{train_model.batch_num = }")
            ec.print(f"{total_precision_history = }")
            ec.print(f"{total_recall_history = }")
            ec.print(f"{total_f_measure_history = }")
            ec.print(f"{total_novelty_history = }")

            # Checkpoint whenever the newest precision ties or beats the best.
            if total_precision_history[-1] == max(total_precision_history):
                ec.print(f"{train_model.batch_num = }, save DIN model")
                train_model.save_DIN(ec.save_path + '/DINModel_state_dict.pth')

            train_model.DINModel.train()

        if train_model.batch_num > batch_number:
            break
