#!/usr/bin/env python

import warnings
warnings.filterwarnings("ignore")


import sys
import concurrent.futures

# Command-line interface: exactly one positional argument <p>, used below as
# max_workers for every ProcessPoolExecutor.
if len(sys.argv) != 2:
    print("Usage ", sys.argv[0], " <p> ")
    # Non-zero status so shells / batch schedulers register the failure;
    # a bare sys.exit() reports success.
    sys.exit(1)
p = int(sys.argv[1])


import sys
import psutil

def check_available_memory():
    """Return ``psutil.virtual_memory().available`` (bytes of memory currently available)."""
    return psutil.virtual_memory().available
#print("Python version:", sys.version)


import numpy as np
import pickle
import pandas as pd
import math
import scipy.stats as stats
import scipy.interpolate as interpolator
from scipy.stats import rankdata
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch import distributions
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Uniform, Normal, StudentT, MultivariateNormal
import xitorch.interpolate as xi
import time
import pyvinecopulib as pv
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
import gzip
from sklearn import datasets





def cdf_lomax(x, a):
  """Lomax (unit-scale) CDF with shape ``a``: 1 - (1 + x) ** (-a)."""
  survival = (1 + x) ** (-a)
  return 1 - survival

def pdf_lomax(x, a):
  """Lomax (unit-scale) density with shape ``a``: a * (1 + x) ** (-a - 1)."""
  exponent = -a - 1
  return a * (1 + x) ** exponent

def alpha(step):
  """Recursion weight (2 - 1/i) * 1/(i + 1) for step i (Fong et al. 2021), as a float32 tensor.

  ``step`` must be non-zero (i = 0 divides by zero).
  """
  weight = (2 - 1 / step) * (1 / (step + 1))
  return torch.tensor(weight, dtype=torch.float32)

def torch_ecdf(torch_data):
    '''
    Empirical CDF (pseudo-observations) of a single sample.

    Each value is mapped to rank / (n + 1); ties receive the average rank
    (scipy ``rankdata`` default), so results lie strictly inside (0, 1).

    torch_data: tensor of shape [n] or [n, 1] (the final reshape to length n
        only succeeds for a single column).
    Returns a float64 tensor of length n. Prints a warning if any value is NaN.
    '''
    values = torch_data.detach().numpy()
    n = values.shape[0]
    # Column-wise ranks; 1-D input is treated as one column.
    ranks = rankdata(values.reshape(n, -1), axis=0)
    pobs = torch.tensor(ranks / (n + 1)).reshape(n)
    if torch.isnan(pobs).any():
        print('Error: NaN in empirical cdf')
    return pobs

class TweakedUniform(torch.distributions.Uniform):
    """Uniform distribution whose methods take an extra, ignored ``context`` argument.

    ``log_prob`` reduces the element-wise log-density with ``sum_except_batch``
    (sum over every dimension but the first); ``sample`` takes an integer count.
    """

    def log_prob(self, value, context):
        """Summed log-density of ``value`` per batch element; ``context`` is ignored."""
        elementwise = super().log_prob(value)
        return sum_except_batch(elementwise)

    def sample(self, num_samples, context):
        """Draw ``num_samples`` samples; ``context`` is ignored."""
        shape = (num_samples,)
        return super().sample(shape)

#some utils
def is_int(x):
    """Return True if ``x`` is an ``int`` instance (``bool`` included, as an ``int`` subclass)."""
    result = isinstance(x, int)
    return result

def is_nonnegative_int(x):
    """Return True if ``x`` is an ``int`` that is >= 0."""
    if not isinstance(x, int):
        return False
    return x >= 0

def sum_except_batch(x, num_batch_dims=1):
    """Sum all elements of ``x`` except along the first ``num_batch_dims`` dimensions.

    Args:
        x: tensor of any rank.
        num_batch_dims: number of leading dimensions to keep.

    Returns:
        Tensor of shape ``x.shape[:num_batch_dims]``. If ``x`` has no
        dimensions beyond the batch ones, ``x`` is returned unchanged.

    Raises:
        TypeError: if ``num_batch_dims`` is not a non-negative ``int``.
    """
    if not (isinstance(num_batch_dims, int) and num_batch_dims >= 0):
        raise TypeError('Number of batch dimensions must be a non-negative integer.')
    reduce_dims = list(range(num_batch_dims, x.ndimension()))
    if not reduce_dims:
        # torch.sum with an empty ``dim`` list reduces over *all* dims, which
        # would collapse the batch axes; there is nothing to sum here.
        return x
    return torch.sum(x, dim=reduce_dims)

def inverse_std_normal(cumulative_prob):
    '''
    Standard-normal quantile function Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1).

    The input is cast to float64 and clipped to [1e-6, 1 - 1e-6] first, so the
    output is always finite. Returns a float64 tensor.
    '''
    eps = 1e-6
    p = cumulative_prob.double().clip(eps, 1 - eps)
    return torch.erfinv(2 * p - 1) * torch.sqrt(torch.tensor(2.0))
def cdf_std_normal(input):
  """Standard-normal CDF of ``input``, clamped to [1e-6, 1 - 1e-6]."""
  eps = 1e-6
  std_normal = torch.distributions.normal.Normal(loc=0, scale=1)
  return torch.clamp(std_normal.cdf(input), eps, 1 - eps)

def pdf_std_normal(input):
  """Standard-normal density of ``input`` (computed as exp(log_prob))."""
  log_density = torch.distributions.normal.Normal(loc=0, scale=1).log_prob(input)
  return log_density.exp()

def bvn_density(rho, u, v, shift = 0.0, scale = 1.0):
  '''
  Bivariate normal density evaluated at the normal scores of (u_i, v_i).

  The scores are inverse_std_normal(u), inverse_std_normal(v). The normal has
  mean (shift, shift) and covariance [[scale, rho], [rho, scale]], so
  ``scale`` is the marginal variance.

  rho: correlation / off-diagonal covariance.
  u, v: tensors of probabilities with the same length.
  Returns a 1-D tensor of length len(u).

  Raises:
    ValueError: if u and v have different lengths.
  '''
  if len(u) != len(v):
    # Fail fast: returning None only surfaces later as an AttributeError in
    # callers that reshape the result.
    raise ValueError('Error: length of u and v should be equal')

  mean = torch.tensor([shift, shift])
  covariance_matrix = torch.tensor([[scale, rho], [rho, scale]])
  multivariate_normal = torch.distributions.MultivariateNormal(mean, covariance_matrix)

  n = len(u)
  points = torch.cat([u.reshape(n, 1), v.reshape(n, 1)], dim=1)
  return multivariate_normal.log_prob(inverse_std_normal(points)).exp()

