import math
import numpy as np
import torch
import torch.optim as optim
from tqdm import trange

from pnn.networks import NNet4l, ProbNNet4l, trainNNet, trainPNNet, computeRiskCertificates, testStochastic
from pnn.bounds import PBBobj
from pnn.data import loadbatches, loaddataset




def trainprior(train_loader_1, train_loader_2, bound_n_size=30000, delta=.025,
               delta_test=.01, mc_samples=10000, kl_penalty=1, classes=2,
               train_method='original', rho_prior=.03, learning_rate_prior=.001,
               momentum_prior=.95, prior_epochs=100, prior_train_opt='det',
               prior_dist='gaussian', objective='invkl', dropout_prob=.2,
               verbose=True, device='cpu'):
    """Build the prior network and train it with SGD for ``prior_epochs`` epochs.

    Parameters
    ----------
    train_loader_1 : object
        batches handed to ``trainNNet`` / ``trainPNNet`` on every epoch

    train_loader_2 : object
        accepted for interface compatibility; not used by this function

    bound_n_size : number
        forwarded to ``PBBobj`` as both ``train_n`` and ``bound_n``
        (only used when ``prior_train_opt == 'pb'``)

    delta, delta_test, mc_samples, kl_penalty, classes, objective
        forwarded unchanged to ``PBBobj`` (only used when
        ``prior_train_opt == 'pb'``)

    train_method : string
        forwarded to ``ProbNNet4l`` and ``PBBobj`` ('pb' mode only)

    rho_prior : float
        first positional argument of ``ProbNNet4l`` ('pb' mode only)

    learning_rate_prior : float
        learning rate of the SGD optimiser

    momentum_prior : float
        momentum of the SGD optimiser

    prior_epochs : int
        number of training epochs

    prior_train_opt : string
        'det' trains an ``NNet4l`` with ``trainNNet``; 'pb' trains a
        ``ProbNNet4l`` with ``trainPNNet`` against a ``PBBobj`` objective

    prior_dist : string
        forwarded to ``ProbNNet4l`` and ``PBBobj`` ('pb' mode only)

    dropout_prob : float
        forwarded to ``NNet4l`` ('det' mode only)

    verbose : bool
        forwarded to the per-epoch training routine

    device : string
        device the network is moved to (e.g. 'cuda')

    Returns
    -------
    tuple
        ``(net01, net02)`` where ``net01`` is the trained prior network and
        ``net02`` is always ``None``

    Raises
    ------
    ValueError
        if ``prior_train_opt`` is neither 'det' nor 'pb'
    """
    # Fail early with a clear message; otherwise an unknown option would leave
    # net01 unbound and surface as an UnboundLocalError at the optimiser line.
    if prior_train_opt not in ('det', 'pb'):
        raise ValueError(
            "prior_train_opt must be 'det' or 'pb', got {!r}".format(prior_train_opt))

    deterministic = prior_train_opt == 'det'

    # Initialize NN for prior
    bound0 = None
    if deterministic:
        net01 = NNet4l(dropout_prob=dropout_prob, device=device).to(device)
    else:
        net01 = ProbNNet4l(rho_prior, prior_dist=prior_dist, device=device, train_method=train_method).to(device)
        bound0 = PBBobj(classes=classes, delta=delta, delta_test=delta_test, mc_samples=mc_samples, kl_penalty=kl_penalty, device=device, train_n=bound_n_size, bound_n=bound_n_size, objective=objective, prior_dist=prior_dist, train_method=train_method)
    net02 = None  # no second prior network is built in this setup

    # Train prior
    optimizer_01 = optim.SGD(net01.parameters(), lr=learning_rate_prior, momentum=momentum_prior)
    for epoch in trange(prior_epochs):
        if deterministic:
            trainNNet(net01, optimizer_01, epoch, train_loader_1, device=device, verbose=verbose)
        else:
            trainPNNet(net01, optimizer_01, bound0, epoch, train_loader_1, verbose=verbose)

    return net01, net02


