import jax
import haiku as hk
import jax.numpy as jnp
from jax.example_libraries import optimizers
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
import neural_tangents as nt
import functools
import operator
import optax
import copy
import models
import sys
from utils import bind, _sub, _add
import modified_resnets


@functools.partial(jax.jit, static_argnums=(3, 6, 7, 8))
def linear_forward(params, params2, state, net_fn, rng, images, is_training = False, centering = False, return_components = False):
    """Evaluate the first-order expansion of ``net_fn`` around ``params``.

    A single ``jax.jvp`` call yields the network output at ``params`` (``f``)
    and its directional derivative (``df``) along ``_sub(params2, params)``
    (presumably the parameter-tree difference ``params2 - params`` -- confirm
    against ``utils._sub``).

    ``net_fn``, ``is_training``, ``centering`` and ``return_components`` are
    static under ``jax.jit``; a new value for any of them triggers a retrace.

    Args:
        params: Parameters at which the network is linearized.
        params2: Parameters at which the linearized network is evaluated.
        state: Network state forwarded to ``net_fn``.
        net_fn: Callable ``(params, state, rng, images, is_training=...)``
            returning an ``(output, aux)`` pair.
        rng: Random key forwarded to ``net_fn``.
        images: Input batch forwarded to ``net_fn``.
        is_training: Forwarded to ``net_fn``.
        centering: When true, return only ``df`` instead of ``f + df``.
        return_components: When true, the second return value is a dict with
            keys ``'state'``, ``'f'`` and ``'df'`` instead of the bare state.

    Returns:
        ``(output, aux)`` where ``output`` is ``df`` (centered) or
        ``_add(f, df)``, and ``aux`` is the auxiliary value returned by
        ``net_fn`` at ``params``, or the components dict described above.
    """
    def _apply_at(p):
        # Only the parameters are differentiated; everything else is closed over.
        return net_fn(p, state, rng, images, is_training=is_training)

    tangent = _sub(params2, params)
    # has_aux=True: net_fn's second output is passed through undifferentiated.
    f_0, df_0, new_state = jax.jvp(_apply_at, (params,), (tangent,), has_aux=True)

    output = df_0 if centering else _add(f_0, df_0)

    if return_components:
        return output, {'state': new_state, 'f': f_0, 'df': df_0}
    return output, new_state

def get_resnet(n_classes):
    """Return Haiku ``(init, apply)`` functions for the modified ResNet-18.

    The network is built with a 3x3, stride-1 initial convolution config and
    ``n_classes`` passed as the first constructor argument.

    Args:
        n_classes: Forwarded to ``modified_resnets.ResNet18``.

    Returns:
        The ``init`` and ``apply`` functions of the stateful Haiku transform.
    """
    def _forward(x, is_training):
        model = modified_resnets.ResNet18(
            n_classes,
            initial_conv_config={'kernel_shape': 3, 'stride': 1},
        )
        return model(x, is_training)

    transformed = hk.transform_with_state(_forward)
    return transformed.init, transformed.apply


def _forward_wide_mlp(x, is_training):
    """Forward pass of an MLP with two 2048-unit ReLU layers and 10 outputs.

    Must be called inside a Haiku transform. ``is_training`` is accepted for
    signature compatibility with the other forward functions but is unused.
    """
    hidden_width = 2048

    # Modules are created in the same order as a literal layer list would
    # create them, so Haiku assigns identical parameter names.
    layers = [hk.Flatten()]
    for _ in range(2):
        layers += [hk.Linear(hidden_width), jax.nn.relu]
    layers.append(hk.Linear(10))

    network = hk.Sequential(layers)
    return network(x)