def GC_density(rho, u, v, shift = 0.0, scale = 1.0):
  '''
  Gaussian-copula density: phi_2(z_u, z_v; rho) / (phi(z_u) * phi(z_v)).

  z_u, z_v are the normal scores inverse_std_normal(u), inverse_std_normal(v).
  ``shift`` and ``scale`` are accepted but not used; the numerator always uses
  bvn_density's defaults (mean 0, unit variance).
  Returns a column tensor of shape (len(u), 1).
  '''
  n_u, n_v = len(u), len(v)
  phi_v = pdf_std_normal(inverse_std_normal(v)).reshape(n_v, 1)
  phi_u = pdf_std_normal(inverse_std_normal(u)).reshape(n_u, 1)
  joint = bvn_density(rho = rho, u = u, v = v).reshape(n_u, 1)
  return joint / (phi_u * phi_v)

def cbvn_density(rho, u, v, shift = 0.0, scale = 1.0):
   '''
   Bivariate normal density at the normal scores of (u_i, v), with v held fixed.

   Same distribution as bvn_density (mean (shift, shift), covariance
   [[scale, rho], [rho, scale]]). ``v`` is multiplied by a (len(u), 1) tensor
   of ones, so it is presumably a scalar (or broadcastable to that shape).
   Returns a 1-D tensor of length len(u).
   '''
   loc = torch.tensor([shift, shift])
   cov = torch.tensor([[scale, rho], [rho, scale]])
   bvn = torch.distributions.MultivariateNormal(loc, cov)

   n = len(u)
   first = u.reshape(n, 1)
   second = v * torch.ones(n, 1)
   points = torch.cat([first, second], dim=1)
   return bvn.log_prob(inverse_std_normal(points)).exp()

def cGC_density(rho, u, v, shift = 0.0, scale = 1.0):
  '''
  Gaussian-copula density c(u_i, v; rho) for a fixed conditioning value v.

  Computed as cbvn_density / (phi(z_u) * phi(z_v)) on the normal scores.
  ``shift`` and ``scale`` are accepted but not used.
  Returns a column tensor of shape (len(u), 1).
  '''
  n = len(u)
  phi_v = pdf_std_normal(inverse_std_normal(v))
  phi_u = pdf_std_normal(inverse_std_normal(u)).reshape(n, 1)
  joint = cbvn_density(rho = rho, u = u, v = v).reshape(n, 1)
  return joint / (phi_u * phi_v)

def cGC_distribution(rho, u, v, shift = 0.0, scale = 1.0):
  '''
  Conditional Gaussian-copula CDF C(u | v; rho).

  Phi((z_u - rho * z_v) / sqrt(1 - rho^2)), with z = inverse_std_normal(.)
  and Phi = cdf_std_normal (clamped to [1e-6, 1 - 1e-6]).
  ``shift`` and ``scale`` are accepted but not used.
  Returns a tensor shaped like (len(u), 1), broadcast against v.
  '''
  z_u = inverse_std_normal(u).reshape(len(u), 1)
  z_v = inverse_std_normal(v)
  denom = torch.sqrt(torch.tensor(1 - rho ** 2))
  return cdf_std_normal((z_u - rho * z_v) / denom)

#Drop highly correlated variables
def drop_corr(y, threshold=0.98):
    """Drop columns whose absolute Pearson correlation with an earlier column exceeds ``threshold``.

    For each pair of columns (i < j) with |corr| > threshold, column j is
    removed. Returns the remaining columns as a numpy array.
    """
    frame = pd.DataFrame(y)
    abs_corr = frame.corr().abs()
    # Keep only the strict upper triangle so each pair is inspected once.
    upper_mask = np.triu(np.ones(abs_corr.shape, dtype=bool), k=1)
    upper = abs_corr.where(upper_mask)
    redundant = [column for column in upper.columns if (upper[column] > threshold).any()]
    return frame.drop(columns=redundant).values

def create_permutatons(obs, perms):
  '''
  Stack ``perms`` random row-permutations of ``obs``.

  obs: tensor of shape [n, d]. Returns shape [perms, n, d]. Each permutation
  draws one torch.randperm(n), so the result depends on the torch RNG state.
  '''
  n_rows = obs.shape[0]
  shuffled = [obs[torch.randperm(n_rows), :] for _ in range(perms)]
  return torch.stack(shuffled)

def batched_permutatons(obs, perms):
  '''
  Permute the rows of ``obs`` ``perms`` times and emit each column separately.

  obs: tensor of shape [n, d]. Returns shape [perms * d, n], where rows
  p*d .. p*d + d - 1 are the d columns under the p-th permutation (one
  torch.randperm(n) per permutation).
  '''
  n_rows = obs.shape[0]
  n_cols = obs.shape[1]
  columns = []
  for _ in range(perms):
    order = torch.randperm(n_rows)
    columns.extend(obs[order, col] for col in range(n_cols))
  return torch.stack(columns)

def bayestcop(input, latents, priordeg, priorcov):
  '''
  Posterior update of a scale matrix and degrees of freedom from latent vectors.

  NOTE(review): this helper looks unfinished (no caller in the visible code).
  The update below is presumably the intended conjugate Wishart-type form
  implied by the variable names; confirm before relying on it.

  input: unused; kept for signature compatibility.
  latents: tensor of shape [num_data, num_dim], one latent vector per row.
  priordeg: prior degrees of freedom.
  priorcov: prior scale matrix of shape [num_dim, num_dim]; not modified.

  Returns (postdegs, postcov) where
    postdegs = num_data + priordeg
    postcov  = (priorcov + sum_n z_n z_n^T) / (postdegs - num_dim + 1)
  '''
  num_data = latents.shape[0]
  num_dim = latents.shape[1]

  postdegs = num_data + priordeg
  # latents^T @ latents is the sum of per-row outer products z_n z_n^T. The
  # previous loop used `postcov =+ ...` (plain re-assignment) of a 1x1 inner
  # product, discarding the prior and every row but the last.
  postcov = priorcov + latents.T @ latents
  postcov = postcov / (postdegs - num_dim + 1)
  return postdegs, postcov

