import random
import numpy as np
from scipy.linalg import eigh
from scipy.cluster.hierarchy import fcluster, linkage
import scipy.spatial.distance as ssd


def generate_cov_matrix(d, con_num):
    """Draw a random symmetric ``d x d`` matrix with a bounded spectrum.

    The matrix is built as ``Q.T @ diag(s) @ Q`` where ``s`` holds ``d``
    values sampled uniformly between 1.0 and ``con_num`` and ``Q`` is the
    orthogonal factor of the QR decomposition of a Gaussian matrix.

    Args:
        d: Dimension of the matrix.
        con_num: Upper end of the range the spectrum is sampled from.

    Returns:
        A ``(d, d)`` array that passed the numerical sanity checks
        (eigenvalues within ``(0.9, con_num + 0.1)`` and symmetric up to
        ``np.allclose`` tolerance).
    """
    spectrum = np.random.uniform(low=1.0, high=con_num, size=d).reshape(d, 1)
    orthogonal, _ = np.linalg.qr(np.random.normal(size=d ** 2).reshape(d, d))
    # Scaling the rows of Q by the spectrum is the same as diag(s) @ Q.
    candidate = orthogonal.T @ (spectrum * orthogonal)

    eigenvalues = np.linalg.eigvals(candidate)
    spectrum_ok = np.all(eigenvalues > 1 - 0.1) and np.all(eigenvalues < con_num + 0.1)
    if not (spectrum_ok and np.allclose(candidate.T, candidate)):
        # Numerically bad draw: discard it and sample again.
        return generate_cov_matrix(d, con_num)
    return candidate


def batches_count_each_type(m, k, kspe=0, alpha=None):
    """Split ``m`` batches among ``k`` components.

    The first ``kspe`` components each receive ``int(m * alpha)`` batches;
    the remaining ``k - kspe`` components share what is left equally. Any
    batches still unassigned after that are handed out one at a time,
    starting from the first component, so the returned counts sum to ``m``.

    Args:
        m: Total number of batches (integer).
        k: Number of components.
        kspe: Number of leading components that get the ``alpha`` share.
        alpha: Fraction of ``m`` given to each of the ``kspe`` components.
            Defaults to ``1 / k``.

    Returns:
        A list of per-component batch counts.

    Raises:
        ValueError: If the ``kspe`` components would need more than ``m``
            batches, or if batches remain but there is no component to
            receive them.
    """
    if k == 1:
        return [m]
    if alpha is None:
        alpha = 1 / k
    batches_spe_comp = int(m * alpha)
    leftover = m - batches_spe_comp * kspe
    if leftover < 0:
        # Without this check the counts below would silently go negative.
        raise ValueError(
            "alpha * kspe exceeds 1: cannot give %d batches to each of %d "
            "components out of %d" % (batches_spe_comp, kspe, m)
        )
    batch_counts = [batches_spe_comp] * kspe
    if k > kspe:
        # Integer arithmetic avoids float rounding for large m.
        batches_other_comp, leftover = divmod(leftover, k - kspe)
        batch_counts += [batches_other_comp] * (k - kspe)
    if leftover:
        if not batch_counts:
            raise ValueError("no components available to receive batches")
        # Spread the leftover cyclically; when it is smaller than the number
        # of components this is "+1 for the first `leftover` entries", and it
        # no longer overruns the list when the leftover is larger.
        share, extra = divmod(leftover, len(batch_counts))
        batch_counts = [
            count + share + int(i < extra)
            for i, count in enumerate(batch_counts)
        ]
    return batch_counts


def generate_batches(W, Sigmas, n, m, sigma, batch_counts=None):
    """Simulate batches of noisy linear-regression data, one block per component.

    For component ``i`` the covariates are drawn from a zero-mean Gaussian
    with covariance ``Sigmas[i]`` and the responses are
    ``covariates @ W[i] + noise``, where the noise is zero-mean Gaussian with
    standard deviation ``sigma``.

    Args:
        W: ``(k, d)`` array; row ``i`` is the regression vector of component ``i``.
        Sigmas: Sequence indexed by component giving each ``(d, d)`` covariance.
        n: Number of samples per batch.
        m: Total number of batches; only used to derive ``batch_counts`` when
            that argument is omitted.
        sigma: Standard deviation of the additive noise.
        batch_counts: Optional per-component batch counts. Defaults to
            ``batches_count_each_type(m, k)``.

    Returns:
        A tuple ``(X, Y)`` where ``X`` has shape ``(total_batches, n, d)`` and
        ``Y`` has shape ``(total_batches, n)``, with the batches of each
        component stored contiguously in component order.
    """
    k, d = np.shape(W)
    if batch_counts is None:
        batch_counts = batches_count_each_type(m, k)
    covariate_mean = np.zeros(d)
    # np.random.normal expects a standard deviation as its scale, so pass
    # sigma itself (not sigma ** 2). abs() keeps negative sigma accepted.
    noise_std = abs(sigma)
    X = []
    Y = []
    for i in range(k):
        num_samples = batch_counts[i] * n
        w = W[i].reshape(d, 1)
        covariates = np.random.multivariate_normal(covariate_mean, Sigmas[i], num_samples)
        noise = np.random.normal(0, noise_std, (num_samples, 1))
        response = np.matmul(covariates, w) + noise
        X.append(covariates.reshape(-1, n, d))
        Y.append(response.reshape(-1, n))
    return np.concatenate(X, axis=0), np.concatenate(Y, axis=0)


