import os
import time
import warnings
import numpy as np

import torch
import torch.utils.data
import torchvision
import torchvision.transforms
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from mytorch.sampler import RASampler
from mytorch.transforms import get_mixup_cutmix
import mytorch.presets as presets
import mytorch.utils as utils
from mymodels import MODEL_DICT, WEIGHTS_DICT
from clip.classes import imagenet_classes


class FeatureData(Dataset):
    """In-memory dataset over features and labels previously written with ``torch.save``.

    Each file is read with ``torch.load`` and converted to a NumPy array, so items
    are returned as ``(feature_row, target)`` NumPy values.

    Args:
        data_path: path to the saved feature tensor; items are indexed along dim 0.
        target_path: path to the saved target tensor, indexed with the same ``idx``
            as the features.
    """

    def __init__(self, data_path, target_path):
        features_tensor = torch.load(data_path)
        self.all_features = features_tensor.numpy()
        targets_tensor = torch.load(target_path)
        self.all_target = targets_tensor.numpy()
        # Distinct label values in ascending order (np.unique returns them sorted).
        unique_labels = np.unique(self.all_target)
        self.classes = [label for label in unique_labels]

    def __len__(self):
        # Number of samples = size of the leading axis of the feature array.
        n_rows = self.all_features.shape[0]
        return n_rows

    def __getitem__(self, idx):
        feature_row = self.all_features[idx]
        label = self.all_target[idx]
        return feature_row, label


