import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Dataset, InMemoryDataset, download_url
from utils import construct_data_from_graph_gvp, construct_data_from_graph_gvp_mean
import lmdb
import pickle

class TankBindDataSet(Dataset):
    """PyG dataset that turns one row of a pandas DataFrame into a protein-compound graph.

    Each row of ``self.data`` (accessed with ``.iloc``) names a protein
    (``protein_name``) and a compound (``compound_name``). Their precomputed
    features are looked up in ``self.protein_dict`` / ``self.compound_dict`` and
    combined by ``construct_data_from_graph_gvp``.

    Persistence: after the base-class constructor runs, the DataFrame and both
    feature dicts are always reloaded from ``self.processed_paths``
    (``data.pt``, ``protein.pt``, ``compound.pt``). PyG calls :meth:`process`,
    which saves the ``data`` / ``protein_dict`` / ``compound_dict`` arguments,
    only when those files are missing. On later runs those constructor
    arguments are therefore ignored.
    """

    def __init__(self, root, data=None, protein_dict=None, compound_dict=None, proteinMode=0, compoundMode=1,
                add_noise_to_com=None, pocket_radius=20, contactCutoff=8.0, predDis=True, shake_nodes=None,
                use_whole_protein=False,
                transform=None, pre_transform=None, pre_filter=None):
        # Stash the in-memory objects so that process() can save them if the
        # processed files do not exist yet.
        self.data = data
        self.protein_dict = protein_dict
        self.compound_dict = compound_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may
        # refuse to unpickle a DataFrame / dict of arrays; pass weights_only=False
        # if loading fails on a newer torch.
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.compound_dict = torch.load(self.processed_paths[2])
        self.add_noise_to_com = add_noise_to_com
        self.proteinMode = proteinMode
        self.compoundMode = compoundMode
        self.pocket_radius = pocket_radius
        self.contactCutoff = contactCutoff
        self.predDis = predDis
        self.shake_nodes = shake_nodes
        self.use_whole_protein = use_whole_protein

    @property
    def processed_file_names(self):
        """File names (relative to PyG's processed dir) for the DataFrame, protein and compound dicts."""
        return ['data.pt', 'protein.pt', 'compound.pt']

    def process(self):
        """Save the in-memory DataFrame and feature dicts to ``self.processed_paths``."""
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])
        torch.save(self.compound_dict, self.processed_paths[2])

    def len(self):
        """Number of rows in ``self.data``."""
        return len(self.data)

    def get(self, idx):
        """Build the graph ``Data`` object for row ``idx`` of ``self.data``.

        Besides the fields produced by ``construct_data_from_graph_gvp``, this
        sets ``affinity``, ``compound_pair``, ``pdb``, ``group``,
        ``real_affinity_mask`` and ``real_y_mask``. When the row allows it, it
        also sets ``is_equivalent_native_pocket`` and ``equivalent_native_y_mask``.

        Raises:
            NotImplementedError: if ``self.proteinMode`` is not 0, the only
                protein featurisation implemented here.
        """
        if self.proteinMode != 0:
            # Without this guard the protein features and ``data`` below are never
            # bound and the method dies with an opaque UnboundLocalError.
            raise NotImplementedError(
                f"proteinMode={self.proteinMode} is not supported; only proteinMode=0 is implemented.")

        line = self.data.iloc[idx]
        pocket_com = line['pocket_com']
        use_compound_com = line['use_compound_com']
        # Per-row override of use_whole_protein; falls back to the dataset-level setting.
        use_whole_protein = line['use_whole_protein'] if "use_whole_protein" in line.index else self.use_whole_protein
        group = line['group'] if "group" in line.index else 'train'
        # Pocket-centre noise and coordinate jitter are applied to training rows only.
        add_noise_to_com = self.add_noise_to_com if group == 'train' else None
        shake_nodes = self.shake_nodes if group == 'train' else None

        protein_name = line['protein_name']  # pdb id
        # Protein features follow the GVP protocol.
        (protein_node_xyz, protein_seq, protein_node_s, protein_node_v,
         protein_edge_index, protein_edge_s, protein_edge_v) = self.protein_dict[protein_name]

        name = line['compound_name']
        # Compound features from torchdrug.
        (coords, compound_node_features, input_atom_edge_list,
         input_atom_edge_attr_list, pair_dis_distribution) = self.compound_dict[name]

        if shake_nodes is not None:
            # Uniform jitter in [-shake_nodes, shake_nodes) on every coordinate component.
            protein_node_xyz = protein_node_xyz + shake_nodes * (2 * np.random.rand(*protein_node_xyz.shape) - 1)
            coords = coords + shake_nodes * (2 * np.random.rand(*coords.shape) - 1)

        data, _input_node_list, _keep_node = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s,
            protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            contactCutoff=self.contactCutoff, includeDisMap=self.predDis,
            pocket_radius=self.pocket_radius, add_noise_to_com=add_noise_to_com,
            use_whole_protein=use_whole_protein,
            use_compound_com_as_pocket=use_compound_com, chosen_pocket_com=pocket_com,
            compoundMode=self.compoundMode)

        affinity = float(line['affinity'])
        data.affinity = torch.tensor([affinity], dtype=torch.float)
        # Pairwise distance distribution: assumed N x N x 16 (per the original comment), flattened to N^2 x 16.
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)
        data.pdb = line['pdb'] if "pdb" in line.index else f'smiles_{idx}'
        data.group = group

        # Affinity / contact targets are only trusted for native (compound-com) pockets.
        data.real_affinity_mask = torch.tensor([use_compound_com], dtype=torch.bool)
        data.real_y_mask = torch.ones(data.y.shape).bool() if use_compound_com else torch.zeros(data.y.shape).bool()

        if "native_num_contact" in line.index:
            # A pocket counts as equivalent to the native one when it recovers at least
            # 90% of the native contacts. (Per the original author, this is the branch
            # taken for the PDBbind data.)
            fract_of_native_contact = (data.y.numpy() > 0).sum() / float(line['native_num_contact'])
            is_equivalent_native_pocket = fract_of_native_contact >= 0.9
            data.is_equivalent_native_pocket = torch.tensor([is_equivalent_native_pocket], dtype=torch.bool)
            # Whole-sample mask: all True or all False.
            data.equivalent_native_y_mask = torch.ones(data.y.shape).bool() if is_equivalent_native_pocket else torch.zeros(data.y.shape).bool()
        elif "ligand_com" in line.index:
            # Fallback when native_num_contact is unavailable: the pocket is equivalent
            # if the ligand centre lies within 8 Angstrom of the pocket centre (mean of
            # data.node_xyz). Per the original author, this branch is not expected to run.
            ligand_com = line["ligand_com"]
            data_pocket_com = data.node_xyz.numpy().mean(axis=0)
            dis = np.sqrt(((ligand_com - data_pocket_com)**2).sum())
            is_equivalent_native_pocket = dis < 8
            data.is_equivalent_native_pocket = torch.tensor([is_equivalent_native_pocket], dtype=torch.bool)
            data.equivalent_native_y_mask = torch.ones(data.y.shape).bool() if is_equivalent_native_pocket else torch.zeros(data.y.shape).bool()
        # Otherwise is_equivalent_native_pocket / equivalent_native_y_mask are not set.
        return data

