import torch
import argparse
import os
from tqdm import tqdm

import numpy as np

from datasets import CustomDatasetFolder
from experiments.model_builder import load_torchvision_model

# Global PyTorch backend switches for this run.
# cuDNN is enabled and allowed to benchmark/auto-select convolution algorithms,
# which pays off when every batch has the same input size.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# Anomaly detection is a debugging aid for the autograd backward pass. This
# script only runs the model under torch.no_grad(), so it likely has no effect here.
torch.autograd.set_detect_anomaly(mode=True)

# python code/extract_features.py --dataset_dir=/data/datasets/RSNA_ICH/original/ --weights_path=/work/work_fran/SmoothAttention/weights/ssl/simclr/rsna/resnet50/best_model.pt --save_dir=/data/data_fran/RSNA_ICH/features_resnet50simclr/ --image_size=224 --batch_size=512 --model=resnet50 --num_workers=16
# python code/extract_features.py --dataset_dir=/data/datasets/RSNA_ICH/original/ --save_dir=/data/data_fran/RSNA_ICH/features_resnet18/ --image_size=224 --batch_size=512 --model=resnet18 --num_workers=16
# CUDA_VISIBLE_DEVICES=2 nohup python code/extract_features.py --dataset_dir=/data/data_fran/Panda/patches_256/images/ --save_dir=/data/data_fran/Panda/patches_256/raw/features_resnet18/ --image_size=256 --batch_size=512 --model=resnet18 --num_workers=4 > panda_feat.out 2>&1 &
# CUDA_VISIBLE_DEVICES=2 nohup python code/extract_features.py --dataset_dir=/data/data_fran/Panda/patches_512/images/ --save_dir=/data/data_fran/Panda/patches_512/raw/features_resnet18/ --image_size=256 --batch_size=512 --model=resnet18 --num_workers=4 > panda_feat.out 2>&1 &

# CUDA_VISIBLE_DEVICES=2 nohup python code/extract_features.py --dataset_dir=/data/datasets/RSNA_ICH/original/ --save_dir=/data/data_fran/RSNA_ICH/features_resnet18/ --image_size=512 --batch_size=512 --model=resnet18 --num_workers=4 > rsna_feat.out 2>&1 &

# CUDA_VISIBLE_DEVICES=0 nohup python code/extract_features.py --dataset_dir=/data/data_fran/Panda/patches_256_new/images/ --save_dir=/data/data_fran/Panda/patches_256_new/raw/features_resnet18/ --image_size=224 --batch_size=512 --model=resnet18 --num_workers=4 > panda_feat_256.out 2>&1 &
# CUDA_VISIBLE_DEVICES=1 nohup python code/extract_features.py --dataset_dir=/data/data_fran/Panda/patches_512_new/images/ --save_dir=/data/data_fran/Panda/patches_512_new/raw/features_resnet18/ --image_size=224 --batch_size=512 --model=resnet18 --num_workers=4 > panda_feat_512.out 2>&1 &
# CUDA_VISIBLE_DEVICES=2 nohup python code/extract_features.py --dataset_dir=/data/data_fran/Panda/patches_1120/images/ --save_dir=/data/data_fran/Panda/patches_1120/raw/features_resnet18/ --image_size=224 --batch_size=512 --model=resnet18 --num_workers=4 > panda_feat_1120.out 2>&1 &


# Command-line interface: (flag, default, type, help) for every option,
# registered in this order.
_CLI_OPTIONS = (
    ('--dataset_dir', '', str, "Dataset dir"),
    ('--weights_path', '', str, "Path to load weights from"),
    ('--save_dir', '', str, "Save dir"),
    ('--image_size', 224, int, "Image size"),
    ('--batch_size', 512, int, "Batch size"),
    ('--model', 'mobilenet_v2', str, "Model name"),
    ('--num_workers', 16, int, 'Number of workers'),
)

parser = argparse.ArgumentParser()
for _flag, _default, _type, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

args = parser.parse_args()

# Echo the effective configuration, one left-aligned name per line.
print('Arguments:')
for _name, _value in vars(args).items():
    print(f'{_name:<25s}: {_value!s}')

# Run on the GPU whenever one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Loading model...')

# load_torchvision_model is a project helper returning the network and an
# optional batch transform (applied on-device before each forward pass below).
model, transforms = load_torchvision_model(args.model)

if args.weights_path:
    # A mistyped path used to be ignored silently, leaving the model with
    # whatever weights load_torchvision_model initialised it with. Fail loudly instead.
    if not os.path.isfile(args.weights_path):
        raise FileNotFoundError(f'Weights file not found: {args.weights_path}')
    print(f'Loading weights from { args.weights_path }')
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only host; the model is moved to `device` right after.
    model.load_state_dict(torch.load(args.weights_path, map_location='cpu'))

# n_feat = 2048
# Inference mode (e.g. BatchNorm uses running stats, dropout disabled).
model.eval()
model.to(device)

print('Creating dataset...')
# Images under --dataset_dir, resized to --image_size by the project dataset class.
dataset = CustomDatasetFolder(root=args.dataset_dir, resize_size=args.image_size)

# shuffle=False keeps batch order deterministic so extracted features stay
# aligned with the image paths collected alongside them.
_loader_options = dict(
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
)
dataloader = torch.utils.data.DataLoader(dataset, **_loader_options)

print('Extracting features...')
# Run every batch through the model, keeping the outputs (on CPU) together with
# the source path of each image, in dataloader order.
features_list = []
path_list = []
with torch.no_grad():
    for X, X_path in tqdm(dataloader, total=len(dataloader)):
        X = X.to(device)  # presumably (batch_size, 3, H, W) -- depends on CustomDatasetFolder

        if transforms is not None:
            X = transforms(X)

        # Expected (batch_size, n_feat). If the model ever returns spatial dims,
        # squeeze them first: features.squeeze(dim=(2, 3)).
        features = model(X)
        features_list.append(features.to('cpu'))
        # extend() appends in place; rebuilding the list with `+` every batch
        # was quadratic in the number of images.
        path_list.extend(X_path)

if not features_list:
    raise RuntimeError(f'No images found in {args.dataset_dir}')

# Concatenate the per-batch features; row i belongs to path_list[i].
features = torch.cat(features_list, dim=0).numpy()

# Save the features, one .npy per image, mirroring the dataset's directory layout.
print('Saving features...')
for feat, orig_path in tqdm(zip(features, path_list), total=len(path_list)):
    # relpath + join instead of str.replace: a trailing-slash mismatch between
    # --dataset_dir and --save_dir no longer glues directory names together.
    rel_path = os.path.relpath(orig_path, args.dataset_dir)
    # splitext strips only the final extension; split('.')[0] truncated at the
    # first dot anywhere in the path (e.g. './out/...' or 'patch_1.5.png').
    save_path = os.path.splitext(os.path.join(args.save_dir, rel_path))[0] + '.npy'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Existing .npy files are left untouched (never overwritten).
    if not os.path.isfile(save_path):
        np.save(save_path, feat)