def get_estimates_prior(B, P, k, T=1):
    """Cluster batches by projected gradient and return one estimate per cluster.

    The samples of every batch are cut into ``T`` consecutive slices. Within
    each slice the two halves give independent projected gradients ``a_i`` and
    ``b_i`` (evaluated at the zero vector, without clipping), from which the
    pairwise quantity ``(a_i - a_j) . (b_i - b_j)`` is formed. The slice-wise
    matrices are combined by an element-wise median (when ``T > 1``), their
    absolute values are fed to average-linkage clustering with at most ``k``
    clusters, and the negated mean projected gradient of each cluster is
    returned.

    Args:
        B: Tuple ``(X, Y)`` with ``X`` of shape ``(m, n, d)`` and ``Y``
            indexable as ``Y[:, a:b]``.
        P: Matrix the ``(m, d)`` gradients are right-multiplied by.
        k: Maximum number of clusters passed to ``fcluster``.
        T: Number of sample slices used for the median.

    Returns:
        A list of column vectors, one per cluster, ordered by the first batch
        index at which each cluster appears.
    """
    X, Y = B
    m, n, d = np.shape(X)
    origin = np.zeros((d, 1))

    def projected_grads(features, targets):
        # Unclipped gradient at the origin, mapped through P.
        return np.matmul(calculate_clipped_grad(features, targets, np.inf, origin), P)

    HH = np.zeros((T, m, m))
    for l in range(T):
        lo, hi = l * n // T, (l + 1) * n // T
        half = (hi - lo) // 2
        Xl, Yl = X[:, lo:hi, :], Y[:, lo:hi]
        proj_a = projected_grads(Xl[:, :half, :], Yl[:, :half])
        proj_b = projected_grads(Xl[:, half:, :], Yl[:, half:])
        block = HH[l]  # view: the in-place updates below write into HH
        block -= np.dot(proj_a, proj_b.T)
        # Broadcasting adds a_j . b_j to every entry of column j.
        block += np.sum(proj_a * proj_b, axis=1)
        # Symmetrising yields (a_i - a_j) . (b_i - b_j) at position (i, j).
        block += block.T

    H = np.median(HH, axis=0) if T > 1 else HH[0]
    np.fill_diagonal(H, 0, wrap=False)
    tree = linkage(ssd.squareform(np.abs(H)), method="average")
    labels = fcluster(tree, k, criterion='maxclust') - 1

    # label -> [running mean of projected gradients, number of batches seen]
    running = {}
    for i in range(m):
        proj = projected_grads(X[i:i + 1, :, :], Y[i:i + 1, :])
        label = labels[i]
        if label not in running:
            running[label] = [proj, 1]
            continue
        mean, count = running[label]
        running[label] = [(count * mean + proj) / (count + 1), count + 1]

    return [-1 * entry[0].reshape(-1, 1) for entry in running.values()]


def calculate_clipped_grad(X, Y, kappa, w):
    """Return the per-batch mean gradient with residuals clipped at ``kappa``.

    For every sample the residual ``x . w - y`` is computed, scaled by
    ``kappa / max(kappa, |residual|)`` (skipped when ``kappa`` is ``np.inf``),
    multiplied by the covariate vector and averaged over the ``n`` samples of
    the batch.

    Args:
        X: Array of shape ``(m, n, d)``.
        Y: Array with ``m * n`` elements, one response per sample.
        kappa: Clipping level for the residuals; ``np.inf`` disables clipping.
        w: Array with ``d`` elements at which the gradient is evaluated.

    Returns:
        An ``(m, d)`` array with one averaged gradient per batch.
    """
    m, n, d = X.shape
    flat_X = X.reshape(m * n, d)
    predictions = flat_X.dot(w.reshape(d, 1)).reshape(m * n, 1)
    residual = predictions - Y.reshape(m * n, 1)
    if kappa == np.inf:
        clipped = residual
    else:
        clipped = (kappa * residual) / np.maximum(kappa, np.abs(residual))
    per_sample = (flat_X * clipped).reshape(m, n, d)
    return (per_sample.sum(axis=1) / n).reshape(m, d)