class TankBindMeanDataSet(Dataset):
    """PyG dataset reading per-row protein/compound features from LMDB stores.

    Each row of ``self.data`` (a DataFrame loaded from ``processed_paths[0]``)
    names a protein and a compound.

    Files read from ``self.processed_paths``:

    - ``processed_paths[1]``: protein features (LMDB, pickled values keyed by ``protein_name``).
    - ``processed_paths[2]``: compound features (LMDB, keyed by ``compound_name``).
    - ``processed_paths[3]``: RDKit coordinates loaded with ``torch.load``.
    - ``processed_paths[4]``: ESM2 features (LMDB); opened only when ``use_esm2_feat`` is true.

    Features are unpickled per item and assembled by ``construct_data_from_graph_gvp_mean``.

    This class defines no ``process`` method, so all processed files must already
    exist. The ``data`` / ``protein_dict`` / ``compound_dict`` arguments are
    accepted for interface compatibility but are overwritten after loading.
    """

    def __init__(self, root, data=None, protein_dict=None, compound_dict=None, proteinMode=0, compoundMode=1,
                add_noise_to_com=None, pocket_radius=20, contactCutoff=8.0, predDis=True, args=None,
                use_whole_protein=False, compound_coords_init_mode=None, seed=42, pre=None,
                transform=None, pre_transform=None, pre_filter=None, noise_for_predicted_pocket=5.0, test_random_rotation=False, pocket_idx_no_noise=True, use_esm2_feat=False):
        self.data = data
        self.protein_dict = protein_dict
        self.compound_dict = compound_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may
        # refuse to unpickle a DataFrame; pass weights_only=False if loading fails.
        self.data = torch.load(self.processed_paths[0])
        self.compound_rdkit_coords = torch.load(self.processed_paths[3])
        # Read-only LMDB handles shared by every get() call.
        # NOTE(review): these environments are opened in the constructing process;
        # LMDB documents that an environment must not be used across fork(), which
        # matters for multi-worker DataLoaders. Consider opening lazily per worker.
        lmdb_options = dict(readonly=True, max_readers=1, lock=False, readahead=False, meminit=False)
        self.protein_dict = lmdb.open(self.processed_paths[1], **lmdb_options)
        self.compound_dict = lmdb.open(self.processed_paths[2], **lmdb_options)
        if use_esm2_feat:
            self.protein_esm2_feat = lmdb.open(self.processed_paths[4], **lmdb_options)
        self.compound_coords_init_mode = compound_coords_init_mode
        self.add_noise_to_com = add_noise_to_com
        self.noise_for_predicted_pocket = noise_for_predicted_pocket
        self.proteinMode = proteinMode
        self.compoundMode = compoundMode
        self.pocket_radius = pocket_radius
        self.contactCutoff = contactCutoff
        self.predDis = predDis
        self.use_whole_protein = use_whole_protein
        self.test_random_rotation = test_random_rotation
        self.pocket_idx_no_noise = pocket_idx_no_noise
        self.use_esm2_feat = use_esm2_feat
        self.seed = seed
        self.args = args
        self.pre = pre

    @property
    def processed_file_names(self):
        """Processed files; index 4 (ESM2 LMDB) is only opened when ``use_esm2_feat`` is set."""
        return ['data.pt', 'protein.lmdb', 'compound_LAS_edge_index.lmdb', 'compound_rdkit_coords.pt', 'esm2_t33_650M_UR50D.lmdb']

    def len(self):
        """Number of rows in ``self.data``."""
        return len(self.data)

    @staticmethod
    def _lmdb_load(env, key):
        """Unpickle the value stored under the string ``key`` in LMDB environment ``env``.

        Raises:
            KeyError: if ``key`` is absent. ``txn.get`` returns None in that case,
                which would otherwise surface as an opaque TypeError from ``pickle.loads``.
        """
        with env.begin() as txn:
            raw = txn.get(key.encode())
            if raw is None:
                raise KeyError(f"key {key!r} not found in LMDB store")
            # pickle.loads can execute arbitrary code: the LMDB files must be trusted.
            return pickle.loads(raw)

    def get(self, idx):
        """Build the graph ``Data`` object for row ``idx`` of ``self.data``.

        Training rows get pocket-centre noise: ``add_noise_to_com`` for native
        pockets (``use_compound_com``) and ``noise_for_predicted_pocket``
        otherwise. Random rotation is enabled for training rows, and for test
        rows when ``test_random_rotation`` is set.

        Raises:
            NotImplementedError: if ``self.proteinMode`` is not 0.
            KeyError: if a protein/compound name is missing from its store.
        """
        if self.proteinMode != 0:
            # Without this guard the protein features and ``data`` below are never
            # bound and the method dies with an opaque UnboundLocalError.
            raise NotImplementedError(
                f"proteinMode={self.proteinMode} is not supported; only proteinMode=0 is implemented.")

        line = self.data.iloc[idx]
        pocket_com = line['pocket_com']
        use_compound_com = line['use_compound_com']
        use_whole_protein = line['use_whole_protein'] if "use_whole_protein" in line.index else self.use_whole_protein
        group = line['group'] if "group" in line.index else 'train'

        if group != 'train':
            add_noise_to_com = None
        elif use_compound_com:
            add_noise_to_com = self.add_noise_to_com
        else:
            add_noise_to_com = self.noise_for_predicted_pocket

        random_rotation = group == 'train' or (group == 'test' and bool(self.test_random_rotation))

        protein_name = line['protein_name']  # pdb id
        # Protein features follow the GVP protocol.
        (protein_node_xyz, protein_seq, protein_node_s, protein_node_v,
         protein_edge_index, protein_edge_s, protein_edge_v) = self._lmdb_load(self.protein_dict, protein_name)
        protein_esm2_feat = self._lmdb_load(self.protein_esm2_feat, protein_name) if self.use_esm2_feat else None

        name = line['compound_name']
        rdkit_coords = self.compound_rdkit_coords[name]
        # Compound features from torchdrug, plus LAS edge index.
        (coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
         pair_dis_distribution, LAS_edge_index) = self._lmdb_load(self.compound_dict, name)

        data, _input_node_list, _keep_node = construct_data_from_graph_gvp_mean(
            self.args, protein_node_xyz, protein_seq, protein_node_s,
            protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            LAS_edge_index, rdkit_coords,
            compound_coords_init_mode=self.compound_coords_init_mode, contactCutoff=self.contactCutoff,
            includeDisMap=self.predDis, pocket_radius=self.pocket_radius,
            add_noise_to_com=add_noise_to_com, use_whole_protein=use_whole_protein,
            pdb_id=name, group=group, seed=self.seed, data_path=self.pre,
            use_compound_com_as_pocket=use_compound_com, chosen_pocket_com=pocket_com,
            compoundMode=self.compoundMode, random_rotation=random_rotation,
            pocket_idx_no_noise=self.pocket_idx_no_noise,
            protein_esm2_feat=protein_esm2_feat)

        data.pdb = line['pdb'] if "pdb" in line.index else f'smiles_{idx}'
        data.group = group
        return data


