import sys

import torch.nn.functional as F
import torch
from pathlib import Path
import yaml
# import dnnlib.util as du
import json
import pickle
import numpy as np
from tqdm import tqdm
import math
from glob import glob
import io
import torch_utils
import dnnlib
from torch_utils import misc
import torch.autograd.forward_ad as fwAD
from functorch import jvp
import uuid

# Device that the network and every data tensor are moved to.
device = 'cuda'


# Side length used for the (3, image_size, image_size) shape checks and
# reshapes further down in this file.
image_size = 32
# Forwarded to the dataset as `use_labels`, and used later to decide whether
# labels are passed to the network.
has_labels = True


# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints that come from a trusted source.
# NOTE(review): the checkpoint file name contains "uncond" while has_labels is
# True -- confirm the network accepts (and presumably ignores) a label argument.
with open('/path/to/assets/edm_nets/edm-cifar10-32x32-uncond-vp.pkl', 'rb') as f:
    data = pickle.load(f)

# The 'ema' entry of the pickle is used as the network.
net = data['ema'].to(device)

# The measurement loop later in this file asserts batch_size == 1.
batch_size = 1

# Container for the dataset / data-loader keyword arguments.
c = dnnlib.EasyDict()
path_to_dataset = "/path/to/datasets/edm_style/cifar10-32x32.zip"

# Keyword arguments handed to construct_class_by_name below; `class_name`
# selects the dataset class to build.
c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=path_to_dataset, use_labels=has_labels, xflip=False, cache=True)

# Extra keyword arguments forwarded to torch.utils.data.DataLoader.
c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1, prefetch_factor=2)

dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)

# Sampler configured with rank=0, num_replicas=1 and a fixed seed of 0.
dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=0, num_replicas=1, seed=0)
dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_size, **c.data_loader_kwargs))

# Pull one batch and rescale it with x / 127.5 - 1
# (maps [0, 255] to [-1, 1]; assumes 8-bit pixel values -- TODO confirm).
images, labels = next(dataset_iterator)
images = images.to(device).to(torch.float32) / 127.5 - 1
labels = labels.to(device)

def display_image(image_pt, normalize=False, ax=None):
    """Render a single image tensor with matplotlib.

    Args:
        image_pt: Tensor of shape ``(3, image_size, image_size)``. Values are
            converted to pixels with ``(x + 1) * 127.5``, i.e. ``[-1, 1]`` is
            mapped to ``[0, 255]``; anything outside that range is clipped.
        normalize: If True, the image is first divided by its largest
            absolute value. The caller's tensor is left untouched.
        ax: Optional object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
            When None, the image is drawn with ``plt.imshow`` followed by
            ``plt.show()``.
    """
    assert image_pt.shape == (3, image_size, image_size)
    if normalize:
        peak = torch.max(torch.abs(image_pt))
        # Out-of-place division: an in-place `/=` would silently rescale the
        # tensor owned by the caller. An all-zero image is left as is instead
        # of being turned into NaNs by a division by zero.
        if peak > 0:
            image_pt = image_pt / peak
    pixels = (image_pt.cpu().detach().numpy().transpose(1, 2, 0) + 1) * 127.5
    # Clip before the cast so out-of-range values saturate rather than wrap
    # around in uint8.
    image_np = np.clip(pixels, 0, 255).astype(np.uint8)
    if ax is None:
        # NOTE(review): `plt` is not imported in the visible part of this file;
        # `import matplotlib.pyplot as plt` must run before this branch is hit.
        plt.imshow(image_np)
        plt.show()
    else:
        ax.imshow(image_np)