def estimate_mse(bspe, w):
    """Return the mean squared prediction error of ``w`` on a single batch.

    Args:
        bspe: Tuple ``(X, Y)`` holding exactly one batch: ``X`` has shape
            ``(1, n, d)`` and ``Y`` has ``n`` elements (the reshapes below
            fail for more than one batch).
        w: Array with ``d`` elements.

    Returns:
        The scalar ``sum((Y - X w) ** 2) / n``.
    """
    X, Y = bspe
    _, n, d = np.shape(X)
    residual = Y.reshape(n, 1) - X.reshape(n, d) @ w.reshape(d, 1)
    # residual.T @ residual is a 1x1 matrix; unwrap it to a scalar.
    return (residual.T @ residual / n)[0][0]


def clipping_estimate(bspe, w, sigma):
    """Return the clipping level ``sqrt(2 * (mse + sigma ** 2))``.

    Args:
        bspe: Single batch ``(X, Y)`` on which the error of ``w`` is measured.
        w: Current estimate.
        sigma: Noise level; its square is added to the measured error.

    Returns:
        The scalar clipping level.
    """
    noise_term = sigma * sigma
    return np.sqrt(2 * (estimate_mse(bspe, w) + noise_term))


def grad_est(B, bspe, kappa, w, P, weights, epsilon=0.3):
    """Return a weighted, projected gradient, down-weighting dissimilar batches.

    Each batch (and the reference batch ``bspe``) is split into two halves
    whose clipped gradients are projected with ``P``. A batch is kept when the
    inner product of its two half-wise differences from the reference is
    below ``epsilon`` times the inner product of the reference's own halves;
    otherwise its entry in ``weights`` is multiplied by 0.1.

    Note:
        ``weights`` is modified in place, so the down-weighting accumulates
        across repeated calls that share the same array.

    Args:
        B: Tuple ``(X, Y)`` with ``X`` of shape ``(m, n, d)``.
        bspe: Reference batch ``(X, Y)``; its projections are broadcast
            against those of ``B``.
        kappa: Clipping level forwarded to ``calculate_clipped_grad``.
        w: Point at which the gradients are evaluated.
        P: Matrix the gradients are right-multiplied by.
        weights: Per-batch weights, one entry per batch of ``B``.
        epsilon: Relative threshold of the similarity test.

    Returns:
        A column vector: the weighted mean over batches of the average of the
        two half-batch projected gradients.
    """
    def half_projections(batches):
        # Projected clipped gradients of the first and second half of each batch.
        features, targets = batches
        _, n, _ = np.shape(features)
        half = n // 2
        first = np.matmul(
            calculate_clipped_grad(features[:, :half, :], targets[:, :half], kappa, w), P)
        second = np.matmul(
            calculate_clipped_grad(features[:, half:, :], targets[:, half:], kappa, w), P)
        return first, second

    proj_all_a, proj_all_b = half_projections(B)
    proj_spe_a, proj_spe_b = half_projections(bspe)

    diff_norm = np.sum((proj_all_a - proj_spe_a) * (proj_all_b - proj_spe_b), axis=1)
    spe_norm = np.sum(proj_spe_a * proj_spe_b, axis=1)
    rejected = ~(diff_norm < epsilon * spe_norm)
    for i in np.flatnonzero(rejected):
        weights[i] *= 0.1

    averaged = 0.5 * (proj_all_a + proj_all_b)
    return (np.matmul(weights, averaged) / np.sum(weights)).reshape(-1, 1)