def get_data(data_mode, logging, addNoise=None, use_whole_protein=False, pre="/PDBbind_data/pdbbind2020"):
    """Build TankBind train/valid/test splits from the PDBbind dataset under ``pre``.

    Args:
        data_mode: ``"0"`` uses the compound features stored with ``<pre>/dataset``.
            ``"1"`` replaces them with ``<pre>/tankbind_data/compound_torchdrug_features.pt``.
        logging: object with an ``info`` method (the ``logging`` module or a logger).
        addNoise: pocket-centre noise for training rows; falsy means no noise.
        use_whole_protein: forwarded to the dataset in mode ``"0"`` only.
        pre: root directory of the preprocessed data.

    Returns:
        ``(train, train_after_warm_up, valid, test, all_pocket_test, info)``:

        - ``train``: training rows with ``use_compound_com``.
        - ``train_after_warm_up``: all training rows.
        - ``valid`` / ``test``: rows with ``use_compound_com``.
        - ``all_pocket_test``: the dataset at ``<pre>/test_dataset``.
        - ``info``: always ``None``.

    Raises:
        ValueError: if ``data_mode`` is neither ``"0"`` nor ``"1"``. Previously
            this surfaced as an UnboundLocalError at the return.
    """
    if data_mode not in ("0", "1"):
        raise ValueError(f"Unsupported data_mode {data_mode!r}; expected '0' or '1'.")

    if data_mode == "0":
        logging.info("re-docking, using dataset: pdbbind2020 pred distance map.")
        logging.info("compound feature based on torchdrug")
        add_noise_to_com = float(addNoise) if addNoise else None
        # proteinMode = 0; compoundMode = 1 is for the GIN model.
        new_dataset = TankBindDataSet(f"{pre}/dataset", add_noise_to_com=add_noise_to_com, use_whole_protein=use_whole_protein)
    else:
        logging.info("self-docking, same as data mode 0 except using LAS_distance constraint masked compound pair distance")
        add_noise_to_com = float(addNoise) if addNoise else None
        # NOTE(review): unlike mode "0", use_whole_protein is not forwarded here (as in the original).
        new_dataset = TankBindDataSet(f"{pre}/dataset", add_noise_to_com=add_noise_to_com)
        new_dataset.compound_dict = torch.load(f"{pre}/tankbind_data/compound_torchdrug_features.pt")

    # c_length is the number of atoms in the compound. This filter drops some
    # samples, so the split sizes are smaller than those reported in the paper.
    new_dataset.data = new_dataset.data.query("c_length < 100 and native_num_contact > 5").reset_index(drop=True)
    d = new_dataset.data
    only_native_train_index = d.query("use_compound_com and group =='train'").index.values
    train = new_dataset[only_native_train_index]
    train_index = d.query("group =='train'").index.values
    train_after_warm_up = new_dataset[train_index]
    valid_index = d.query("use_compound_com and group =='valid'").index.values
    valid = new_dataset[valid_index]
    test_index = d.query("use_compound_com and group =='test'").index.values
    test = new_dataset[test_index]

    all_pocket_test = TankBindDataSet(f"{pre}/test_dataset")
    all_pocket_test.compound_dict = torch.load("../predictions/pdbbind_test_compound_dict_based_on_rdkit.pt")
    # Evaluation info for the test set is not built here.
    info = None
    return train, train_after_warm_up, valid, test, all_pocket_test, info

