import torch
import numpy as np
import os,sys
sys.path.append('src/model/DeFoRec_dir')
sys.path.append('src/model/ExpConfig_dir')


from ExpConfig_f import ExpConfig

import torch
import numpy as np
import lib
from measure_defo_f import *
import os
from pandas import DataFrame



def _read_item_codes(kv_file):
    """Parse a file of ``item_id::leaf_code`` lines into two int32 arrays."""
    ids, codes = [], []
    with open(kv_file) as f:
        for line in f:
            id_code = line.split('::')
            ids.append(int(id_code[0]))
            codes.append(int(id_code[1]))
    return np.array(ids, dtype=np.int32), np.array(codes, dtype=np.int32)


def _hit_flags(labels, retrieved):
    """Return, per row, whether the label is among that row's retrieved items."""
    return [label in set(row) for label, row in zip(labels, retrieved)]


def filiting_func(cuda, N, topk, ec, tree_model_path_dict, bs=200):
    """
    Run a saved tree retrieval model over the training instances and save,
    for every instance, whether its label is among the retrieved items.

    para:
        cuda: CUDA device index; the device used is ``cuda:{cuda}``
        N: beam size
        topk: <= N, Number of items finally selected
        ec: experiment helper object; this function reads
            ``ec.para_dict["datan"]`` and calls ``ec.print`` and
            ``ec.save_data2`` (presumably an ExpConfig -- confirm with caller)
        tree_model_path_dict: dict with the keys
            "t_model_path": path of tree model
            "kv_file": path of the ``item_id::leaf_code`` mapping file
        bs: batch size each predict

    The complete hit list is saved via ``ec.save_data2`` under the name
    "training_filiting". Partial lists are additionally saved as
    "training_filiting_{i}" when batch index i reaches 1000, 2000 and 10000.
    Nothing is returned.
    """

    t_model_path = tree_model_path_dict["t_model_path"]
    kv_file = tree_model_path_dict["kv_file"]

    # parameters
    sampling_method = 'uniform_multiclass'  # top_down,softmax,all_negative_sampling,uniform_multiclass
    weight_decay = 1e-3
    optimizer = lambda params: torch.optim.Adam(params, lr=1.0e-3, amsgrad=True, weight_decay=weight_decay)
    print(ec.para_dict["datan"])
    data_set_name = ec.para_dict["datan"]
    has_processed_data = True
    train_sample_seg_cnt = 10  # the training data is located in the train_sample_seg_cnt datafiles
    parall = 4
    seq_len = 70  # seq_len-1 is the number of behaviours in all the windows
    tree_learner_mode = 'jtm'
    gamma = 0.0
    item_node_share_embedding = True
    embed_dim = 24
    training_batch_size = 100  # 500
    validation_batch_size = 50

    # The template has a single placeholder, so the path depends only on the
    # data set name.
    data_file_prefix = 'data/{}/processed_dataset/'.format(data_set_name)

    if not has_processed_data:
        os.makedirs(data_file_prefix, exist_ok=True)
    train_instances_file = data_file_prefix + 'train_instances'
    test_instances_file = data_file_prefix + 'test_instances'
    validation_instances_file = data_file_prefix + 'validation_instances'

    featrue_groups = [20, 20, 10, 10, 2, 2, 2, 1, 1, 1]
    assert sum(featrue_groups) == seq_len - 1

    # Select the requested GPU as the current device, then refer to it by the
    # generic 'cuda' name for everything created afterwards.
    torch.cuda.set_device(f'cuda:{cuda}')
    device = 'cuda'

    assert kv_file is not None
    ids, codes = _read_item_codes(kv_file)

    ec.print('min item id is {}, max item id is {}'.format(ids.min(), ids.max()))
    ec.print('min leaf node code is {}, max leaf node code is {}'.format(codes.min(), codes.max()))

    item_num = len(ids)
    ec.print('item number is {}'.format(item_num))

    from lib.generate_training_batches import Train_instance
    train_instances = Train_instance(parall=parall)
    # NOTE(review): the next three results are not used below. The calls are
    # kept because it is not visible from this file whether they populate
    # state on train_instances (e.g. training_data) -- confirm before removing.
    training_batch_generator = train_instances.training_batches(train_instances_file, train_sample_seg_cnt, batchsize=training_batch_size)
    validation_batch_generator = train_instances.validation_batches(validation_instances_file, batchsize=validation_batch_size)
    test_instances = train_instances.read_test_instances_file(test_instances_file)
    training_instance_index_pair = train_instances.get_item_instance_pair_index(train_instances_file, train_sample_seg_cnt)

    from lib.trainer import TrainModel
    train_model = TrainModel(ids, codes,
                             embed_dim=embed_dim,
                             feature_groups=featrue_groups,
                             all_training_instance=train_instances.training_data,
                             item_user_pair_dict=training_instance_index_pair,
                             parall=parall,
                             optimizer=optimizer,
                             N=N,
                             sampling_method=sampling_method,
                             tree_learner_mode=tree_learner_mode,
                             item_node_share_embedding=item_node_share_embedding,
                             device=device,
                             gamma=gamma
                             )
    ec.print(train_model.network_model)

    # NOTE(review): the checkpoint holds a whole pickled model object, not a
    # bare state dict, so torch.load unpickles arbitrary code -- only load
    # trusted files. Recent PyTorch releases may also require an explicit
    # weights_only=False for this kind of checkpoint -- confirm against the
    # installed version.
    saved_model = torch.load(t_model_path, map_location=torch.device(device))
    train_model.network_model.load_state_dict(saved_model.state_dict())

    train_model.network_model.eval()

    # Ceiling division: number of batches needed to cover every instance.
    bs_count = (len(train_instances.training_data) - 1) // bs + 1
    all_result = np.zeros((len(train_instances.training_data), topk), dtype=np.int32)
    checkpoints = (1000, 2000, 10000)

    for i in range(bs_count):
        if i % 100 == 0:
            ec.print(f"{i=}")

        if i in checkpoints:
            # Best-effort partial dump of the batches finished so far; a
            # failure here is reported but must not abort the main run.
            try:
                done = i * bs
                partial_labels = train_instances.training_labels[:done].numpy()
                ec.save_data2(f"training_filiting_{i}", _hit_flags(partial_labels, all_result[:done]))
            except Exception as exc:
                ec.print(f"some issue: {exc!r}")

        bs_user = train_instances.training_data[i * bs:(i + 1) * bs]
        all_result[i * bs:(i + 1) * bs] = train_model.predict(bs_user, N, topk, forest=False)

    tl = train_instances.training_labels.numpy()
    ec.save_data2("training_filiting", _hit_flags(tl, all_result))