def subspace_estimation(B, ell, kappa, w):
    """Return the projector onto the top-``ell`` eigenvectors of a gradient matrix.

    The clipped gradients of the two halves of every batch, ``G_a`` and
    ``G_b``, are combined into ``G_a.T @ G_b + G_b.T @ G_a``; the eigenvectors
    belonging to its ``ell`` largest eigenvalues span the returned subspace.

    Args:
        B: Tuple ``(X, Y)`` with ``X`` of shape ``(m, n, d)``.
        ell: Number of eigenvectors to keep.
        kappa: Clipping level forwarded to ``calculate_clipped_grad``.
        w: Point at which the gradients are evaluated.

    Returns:
        The ``(d, d)`` matrix ``U @ U.T`` built from the selected eigenvectors.
    """
    X, Y = B
    _, n, d = np.shape(X)
    half = n // 2
    grads_first = calculate_clipped_grad(X[:, :half, :], Y[:, :half], kappa, w)
    grads_second = calculate_clipped_grad(X[:, half:, :], Y[:, half:], kappa, w)
    A = grads_first.T @ grads_second + grads_second.T @ grads_first
    # eigh orders eigenvalues ascending, so the last `ell` indices are the largest.
    _, top_vectors = eigh(A, subset_by_index=(d - ell, d - 1))
    return top_vectors @ top_vectors.T


def main_algo_single_comp(BS, BM, bspe, sigma, ell, con_num, R, init_est, weights):
    """Run ``R`` projected-gradient rounds guided by one reference batch.

    Every round recomputes the clipping level from ``bspe``, re-estimates the
    projection from ``BS``, obtains a filtered gradient from ``BM`` and takes
    a step of size ``0.8 / con_num``.

    Args:
        BS: Batches used for subspace estimation.
        BM: Batches used for gradient estimation.
        bspe: Single reference batch ``(X, Y)``.
        sigma: Noise level used in the clipping level.
        ell: Subspace dimension.
        con_num: Divisor of the step size.
        R: Number of rounds.
        init_est: Starting estimate; returned unchanged when ``R`` is 0.
        weights: Per-batch weights for ``BM``; ``grad_est`` updates them in
            place, so the same array carries over between rounds.

    Returns:
        The estimate after ``R`` rounds.
    """
    estimate = init_est
    step = 0.8
    for _ in range(R):
        kappa = clipping_estimate(bspe, estimate, sigma)
        projection = subspace_estimation(BS, ell, kappa, estimate)
        direction = grad_est(BM, bspe, kappa, estimate, projection, weights, epsilon=0.2)
        estimate = estimate - step * direction / con_num
    return estimate


def main_algo_multiple_comp(BS, BM, Bspe, sigma, ell, con_num, R, init_est):
    """Run the single-component algorithm once per reference batch in ``Bspe``.

    Args:
        BS: Batches used for subspace estimation.
        BM: Tuple ``(X, Y)`` of batches used for gradient estimation; ``X``
            has shape ``(m2, n2, d)``.
        Bspe: Tuple ``(Xspe, Yspe)`` of reference batches; ``Xspe`` has shape
            ``(kspe, nspe, d)``. ``nspe`` may differ from ``n2``.
        sigma: Noise level.
        ell: Subspace dimension.
        con_num: Divisor of the step size.
        R: Number of rounds per reference batch.
        init_est: Starting estimate for every run.

    Returns:
        A list with one estimate per reference batch, in the order of ``Bspe``.
    """
    X2, _ = BM
    m2, _, _ = X2.shape
    Xspe, Yspe = Bspe
    # Use the reference batches' own sample count: reshaping with BM's batch
    # size fails whenever the two collections use different batch sizes.
    kspe, nspe, d = Xspe.shape
    estimates = []
    for i in range(kspe):
        bspe = (Xspe[i].reshape(1, nspe, d), Yspe[i].reshape(1, nspe))
        # Fresh weights per run: the single-component routine's gradient
        # step down-weights entries in place.
        weights = np.ones(m2)
        estimates.append(
            main_algo_single_comp(BS, BM, bspe, sigma, ell, con_num, R, init_est, weights)
        )
    return estimates


