from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import OrderedDict
import tensorflow as tf
from PIL import Image
import os
import itertools
from typing import List
from torch.cuda.amp import GradScaler, autocast

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
import gc

from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
from pathlib import Path
import wandb
from tqdm import tqdm
import heapq

def to_chunks(it, size):
  """Lazily split iterable `it` into consecutive tuples of `size` items.

  `size` may be fractional; it is rounded up. The final chunk may be shorter
  than `size`. The input is consumed lazily as chunks are requested.

  Raises:
    ValueError: if `size` rounds up to less than 1. A chunk size of 0 would
      otherwise produce an empty iterator and silently drop every element.
  """
  chunk_size = int(math.ceil(size))
  if chunk_size < 1:
    raise ValueError(f"chunk size must be at least 1, got {size!r}")
  it = iter(it)
  # iter(callable, sentinel) keeps calling until an empty tuple signals exhaustion.
  return iter(lambda: tuple(itertools.islice(it, chunk_size)), ())

# Single hard-coded compute device used by the FFCV ToDevice pipeline stages,
# the models (.to(device)) and get_outputs throughout this script.
# NOTE(review): "cuda:1" assumes at least two visible GPUs; adjust on other hosts.
device = torch.device("cuda:1")

def generate_until(gen_f, pred):
  """Call `gen_f` repeatedly and return the first result that satisfies `pred`.

  This is rejection sampling: it never terminates if `gen_f` cannot produce an
  acceptable value, so callers must make sure one exists.
  """
  while True:
    res = gen_f()
    if pred(res):
      return res

def gen_order(*, labels, label_map, n_samples, n_negatives):
    """Sample `n_samples` (anchor, positive, negatives...) index tuples from class labels.

    For each tuple:
      * an anchor index `a` is drawn uniformly from range(len(labels));
      * a positive is drawn from `label_map[labels[a]]`, rejecting `a` itself;
      * `n_negatives` indices whose label differs from the anchor's are drawn by
        rejection sampling over all points.

    Args:
        labels: per-point class labels, indexed by point index.
        label_map: indexable by label and giving point indices for that label
            (presumably built with `class_label_list_to_sets`; the positive is
            only guaranteed to share the anchor's label if the two agree).
        n_samples: number of tuples to draw.
        n_negatives: number of negatives per tuple.

    Returns:
        Flat list of point indices, `n_negatives + 2` entries per tuple, in the
        order anchor, positive, negatives.

    Raises:
        ValueError: when sampling could never finish. This happens when the
            anchor's label has no member other than the anchor in `label_map`,
            or when negatives are requested but every point has the same label.
    """
    n_points = len(labels)
    # Validate without touching the RNG so valid inputs sample exactly as before.
    if n_samples > 0 and n_negatives > 0 and len(set(labels)) < 2:
        raise ValueError("negatives requested but all points share a single label")
    order = []
    for _ in range(n_samples):
        a = random.randrange(n_points)
        label = labels[a]
        same_label = label_map[label]
        if all(p == a for p in same_label):
            raise ValueError(
                f"no positive available for anchor {a}: label {label!r} has no other member")
        order.append(a)
        # The lambdas are consumed immediately by generate_until, so capturing the
        # loop variables `a` / `label` is safe (no late-binding surprise).
        order.append(generate_until(
            lambda: random.choice(same_label),
            lambda p: p != a,
        ))
        for _ in range(n_negatives):
            order.append(generate_until(
                lambda: random.randrange(n_points),
                lambda n: labels[n] != label,
            ))
    return order

def get_embedding_dataloader(dataset_path, batch_size):
    """Return an FFCV loader over `dataset_path` for feature extraction.

    Images are decoded, moved to `device`, converted to float16 and normalised
    with CIFAR-100 channel statistics. Labels are decoded to integers on `device`.
    Samples are visited with OrderOption.SEQUENTIAL, and the final partial batch
    is kept (drop_last=False).

    Args:
        dataset_path: path to the FFCV (.beton) file.
        batch_size: number of samples per batch.
    """
    # Per-channel CIFAR-100 mean/std, scaled to the 0-255 pixel range.
    CIFAR100_MEAN = [255 * x for x in [0.5071, 0.4865, 0.4409]]
    CIFAR100_STD = [255 * x for x in [0.2673, 0.2564, 0.2762]]

    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
    image_pipeline: List[Operation] = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]

    return Loader(dataset_path,
                  batch_size=batch_size,
                  num_workers=2,
                  order=OrderOption.SEQUENTIAL,
                  drop_last=False,
                  pipelines={'image': image_pipeline, 'label': label_pipeline})