def Energy_Score_pytorch(beta, observations_y, simulations_Y):
        '''
        Sample-based energy score with power ``beta`` (CRPS-like for 1-D data).

        Returns 2 * mean|Y - y|^beta - mean_{i != j}|Y_i - Y_j|^beta, i.e. twice the
        usual energy score E|Y - y|^beta - 0.5 E|Y - Y'|^beta. The first mean runs
        over all (observation, simulation) pairs. The second term sums the full
        m x m matrix, whose diagonal is zero, and divides by m (m - 1).
        Distances are element-wise absolute differences, so inputs should be 1-D.
        '''
        m = len(simulations_Y)

        obs_vs_sim = (observations_y.unsqueeze(1) - simulations_Y.unsqueeze(0)).float().abs().pow(beta)
        sim_vs_sim = (simulations_Y.unsqueeze(1) - simulations_Y.unsqueeze(0)).float().abs().pow(beta)

        return 2 * obs_vs_sim.mean() - sim_vs_sim.sum() / (m * (m - 1))

#Simulate from d-dimensional diagonal GMM
def simulate_GMM(d):
    '''
    Simulate train/test data from a 2-component diagonal Gaussian mixture in d dims.

    Component means are 2 and -1 in every coordinate, with unit variances and
    equal weights. The numpy global seed is fixed to 100. Returns float32
    tensors (train [100, d], test [1000, d]), both standardised with the
    training mean and std.
    '''
    np.random.seed(100)
    n_components = 2
    n_train, n_test = 100, 1000
    means = np.array([[2] * d, [-1] * d])
    variances = np.ones((n_components, d))

    def draw(size):
        # Component labels first, then the Gaussian noise (same RNG order as before).
        labels = stats.bernoulli.rvs(p=0.5, size=size)
        return np.random.randn(size, d) * np.sqrt(variances[labels, :]) + means[labels, :]

    train = draw(n_train)
    centre = np.mean(train, axis=0)
    spread = np.std(train, axis=0)
    train = (train - centre) / spread

    test = draw(n_test)
    test = (test - centre) / spread

    return torch.tensor(train, dtype=torch.float32), torch.tensor(test, dtype=torch.float32)

def minmax_unif(obs):
  '''
  CDF and density of obs under a uniform on [min(obs) - 0.001, max(obs) + 0.001].

  Returns (cdfs, pdfs), both shaped like obs.
  '''
  lo = torch.min(obs) - 0.001
  hi = torch.max(obs) + 0.001
  prior = torch.distributions.uniform.Uniform(lo, hi)
  return prior.cdf(obs), prior.log_prob(obs).exp()

def empirical_dist(obs):
  '''
  Empirical distribution putting mass 1/N on each of the N points of obs.

  Returns (torch_ecdf(obs), tensor shaped like obs filled with 1/N).
  '''
  cdf_values = torch_ecdf(obs)
  mass = 1 / obs.shape[0]
  return cdf_values, torch.ones_like(obs) * mass

def energy_cv(data, K, up = 4., low = 2., size = 10, beta = .5):
  '''
  K-fold cross-validated choice of the TLL bandwidth multiplier for a vine copula.

  For each of ``size`` multipliers in linspace(low, up, size) and each fold,
  fit a pyvinecopulib vine (TLL pair copulas, mBIC selection) on the training
  part, simulate 100 draws and score them against the held-out rows with
  Energy_Score_pytorch. Returns the multiplier with the lowest fold-averaged score.

  data: array-like accepted by pv.Vinecop (presumably an [n, d] numpy array of
    pseudo-observations — TODO confirm).
  '''
  kfold = KFold(n_splits=K, random_state=100, shuffle=True)
  bgrids = np.linspace(low, up, size)
  in_sample = torch.zeros([size, K])
  # enumerate gives each fold its own column. The previous counter was reset to
  # 0 inside the loop, so every fold wrote column 0 and the K-1 untouched zero
  # columns biased the average.
  for fold, (train, test) in enumerate(kfold.split(data)):
    # Energy_Score_pytorch needs tensors (it calls .unsqueeze).
    held_out = torch.as_tensor(data[test], dtype=torch.float32)
    for epoch in range(size):
      controls = pv.FitControlsVinecop(family_set=[pv.BicopFamily.tll], selection_criterion='mbic', nonparametric_method='constant', nonparametric_mult=bgrids[epoch], num_threads = 2048)
      cop = pv.Vinecop(data[train], controls=controls)
      news = cop.simulate(100)
      in_sample[epoch, fold] = Energy_Score_pytorch(beta, held_out, torch.tensor(news, dtype=torch.float32))
  in_sample_err = torch.mean(in_sample, dim=1)
  return bgrids[int(torch.argmin(in_sample_err))]


def get_context(observations, rhovec, init_dist = 'Normal', a = 1.):
    '''
    Recursive R-BP CDF values of the training points, for every permutation and dimension.

    observations: tensor indexed as [perm, data_index, dim].
    rhovec: per-dimension copula correlation; rhovec[j] is used for dimension j.
    init_dist: initial distribution p0 — 'Normal', 'Cauchy', 'Lomax' or 'Unif'.
    a: Lomax shape (only used when init_dist == 'Lomax').

    Returns a float32 tensor [num_perm, num_data, num_dim] where
    context[perm, k, j] is the CDF after k recursion updates, evaluated at
    observation k of that permutation.

    Raises:
        ValueError: for an unknown init_dist.
    '''
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_perm = observations.shape[0]
    num_data = observations.shape[1]
    num_dim = observations.shape[2]

    context = torch.zeros([num_perm, num_data, num_dim])

    for j in range(num_dim):
        for perm in range(num_perm):
            x = observations[perm, :, j]
            if init_dist == 'Normal':
                cdf = torch.distributions.normal.Normal(loc=0, scale=1).cdf(x).reshape(num_data)
            elif init_dist == 'Cauchy':
                cdf = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0).cdf(x).reshape(num_data)
            elif init_dist == 'Lomax':
                cdf = cdf_lomax(x, a)
            else:  # 'Unif'
                cdf, _ = minmax_unif(x.reshape(num_data))

            # Keep CDFs strictly inside (0, 1). The upper bound was 1 + flt,
            # which let values reach/exceed 1.
            cdf = torch.clip(cdf, min=flt, max=1. - flt)
            context[perm, 0, j] = cdf[0]

            for k in range(1, num_data):
                weight = alpha(k)
                # Update the CDF of the remaining points with observation k-1 (cdf[0]).
                Cop = cGC_distribution(rho = rhovec[j], u = cdf[1:], v = cdf[0]).reshape(num_data-k)
                cdf = (1 - weight) * cdf[1:] + weight * Cop
                cdf = torch.clip(cdf, min=flt, max=1. - flt)
                context[perm, k, j] = cdf[0]

    return context