def get_data_v2(data_mode, logging, addNoise=None, use_whole_protein=False, pre="/PDBbind_data/pdbbind2020"):
    """Build native-pocket splits plus duplicated "update" splits (re-docking, mode ``"0"``).

    Every row with ``use_compound_com`` is duplicated with
    ``use_compound_com=False``. The duplicates are identified by the
    ``not pocket_com == pocket_com`` (i.e. NaN ``pocket_com``) query and are then
    flagged ``use_compound_com=True`` and ``for_update=True``.

    All subsets share the same DataFrame object, so rows flagged after a subset
    is created are seen by that subset as flagged.

    Returns:
        ``(train, train_update, valid, valid_update, test, test_update, all_pocket_test, info)``:

        - ``train_update``: the native training rows followed by their flagged duplicates.
        - ``info``: always ``None``.

    Raises:
        NotImplementedError: for ``data_mode == "1"``. The original branch never
            built the update splits and always crashed at the return; use ``get_data``.
        ValueError: for any other unknown ``data_mode``.
    """
    if data_mode == "1":
        raise NotImplementedError("get_data_v2 does not build update splits for data_mode '1'; use get_data instead.")
    if data_mode != "0":
        raise ValueError(f"Unsupported data_mode {data_mode!r}; only '0' is implemented.")

    logging.info("re-docking, using dataset: pdbbind2020 pred distance map.")
    logging.info("compound feature based on torchdrug")
    add_noise_to_com = float(addNoise) if addNoise else None

    # proteinMode = 0; compoundMode = 1 is for the GIN model.
    new_dataset = TankBindDataSet(f"{pre}/dataset", add_noise_to_com=add_noise_to_com, use_whole_protein=use_whole_protein)
    # c_length is the number of atoms in the compound. This filter drops some
    # samples, so the split sizes are smaller than those reported in the paper.
    new_dataset.data = new_dataset.data.query("c_length < 100 and native_num_contact > 5").reset_index(drop=True)
    new_dataset.data['for_update'] = False
    # .assign builds a fresh frame, avoiding attribute assignment on a query() result.
    duplicates = new_dataset.data.query("use_compound_com").assign(use_compound_com=False)
    new_dataset.data = pd.concat([new_dataset.data, duplicates], axis=0).reset_index(drop=True)
    d = new_dataset.data
    flag_cols = ['use_compound_com', 'for_update']

    only_native_train_index = d.query("use_compound_com and group =='train'").index.values
    train = new_dataset[only_native_train_index]
    train_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='train'").index.values
    assert len(train_update_index) == len(only_native_train_index)
    # .loc on the frame itself (labels == positions after reset_index). Chained
    # `d.col.iloc[...] = ...` silently does nothing under pandas Copy-on-Write.
    d.loc[train_update_index, flag_cols] = True

    train_update_index_all = np.concatenate((only_native_train_index, train_update_index))
    train_update = new_dataset[train_update_index_all]

    valid_index = d.query("use_compound_com and group =='valid'").index.values
    valid = new_dataset[valid_index]
    valid_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='valid'").index.values
    assert len(valid_update_index) == len(valid_index)
    valid_update = new_dataset[valid_update_index]
    test_index = d.query("use_compound_com and group =='test'").index.values
    test = new_dataset[test_index]
    test_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='test'").index.values
    assert len(test_update_index) == len(test_index)
    test_update = new_dataset[test_update_index]

    d.loc[valid_update_index, flag_cols] = True
    d.loc[test_update_index, flag_cols] = True

    all_pocket_test = TankBindDataSet(f"{pre}/test_dataset")
    all_pocket_test.compound_dict = torch.load("../predictions/pdbbind_test_compound_dict_based_on_rdkit.pt")
    # Evaluation info for the test set is not built here.
    info = None
    return train, train_update, valid, valid_update, test, test_update, all_pocket_test, info