def get_conv_net(n_classes):
    """Return Haiku ``(init, apply)`` functions for ``ConvNet``.

    Args:
        n_classes: Forwarded to the ``ConvNet`` constructor.

    Returns:
        The ``init`` and ``apply`` functions of the stateful Haiku transform.
    """
    def _forward(x, is_training):
        # The module must be instantiated inside the transformed function.
        return ConvNet(n_classes)(x, is_training)

    transformed = hk.transform_with_state(_forward)
    return transformed.init, transformed.apply


class ConvNet(hk.Module):
    """Three conv / batch-norm / ReLU / average-pool stages and a linear head.

    Each stage uses a 256-channel 3x3 convolution, batch normalization with
    learned scale and offset (decay rate 0.9), ReLU, and 2x2 stride-2 'SAME'
    average pooling. The result is flattened and passed to a linear layer
    with ``n_classes`` outputs.
    """

    def __init__(self, n_classes):
        super().__init__()

        n_stages = 3
        # Creation order (all convs, then all batch norms, then all pools)
        # determines the Haiku submodule names, so it is kept as-is.
        self.convs = [hk.Conv2D(256, [3, 3]) for _ in range(n_stages)]
        self.batchnorms = [
            hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
            for _ in range(n_stages)
        ]
        self.pools = [
            hk.AvgPool(window_shape=[2, 2], strides=[2, 2], padding='SAME')
            for _ in range(n_stages)
        ]

        self.flatten = hk.Flatten()
        self.logits = hk.Linear(n_classes)

    def __call__(self, x, is_training):
        """Run the stages in order and return the output of the linear head.

        ``is_training`` is forwarded to every batch-norm layer.
        """
        for conv, batchnorm, pool in zip(self.convs, self.batchnorms, self.pools):
            x = pool(jax.nn.relu(batchnorm(conv(x), is_training=is_training)))

        return self.logits(self.flatten(x))

def get_narrow_mlp(n_classes):
    """Return Haiku ``(init, apply)`` functions for a 1024-wide two-layer MLP.

    The network flattens its input, applies two 1024-unit linear layers each
    followed by ReLU, and ends with a linear layer of ``n_classes`` outputs.

    Args:
        n_classes: Output size of the final linear layer.

    Returns:
        The ``init`` and ``apply`` functions of the stateful Haiku transform.
    """
    hidden_width = 1024

    def _forward(x, is_training):
        # is_training is unused; it is accepted so every model shares one
        # forward signature.
        layers = [hk.Flatten()]
        for _ in range(2):
            layers += [hk.Linear(hidden_width), jax.nn.relu]
        layers.append(hk.Linear(n_classes))
        return hk.Sequential(layers)(x)

    transformed = hk.transform_with_state(_forward)
    return transformed.init, transformed.apply

def get_wide_mlp():
    """Return Haiku ``(init, apply)`` functions for ``_forward_wide_mlp``.

    Returns:
        The ``init`` and ``apply`` functions of the stateful Haiku transform.
    """
    transformed = hk.transform_with_state(_forward_wide_mlp)
    init_fn = transformed.init
    apply_fn = transformed.apply
    return init_fn, apply_fn


def get_model(model_name, n_classes):
    """Return the ``(init, apply)`` pair for the model called ``model_name``.

    Supported names are ``'resnet18'``, ``'mlp'`` and ``'conv'``.

    Args:
        model_name: Name of the architecture to build.
        n_classes: Forwarded to the selected model factory.

    Returns:
        Whatever the selected factory returns (an ``(init, apply)`` pair).

    Raises:
        SystemExit: If ``model_name`` is not a supported name. The message is
            written to stderr and the exit status is 1, so shells and job
            schedulers see the run as failed rather than successful.
    """
    if model_name == 'resnet18':
        return get_resnet(n_classes)
    if model_name == 'mlp':
        return get_narrow_mlp(n_classes)
    if model_name == 'conv':
        return get_conv_net(n_classes)

    # A bare sys.exit() reports status 0 (success); passing the message makes
    # the interpreter print it to stderr and exit with status 1.
    sys.exit("Invalid model: {}".format(model_name))