from sklearn.model_selection import train_test_split

from datasets import RSNADataset, PandaDataset, CamelyonDataset

def _split_train_val_dataset_name(name, min_parts, expected_format):
    """Split a dataset name on '-' and check it has enough components.

    Args:
        name: Dataset name as given in the configuration.
        min_parts: Minimum number of '-'-separated components required.
        expected_format: Human-readable name layout, used in the error text.

    Returns:
        The list of '-'-separated components of ``name``.

    Raises:
        ValueError: If ``name`` has fewer than ``min_parts`` components.
    """
    parts = name.split('-')
    if len(parts) < min_parts:
        raise ValueError(
            f"Dataset name '{name}' does not match the expected format "
            f"'{expected_format}'"
        )
    return parts


def _stratified_train_val_split(dataset, val_prop, seed):
    """Split ``dataset`` into train and validation subsets, stratified by bag label.

    Bag indices are split with ``train_test_split`` using ``val_prop`` as the
    validation proportion and ``seed`` as the random state; both subsets are
    then built with ``dataset.subset``.
    """
    bags_labels = dataset.get_bag_labels()
    idx = list(range(len(bags_labels)))
    idx_train, idx_val = train_test_split(
        idx, test_size=val_prop, random_state=seed, stratify=bags_labels
    )
    return dataset.subset(idx_train), dataset.subset(idx_val)


def load_train_val_dataset(config):
    """Build the training and validation datasets described by ``config``.

    The dataset family is selected by substring match on
    ``config.dataset_name``, checked in this order: ``rsna``, ``panda``,
    ``camelyon16``. Supported name layouts:

    * ``rsna`` or ``rsna-features_<model_name>``
    * ``panda-<patch_dir>`` or ``panda-<patch_dir>-<features_dir>``
    * ``camelyon16-<patch_dir>-<features_dir>``

    Args:
        config: Object exposing ``dataset_name``, ``val_prop``, ``seed`` and
            ``use_inst_distances``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple obtained from a stratified
        split of the full training data.

    Raises:
        ValueError: If the dataset name is not supported, is missing a
            required '-'-separated component, or names camelyon16 without
            features.
    """
    name = config.dataset_name
    val_prop = config.val_prop
    seed = config.seed
    # Forwarded explicitly to the dataset constructors that accept it.
    n_samples = None

    if "rsna" in name:
        if 'features' in name:
            parts = _split_train_val_dataset_name(
                name, 2, 'rsna-features_<model_name>'
            )
            features_dir_name = parts[1]
            # NOTE(review): this branch is rooted at /data/RSNA_ICH/ while the
            # other RSNA paths use /data/datasets/RSNA_ICH/ - confirm intended.
            data_path = f'/data/RSNA_ICH/raw/{features_dir_name}/'
            processed_data_path = f'/data/RSNA_ICH/processed/{features_dir_name}/'
        else:
            data_path = '/data/datasets/RSNA_ICH/original/'
            processed_data_path = '/data/datasets/RSNA_ICH/processed/original/'

        dataset = RSNADataset(
            data_path=data_path,
            processed_data_path=processed_data_path,
            csv_path='/data/datasets/RSNA_ICH/bags_train.csv',
            n_samples=n_samples,
            use_slice_distances=config.use_inst_distances,
        )

    elif "panda" in name:
        # Ex: panda-patches_512-features_resnet18
        uses_features = 'features' in name
        parts = _split_train_val_dataset_name(
            name,
            3 if uses_features else 2,
            'panda-<patch_dir>-<features_dir>' if uses_features else 'panda-<patch_dir>',
        )
        patch_dir = parts[1]

        if uses_features:
            features_dir_name = parts[2]
            data_path = f'/data/Panda/{patch_dir}/raw/{features_dir_name}/'
            processed_data_path = f'/data/Panda/{patch_dir}/processed/{features_dir_name}/'
        else:
            data_path = f'/data/Panda/{patch_dir}/images/'
            processed_data_path = f'/data/Panda/{patch_dir}/processed/images/'

        dataset = PandaDataset(
            data_path=data_path,
            processed_data_path=processed_data_path,
            csv_path=f'/data/Panda/{patch_dir}/train_val_patches.csv',
            n_samples=n_samples,
            use_patch_distances=config.use_inst_distances,
        )

    elif 'camelyon16' in name:
        # Ex: camelyon16-patches_512_preset-features_resnet50_bt
        if 'features' not in name:
            raise ValueError("camelyon16 dataset only supports features")

        parts = _split_train_val_dataset_name(
            name, 3, 'camelyon16-<patch_dir>-<features_dir>'
        )
        patches_dir_name = parts[1]
        features_dir_name = parts[2]

        dataset = CamelyonDataset(
            f'/data/CAMELYON16/{patches_dir_name}/',
            '/data/CAMELYON16/original/train.csv',
            features_dir_name,
            use_patch_distances=config.use_inst_distances,
        )

    else:
        raise ValueError(f"Dataset {name} not supported")

    return _stratified_train_val_split(dataset, val_prop, seed)