def get_data_mean_v2(args, logging, addNoise=None, use_whole_protein=False, compound_coords_init_mode='pocket_center_rdkit', pre="/PDBbind_data/pdbbind2020"):
    """Build native-pocket splits plus duplicated "update" splits backed by TankBindMeanDataSet.

    The size/contact filter is applied to training rows only; all valid/test
    rows are kept.

    Every row with ``use_compound_com`` is duplicated with
    ``use_compound_com=False``. The duplicates (selected by the NaN-``pocket_com``
    query) are then flagged ``use_compound_com=True`` and ``for_update=True``.
    All subsets share the same DataFrame object.

    Args:
        args: namespace providing ``data``, ``pocket_radius``,
            ``noise_for_predicted_pocket``, ``test_random_rotation`` and
            ``pocket_idx_no_noise``.

    Returns:
        ``(train, train_update, train_update_only, valid, valid_update, test, test_update, info)``:

        - ``train_update``: native training rows followed by their duplicates.
        - ``train_update_only``: the duplicates alone.
        - ``info``: always ``None``.

    Raises:
        NotImplementedError: for ``args.data == "1"``. The original branch never
            built the update splits and always crashed at the return.
        ValueError: for any other unknown ``args.data``.
    """
    if args.data == "1":
        raise NotImplementedError("get_data_mean_v2 does not build update splits for args.data '1'.")
    if args.data != "0":
        raise ValueError(f"Unsupported args.data {args.data!r}; only '0' is implemented.")

    logging.info("re-docking, using dataset: pdbbind2020 pred distance map.")
    logging.info("compound feature based on torchdrug")
    add_noise_to_com = float(addNoise) if addNoise else None

    # NOTE(review): unlike get_data_mean_v3, `args` and `seed` are not forwarded, so
    # TankBindMeanDataSet.get passes args=None to the graph builder. Confirm this is intended.
    new_dataset = TankBindMeanDataSet(f"{pre}/dataset", add_noise_to_com=add_noise_to_com, use_whole_protein=use_whole_protein, compound_coords_init_mode=compound_coords_init_mode, pocket_radius=args.pocket_radius, noise_for_predicted_pocket=args.noise_for_predicted_pocket,
                                        test_random_rotation=args.test_random_rotation, pocket_idx_no_noise=args.pocket_idx_no_noise, pre=pre)
    # c_length is the number of atoms in the compound; the filter only applies to training rows.
    train_tmp = new_dataset.data.query("c_length < 100 and native_num_contact > 5 and group =='train'").reset_index(drop=True)
    valid_test_tmp = new_dataset.data.query("group == 'valid' or group == 'test'").reset_index(drop=True)
    new_dataset.data = pd.concat([train_tmp, valid_test_tmp], axis=0).reset_index(drop=True)
    new_dataset.data['for_update'] = False
    # .assign builds a fresh frame, avoiding attribute assignment on a query() result.
    duplicates = new_dataset.data.query("use_compound_com").assign(use_compound_com=False)
    new_dataset.data = pd.concat([new_dataset.data, duplicates], axis=0).reset_index(drop=True)
    d = new_dataset.data
    flag_cols = ['use_compound_com', 'for_update']

    only_native_train_index = d.query("use_compound_com and group =='train'").index.values
    train = new_dataset[only_native_train_index]
    train_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='train'").index.values
    assert len(train_update_index) == len(only_native_train_index)
    # .loc on the frame itself (labels == positions after reset_index). Chained
    # `d.col.iloc[...] = ...` silently does nothing under pandas Copy-on-Write.
    d.loc[train_update_index, flag_cols] = True

    train_update_index_all = np.concatenate((only_native_train_index, train_update_index))
    train_update = new_dataset[train_update_index_all]
    train_update_only = new_dataset[train_update_index]

    valid_index = d.query("use_compound_com and group =='valid'").index.values
    valid = new_dataset[valid_index]
    valid_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='valid'").index.values
    assert len(valid_update_index) == len(valid_index)
    valid_update = new_dataset[valid_update_index]
    test_index = d.query("use_compound_com and group =='test'").index.values
    test = new_dataset[test_index]
    test_update_index = d.query("not use_compound_com and not pocket_com == pocket_com and group =='test'").index.values
    assert len(test_update_index) == len(test_index)
    test_update = new_dataset[test_update_index]

    d.loc[valid_update_index, flag_cols] = True
    d.loc[test_update_index, flag_cols] = True

    info = None
    return train, train_update, train_update_only, valid, valid_update, test, test_update, info

