from scipy.stats import gmean
import numpy as np
from magnipy import Magnipy
from distances import get_dist
from vendi_score import vendi
import pandas as pd
import pickle
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import numpy as np
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.ensemble import RandomForestRegressor
from sklearn.isotonic import IsotonicRegression
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import pairwise_distances
from n_gram_sim import *
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedKFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier


def regression_cross_validation(df, col_true, scoring='r2', model=LinearRegression(), 
                                n_splits=5, n_repeats=20, get_scores=False):
    """Cross-validate a one-feature model for every column of ``df``.

    Every column except ``col_true`` is used, on its own, as the single
    predictor of ``col_true``. Rows in which either that feature or the
    target is missing are dropped for that column only.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing the target column and the candidate features.
    col_true : str
        Name of the target column.
    scoring : str
        Scorer handed to ``cross_val_score``; also used as the prefix of
        the summary column names (``<scoring>_mean`` / ``<scoring>_std``).
    model : estimator
        Estimator handed to ``cross_val_score``.
    n_splits, n_repeats : int
        Configuration of the ``RepeatedKFold`` splitter (seeded with 1).
    get_scores : bool
        When true, return the raw per-fold scores instead of a summary.

    Returns
    -------
    dict or pandas.DataFrame
        If ``get_scores`` is true, a dict mapping each feature column to
        its array of fold scores. Otherwise a DataFrame indexed by feature
        column with the mean and standard deviation of those scores; it is
        empty (but has both columns) when ``df`` has no feature columns.
    """
    features = df.drop(col_true, axis=1)
    target = df[col_true]

    # The splitter is seeded with an integer, so one instance can be reused
    # for every feature instead of being rebuilt per column.
    cv = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=1)

    fold_scores = {}
    for col in features.columns:
        # Keep only rows where both this feature and the target are present.
        keep = ~(features[col].isna() | target.isna())
        X_col = np.array(features[[col]][keep])
        y_col = target[keep]
        fold_scores[col] = cross_val_score(model, X_col, y_col, scoring=scoring, cv=cv)

    if get_scores:
        return fold_scores

    # Build the summary once, after the loop, and under its own name so the
    # input frame is never shadowed (previously an input without feature
    # columns was handed straight back to the caller).
    summary = pd.DataFrame(
        [[np.mean(scores), np.std(scores)] for scores in fold_scores.values()],
        index=list(fold_scores),
        columns=[scoring + "_mean", scoring + "_std"],
    )
    return summary


def resample_correlation_mean_std(df, col1, col2, num_resamples, size_resample, corr="spearman", replace=False):
    """Estimate the correlation of two columns over repeated row subsamples.

    Draws ``num_resamples`` random subsets of ``size_resample`` rows from
    ``df`` (with replacement only when ``replace`` is true) and computes the
    ``corr`` correlation between ``col1`` and ``col2`` on each subset. An
    undefined (NaN) correlation is recorded as 0.

    Returns a 5-tuple: mean, standard deviation, 2.5% quantile and 97.5%
    quantile of the collected correlations, followed by the list itself.
    """
    def _draw_correlation():
        # One subsample and its correlation; NaN results are mapped to 0.
        subset = df.sample(n=size_resample, replace=replace)
        value = subset[col1].corr(subset[col2], corr, min_periods=0)
        return 0 if np.isnan(value) else value

    correlations = [_draw_correlation() for _ in range(num_resamples)]

    return (
        np.mean(correlations),
        np.std(correlations),
        np.quantile(correlations, q=0.025),
        np.quantile(correlations, q=0.975),
        correlations,
    )


def resample_r2_mean_std(correlations):
    """Return ``(mean, standard deviation)`` of the given collection of values."""
    return np.mean(correlations), np.std(correlations)

def r2_subsampling_summary(df, name="", corr="spearman",
                             cats=["mag_abs_diff", "mag_ext_diff", "mag_cut_diff", "mag_diff_scaled", "mag_diff_scaled2"]):
    """Summarise each column in ``cats`` by its mean and standard deviation.

    Missing values are dropped from every column before it is summarised.
    ``corr`` is accepted but not used by this function.

    Returns a DataFrame indexed by ``cats`` with the columns
    ``mean_corr_<name>`` and ``std_<name>``.
    """
    # One (mean, std) pair per requested column, in the order of `cats`.
    stats = [resample_r2_mean_std(df[cat].dropna()) for cat in cats]
    means = [mean for mean, _ in stats]
    spreads = [spread for _, spread in stats]

    return pd.DataFrame(
        {"mean_corr_"+name: means, "std_"+name: spreads},
        index=cats,
    )