def trainposterior(net01, net02, train_loader_1batch, train_loader_1, train_loader_2,
                   set_bound_1batch_1, set_bound_1batch_2, train_loader, test_loader,
                   bound_n_size=30000, posterior_n_size=30000, delta=.025,
                   delta_test=.01, mc_samples=10000, kl_penalty=1, classes=2,
                   train_method='original', rho_prior=.03, learning_rate=.001,
                   momentum=.95, train_epochs=100, prior_train_opt='det',
                   prior_dist='gaussian', objective='invkl', verbose=True,
                   verbose_test=False, device='cpu', toolarge=False):
    """Train the posterior probabilistic network and evaluate it.

    A ``ProbNNet4l`` is initialised from ``net01`` and trained with
    ``trainPNNet`` on ``train_loader`` for ``train_epochs`` epochs. Afterwards
    ``computeRiskCertificates`` and ``testStochastic`` are evaluated once.

    Parameters
    ----------
    net01 : object
        prior network, passed to ``ProbNNet4l`` as ``init_net``

    net02, train_loader_1batch
        accepted for interface compatibility; not used by this function

    train_loader_1, train_loader_2, set_bound_1batch_1, set_bound_1batch_2
        forwarded unchanged to ``computeRiskCertificates``

    train_loader : object
        batches used by ``trainPNNet`` on every epoch

    test_loader : object
        data passed to ``testStochastic``

    bound_n_size, posterior_n_size : number
        forwarded to ``PBBobj`` as ``bound_n`` and ``train_n`` respectively

    delta, delta_test, mc_samples, kl_penalty, classes, objective
        forwarded unchanged to ``PBBobj``

    train_method, prior_dist : string
        forwarded to both ``ProbNNet4l`` and ``PBBobj``

    rho_prior : float
        first positional argument of ``ProbNNet4l``

    learning_rate, momentum : float
        hyperparameters of the SGD optimiser

    train_epochs : int
        number of training epochs

    prior_train_opt : string
        forwarded to ``ProbNNet4l``

    verbose : bool
        forwarded to ``trainPNNet``

    verbose_test : bool
        if true, the certificate and the stochastic test error are evaluated
        and printed every 20 epochs (epochs 0, 20, 40, ...)

    device : string
        device the network is moved to (e.g. 'cuda')

    toolarge : bool
        forwarded to ``computeRiskCertificates``

    Returns
    -------
    tuple
        ``(ub_risk_01, stch_err)``: the first value returned by
        ``computeRiskCertificates`` and the value returned by
        ``testStochastic``, both evaluated after the last epoch
    """
    # Initialize posterior PNN
    net_1 = ProbNNet4l(rho_prior, prior_dist=prior_dist, device=device, init_net=net01, train_method=train_method, prior_train_opt=prior_train_opt).to(device)
    net_2 = None  # no second posterior network is built in this setup
    bound = PBBobj(classes=classes, delta=delta, delta_test=delta_test, mc_samples=mc_samples, kl_penalty=kl_penalty, device=device, train_n=posterior_n_size, bound_n=bound_n_size, objective=objective, prior_dist=prior_dist, train_method=train_method)

    def _certificates():
        # The two leading "averaged network" arguments are always None here.
        return computeRiskCertificates(None, None, net_1, net_2, bound, toolarge=toolarge, device=device, train_loader_1=train_loader_1, train_loader_2=train_loader_2, set_bound_1batch_1=set_bound_1batch_1, set_bound_1batch_2=set_bound_1batch_2)

    # Run training of posterior
    optimizer_1 = optim.SGD(net_1.parameters(), lr=learning_rate, momentum=momentum)

    for epoch in trange(train_epochs):
        # verbose is passed by keyword, matching the call in trainprior, so it
        # cannot be bound to a different positional parameter of trainPNNet.
        trainPNNet(net_1, optimizer_1, bound, epoch, train_loader, verbose=verbose)

        if verbose_test and epoch % 20 == 0:
            # Report the intermediate stats; they were previously computed and
            # then silently discarded.
            ub_risk_01, kl, err_01_train = _certificates()
            stch_err = testStochastic(net_1, test_loader, bound, device=device)
            print(f"[epoch {epoch}] ub_risk_01={ub_risk_01} kl={kl} err_01_train={err_01_train} stch_err={stch_err}")

    # Final evaluation after the last epoch
    ub_risk_01, _, _ = _certificates()
    stch_err = testStochastic(net_1, test_loader, bound, device=device)

    return ub_risk_01, stch_err


