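""" This file exports and saves CLIP (ViT-L/14) image features of the training split of
CIFAR-10, CIFAR-20 (CIFAR-100 with coarse labels), CIFAR-100, or ImageNet.
Run this file once before using the rest of the code, so that the CLIP features of the input dataset are saved.
Input arguments:
--dataset: dataset name, one of cifar10, cifar100coarse, cifar100, imagenet
--data_path: directory the dataset is loaded from
--save_path: directory the feature file `{dataset}-clipfeat.pt` is written to """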

from torch.utils.data import DataLoader
import clip
import torch
from tqdm import tqdm
import argparse
import os
import sys
sys.path.insert(0, './')

from data.datasets import load_dataset


# Command-line interface.
# Both path options go through os.path.expanduser so that a "~" (including the
# "~/data" default -- argparse applies `type` to string defaults) is expanded to
# the user's home directory instead of being treated as a literal directory
# name (os.makedirs("~/x") would otherwise create a folder called "~").
parser = argparse.ArgumentParser(description='export clip features')
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'cifar100coarse', 'cifar100', 'imagenet'],
                    help='dataset name')
parser.add_argument('--data_path', type=os.path.expanduser, default='~/data',
                    help='dataset path')
parser.add_argument('--save_path', type=os.path.expanduser, default='./data/',
                    help='feature path')
args = parser.parse_args()

# --- Data --------------------------------------------------------------------
# shuffle=False keeps the loader in dataset order, so row i of the exported
# feature matrix lines up with label i and with sample i of the dataset.
train_dataset = load_dataset(args.dataset, train=True, path=args.data_path)
train_loader = DataLoader(train_dataset, batch_size=150, shuffle=False, drop_last=False, num_workers=8)

# --- Model -------------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, preprocess = clip.load("ViT-L/14", device=device)
clip_model.to(device)
# NOTE(review): `preprocess` (CLIP's own resize/normalise transform) is not
# applied here; this assumes load_dataset already yields image tensors suitable
# for ViT-L/14 -- confirm against data.datasets.
# Inference runs on a single device: encode_image is called on the model
# directly (no torch.nn.DataParallel wrapper).

# --- Output location ---------------------------------------------------------
# exist_ok avoids the check-then-create race; os.path.join produces a correct
# path whether or not save_path ends with a separator.
if args.save_path:
    os.makedirs(args.save_path, exist_ok=True)
# Full path of the output .pt file (despite the name, this is a file, not a dir).
model_dir = os.path.join(args.save_path, args.dataset + '-clipfeat.pt')

print("exporting features to {}".format(model_dir))

# --- Feature extraction ------------------------------------------------------
features = []
ys = []
with torch.no_grad():
    for x, y in tqdm(train_loader, total=len(train_loader)):
        x_feature = clip_model.encode_image(x.to(device))
        # Move each batch's features to CPU immediately so device memory does
        # not grow with the dataset size.
        features.append(x_feature.cpu())
        # Labels are only collected, never used on the device; keep them on CPU.
        ys.append(y.cpu())

final_features = torch.cat(features, dim=0)
final_ys = torch.cat(ys, dim=0)
print(final_features.shape)
print(final_ys.shape)
torch.save({'features': final_features, 'labels': final_ys}, model_dir)
print("done!")