# Noise levels: 100 values spaced evenly in log space between min_sigma and
# max_sigma (inclusive).
min_sigma = 0.002
max_sigma = 80
log_sigmas = np.linspace(np.log(min_sigma), np.log(max_sigma), 100)
sigmas = np.exp(log_sigmas)
# Number of batches drawn from the data loader by the measurement loop.
num_repeats = 25
# Result buffers indexed as [repeat, sigma index, batch element].
# Only score_grad_norms is written in the part of the file visible here.
score_grad_norms = np.zeros((num_repeats, len(sigmas), batch_size))
left_score_norms = np.zeros((num_repeats, len(sigmas), batch_size))
right_score_norms = np.zeros((num_repeats, len(sigmas), batch_size))
vanilla_grad_norms = np.zeros((num_repeats, len(sigmas), batch_size))
jacobian_norms = np.zeros((num_repeats, len(sigmas), batch_size))

# A new iterator is built over the same dataset and sampler objects,
# replacing the one created during setup.
dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_size, **c.data_loader_kwargs))
# Random identifier for this run (not used in the part of the file visible here).
unique_id = str(uuid.uuid4())


def compute_gradient_trace(f, x, t, batch_size=5, mini_batch_size=1, labels=None):
    """Estimate the gradient, with respect to ``x``, of the trace of ``f``'s Jacobian.

    For a random probe ``v ~ N(0, I)`` the quadratic form ``v^T J(x) v`` is a
    stochastic estimate of ``tr J(x)``. The quadratic form is obtained from a
    vector-Jacobian product and then differentiated once more with respect to
    ``x``; the resulting gradients are averaged.

    Args:
        f: Callable invoked as ``f(x_img, t)`` or, when ``labels`` is given,
            ``f(x_img, t, labels)``, where ``x_img`` is ``x`` reshaped to
            ``(1, 3, S, S)``. Its output must have as many elements as ``x``.
        x: 1-D tensor with ``3 * S * S`` elements. It is not modified; a
            detached copy is differentiated instead.
        t: Passed through to ``f`` unchanged.
        batch_size: Number of independent gradient estimates that are averaged.
        mini_batch_size: Number of probes averaged into one trace estimate
            before it is differentiated.
        labels: Optional extra argument forwarded to ``f``.

    Returns:
        Tensor with the shape of ``x`` holding the averaged gradient. Estimates
        that do not depend on ``x`` contribute zero.

    Raises:
        ValueError: If ``batch_size`` or ``mini_batch_size`` is not positive,
            or if ``x`` is not a flat ``3 * S * S`` vector.
    """
    if batch_size < 1 or mini_batch_size < 1:
        raise ValueError("batch_size and mini_batch_size must be positive")

    x = x.clone().detach().requires_grad_()

    # Derive the image side from x itself rather than from a module-level
    # constant, so the function only depends on its arguments.
    side = math.isqrt(x.numel() // 3)
    if x.dim() != 1 or 3 * side * side != x.numel():
        raise ValueError("x must be a flat tensor with 3 * S * S elements")

    extra_args = () if labels is None else (labels,)

    def quadratic_form(probe):
        """Return probe^T J(x) probe as a scalar that stays differentiable in x."""
        output = f(x.reshape(1, 3, side, side), t, *extra_args).reshape(x.numel())
        # autograd.grad with grad_outputs yields the vector-Jacobian product
        # probe^T J; create_graph=True keeps it differentiable a second time.
        (vjp,) = torch.autograd.grad(
            outputs=output, inputs=x, grad_outputs=probe, create_graph=True
        )
        return torch.dot(vjp, probe)

    grad_trace = torch.zeros_like(x)

    for _ in range(batch_size):
        trace_estimate = sum(
            quadratic_form(torch.randn_like(x)) for _ in range(mini_batch_size)
        ) / mini_batch_size

        # When the Jacobian does not depend on x (e.g. f is linear), the
        # estimate carries no graph back to x and its gradient is zero.
        if not trace_estimate.requires_grad:
            continue

        # autograd.grad (instead of .backward()) only computes the gradient
        # for x and leaves the .grad of every other leaf, such as network
        # parameters, untouched.
        (grad_x,) = torch.autograd.grad(trace_estimate, x, allow_unused=True)
        if grad_x is not None:
            grad_trace = grad_trace + grad_x.detach()

    return grad_trace / batch_size


# compute_gradient_trace reshapes its input to a single (1, 3, H, W) image,
# so the loop below only supports one image per batch.
assert batch_size == 1

# For every repeat, draw a fresh batch and sweep over the noise levels.
for repeat_idx in tqdm(range(num_repeats)):
    images, labels = next(dataset_iterator)
    # Same rescaling as in the setup section: x / 127.5 - 1.
    images = images.to(device).to(torch.float32) / 127.5 - 1
    labels = labels.to(device)
    # Starts at 1 because the step size below needs sigmas[i-1]; index 0 of
    # the result buffers is therefore never written by this loop.
    for i in range(1, len(sigmas)):
        # Perturb the images with Gaussian noise of standard deviation sigmas[i].
        noisy_images = images + torch.randn_like(images) * sigmas[i]
        noisy_images.requires_grad_()
        # Per-sample noise level as a tensor of shape (batch_size,).
        sigmas_pt = torch.ones((batch_size,), device=device) * sigmas[i] 

        # h = eps_multiplier * sigmas[i]
        # Step size: gap to the previous (smaller) noise level.
        h = sigmas[i] - sigmas[i-1]

        # NOTE(review): this helper only reads `net` from the enclosing scope,
        # yet it is redefined on every iteration; it could be hoisted out of
        # the loops.
        def score_function(x, t, input_labels=None):
            # Returns (net(x, t[, labels]) - x) / t**2; presumably `net`
            # outputs a denoised estimate of x -- TODO confirm.

            if input_labels is not None:
                #print("calling l abel net")
                denoised = net(x, t, input_labels)
            else:
                #print("calling vanilla net")
                denoised = net(x, t)
            return ((denoised - x) / t**2)

        # Stochastic estimate of the gradient of the trace of the Jacobian of
        # score_function at the flattened noisy image, reshaped back to an image.
        # NOTE(review): `labels` is passed unconditionally here, whereas the
        # network calls below are gated on has_labels.
        result = compute_gradient_trace(
            f=score_function,
            x=noisy_images.reshape(3*image_size*image_size),
            t=sigmas_pt.reshape(1,),
            batch_size=5,
            mini_batch_size=1,
            labels=labels
        ).reshape((1, 3, image_size, image_size))

        if torch.isnan(result).any():
            print("nan in result")

        # Network output at the current noise level.
        if has_labels:
            cleaned_images = net(noisy_images, sigmas_pt, labels)
        else:
            cleaned_images = net(noisy_images, sigmas_pt)

        # Same formula as score_function, evaluated on the whole batch.
        score = (cleaned_images - noisy_images) / (sigmas_pt[:, None, None, None]**2)
            
        
        # One explicit Euler step of size h along (cleaned - noisy) / sigma.
        euler_stepped_images = noisy_images + ((cleaned_images - noisy_images) / sigmas[i]) * h
            

        # Network output at the stepped point; the noise level passed is
        # sigmas_pt - h, which equals sigmas[i-1].
        if has_labels:
            euler_stepped_clean_images = net(euler_stepped_images, sigmas_pt - h, labels)
        else:
            euler_stepped_clean_images = net(euler_stepped_images, sigmas_pt - h)

        # Score formula applied at the stepped point and reduced noise level.
        score_eps = (euler_stepped_clean_images - euler_stepped_images) / ((sigmas_pt - h)[:, None, None, None]**2)

        # Square root of the batch mean of
        #   sum over (C, H, W) of sigma^2 * (score_eps - score + h * sigma * result)^2.
        score_grad_norms[repeat_idx, i, :] = torch.sum(
            (sigmas_pt[:, None, None, None] ** 2) * (score_eps - score + h * sigmas[i] * result)**2,
            dim=(1,2,3)
        ).mean().sqrt().cpu().detach().numpy()