def get_context_single_dim_single_rho( observations_d, rho_d, init_dist = 'Normal', a = 1.):
    '''
    Gets the v values for the R-BP recursion (recursive CDFs of the training
    data) for a single dimension and a single rho.

    observations_d: torch.tensor [perm, data_index], the data for the dimension.
    rho_d: float, the rho value to use.
    init_dist: str, the initial distribution p0: 'Normal', 'Cauchy', 'Lomax' or 'Unif'.
    a: float, the shape parameter for the Lomax distribution.

    Returns a float32 tensor [num_perm, num_data]; entry [perm, k] is the CDF
    after k updates evaluated at observation k of that permutation.

    Raises:
        ValueError: for an unknown init_dist.
    '''
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_perm = observations_d.shape[0]
    num_data = observations_d.shape[1]

    context = torch.zeros([num_perm, num_data])

    for perm in range(num_perm):
        x = observations_d[perm, :]
        if init_dist == 'Normal':
            cdf = torch.distributions.normal.Normal(loc=0, scale=1).cdf(x).reshape(num_data)
        elif init_dist == 'Cauchy':
            cdf = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0).cdf(x).reshape(num_data)
        elif init_dist == 'Lomax':
            cdf = cdf_lomax(x, a)
        else:  # 'Unif'
            cdf, _ = minmax_unif(x.reshape(num_data))

        # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
        cdf = torch.clip(cdf, min=flt, max=1. - flt)
        context[perm, 0] = cdf[0]

        for k in range(1, num_data):
            weight = alpha(k)
            Cop = cGC_distribution(rho = rho_d, u = cdf[1:], v = cdf[0]).reshape(num_data-k)
            cdf = (1 - weight) * cdf[1:] + weight * Cop
            cdf = torch.clip(cdf, min=flt, max=1. - flt)
            context[perm, k] = cdf[0]

    return context

def evaluate_prcopula(test_points, context, rhovec, init_dist = 'Normal', a = 1.):
    '''
    Evaluate the R-BP marginal densities and CDFs at test points for all dimensions.

    test_points: tensor [num_evals, num_dim].
    context: tensor [num_perm, num_data, num_dim] (e.g. from get_context).
    rhovec: per-dimension copula correlation.
    init_dist / a: initial distribution p0 ('Normal', 'Cauchy', 'Lomax', 'Unif')
        and the Lomax shape.

    Returns (dens, cdfs), each [num_evals, num_dim], averaged over permutations.

    Raises:
        ValueError: for an unknown init_dist.
    '''
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_evals = test_points.shape[0]
    num_perm = context.shape[0]
    num_data = context.shape[1]
    num_dim = test_points.shape[1]

    dens = torch.zeros([num_perm, num_evals, num_dim])
    cdfs = torch.zeros([num_perm, num_evals, num_dim])

    for j in range(num_dim):
        x = test_points[:, j]
        # p0 does not depend on the permutation: evaluate it once per dimension.
        if init_dist == 'Normal':
            p0 = torch.distributions.normal.Normal(loc=0, scale=1)
            cdf0 = p0.cdf(x).reshape(num_evals)
            pdf0 = p0.log_prob(x).exp().reshape(num_evals)
        elif init_dist == 'Cauchy':
            p0 = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0)
            cdf0 = p0.cdf(x).reshape(num_evals)
            pdf0 = p0.log_prob(x).exp().reshape(num_evals)
        elif init_dist == 'Lomax':
            cdf0 = cdf_lomax(x, a)
            pdf0 = pdf_lomax(x, a)
        else:  # 'Unif'
            cdf0, pdf0 = minmax_unif(x.reshape(num_evals))
        # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
        cdf0 = torch.clip(cdf0, min=flt, max=1. - flt)

        for perm in range(num_perm):
            cdf, pdf = cdf0, pdf0
            for k in range(num_data):
                weight = alpha(k + 1)
                v = context[perm, k, j]
                # Both copula terms use the CDF from *before* this step's update.
                cop = cGC_density(rho = rhovec[j], u = cdf, v = v).reshape(num_evals)
                Cop = cGC_distribution(rho = rhovec[j], u = cdf, v = v).reshape(num_evals)
                cdf = (1 - weight) * cdf + weight * Cop
                cdf = torch.clip(cdf, min=flt, max=1. - flt)
                pdf = (1 - weight) * pdf + weight * cop * pdf

            dens[perm, :, j] = pdf
            cdfs[perm, :, j] = cdf

    return torch.mean(dens, dim=0), torch.mean(cdfs, dim=0)


def evaluate_prcopula_single_cdf_d(args):
    '''
    Map the test points of one dimension to R-BP CDF values.

    args: tuple (dim, test_points_d, context, rho_d, init_dist, a), packed as one
        argument so the function can be used with executor.map.
        test_points_d: 1-D tensor [num_evals]; context: [num_perm, num_data];
        rho_d: copula correlation; init_dist: 'Normal', 'Cauchy', 'Lomax' or
        'Unif'; a: Lomax shape.

    Returns (dim, cdf) where cdf [num_evals] is averaged over permutations.

    Raises:
        ValueError: for an unknown init_dist.
    '''
    dim, test_points_d, context, rho_d, init_dist, a = args
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_evals = test_points_d.shape[0]
    num_perm = context.shape[0]
    num_data = context.shape[1]
    cdfs = torch.zeros([num_perm, num_evals])

    # p0 does not depend on the permutation: evaluate it once.
    if init_dist == 'Normal':
        cdf0 = torch.distributions.normal.Normal(loc=0, scale=1).cdf(test_points_d).reshape(num_evals)
    elif init_dist == 'Cauchy':
        cdf0 = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0).cdf(test_points_d).reshape(num_evals)
    elif init_dist == 'Lomax':
        cdf0 = cdf_lomax(test_points_d, a)
    else:  # 'Unif'
        cdf0, _ = minmax_unif(test_points_d.reshape(num_evals))
    # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
    cdf0 = torch.clip(cdf0, min=flt, max=1. - flt)

    for perm in range(num_perm):
        cdf = cdf0
        for k in range(num_data):
            weight = alpha(k + 1)
            Cop = cGC_distribution(rho = rho_d, u = cdf, v = context[perm, k]).reshape(num_evals)
            cdf = (1 - weight) * cdf + weight * Cop
            cdf = torch.clip(cdf, min=flt, max=1. - flt)
        cdfs[perm, :] = cdf

    return (dim, torch.mean(cdfs, dim=0))
  
  