def save_all_features(model, criterion, data_loader, device, args, print_freq=100, is_train=True, log_suffix=""):
    """Run ``model`` over ``data_loader`` and save the extracted features and targets to disk.

    Features come from ``model.forward_inter_features`` when ``args.save_inter`` is set
    (in that mode only the first batch is processed), otherwise from
    ``model.forward_features``. Results are written under
    ``saved_activations/<args.model>/`` as
    ``target_model_all_{train|test}[_inter]_{features|target}.pth``. When
    ``torch.distributed`` is initialized, each rank writes its own shard with a
    ``_rank<N>`` suffix so ranks do not overwrite each other's files.

    Args:
        model: module exposing ``forward_features`` / ``forward_inter_features``.
        criterion: unused; kept for signature compatibility with callers.
        data_loader: iterable yielding ``(image, target)`` batches.
        device: device the images are moved to before the forward pass.
        args: namespace; reads ``save_inter`` and ``model``.
        print_freq: logging interval (in iterations) for the metric logger.
        is_train: selects the ``train`` / ``test`` tag in the output file names.
        log_suffix: appended to the progress-log header.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples, all_output_features, all_target = 0, [], []
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            batch_size = image.shape[0]
            image = image.to(device, non_blocking=True)
            if args.save_inter:
                output_features = model.forward_inter_features(image)
            else:
                output_features = model.forward_features(image)

            # Accumulate features and targets on the CPU so the collected results
            # do not hold device memory; targets never need to visit the device.
            all_output_features.append(output_features.cpu())
            all_target.append(target.cpu())
            num_processed_samples += batch_size
            if args.save_inter:
                # Intermediate features are only dumped for the first batch.
                break

    if not all_output_features:
        raise RuntimeError("data_loader yielded no batches; there are no features to save")
    all_output_features = torch.cat(all_output_features, dim=0)
    all_target = torch.cat(all_target, dim=0)

    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # In --save-inter mode a single batch is processed on purpose, so a size
    # mismatch there is expected and not worth warning about.
    if (
        not args.save_inter
        and hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
    ):
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    save_features_path, save_target_path = _feature_save_paths(args, is_train)
    torch.save(all_output_features, save_features_path)
    torch.save(all_target, save_target_path)


def _feature_save_paths(args, is_train):
    """Return ``(features_path, target_path)`` for a dump, creating the output directory."""
    tail_info = "train" if is_train else "test"
    kind = "inter_" if args.save_inter else ""
    save_dir = os.path.join("saved_activations", args.model)
    os.makedirs(save_dir, exist_ok=True)
    # Every rank runs this function; without a per-rank suffix they would all
    # write (and overwrite) the same files with their own partial shard.
    rank_suffix = ""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank_suffix = f"_rank{torch.distributed.get_rank()}"
    save_features_path = os.path.join(save_dir, f"target_model_all_{tail_info}_{kind}features{rank_suffix}.pth")
    save_target_path = os.path.join(save_dir, f"target_model_all_{tail_info}_{kind}target{rank_suffix}.pth")
    return save_features_path, save_target_path


def load_data_save(data_dir, args):
    """Build an ImageFolder dataset and a sampler for a single feature-extraction pass.

    Preprocessing is taken from the torchvision weights enum named by
    ``args.weights`` (its ``transforms(antialias=True)``). With
    ``args.backend == "tensor"`` a ``PILToTensor`` step is prepended.

    The sampler visits images in a fixed order without shuffling or repeated
    augmentation, because the dump should contain each image once.

    Args:
        data_dir: root directory laid out for ``torchvision.datasets.ImageFolder``.
        args: namespace; reads ``weights``, ``backend`` and ``distributed``.

    Returns:
        ``(dataset, sampler)``.

    Raises:
        ValueError: if ``args.weights`` is not set.
    """
    print("Loading data")
    weights_name = getattr(args, "weights", None)
    if weights_name is None:
        # torchvision.models.get_weight(None) fails with an opaque AttributeError.
        raise ValueError(
            "--weights must name a torchvision weights enum (e.g. ResNet18_Weights.IMAGENET1K_V1) "
            "so the matching preprocessing can be built"
        )
    weights = torchvision.models.get_weight(weights_name)
    preprocessing = weights.transforms(antialias=True)
    if args.backend == "tensor":
        preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
    dataset = torchvision.datasets.ImageFolder(
        data_dir,
        preprocessing
    )

    print("Creating data loaders")
    if args.distributed:
        # No shuffling and no repeated-augmentation sampler: for feature extraction
        # each image should be visited once, in a deterministic order.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    return dataset, sampler


def main(args):
    """Entry point: build the model and dump features for the train and val splits.

    Only the ``--test-only`` path (feature extraction) is implemented. Without
    that flag the function warns and returns.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    print("Creating model")
    model = MODEL_DICT[args.model](weights=WEIGHTS_DICT[args.model])
    model.to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    if not args.test_only:
        warnings.warn("Only --test-only (feature extraction) is implemented; nothing was done.")
        return

    # We disable the cudnn benchmarking because it can noticeably affect the accuracy
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Bind the split tag explicitly. Deriving it from the path text ('train' in run_dir)
    # would mark the val split as train whenever --data-path itself contains "train".
    for run_dir, is_train, split_name in ((train_dir, True, "train"), (val_dir, False, "val")):
        dataset, sampler = load_data_save(run_dir, args)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.workers, pin_memory=True
        )
        save_all_features(
            model,
            criterion,
            data_loader,
            device=device,
            args=args,
            print_freq=args.print_freq,
            is_train=is_train,
            log_suffix=split_name,
        )

def get_args_parser(add_help=True):
    """Build the command-line parser for the feature-extraction script.

    Most options mirror the torchvision classification reference script. This file
    reads only a subset of them: data path, model, device, batch size, workers,
    print frequency, test-only, deterministic algorithms, label smoothing,
    distributed settings, weights, backend and save-inter.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="datasets/ILSVRC2012", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=512, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    # argparse applies `type` before checking `choices`, so validation is case-insensitive.
    # Without `choices`, a typo such as "tensors" would silently fall back to the PIL path.
    parser.add_argument(
        "--backend", default="PIL", type=str.lower, choices=("pil", "tensor"), help="PIL or tensor - case insensitive"
    )
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    # My added
    parser.add_argument('--save-inter', action='store_true', help='save intermediate features')

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(get_args_parser().parse_args())