def runexp(data_slice_idx, sigma_prior, learning_rate, momentum, objective='quad',
           learning_rate_prior=0.01, momentum_prior=0.95, delta=0.025, layers=9,
           delta_test=0.01, mc_samples=1000, kl_penalty=1, train_epochs=100,
           prior_dist='gaussian', verbose=True, device='cpu', prior_epochs=1,
           dropout_prob=0.2, used_dataset_size=1000, verbose_test=True,
           perc_prior=0.5, batch_size=250, train_method='original',
           model_type='fcn', prior_train_opt='det', name_data='binarymnist'):
    """Run an experiment with PAC-Bayes inspired training objectives

    The training set is shuffled with a fixed seed, one contiguous slice of
    ``used_dataset_size`` points is selected, a prior is trained with
    ``trainprior`` and a posterior with ``trainposterior``.

    Note that every call re-seeds the global torch and numpy random number
    generators and sets the cuDNN deterministic/benchmark flags.

    Parameters
    ----------
    data_slice_idx : int
        index of the slice of the shuffled training set to use; the slice
        covers ``[data_slice_idx * used_dataset_size,
        (data_slice_idx + 1) * used_dataset_size)``

    sigma_prior : float
        scale hyperparameter for the prior; converted to
        ``rho_prior = log(exp(sigma_prior) - 1)``, so it must be > 0

    learning_rate : float
        learning rate hyperparameter used for the optimiser

    momentum : float
        momentum hyperparameter used for the optimiser

    objective : string
        training objective to use

    learning_rate_prior : float
        learning rate used in the optimiser for learning the prior (only
        applicable if prior is learnt)

    momentum_prior : float
        momentum used in the optimiser for learning the prior (only
        applicable if prior is learnt)

    delta : float
        confidence parameter for the risk certificate

    layers : int
        integer indicating the number of layers (currently not used by this
        function)

    delta_test : float
        confidence parameter for chernoff bound

    mc_samples : int
        number of monte carlo samples for estimating the risk certificate
        (set to 1000 by default as it is more computationally efficient,
        although larger values lead to tighter risk certificates)

    kl_penalty : float
        penalty for the kl coefficient in the training objective

    train_epochs : int
        number of training epochs for training

    prior_dist : string
        type of prior and posterior distribution (can be gaussian or laplace)

    verbose : bool
        whether to print metrics during training

    device : string
        device the code will run in (e.g. 'cuda')

    prior_epochs : int
        number of epochs used for learning the prior

    dropout_prob : float
        probability of an element to be zeroed.

    used_dataset_size : int
        number of training points in the slice used for the whole experiment

    verbose_test : bool
        whether to print test and risk certificate stats during training epochs

    perc_prior : float
        percentage of data to be used to learn the prior

    batch_size : int
        batch size for experiments

    train_method : string
        forwarded unchanged to ``trainprior`` and ``trainposterior``

    model_type : string
        currently not used by this function

    prior_train_opt : string
        how the prior is trained; forwarded to ``trainprior`` and
        ``trainposterior`` ('det' or 'pb')

    name_data : string
        name of the dataset to use (check data file for more info)

    Returns
    -------
    tuple
        ``(ub_risk_01, stch_err)`` as returned by ``trainposterior``

    Raises
    ------
    ValueError
        if ``sigma_prior <= 0`` (from ``math.log``) or if the requested slice
        of the training set is empty
    """

    # this makes the initialised prior the same for all bounds
    torch.manual_seed(10)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    loader_kargs = {'num_workers': 1,
                    'pin_memory': True} if torch.cuda.is_available() else {}

    train, test = loaddataset(name_data)
    # rho_prior is the value with log(1 + exp(rho_prior)) == sigma_prior
    rho_prior = math.log(math.exp(sigma_prior)-1.0)

    n_train = used_dataset_size  # number of total points to be used
    perc_train = 1               # always use the whole selected slice

    # shuffle the data and targets in train
    n_available = train.data.shape[0]
    idx = torch.randperm(n_available)
    train.data = train.data[idx]
    train.targets = train.targets[idx]

    # slice it according to data_slice_idx
    data_idx_start = data_slice_idx*n_train
    data_idx_end = (data_slice_idx+1)*n_train
    train.data = train.data[data_idx_start:data_idx_end]
    train.targets = train.targets[data_idx_start:data_idx_end]
    # An out-of-range slice would otherwise hand an empty dataset to the
    # loaders and to the bound computation.
    if train.data.shape[0] == 0:
        raise ValueError(
            "empty training slice [{}:{}] for a dataset of {} points "
            "(data_slice_idx={}, used_dataset_size={})".format(
                data_idx_start, data_idx_end, n_available,
                data_slice_idx, used_dataset_size))
    ## this is now the new dataset to use

    # Load data
    train_loader_1batch, train_loader_1, train_loader_2, set_bound_1batch_1, set_bound_1batch_2, train_loader, test_loader = loadbatches(train, test, loader_kargs, batch_size, perc_train=perc_train, perc_prior=perc_prior)

    # Sizes of data subsets
    posterior_n_size = len(train_loader.dataset)                       # number of data points used to train the posterior (i.e. all training data)
    bound_n_size = len(train_loader.dataset) * (1-perc_prior)          # number of data points used to compute the risk certificate (i.e. all train - num used in prior); a float when perc_prior is a float

    toolarge = False

    # Run prior training
    net01, net02 = trainprior(train_loader_1, train_loader_2, bound_n_size=bound_n_size, delta=delta, delta_test=delta_test, mc_samples=mc_samples, kl_penalty=kl_penalty, train_method=train_method, rho_prior=rho_prior, learning_rate_prior=learning_rate_prior, momentum_prior=momentum_prior, prior_epochs=prior_epochs, prior_train_opt=prior_train_opt, prior_dist=prior_dist, objective=objective, dropout_prob=dropout_prob, verbose=verbose, device=device)

    # Run posterior training
    ub_risk_01, stch_err = trainposterior(net01, net02, train_loader_1batch, train_loader_1, train_loader_2, set_bound_1batch_1, set_bound_1batch_2, train_loader, test_loader, bound_n_size=bound_n_size, posterior_n_size=posterior_n_size, delta=delta, delta_test=delta_test, mc_samples=mc_samples, kl_penalty=kl_penalty, train_method=train_method, rho_prior=rho_prior, learning_rate=learning_rate, momentum=momentum, train_epochs=train_epochs, prior_train_opt=prior_train_opt, prior_dist=prior_dist, objective=objective, verbose=verbose, verbose_test=verbose_test, device=device, toolarge=toolarge)

    return ub_risk_01, stch_err
