import time
import argparse
import jax
import matplotlib.pyplot as plt
import optax
import matfree
import tree_math as tm
from flax import linen as nn
from jax import nn as jnn
from jax import numpy as jnp
from jax import random, jit
import pickle
from src.losses import cross_entropy_loss, accuracy_preds, nll
from src.helper import get_gvp_fun, compute_num_params, get_ggn_vector_product
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from src.models import ResNet, ResNetBlock, PreActResNetBlock
from jax import flatten_util
import matplotlib.pyplot as plt
from src.data import CIFAR10, n_classes
from src.sampling.lanczos_diffusion import lanczos_diffusion
import torch
from src.data.torch_datasets import MNIST, numpy_collate_fn
import flax

if __name__ == "__main__":
    # Driver script: load a trained CIFAR-10 ResNet checkpoint, print the norm
    # of the product returned by get_gvp_fun applied to a vector of ones, then
    # time a lanczos_diffusion run and drop into the debugger.

    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at checkpoints you trust. The context manager guarantees the file handle
    # is closed, which a bare `pickle.load(open(...))` does not.
    checkpoint_path = "./checkpoints/CIFAR-10/ResNet/epoch200_seed0_params.pickle"
    with open(checkpoint_path, "rb") as checkpoint_file:
        param_dict = pickle.load(checkpoint_file)
    params = param_dict['params']
    batch_stats = param_dict['batch_stats']

    output_dim = 10
    model = ResNet(
        num_classes=output_dim,
        c_hidden=(16, 32, 64),
        num_blocks=(3, 3, 3),
        act_fn=nn.relu,
        block_class=ResNetBlock,  # PreActResNetBlock is the imported alternative
    )

    # Build one array from the first element of every dataset item.
    n_samples_per_class = None
    cls = list(range(n_classes("CIFAR-10")))
    dataset = CIFAR10(
        path_root='/xxx/data',
        train=True,
        n_samples_per_class=n_samples_per_class,
        download=True,
        cls=cls,
        seed=0,
    )
    X = jnp.array([data[0] for data in dataset])

    # NOTE: get_ggn_vector_product (imported above) was previously tried here
    # on the un-reshaped X as a cross-check; it is currently unused.
    D = compute_num_params(params)
    v0 = jnp.ones(D,)

    # Split X into N chunks of gvp_bs samples each; the trailing samples that
    # do not fill a whole chunk are dropped by the slice.
    gvp_bs = 5000
    N = X.shape[0] // gvp_bs
    if N == 0:
        # Without this guard the reshape below would silently produce an empty
        # (0, gvp_bs, ...) array and pass it on to the helpers.
        raise ValueError(
            f"Dataset has {X.shape[0]} samples, fewer than gvp_bs={gvp_bs}; "
            "cannot form a single batch."
        )
    data_array = X[: N * gvp_bs].reshape((N, gvp_bs) + X.shape[1:])

    ggn_vector_product_1 = get_gvp_fun(
        params,
        model,
        data_array,
        gvp_bs,
        "classification",
        "running",
        is_resnet=True,
        batch_stats=batch_stats,
    )
    print(jnp.linalg.norm(ggn_vector_product_1(v0)))

    n_steps = 5
    n_samples = 5
    alpha = 1.0
    sample_key = jax.random.PRNGKey(0)
    n_params = D
    rank = 5
    gvp_type = "batch-sum"

    start_time = time.time()
    # Arguments are passed positionally in the order lanczos_diffusion expects;
    # the meaning of the literal 1.0 and "non-kernel-eigvals" is defined by its
    # signature, which is not visible here - TODO confirm before changing.
    nonker_posterior_samples = lanczos_diffusion(
        model,
        params,
        n_steps,
        n_samples,
        alpha,
        sample_key,
        n_params,
        rank,
        data_array,
        "classification",
        1.0,
        "non-kernel-eigvals",
        gvp_type,
        gvp_bs,
        is_resnet=True,
        batch_stats=batch_stats,
    )
    # NOTE(review): the message reports n_steps - 1 steps; confirm this matches
    # how lanczos_diffusion counts steps.
    print(f"Lanczos diffusion (for a {n_params} parameter model with {n_steps - 1} steps, {n_samples} samples and {rank} iterations) took {time.time()-start_time:.5f} seconds")

    # Stop in the debugger so nonker_posterior_samples can be inspected
    # interactively; this script does not save them anywhere.
    breakpoint()
    