def evaluate_prcopula_single_pdfandcdf_d(args):
  '''
  Map the test points of one dimension to R-BP density AND CDF values.

  args: tuple (dim, test_points_d, context, rho_d, init_dist, a), packed as one
      argument so the function can be used with executor.map.
      test_points_d: 1-D tensor [num_evals]; context: [num_perm, num_data];
      rho_d: copula correlation; init_dist: 'Normal', 'Cauchy', 'Lomax' or
      'Unif'; a: Lomax shape.

  Returns (dim, density, cdf), each [num_evals], averaged over permutations.

  Raises:
      ValueError: for an unknown init_dist.
  '''
  dim, test_points_d, context, rho_d, init_dist, a = args
  flt = 1e-6
  if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
      raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

  num_evals = test_points_d.shape[0]
  num_perm = context.shape[0]
  num_data = context.shape[1]
  cdfs = torch.zeros([num_perm, num_evals])
  dens = torch.zeros([num_perm, num_evals])

  # p0 does not depend on the permutation: evaluate it once.
  if init_dist == 'Normal':
      p0 = torch.distributions.normal.Normal(loc=0, scale=1)
      cdf0 = p0.cdf(test_points_d).reshape(num_evals)
      pdf0 = p0.log_prob(test_points_d).exp().reshape(num_evals)
  elif init_dist == 'Cauchy':
      p0 = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0)
      cdf0 = p0.cdf(test_points_d).reshape(num_evals)
      pdf0 = p0.log_prob(test_points_d).exp().reshape(num_evals)
  elif init_dist == 'Lomax':
      cdf0 = cdf_lomax(test_points_d, a)
      pdf0 = pdf_lomax(test_points_d, a)
  else:  # 'Unif'
      cdf0, pdf0 = minmax_unif(test_points_d.reshape(num_evals))
  # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
  cdf0 = torch.clip(cdf0, min=flt, max=1. - flt)

  for perm in range(num_perm):
      cdf, pdf = cdf0, pdf0
      for k in range(num_data):
          weight = alpha(k + 1)
          v = context[perm, k]
          # Both copula terms use the CDF from *before* this step's update.
          cop = cGC_density(rho = rho_d, u = cdf, v = v).reshape(num_evals)
          Cop = cGC_distribution(rho = rho_d, u = cdf, v = v).reshape(num_evals)
          cdf = (1 - weight) * cdf + weight * Cop
          cdf = torch.clip(cdf, min=flt, max=1. - flt)
          pdf = (1 - weight) * pdf + weight * cop * pdf

      dens[perm, :] = pdf
      cdfs[perm, :] = cdf

  return (dim, torch.mean(dens, dim=0), torch.mean(cdfs, dim=0))


def grids_cdfs(size, context, rhovec, data, extrap_tail = .1, init_dist = 'Normal', a = 1.):
    '''
    R-BP CDF on an evenly spaced grid per dimension (used for inverse-CDF sampling).

    size: number of grid points per dimension.
    context: tensor [num_perm, num_data, num_dim] (e.g. from get_context).
    rhovec: per-dimension copula correlation.
    data: observations whose last axis is the dimension, e.g. [data_index, dim].
        Extra leading axes (such as [perm, data_index, dim]) are also accepted;
        the grid for dimension j spans every value in data[..., j].
    extrap_tail: margin added beyond the data's min and max.
    init_dist / a: initial distribution p0 and Lomax shape.

    Returns (gridmat [size, num_dim], cdf [size, num_dim] averaged over permutations).

    Raises:
        ValueError: for an unknown init_dist.
    '''
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_perm = context.shape[0]
    num_data = context.shape[1]
    num_dim = context.shape[2]

    gridmat = torch.zeros([size, num_dim])
    cdfs = torch.zeros([num_perm, size, num_dim])

    for j in range(num_dim):
        # `...` selects dimension j on the last axis for both 2-D and 3-D data
        # (data[:, j] on a 3-D tensor would pick observation j instead).
        values_j = data[..., j]
        lower = torch.min(values_j) - extrap_tail
        upper = torch.max(values_j) + extrap_tail
        xgrids = torch.linspace(lower, upper, size)
        gridmat[:, j] = xgrids

        # p0 on the grid is the same for every permutation.
        if init_dist == 'Normal':
            cdf0 = torch.distributions.normal.Normal(loc=0, scale=1).cdf(xgrids).reshape(size)
        elif init_dist == 'Cauchy':
            cdf0 = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0).cdf(xgrids).reshape(size)
        elif init_dist == 'Lomax':
            cdf0 = cdf_lomax(xgrids, a)
        else:  # 'Unif'
            cdf0, _ = minmax_unif(xgrids.reshape(size))
        # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
        cdf0 = torch.clip(cdf0, min=flt, max=1. - flt)

        for perm in range(num_perm):
            cdf = cdf0
            for k in range(num_data):
                weight = alpha(k + 1)
                Cop = cGC_distribution(rho = rhovec[j], u = cdf, v = context[perm, k, j]).reshape(size)
                cdf = (1 - weight) * cdf + weight * Cop
                cdf = torch.clip(cdf, min=flt, max=1. - flt)
            cdfs[perm, :, j] = cdf

    return gridmat, torch.mean(cdfs, dim=0)
    
    