def get_data_mean_v3(args, logger, addNoise=None, use_whole_protein=False, compound_coords_init_mode='pocket_center_rdkit', pre="/PDBbind_data/pdbbind2020"):
    """Build native-pocket train/valid/test splits backed by TankBindMeanDataSet.

    Only rows with ``use_compound_com`` are kept. Training rows are additionally
    filtered by ``c_length < 100 and native_num_contact > 5``.

    Args:
        args: namespace providing ``data``, ``pocket_radius``,
            ``noise_for_predicted_pocket``, ``test_random_rotation``,
            ``pocket_idx_no_noise``, ``use_esm2_feat`` and ``seed``. It is also
            forwarded to the dataset.
        logger: object with a ``log_message`` method.
        addNoise: pocket-centre noise for native training rows; falsy means no noise.

    Returns:
        ``(train, valid, test)`` subsets of the same dataset.

    Raises:
        ValueError: if ``args.data`` is not ``"0"``. Previously this surfaced as
            an UnboundLocalError at the return.
    """
    if args.data != "0":
        raise ValueError(f"Unsupported args.data {args.data!r}; only '0' is implemented.")

    logger.log_message("Loading dataset")
    logger.log_message("compound feature based on torchdrug")
    logger.log_message("protein feature based on esm2")
    add_noise_to_com = float(addNoise) if addNoise else None

    new_dataset = TankBindMeanDataSet(f"{pre}/dataset", add_noise_to_com=add_noise_to_com, use_whole_protein=use_whole_protein, compound_coords_init_mode=compound_coords_init_mode, pocket_radius=args.pocket_radius, noise_for_predicted_pocket=args.noise_for_predicted_pocket,
                                        test_random_rotation=args.test_random_rotation, pocket_idx_no_noise=args.pocket_idx_no_noise, use_esm2_feat=args.use_esm2_feat, seed=args.seed, pre=pre, args=args)
    # c_length is the number of atoms in the compound. This filter drops some
    # samples, so the split sizes are smaller than those reported in the paper.
    train_tmp = new_dataset.data.query("c_length < 100 and native_num_contact > 5 and group =='train' and use_compound_com").reset_index(drop=True)
    valid_test_tmp = new_dataset.data.query("(group == 'valid' or group == 'test') and use_compound_com").reset_index(drop=True)
    new_dataset.data = pd.concat([train_tmp, valid_test_tmp], axis=0).reset_index(drop=True)
    d = new_dataset.data

    train = new_dataset[d.query("group =='train'").index.values]
    valid = new_dataset[d.query("group =='valid'").index.values]
    test = new_dataset[d.query("group =='test'").index.values]
    return train, valid, test