def main_algo_all_comp(BS, BM, sigma, ell, con_num, init_est, R):
    """Repeatedly fit an estimate to a random still-unexplained batch of ``BM``.

    Each iteration picks a random batch that no previous estimate explains,
    runs the single-component algorithm with it as the reference (its own
    weight set to 0, already-explained batches weighted 0), and then marks as
    explained every remaining batch whose mean squared error under the new
    estimate is at most ``d / 5``. The loop runs while at least 2% of the
    batches are unexplained and no more than ``m2 / 8`` estimates exist.

    Args:
        BS: Batches used for subspace estimation.
        BM: Tuple ``(X, Y)`` with ``X`` of shape ``(m2, n2, d)``.
        sigma: Noise level.
        ell: Subspace dimension.
        con_num: Divisor of the step size.
        init_est: Starting estimate for every run.
        R: Number of rounds per run.

    Returns:
        The list of estimates, in the order they were produced.
    """
    X2, Y2 = BM
    m2, n2, d = X2.shape

    def single_batch(idx):
        # Batch `idx` of BM packaged as a one-batch collection.
        return X2[idx].reshape(1, n2, d), Y2[idx].reshape(1, n2)

    estimates = []
    unexplained = list(range(m2))
    weights = np.ones(m2)
    while len(unexplained) >= 0.02 * m2 and len(estimates) <= m2 / 8:
        anchor = random.choice(unexplained)
        # Work on a copy so in-place down-weighting during the run does not
        # leak into `weights`, and exclude the reference batch itself.
        trial_weights = weights.copy()
        trial_weights[anchor] = 0
        w = main_algo_single_comp(
            BS, BM, single_batch(anchor), sigma, ell, con_num, R, init_est, trial_weights)
        estimates.append(w)

        unexplained = [
            idx for idx in unexplained
            if estimate_mse(single_batch(idx), w) > d / 5
        ]
        weights = np.zeros(m2)
        for idx in unexplained:
            weights[idx] = 1
    return estimates


def main_algo_prior_work(BS, BM, k, d):
    """Baseline: estimate a ``k``-dimensional subspace, then cluster within it.

    Args:
        BS: Batches used for the (unclipped, zero-centred) subspace estimate.
        BM: Batches that are clustered and averaged.
        k: Subspace dimension and maximum number of clusters.
        d: Covariate dimension.

    Returns:
        The list of per-cluster estimates from ``get_estimates_prior``.
    """
    origin = np.zeros((d, 1))
    projection = subspace_estimation(BS, k, np.inf, origin)
    return get_estimates_prior(BM, projection, k, 1)


def clustering_using_list(L, W, Sigmas, sigma, m, n):
    """Evaluate candidate estimates by assigning fresh batches to the best one.

    For each of ``m`` trials a component is drawn uniformly at random, a batch
    of ``n`` samples is generated from it, and the candidate in ``L`` with the
    smallest mean squared error on that batch is selected. The selected
    candidate is scored on a second, independent batch of 100 samples from the
    same component. Finally, every group of batches that chose the same
    candidate is refit by least squares.

    Args:
        L: Non-empty list of candidate estimates.
        W: ``(k, d)`` array of true regression vectors.
        Sigmas: Sequence (sliceable, indexed by component) of covariances.
        sigma: Noise level forwarded to ``generate_batches``.
        m: Number of test batches.
        n: Samples per test batch.

    Returns:
        A tuple ``(mean_pred_error, incorrect_count, new_L)`` where
        ``mean_pred_error`` is the average score over the ``m`` trials,
        ``incorrect_count`` is ``m`` minus the summed size of the majority
        true component within each group, and ``new_L`` holds one refit
        estimate of shape ``(1, d)`` per group.
    """
    k, d = np.shape(W)
    # candidate index -> [true component ids, covariate blocks, response blocks]
    clusters = {}
    pred_error = []
    for _ in range(m):
        index = random.randint(0, k - 1)
        w_true = W[index].reshape(1, d)
        # Slice Sigmas so the single-component generator pairs the chosen
        # regression vector with its own covariance; passing the whole
        # sequence would always read Sigmas[0].
        sigma_true = Sigmas[index:index + 1]
        batch = generate_batches(w_true, sigma_true, n, 1, sigma)
        X, Y = batch

        est_index = 0
        mse = estimate_mse(batch, L[0])
        for j in range(1, len(L)):
            candidate_mse = estimate_mse(batch, L[j])
            if mse > candidate_mse:
                mse = candidate_mse
                est_index = j

        pred_batch = generate_batches(w_true, sigma_true, 100, 1, sigma)
        pred_error.append(estimate_mse(pred_batch, L[est_index]))

        true_ids, x_blocks, y_blocks = clusters.setdefault(est_index, [[], [], []])
        true_ids.append(index)
        x_blocks.append(X.reshape(-1, d))
        y_blocks.append(Y.reshape(-1, 1))

    new_L = []
    correct_count = 0
    for true_ids, x_blocks, y_blocks in clusters.values():
        pooled_X = np.concatenate(x_blocks, axis=0)
        pooled_Y = np.concatenate(y_blocks, axis=0)
        new_est = np.linalg.lstsq(pooled_X, pooled_Y, rcond=None)[0]
        new_L.append(new_est.T)
        # Batches from the group's most frequent true component count as correct.
        _, counts = np.unique(true_ids, return_counts=True)
        correct_count += counts.max()
    incorrect_count = m - correct_count
    return np.mean(pred_error), incorrect_count, new_L