def grids_cdfs_single_dim_single_rho(size, context_d_rho, current_rho, obs_d, extrap_tail = .1, init_dist = 'Normal', a = 1.):
    '''
    Gets CDF values for a single dimension and a single rho on an evenly spaced
    grid. These are used to numerically invert the R-BP CDF in order to sample from it.

    size: int, the number of grid points to evaluate the CDF at.
    context_d_rho: tensor [num_perm, num_data]; context_d_rho[perm, k] is the v value for step k.
    current_rho: float, the rho value being evaluated.
    obs_d: torch.tensor, the data for the dimension (any shape; its global
        min/max, widened by extrap_tail, set the grid range).
    init_dist / a: initial distribution p0 ('Normal', 'Cauchy', 'Lomax', 'Unif') and Lomax shape.

    Returns (xgrids [size], cdf [size] averaged over permutations).

    Raises:
        ValueError: for an unknown init_dist.
    '''
    flt = 1e-6
    if init_dist not in ('Normal', 'Cauchy', 'Lomax', 'Unif'):
        raise ValueError(f"init_dist must be 'Normal', 'Cauchy', 'Lomax' or 'Unif', got {init_dist!r}")

    num_perm = context_d_rho.shape[0]
    num_data = context_d_rho.shape[1]

    cdfs = torch.zeros([num_perm, size])

    lower = torch.min(obs_d) - extrap_tail
    upper = torch.max(obs_d) + extrap_tail
    xgrids = torch.linspace(lower, upper, size)

    # p0 on the grid does not depend on the permutation: evaluate it once.
    if init_dist == 'Normal':
        cdf0 = torch.distributions.normal.Normal(loc=0, scale=1).cdf(xgrids).reshape(size)
    elif init_dist == 'Cauchy':
        cdf0 = torch.distributions.cauchy.Cauchy(loc=0.0, scale=1.0).cdf(xgrids).reshape(size)
    elif init_dist == 'Lomax':
        cdf0 = cdf_lomax(xgrids, a)
    else:  # 'Unif'
        cdf0, _ = minmax_unif(xgrids.reshape(size))
    # Keep CDFs strictly inside (0, 1); the upper bound was 1 + flt.
    cdf0 = torch.clip(cdf0, min=flt, max=1. - flt)

    for perm in range(num_perm):
        cdf = cdf0
        # recursion for p_i, i > 0
        for k in range(num_data):
            weight = alpha(k + 1)
            Cop = cGC_distribution(rho = current_rho, u = cdf, v = context_d_rho[perm, k]).reshape(size)
            cdf = (1 - weight) * cdf + weight * Cop
            cdf = torch.clip(cdf, min=flt, max=1. - flt)
        cdfs[perm, :] = cdf

    # the grid and the CDF averaged across permutations, for 1 dimension and 1 rho
    return xgrids, torch.mean(cdfs, dim=0)




def linear_energy_grid_search(observations, rhovec, beta = 0.5, size = 1000, evalsz = 100, extrap_tail = .1, extrap_bound = .5, init_dist = 'Normal', a = 1.):
    '''
    Energy Score of every R-BP marginal for one candidate vector [rho_1, ..., rho_D].

    observations: tensor [perm, data_index, dim].
    size: grid points used to approximate each marginal CDF.
    evalsz: number of samples drawn (via linear inverse-CDF interpolation) per dimension.
    extrap_tail / extrap_bound: grid margin beyond the data, and the extra
        margin at which the inverse CDF is pinned to 0 / 1.

    Returns a float32 tensor [num_dim] of scores (lower is better), each computed
    against the first permutation's data. Uses torch.rand, so results depend on
    the global torch RNG state.
    '''
    num_data = observations.shape[1]
    num_dim = observations.shape[2]

    ctxtmat = get_context(observations, rhovec, init_dist, a)
    # grids_cdfs takes per-dimension ranges from a [data_index, dim] tensor.
    # Every permutation holds the same points, so the first one gives the right
    # range; the 3-D permutation tensor would be indexed along the wrong axis.
    gridmatrix, gridcdf = grids_cdfs(size, ctxtmat, rhovec, observations[0], extrap_tail, init_dist, a)
    sams = torch.rand([evalsz, num_dim])
    scores = torch.zeros([num_dim])

    n_grid = gridmatrix.shape[0]
    for dim in range(num_dim):
        grid = gridmatrix[:, dim].reshape([n_grid])
        lcb = torch.min(grid) - extrap_bound
        ucb = torch.max(grid) + extrap_bound
        sorted_grids = torch.cat([lcb.unsqueeze(0), grid, ucb.unsqueeze(0)])
        cdf_values = torch.cat([torch.tensor(0.0).unsqueeze(0), gridcdf[:, dim].reshape([gridcdf.shape[0]]), torch.tensor(1.0).unsqueeze(0)])
        inv = xi.Interp1D(cdf_values, sorted_grids, method="linear")
        observed = observations[0, :, dim].reshape([num_data, 1])
        simulated = inv(sams[:, dim]).reshape([evalsz, 1])
        scores[dim] = Energy_Score_pytorch(beta, observed, simulated)

    return scores


def single_dand_rho_linear_energy_grid_search(dim, rho_d, observations_d, beta = 0.5, size = 1000, evalsz = 100, extrap_tail = .1, extrap_bound = .5, init_dist = 'Normal', a = 1.):
    '''
    Energy Score of one dimension's R-BP marginal for one candidate rho.

    dim: dimension index; only echoed back in the result.
    rho_d: candidate copula correlation.
    observations_d: tensor [perm, data_index] with this dimension's data under each permutation.
    The remaining arguments control the CDF grid, the number of inverse-CDF
    samples (evalsz) and the energy-score power (beta).

    Returns (dim, rho_d, e_score). The score is computed against
    observations_d[0] and uses torch.rand, so it depends on the global torch RNG state.
    '''
    # v values of the recursion, context[perm, k]
    context_d_rho = get_context_single_dim_single_rho(observations_d, rho_d, init_dist, a)
    # grid points and permutation-averaged CDF on them
    grid_x, grid_u = grids_cdfs_single_dim_single_rho(size, context_d_rho, rho_d, observations_d, extrap_tail, init_dist, a)
    uniforms = torch.rand([evalsz])

    # Pin the inverse CDF to 0 / 1 at extrap_bound beyond the grid ends.
    below = (torch.min(grid_x) - extrap_bound).unsqueeze(0)
    above = (torch.max(grid_x) + extrap_bound).unsqueeze(0)
    knots_x = torch.cat([below, grid_x, above])
    knots_u = torch.cat([torch.tensor(0.0).unsqueeze(0), grid_u, torch.tensor(1.0).unsqueeze(0)])

    inverse_cdf = xi.Interp1D(knots_u.flatten(), knots_x.flatten(), method="linear")
    simulated = inverse_cdf(uniforms).flatten()
    e_score = Energy_Score_pytorch(beta, observations_d[0], simulated)

    return (dim, rho_d, e_score)

def extract_grids_search(scores, lower = 0.4, upper = 0.99):
  '''
  Pick, per dimension, the theta with the lowest score.

  scores: tensor [size, num_dim]; row i corresponds to the i-th value of
    linspace(lower, upper, size). Ties go to the first (smallest) theta.
  Returns a float32 tensor [num_dim] of the selected thetas.
  '''
  size, num_dim = scores.shape[0], scores.shape[1]
  theta_dic = torch.linspace(lower, upper, size)
  best_rows = torch.argmin(scores.reshape(size, num_dim), dim=0)
  return theta_dic[best_rows]