def getembeddings(dataset_path):
    """Embed every sample of the FFCV dataset at `dataset_path` with a frozen feature extractor.

    The extractor is torchvision's ImageNet-pretrained ResNet-18 with its `fc`
    head replaced by `nn.Identity`. It runs in eval mode under no_grad and
    autocast. Returns a float32 CPU tensor with one row per sample, in the
    order the loader yields them.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    # Drop the classifier so the network returns its penultimate-layer features.
    backbone.fc = nn.Identity()
    backbone.eval()
    feature_chunks = []
    loader = get_embedding_dataloader(dataset_path, batch_size=1000)
    with torch.no_grad():
        for batch in loader:
            with autocast():
                features = get_outputs(backbone, batch)
                feature_chunks.append(features.detach().cpu())
    # Autocast produces reduced-precision activations; hand back float32.
    all_features = torch.cat(feature_chunks)
    return all_features.float()

def arboricity(g):
    """Return the degeneracy of the (multi)graph `g`, computed by min-degree peeling.

    The algorithm repeatedly removes a vertex of minimum remaining degree and
    returns the largest degree seen at removal time. Despite the name, this is
    the graph's degeneracy. It bounds the arboricity a(G):
    a(G) <= degeneracy <= 2 * a(G) - 1.

    Args:
        g: adjacency list. g[u] lists u's neighbours; parallel edges are counted
           with multiplicity.

    Returns:
        The degeneracy, or 0 for a graph with no edges or no vertices.
    """
    degree = [len(neigh) for neigh in g]
    deleted = [False] * len(g)
    heap = [(d, v) for v, d in enumerate(degree)]
    heapq.heapify(heap)
    res = 0
    while heap:
        _, u = heapq.heappop(heap)
        if deleted[u]:
            # Stale entry. Degrees only decrease, so a vertex's freshest (smallest)
            # entry is always popped first.
            continue
        deleted[u] = True
        res = max(res, degree[u])
        for v in g[u]:
            if deleted[v]:
                # A removed vertex's degree is never read again; skip it and its heap entry.
                continue
            degree[v] -= 1
            heapq.heappush(heap, (degree[v], v))
    return res


def construct_graph(tuples):
    """Build an undirected adjacency list linking each tuple's first index to the rest.

    For every tuple t, an edge is added between t[0] and each element of t[1:].
    Repeated pairs are kept, so the result is a multigraph.

    Args:
        tuples: iterable of non-empty sequences of non-negative vertex indices.

    Returns:
        List g of length 1 + (largest index seen), where g[i] lists i's
        neighbours. Returns [] when `tuples` is empty.
    """
    n = 1 + max((max(t) for t in tuples), default=-1)
    g = [[] for _ in range(n)]
    for first, *rest in tuples:
        for v in rest:
            g[first].append(v)
            g[v].append(first)
    return g

def gen_order_with_ground_truth(embeddings, *, n_samples, n_negatives):
    """Sample `n_samples` tuples whose positive is chosen by embedding similarity.

    Each tuple draws `n_negatives + 2` distinct indices uniformly at random. The
    first is the anchor. Of the remaining candidates, the one with the largest
    dot product with the anchor's embedding is swapped to the front and becomes
    the positive; the rest are negatives.

    Args:
        embeddings: tensor indexed by point along dim 0. Presumably
            (n_points, dim), e.g. the output of `getembeddings`.
        n_samples: number of tuples to draw.
        n_negatives: number of negatives per tuple.

    Returns:
        Flat list of point indices; each tuple contributes [anchor, positive, *negatives].

    Side effect: when at least one tuple was drawn, prints the value of
    `arboricity` on the graph that links each anchor to its candidates.
    """
    n_points = embeddings.shape[0]
    order = []
    tuples = []
    for _ in range(n_samples):
        anchor, *candidates = random.sample(range(n_points), k=n_negatives + 2)
        similarities = embeddings[anchor] @ embeddings[candidates].T
        best = int(torch.argmax(similarities))
        candidates[0], candidates[best] = candidates[best], candidates[0]
        tuples.append([anchor] + candidates)
        order.extend(tuples[-1])
    # construct_graph cannot size a graph from zero tuples; skip the statistic then.
    if tuples:
        graph = construct_graph(tuples)
        print("Arboricity: ", arboricity(graph))
    return order


def contrastive_dataloader_with_ground_truth(dataset_path, embeddings, *, n_samples, n_negatives, batch_size):
    """Build an FFCV loader that yields sampled contrastive tuples in a fixed order.

    The index order comes from `gen_order_with_ground_truth`. It is a flat list
    of tuples, each [anchor, positive, *negatives] of length `n_negatives + 2`.
    `batch_size` is rounded down to a multiple of that tuple length. Assuming the
    loader yields `order` as given (SEQUENTIAL), every batch then holds whole
    tuples. Images are normalised with CIFAR-100 statistics and moved to `device`
    as float16.

    Args:
        dataset_path: path to the FFCV (.beton) file.
        embeddings: per-point embeddings used to pick each tuple's positive.
        n_samples: number of tuples.
        n_negatives: negatives per tuple.
        batch_size: requested batch size, rounded down to whole tuples.

    Raises:
        ValueError: if `batch_size` is smaller than one tuple (`n_negatives + 2`);
            rounding down would otherwise give a batch size of 0.
    """
    tuple_size = n_negatives + 2
    if batch_size < tuple_size:
        raise ValueError(
            f"batch_size ({batch_size}) must be at least n_negatives + 2 ({tuple_size})")
    batch_size = batch_size - batch_size % tuple_size

    order = gen_order_with_ground_truth(embeddings, n_samples=n_samples, n_negatives=n_negatives)

    # Per-channel CIFAR-100 mean/std, scaled to the 0-255 pixel range.
    CIFAR100_MEAN = [255 * x for x in [0.5071, 0.4865, 0.4409]]
    CIFAR100_STD = [255 * x for x in [0.2673, 0.2564, 0.2762]]

    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
    image_pipeline: List[Operation] = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]

    return Loader(dataset_path,
                  batch_size=batch_size,
                  num_workers=2,
                  order=OrderOption.SEQUENTIAL,
                  indices=order,
                  drop_last=False,
                  pipelines={'image': image_pipeline, 'label': label_pipeline})


def get_batch_length(batch):
    """Return how many examples `batch` holds, counted from its second element (presumably the labels)."""
    labels = batch[1]
    return len(labels)

def get_outputs(model, batch):
    """Move the first element of `batch` (the inputs) to `device` and return `model` applied to it."""
    inputs = batch[0]
    return model(inputs.to(device))

# Triplet margin loss over (anchor, positive, negative) embeddings. The keyword
# arguments spell out PyTorch's defaults: margin 1 and Euclidean (p=2) distance.
loss_f = nn.TripletMarginLoss(margin=1.0, p=2.0)

def contrastive_loss_acc(outputs, n_negatives):
  """Triplet loss and triplet accuracy for a flat batch of tuple embeddings.

  Rows of `outputs` are read in consecutive groups of three: anchor, positive,
  negative.

  Args:
    outputs: 2-D tensor (n_rows, dim); n_rows must be a multiple of 3.
    n_negatives: negatives per tuple; only 1 is supported.

  Returns:
    (loss, acc) as 0-d tensors. `loss` is the module-level `loss_f` (triplet
    margin loss) applied to the grouped rows. `acc` is the fraction of tuples
    whose positive is strictly closer (Euclidean) to the anchor than the negative.

  Raises:
    NotImplementedError: if n_negatives != 1.
    ValueError: if `outputs` is not 2-D or its row count is not a multiple of 3.
  """
  # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
  if n_negatives != 1:
    raise NotImplementedError(f"only n_negatives == 1 is supported, got {n_negatives}")
  if outputs.dim() != 2:
    raise ValueError(f"expected 2-D outputs, got shape {tuple(outputs.shape)}")
  n_rows = outputs.shape[0]
  if n_rows % 3 != 0:
    raise ValueError(f"number of rows ({n_rows}) is not a multiple of 3")
  # Explicit leading size (not -1) so an empty batch still reshapes cleanly.
  tuples_sep = outputs.reshape(n_rows // 3, 3, outputs.shape[1])
  anchor = tuples_sep[:, 0]
  positive = tuples_sep[:, 1]
  negative = tuples_sep[:, 2]
  loss = loss_f(anchor, positive, negative)
  acc = torch.mean((torch.linalg.vector_norm(anchor - positive, dim=1)
                    < torch.linalg.vector_norm(anchor - negative, dim=1)
                    ).float())
  return loss, acc


def train(model, train_loader, optimizer, epoch, n_vals, loss_acc_f):
  """Train `model` for one pass over `train_loader`, using at most `n_vals` outputs.

  After each optimizer step, the running loss and accuracy (weighted by the
  number of rows actually used) are logged to wandb and shown on a tqdm bar.

  Args:
    model: network to optimise; switched to train mode here.
    train_loader: iterable of batches accepted by `get_outputs`.
    optimizer: optimizer over `model`'s parameters.
    epoch: epoch index, used only in the progress-bar text.
    n_vals: maximum number of model outputs (rows) to train on. Rows beyond the
      budget in the last batch are dropped.
    loss_acc_f: maps outputs to a (loss, accuracy) pair; accuracy is expected to
      be a fraction, since it is multiplied by 100 for reporting.
  """
  model.train()
  if n_vals <= 0:
    return  # empty budget: nothing to train on

  total_loss = 0.0
  total_acc = 0.0
  n_points = 0

  progress_bar = tqdm(train_loader, position=0, leave=True, miniters=10)
  for batch in progress_bar:
    optimizer.zero_grad()
    # NOTE(review): the forward pass runs under float16 autocast, but the loss is
    # not scaled with a GradScaler. PyTorch's AMP recipe pairs the two to avoid
    # gradient underflow; consider adding one.
    with autocast():
      outputs = get_outputs(model, batch)
    n_vals -= outputs.shape[0]
    if n_vals < 0:
      # n_vals is negative here: drop the rows that exceed the budget.
      outputs = outputs[:n_vals]
    # Weight by the rows actually scored, not the full batch size.
    cur_points = outputs.shape[0]
    n_points += cur_points
    loss, acc = loss_acc_f(outputs)
    total_loss += loss.detach().item() * cur_points
    # float() turns a 0-d (possibly CUDA) tensor into a plain number for logging.
    total_acc += float(acc) * cur_points
    loss.backward()
    optimizer.step()
    mean_loss = total_loss / n_points
    mean_acc = 100 * total_acc / n_points
    wandb.log({"train": {
        "loss": mean_loss,
        "acc": mean_acc,
    }})
    progress_bar.set_description(f'Train Epoch: {epoch}  Loss: {mean_loss:.6f} Accuracy: {mean_acc}%', refresh=False)

    # Stop once the budget is used up (including exactly 0), so no extra batch
    # goes through the model in train mode.
    if n_vals <= 0:
      break

def test(model, test_loader, msg, n_vals, loss_acc_f):
  """Evaluate `model` on at most `n_vals` outputs of `test_loader` and report the result.

  Runs in eval mode under no_grad and autocast. Per-batch loss and accuracy from
  `loss_acc_f` are averaged with weights equal to the number of rows scored.

  Args:
    model: network to evaluate.
    test_loader: iterable of batches accepted by `get_outputs`.
    msg: prefix for the printed summary line.
    n_vals: maximum number of model outputs (rows) to score.
    loss_acc_f: maps outputs to a (loss, accuracy) pair; accuracy is expected to
      be a fraction, since it is multiplied by 100 here.

  Returns:
    (mean loss, accuracy in percent) as Python floats.

  Raises:
    ValueError: if `n_vals` is not positive or the loader produced no batches.
  """
  model.eval()
  if n_vals <= 0:
    raise ValueError(f"{msg}: n_vals must be positive, got {n_vals}")
  total_loss = 0.0
  total_acc = 0.0
  n_points = 0
  with torch.no_grad():
    for batch in test_loader:
      with autocast():
        outputs = get_outputs(model, batch)
      n_vals -= outputs.shape[0]
      if n_vals < 0:
        # n_vals is negative here: drop the rows that exceed the budget.
        outputs = outputs[:n_vals]
      # Weight by the rows actually scored, not the full batch size.
      cur_points = outputs.shape[0]
      n_points += cur_points
      loss, acc = loss_acc_f(outputs)
      total_loss += loss.item() * cur_points
      # float() turns a 0-d (possibly CUDA) tensor into a plain number.
      total_acc += float(acc) * cur_points
      if n_vals <= 0:
        break
  if n_points == 0:
    raise ValueError(f"{msg}: no outputs were evaluated (empty loader?)")
  total_loss /= n_points
  total_acc = 100 * total_acc / n_points

  print(f'{msg} Average loss: {total_loss:.4f}, Accuracy: {total_acc:.2f}%')
  return total_loss, total_acc

def get_labels(dataset):
    """Return `dataset.targets` as a plain Python list, with elements converted via a torch tensor."""
    label_tensor = torch.tensor(dataset.targets)
    return label_tensor.tolist()

def class_label_list_to_sets(assignment, n_classes):
    """Group point indices by class.

    Args:
        assignment: sequence where assignment[i] is the class of point i.
        n_classes: number of groups to create.

    Returns:
        List of `n_classes` lists; entry c holds, in increasing order, every
        index i with assignment[i] == c.
    """
    groups = []
    for _ in range(n_classes):
        groups.append([])
    for position, label in enumerate(assignment):
        groups[label].append(position)
    return groups


if __name__ == "__main__":
    # Deliberate guard: remove it and fill in the wandb key (and entity below) before running.
    assert False, "We use wandb for logging. Please specify your wandb key below"
    wandb.login(key="")
    cifar100_train_dataset = datasets.CIFAR100('/tmp', train=True, download=True)
    # NOTE(review): these labels are not used by the experiment below.
    cifar100_train_labels = get_labels(cifar100_train_dataset)
    # Frozen pretrained features; they decide which candidate is each tuple's positive.
    cifar100_train_embeddings = getembeddings("/tmp/cifar100_train.beton")

    n_runs = 10
    n_negatives = 1
    epochs = 51
    lr = 0.1
    dim = 128
    sqr = dim ** 2

    def loss_acc_f(outputs):
        """Triplet loss/accuracy for this experiment's tuple size."""
        return contrastive_loss_acc(outputs, n_negatives)

    # Sweep the number of training tuples as a multiple of dim ** 2.
    for scale in [0.5, 0.75, 1, 1.5, 2, 2.5, 3]:
        n_samples = int(sqr * scale)
        # Each tuple contributes an anchor, a positive and n_negatives negatives.
        n_vals = n_samples * (2 + n_negatives)
        for _ in range(n_runs):
            wandb.init(
                project="dimension_close_sqr", entity="", reinit=True, name=f"dim={dim} lr={lr} n_samples={n_samples}",
                config={"dim": dim, "lr": lr, "n_samples": n_samples}
            )
            train_loader = contrastive_dataloader_with_ground_truth("/tmp/cifar100_train.beton", cifar100_train_embeddings, n_samples=n_samples, n_negatives=n_negatives, batch_size=500)
            model = models.resnet18()
            model.fc = nn.Linear(512, dim, bias=False)
            model = model.to(device)
            optimizer = optim.Adadelta(model.parameters(), lr=lr)

            for epoch in range(epochs):
                # Every 5 epochs, score the (same) training tuples before that epoch's update.
                if epoch % 5 == 0:
                    test_loss, test_acc = test(model, train_loader, "Train", n_vals, loss_acc_f)
                    wandb.log({"test": {"acc": test_acc, "loss": test_loss}}, commit=False)
                train(model, train_loader, optimizer, epoch, n_vals, loss_acc_f)
            # test() returns (loss, accuracy); report only the accuracy.
            _, final_acc = test(model, train_loader, "Final Train", n_vals, loss_acc_f)
            print(f"{dim} {n_samples} Final accuracy: {final_acc}")
            print("-----------------------------------------------")
            wandb.finish()