def corr_subsampling_summary(df_results, name="", corr="spearman", label_value="label_value",
                             cats=["mag_abs_diff", "mag_ext_diff", "mag_cut_diff", "mag_diff_scaled", "mag_diff_scaled2"], 
                             n_resample=100, size_resample=100, replace=False):
    """Summarise resampled correlations between each measure and a label.

    For every column in ``cats`` the correlation with the ``label_value``
    column is estimated by ``resample_correlation_mean_std`` using
    ``n_resample`` subsamples of ``size_resample`` rows.

    Parameters
    ----------
    df_results : pandas.DataFrame
        Table holding the measure columns and the label column.
    name : str
        Suffix for the summary column names; also stored in the ``model``
        column of the long-format output.
    corr : str
        Correlation method forwarded to the resampling helper.
    label_value : str
        Name of the column every measure is correlated with.
    cats : list of str
        Measure columns to evaluate.
    n_resample, size_resample, replace
        Forwarded to the resampling helper.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        A summary indexed by ``cats`` with ``mean_corr_<name>``,
        ``std_<name>``, ``lower_<name>`` and ``upper_<name>`` columns, and
        a long-format frame with one row per resample and the columns
        ``correlation``, ``measure`` and ``model`` (an empty DataFrame when
        ``cats`` is empty).
    """
    mus = []
    stds = []
    lows = []
    ups = []
    frames = []
    for cat in cats:
        mu, sig, low, up, correlations = resample_correlation_mean_std(
            df_results, cat, label_value,
            num_resamples=n_resample, size_resample=size_resample,
            corr=corr, replace=replace,
        )
        mus.append(mu)
        stds.append(sig)
        lows.append(low)
        ups.append(up)

        per_measure = pd.DataFrame(correlations, columns=["correlation"])
        per_measure["measure"] = cat
        per_measure["model"] = name
        frames.append(per_measure)

    # Concatenate once instead of growing a frame inside the loop, which
    # re-copied every accumulated row on each iteration. pd.concat rejects
    # an empty list, so fall back to the empty frame the loop started from.
    corrs = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    table_results = pd.DataFrame(
        {"mean_corr_"+name: mus, "std_"+name: stds, "lower_"+name: lows, "upper_"+name: ups},
        index=cats,
    )
    return table_results, corrs


def classification_cross_validation_Xy(X, y, scoring='accuracy', model=KNeighborsClassifier(), n_splits=5, n_repeats=20):
    """Score ``model`` on ``(X, y)`` with repeated k-fold cross-validation.

    The folds come from ``RepeatedKFold(n_splits, n_repeats)`` seeded with 1.

    Returns a one-row DataFrame (index label ``1``) with the columns
    ``<scoring>_mean`` and ``<scoring>_std``, together with the array of
    individual fold scores.
    """
    splitter = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=1)
    scores = cross_val_score(model, X, y, scoring=scoring, cv=splitter)

    # Single summary row: mean and standard deviation over all folds.
    summary = pd.DataFrame({1: [np.mean(scores), np.std(scores)]}).T
    summary.columns = [scoring+"_mean", scoring+"_std"]
    return summary, scores

def classification_cross_validation_Xypre(X, y, scoring='accuracy', model=KNeighborsClassifier(metric="precomputed"), n_splits=5, n_repeats=20):
    """Score ``model`` on ``(X, y)`` with repeated k-fold cross-validation.

    Differs from the plain variant only in its default estimator, a
    k-nearest-neighbours classifier with ``metric="precomputed"``.
    NOTE(review): ``X`` is presumably a square matrix of pairwise
    distances -- confirm against the callers.

    The folds come from ``RepeatedKFold(n_splits, n_repeats)`` seeded with 1.

    Returns a one-row DataFrame (index label ``1``) with the columns
    ``<scoring>_mean`` and ``<scoring>_std``, together with the array of
    individual fold scores.
    """
    splitter = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=1)
    scores = cross_val_score(model, X, y, scoring=scoring, cv=splitter)

    # Single summary row: mean and standard deviation over all folds.
    summary = pd.DataFrame({1: [np.mean(scores), np.std(scores)]}).T
    summary.columns = [scoring+"_mean", scoring+"_std"]
    return summary, scores