def exract_grids_search_single_dim_single_rho(results):
  '''
  Get the best rho for every dimension from grid-search results.

  results: iterable of (dim, rho, e_score) triples (e.g. rows of a numeric
    array). Dimensions are assumed to be 0 .. D-1. For each dimension, the rho
    with the smallest e_score wins; ties go to the first one seen.
  Returns a float64 numpy array of length D.
  '''
  grouped = {}
  for dim_out, rho_out, score_out in results:
      grouped.setdefault(dim_out, []).append((rho_out, score_out))

  best_rho = np.zeros(len(grouped))
  for dim_key in grouped:
      candidates = grouped[int(dim_key)]
      best_rho[int(dim_key)] = min(candidates, key=lambda pair: pair[1])[0]
  return best_rho


def linvsampling(observations, context, sams, rhovec, beta = 0.5, approx = 1000, extrap_tail = .1, extrap_bound = .5, init_dist = 'Normal', a = 1.):
    '''
    Inverse-CDF sampling from the R-BP marginals via linear interpolation.

    observations: tensor [perm, data_index, dim].
    context: tensor [perm, data_index, dim] (e.g. from get_context).
    sams: tensor [num_samples, dim] of uniforms; overwritten IN PLACE with the
        samples and also returned.
    beta: unused.
    approx: number of grid points used to approximate each marginal CDF.
    '''
    # grids_cdfs takes per-dimension ranges from a [data_index, dim] tensor;
    # all permutations hold the same points, so use the first one.
    gridmatrix, gridcdf = grids_cdfs(approx, context, rhovec, observations[0], extrap_tail, init_dist, a)
    n_grid = gridmatrix.shape[0]
    for dim in range(observations.shape[2]):
        grid = gridmatrix[:, dim].reshape([n_grid])
        lcb = torch.min(grid) - extrap_bound
        ucb = torch.max(grid) + extrap_bound
        sorted_grids = torch.cat([lcb.unsqueeze(0), grid, ucb.unsqueeze(0)])
        cdf_values = torch.cat([torch.tensor(0.0).unsqueeze(0), gridcdf[:, dim].reshape([gridcdf.shape[0]]), torch.tensor(1.0).unsqueeze(0)])
        inv = xi.Interp1D(cdf_values, sorted_grids, method="linear")
        sams[:, dim] = inv(sams[:, dim])

    return sams



def extract_future_ctxt(futures_ctxt, dim_nb):
  '''
  Assemble per-dimension contexts into one tensor.

  futures_ctxt: iterable of (dim, ctx_dim) pairs in any order, where each
    ctx_dim is [perms, train_n].
  Returns a tensor [perms, train_n, dim_nb] with ctx_dim placed at [..., dim].
  Every dim in 0 .. dim_nb-1 must be present.
  '''
  slots = [None for _ in range(dim_nb)]
  for d, ctx in futures_ctxt:
      slots[d] = torch.tensor(ctx)
  return torch.stack(slots, -1)


def extract_future_pseudos(futures_ctxt, dim_nb):
  '''
  Assemble per-dimension pseudo-observations into one tensor.

  futures_ctxt: iterable of (dim, pseudos) pairs in any order, each pseudos
    of shape [train_n].
  Returns a tensor [train_n, dim_nb] with column dim holding that dimension.
  Every dim in 0 .. dim_nb-1 must be present.
  '''
  columns = [None for _ in range(dim_nb)]
  for d, pseudo in futures_ctxt:
      columns[d] = torch.tensor(pseudo)
  return torch.stack(columns, -1)


def extract_future_pseudos_train(futures_ctxt, dim_nb):
  '''
  Assemble per-dimension (values, cdf) results into two tensors.

  futures_ctxt: iterable of (dim, first, second) triples in any order, each
    of shape [n]. The first entry is stored as "nll" (the caller passes
    densities here) and the second as the CDF.
  Returns (nll [n, dim_nb], cdf [n, dim_nb]). Every dim in 0 .. dim_nb-1 must be present.
  '''
  nll_cols = [None for _ in range(dim_nb)]
  cdf_cols = [None for _ in range(dim_nb)]
  for d, nll_d, cdf_d in futures_ctxt:
      nll_cols[d] = torch.tensor(nll_d)
      cdf_cols[d] = torch.tensor(cdf_d)
  nll_out = torch.stack(nll_cols, -1)
  cdf_out = torch.stack(cdf_cols, -1)
  return nll_out, cdf_out



















def sample_GMM(K,n,d,probs):
    '''
    Sample n points from a random K-component Gaussian mixture in d dimensions.

    Component k has a mean drawn uniformly from [-50, 50]^d, a covariance drawn
    from Wishart(df=d, scale=I), and exactly int(probs[k] * n) points. Draws use
    the numpy global RNG. Returns an [n, d] array with rows shuffled.

    Raises:
        ValueError: if len(probs) != K, or if the per-component counts
            int(probs[k] * n) do not add up to n.
    '''
    if len(probs) != K:
        raise ValueError('probs must have exactly K entries.')
    sizes = [int(prob * n) for prob in probs]
    if np.sum(sizes) != n:
        raise ValueError('The sum of probs*n must be equal to the number of samples.')

    blocks = []
    for k in range(K):
        mean = np.random.uniform(-50, 50, d)
        cov = stats.wishart(df=d, scale=np.eye(d)).rvs()
        draws = stats.multivariate_normal(mean=mean, cov=cov, allow_singular=True).rvs(size=sizes[k])
        # rvs squeezes singleton axes (size == 1 or d == 1); restore [size, d]
        # so the blocks stack into an [n, d] array.
        blocks.append(np.reshape(draws, (sizes[k], d)))
    samples = np.vstack(blocks)
    return samples[np.random.permutation(n)]

# Example: sample_GMM(K=4, n, d, probs=[0.2, 0.3, 0.1, 0.4]); int(p*n) over probs must sum to n.
























