import matplotlib.pyplot as plt
import numpy as np

def plot_posteriors(posteriors, true_params, N, epsilon):
    """Plot marginal posterior histograms for every regression parameter.

    One subplot is drawn per parameter: the ``data_dim`` weights, the bias
    term and the noise variance sigma^2. Each method in ``posteriors`` is drawn
    as a step histogram and the true value is marked with a dashed line.

    Args:
        posteriors: dict mapping a method name to a pair
            ``(theta_samples, sigma_squared_samples)``. The pair is stacked as
            ``hstack((theta_samples, sigma_squared_samples[:, None]))``, so
            presumably ``theta_samples`` is (n_samples, data_dim + 1) and the
            second element is (n_samples,) -- TODO confirm against callers.
            The dict is NOT modified.
        true_params: dict with keys ``'mu_x'`` (its first dimension gives
            ``data_dim``), ``'theta'`` and ``'sigma_squared'``.
        N: number of data points; only used in the figure title.
        epsilon: only used in the figure title.
    """
    data_dim = true_params['mu_x'].shape[0]
    true_values = np.hstack((true_params['theta'].flatten(),
                             true_params['sigma_squared']))

    param_labels = [r'$\theta_%d$' % i for i in range(data_dim)]
    param_labels.extend([r'$\theta_{bias}$', r'$\sigma^2$'])

    # Stack samples into a local dict rather than overwriting the caller's
    # entries: an in-place rewrite breaks a second call with the same dict and
    # any later code that reads posterior[0] / posterior[1].
    stacked = {
        method: np.hstack((samples[0], samples[1][:, None]))
        for method, samples in posteriors.items()
    }
    methods = list(stacked.keys())

    alpha = 0.4
    n_bins = 100

    # One column per weight, plus the bias term and sigma_squared.
    _, axes = plt.subplots(ncols=len(param_labels))

    for p, (ax, param) in enumerate(zip(axes, param_labels)):
        ax.hist([stacked[method][:, p] for method in methods], bins=n_bins,
                alpha=alpha, linewidth=1.5, histtype='step', stacked=False,
                fill=False, label=methods)
        ax.axvline(true_values[p], color='k', linestyle='--', label='true parameter')
        ax.set_xlabel(param)
        ax.set_yticks(())

    # A single legend on the last subplot avoids repeating it on every axis.
    axes[-1].legend()

    plt.suptitle(r'data dim $= %d; N = %d; \epsilon = %.2f$' % (data_dim, N, epsilon))
    plt.show()

def mnlp(X, y, pred_y, pred_var, *, standardize=False):
    """Return the mean negative log probability of targets under Gaussian predictions.

    Each target ``y[i]`` is scored under ``N(pred_y[i], pred_var[i])``, and
    the per-point negative log densities are averaged.

    Args:
        X: unused; kept so existing call sites continue to work.
        y: observed targets (flattened).
        pred_y: predictive means, one per target (flattened).
        pred_var: predictive variances, one per target (flattened). These must
            already include the observation-noise variance sigma**2; no noise
            is added here.
        standardize: if True, subtract the score of a trivial baseline that
            predicts ``N(mean(y), var(y))`` for every point. Negative values
            then mean the model beats that baseline.

    Returns:
        The (optionally standardized) mean negative log probability.
    """
    y = np.ravel(y)
    pred_y = np.ravel(pred_y)
    pred_var = np.ravel(pred_var)
    nll_per_x = 0.5 * np.log(2 * np.pi * pred_var) + 0.5 * (y - pred_y) ** 2 / pred_var
    logloss = np.mean(nll_per_x)

    if standardize:
        # Baseline: a single Gaussian fitted to the targets themselves.
        residual = y - np.mean(y)
        control_var = np.mean(residual ** 2)
        control_nll_per_x = (0.5 * np.log(2 * np.pi * control_var)
                             + 0.5 * residual ** 2 / control_var)
        logloss = logloss - np.mean(control_nll_per_x)

    return logloss

def get_pred_y_and_var(X, theta_samples, sigma_samples):
    """Return the predictive mean and variance of an equally weighted sample mixture.

    Each posterior sample s gives the Gaussian prediction
    ``N(X @ theta_s, sigma_s)``. The mixture of all samples has

        mean     = E_s[X @ theta_s]
        variance = Var_s[X @ theta_s] + E_s[sigma_s]   (law of total variance)

    Args:
        X: inputs, (n_points, n_features).
        theta_samples: weight samples, (n_samples, n_features).
        sigma_samples: noise-variance samples; ``sigma_samples.T`` must
            broadcast against (n_points, n_samples), e.g. shape (n_samples,)
            or (n_samples, 1).

    Returns:
        ``(pred_y, pred_var)``, each of shape (n_points,).
    """
    per_sample = X @ theta_samples.T  # (n_points, n_samples), computed once
    pred_y = per_sample.mean(axis=1)
    noise = np.broadcast_to(sigma_samples.T, per_sample.shape).mean(axis=1)
    # var() centres before squaring, avoiding the catastrophic cancellation of
    # E[f^2] - E[f]^2 when predictions are large relative to their spread.
    pred_var = per_sample.var(axis=1) + noise
    return pred_y, pred_var

def compute_mnlp_from_dataset(dataset, posterior):
    """Score one posterior on a dataset by mean negative log probability.

    ``dataset`` supplies ``'X'`` (inputs) and ``'y'`` (targets). ``posterior``
    is indexable, with weight samples at ``[0]`` and noise-variance samples at
    ``[1]``. Those samples are turned into predictive means/variances by
    ``get_pred_y_and_var``, and the targets are scored with ``mnlp``.
    """
    inputs, targets = dataset['X'], dataset['y']
    mean, variance = get_pred_y_and_var(inputs, posterior[0], posterior[1])
    return mnlp(inputs, targets, mean, variance)

def compute_mnlps_from_dataset(dataset, posteriors):
    """Score every non-prior posterior on a dataset by mean negative log probability.

    Args:
        dataset: dict with ``'X'`` (inputs) and ``'y'`` (targets).
        posteriors: dict mapping a method name to a
            ``(theta_samples, sigma_samples)`` pair. The entry keyed
            ``'prior'`` is skipped.

    Returns:
        ``(mean_mnlp, mnlps)``. ``mnlps`` lists one score per scored posterior,
        in the dict's iteration order, and ``mean_mnlp`` is their average. If
        no posterior is scored, ``mean_mnlp`` is NaN and ``mnlps`` is empty.
    """
    # Delegate to the single-posterior scorer so both paths stay in sync.
    mnlps = [
        compute_mnlp_from_dataset(dataset, posterior)
        for method, posterior in posteriors.items()
        if method != 'prior'
    ]
    # Averaging an empty array also gives NaN, but with a RuntimeWarning.
    mean_mnlp = np.mean(mnlps) if mnlps else np.nan
    return mean_mnlp, mnlps