def _split_test_dataset_name(name, min_parts, expected_format):
    """Split a dataset name on '-' and check it has enough components.

    Args:
        name: Dataset name as given in the configuration.
        min_parts: Minimum number of '-'-separated components required.
        expected_format: Human-readable name layout, used in the error text.

    Returns:
        The list of '-'-separated components of ``name``.

    Raises:
        ValueError: If ``name`` has fewer than ``min_parts`` components.
    """
    parts = name.split('-')
    if len(parts) < min_parts:
        raise ValueError(
            f"Dataset name '{name}' does not match the expected format "
            f"'{expected_format}'"
        )
    return parts


def load_test_dataset(config):
    """Build the test dataset described by ``config``.

    The dataset family is selected by substring match on
    ``config.dataset_name``, checked in this order: ``rsna``, ``panda``,
    ``camelyon16``. Supported name layouts:

    * ``rsna`` or ``rsna-features_<model_name>``
    * ``panda-<patch_dir>`` or ``panda-<patch_dir>-<features_dir>``
    * ``camelyon16-<patch_dir>-<features_dir>``

    Args:
        config: Object exposing ``dataset_name`` and ``use_inst_distances``.

    Returns:
        The dataset object built from the test CSV of the selected family.

    Raises:
        ValueError: If the dataset name is not supported, is missing a
            required '-'-separated component, or names camelyon16 without
            features.
    """
    name = config.dataset_name
    # Forwarded explicitly to the dataset constructors that accept it.
    n_samples = None

    if 'rsna' in name:
        if 'features' in name:
            parts = _split_test_dataset_name(name, 2, 'rsna-features_<model_name>')
            features_dir_name = parts[1]
            # NOTE(review): this branch is rooted at /data/RSNA_ICH/ while the
            # other RSNA paths use /data/datasets/RSNA_ICH/ - confirm intended.
            data_path = f'/data/RSNA_ICH/raw/{features_dir_name}/'
            processed_data_path = f'/data/RSNA_ICH/processed/{features_dir_name}/'
        else:
            data_path = '/data/datasets/RSNA_ICH/original/'
            processed_data_path = '/data/datasets/RSNA_ICH/processed/original/'

        return RSNADataset(
            data_path=data_path,
            processed_data_path=processed_data_path,
            csv_path='/data/datasets/RSNA_ICH/bags_test.csv',
            n_samples=n_samples,
            use_slice_distances=config.use_inst_distances,
        )

    if 'panda' in name:
        uses_features = 'features' in name
        parts = _split_test_dataset_name(
            name,
            3 if uses_features else 2,
            'panda-<patch_dir>-<features_dir>' if uses_features else 'panda-<patch_dir>',
        )
        patch_dir = parts[1]

        if uses_features:
            features_dir_name = parts[2]
            data_path = f'/data/Panda/{patch_dir}/raw/{features_dir_name}/'
            processed_data_path = f'/data/Panda/{patch_dir}/processed/{features_dir_name}/'
        else:
            data_path = f'/data/Panda/{patch_dir}/images/'
            processed_data_path = f'/data/Panda/{patch_dir}/processed/images/'

        return PandaDataset(
            data_path=data_path,
            processed_data_path=processed_data_path,
            csv_path=f'/data/Panda/{patch_dir}/test_patches.csv',
            n_samples=n_samples,
            use_patch_distances=config.use_inst_distances,
        )

    if "camelyon16" in name:
        # Ex: camelyon16-patches_512_preset-features_resnet50_bt
        if 'features' not in name:
            raise ValueError("camelyon16 dataset only supports features")

        parts = _split_test_dataset_name(
            name, 3, 'camelyon16-<patch_dir>-<features_dir>'
        )
        patches_dir_name = parts[1]
        features_dir_name = parts[2]

        return CamelyonDataset(
            f'/data/CAMELYON16/{patches_dir_name}/',
            '/data/CAMELYON16/original/test.csv',
            features_dir_name,
            use_patch_distances=config.use_inst_distances,
        )

    raise ValueError(f"Dataset {name} not supported")