# Experiment driver. For each of 5 seeds: standardise the sklearn digits data
# (50/50 train/test split), grid-search one rho per pixel dimension by Energy
# Score, build the R-BP context, compute training pseudo-observations, and
# evaluate test densities/CDFs. The summed per-dimension test log-density
# (the product of the marginals) is averaged over test rows and pickled per seed.
#
# NOTE(review): the worker functions below read the module globals
# `y_permutations` and `p0_class`, and the `__main__` guard sits inside the
# loop. With the 'fork' start method (Linux default), workers inherit the
# parent's current globals. With 'spawn' (macOS/Windows default), each worker
# re-imports this file, re-runs the whole seed loop up to the guard, and ends
# with the globals of the LAST seed. This presumably assumes 'fork' — confirm.
for seed in range(5):


    torch.manual_seed(seed)
    np.random.seed(seed)
    # Load the digits dataset
    digits = datasets.load_digits()
    df = digits.images.reshape(1797,64)

    df = pd.DataFrame(df)


    p0_class = 'Cauchy'  # initial distribution p0, passed as init_dist to the recursion

    data = df
    rho_grid_size = 10
    n_perms = 10




    ### Data processing
    #print('Data processing....................................................')
    #data = drop_corr(data)
    data = pd.DataFrame(data).values
    frac_train = 0.5
    perms = n_perms # number of permutation

    n_tot = np.shape(data)[0]
    n = int(frac_train*n_tot)

    # Fixed random_state: the split is the same for every seed; only the
    # noise and the permutations change.
    train_ind,test_ind = train_test_split(np.arange(n_tot),test_size = n_tot - n,train_size = n,random_state = 1)

    # Standardise with training statistics; constant pixels keep std 1 to avoid division by zero.
    y = torch.tensor(data[train_ind], dtype=torch.float32)
    mean_y = torch.mean(y,axis = 0)
    std_y = torch.std(y,axis = 0)
    std_y[std_y==0] = 1
    y = (y-mean_y)/std_y

    # Dequantisation
    noise_to_digits = 1e-6*stats.norm.rvs(size = y.shape)
    y = y + noise_to_digits
    #subset = 200
    #y = y[:subset]
    y_permutations = create_permutatons(y, perms) # create permutations of the data torch.Size([perms, n, d])

    y_test = torch.tensor(data[test_ind], dtype=torch.float32)
    y_test = (y_test-mean_y)/std_y

    ### Training for M=marginal
    #print('Training for M=marginal....................................................')
    size = rho_grid_size # 50 by default
    theta_grids = torch.linspace(0.8, 0.999, size) # define the range and size of thetas you want to search



    # Worker for the rho grid search: scores one (dimension, rho) pair.
    # Defined at module level (inside the loop) so ProcessPoolExecutor can pickle it by name.
    def single_rho_d_grid_search(args): # returns [dim, rho, energy]
        dim, rho_d = args

        dim_out,rho_d_out,e_score_out = single_dand_rho_linear_energy_grid_search(dim,rho_d, y_permutations[:,:,dim], 
                                        beta = 0.5, size = 1000, evalsz = 100, extrap_tail = .1, extrap_bound = .5, init_dist = p0_class, a = 1.)
        return (dim, rho_d,e_score_out)


    # Worker: R-BP context of one dimension for its selected rho; returns (dim, context).
    def get_context_single_dim_single_rho_futures(args):
      observations_d, rho_d, dim = args
      return (dim,get_context_single_dim_single_rho(observations_d, rho_d, init_dist=p0_class,a=1.))

    start = time.time()

    #print('start of futures',check_available_memory())
    if __name__ == '__main__':
        #available_memory = check_available_memory()
        #print("Available Memory:", available_memory)
        # Create a list of inputs to the function f
        print('Data loaded, shape:', df.shape)
        print('p', p)
        print('perms', n_perms, 'rho_grid_size', rho_grid_size)
        #print('Subset size:', subset)


        inputs = [(a,b) for a in range(y.shape[1]) for b in theta_grids] # for each (dimension,rho_d) pair


        # Evaluate f for all inputs using a pool of processes
        with concurrent.futures.ProcessPoolExecutor(max_workers=p) as executor:
            results = executor.map(single_rho_d_grid_search, inputs)

        # Each result (dim, rho tensor, score tensor) becomes a numeric row.
        future_escores = np.array([np.array(result) for result in results])
        # gather the results

        # opt[d] = rho with the lowest Energy Score for dimension d (float64 numpy array)
        opt = exract_grids_search_single_dim_single_rho(future_escores)

        #print('y_test', y_test)
        #print('y', y)
        runtime = time.time()-start
        print('end of rho',runtime)

        # get the context for the optimal rhos
        inputs_ctxt = [(y_permutations[:,:,dim],opt[dim],dim) for dim in range(y.shape[1])]
        with concurrent.futures.ProcessPoolExecutor(max_workers=p) as executor:
            results_ctxt = executor.map(get_context_single_dim_single_rho_futures, inputs_ctxt)

        # ctxt: [perms, n_train, num_dim]
        ctxt = extract_future_ctxt(results_ctxt,y.shape[1])

        runtime = time.time()-start
        print('end of ctxt',runtime)

        # get copula pseudo observations based on training data
        inputs_pseudos = [(dim, y[:,dim], ctxt[:,:,dim], opt[dim], p0_class , 1.0 ) for dim in range(y.shape[1])]
        with concurrent.futures.ProcessPoolExecutor(max_workers=p) as executor:
            results_pseudos = executor.map(evaluate_prcopula_single_cdf_d, inputs_pseudos)

        pseudos = extract_future_pseudos(results_pseudos,y.shape[1])

        runtime = time.time()-start
        print('end of pseudos',runtime)

        # compute marginal nll and testcdf
        inputs_test = [(dim,y_test[:,dim], ctxt[:,:,dim], opt[dim], p0_class , 1.0 ) for dim in range(y.shape[1])]
        with concurrent.futures.ProcessPoolExecutor(max_workers=p) as executor:
            results_test = executor.map(evaluate_prcopula_single_pdfandcdf_d, inputs_test)

        # NOTE: despite the name, test_nll holds test DENSITIES [n_test, num_dim];
        # the NLL is formed on the next line.
        test_nll,test_cdf = extract_future_pseudos_train(results_test,y.shape[1])
        #print('nll',test_nll.shape,test_nll)
        #print('test_cdf',test_cdf.shape,test_cdf)
        # Mean over test rows of -sum_d log p_d(y_test[:, d]).
        marginals = -torch.mean(torch.sum(torch.log(test_nll), dim=1))
        # NOTE(review): the file has a .txt name but contains pickled binary data.
        # `runtime` was last updated after the pseudo-observation stage, so the
        # saved/printed runtime excludes the test-evaluation stage.
        with open('output_digits_marginals'+str(seed)+'.txt', 'wb') as f:
          pickle.dump([opt,runtime,y_test,y,y_permutations,mean_y,std_y,pseudos,test_nll,test_cdf,marginals, ctxt, seed ], f)

        print('Seed '+str(seed)+' nll averaged: ',marginals)

        print('end of Seed',runtime)
      




































