from library.plotting import utils, plot_2d

import numpy as np
import matplotlib.pyplot as plt
import torch
from slicetca.core import decompositions

'''
Contains plotting and utility functions for main.py
'''

def var_exp(true_x, x):
    """Return the fraction of the variance of ``true_x`` explained by ``x``.

    Computed as ``1 - mean((true_x - x)**2) / true_x.var()`` and returned as a
    Python scalar via ``.item()``. Note that ``.var()`` is the population
    variance for numpy arrays but the unbiased variance for torch tensors.
    """
    mean_sq_residual = ((true_x - x) ** 2).mean()
    unexplained = mean_sq_residual / true_x.var()
    return (1 - unexplained).item()

def tca(x, rank, max_iter=10**4):
    """Fit a rank-``rank`` TCA model to ``x`` and return the variance it explains.

    ``x`` is copied to a torch tensor (on CUDA when available) and scaled to
    unit global std. A ``decompositions.TCA`` model is then fitted with Adam
    (lr 1e-3) for at most ``max_iter`` iterations. The returned score is
    ``1 - mean((x - x_hat)**2 / var_k)``, where ``var_k`` is the variance of
    slice ``x[k]`` over axes 1 and 2.

    Assumes ``x`` is 3-D (``var`` is taken over dims 1 and 2); TODO confirm
    whether higher-rank inputs are ever passed.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data = torch.tensor(x, device=device)
    data = data / data.std()

    # Per-slice variance, shaped (k, 1, 1) so it broadcasts against ``data``.
    slice_var = data.var(dim=(1, 2))[:, None, None]

    model = decompositions.TCA(data.shape, rank, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=10 ** -3)
    model.fit(data, optimizer, min_std=10 ** -6, iter_std=500, max_iter=max_iter)

    normalized_sq_err = (data - model.construct()) ** 2 / slice_var
    return (1 - normalized_sq_err.mean()).item()


def var_exp_per_pc(x, y=None, max_pc=10, var_leg=0, tol=10**-8):
    """Cumulative variance of ``x`` explained by the leading right-singular vectors of ``y``.

    ``x`` is first divided by its global standard deviation. The basis comes
    from an SVD of ``y`` flattened to ``(-1, y.shape[-1])``; when ``y`` is None
    the normalized ``x`` is used. For ``i = 1..max_pc``, ``x`` is projected
    onto the first ``i`` right-singular vectors and reconstructed, and
    ``1 - mean(residual**2 / var_x)`` is recorded. ``var_x`` is the variance of
    ``x`` computed separately for each index along axis ``var_leg`` and
    clipped below at ``tol``.

    Args:
        x: array whose last axis is the feature axis being compressed.
        y: optional array defining the basis; its last axis must match ``x``'s.
        max_pc: number of cumulative components evaluated.
        var_leg: axis of ``x`` whose slices are variance-normalized separately.
            Negative values count from the end, as in numpy.
        tol: lower bound on the per-slice variance (avoids division by zero).

    Returns:
        ``(explained, participation_ratio)``. ``explained`` is a list of
        ``max_pc`` floats. ``participation_ratio`` is
        ``sum(S)**2 / sum(S**2)`` over the singular values ``S`` of ``y``.
    """
    x = np.asarray(x)
    x = x / x.std()
    y = x if y is None else np.asarray(y)

    # np.linalg.svd returns V^H: the rows of ``vh`` are the right singular vectors.
    _, s, vh = np.linalg.svd(y.reshape(-1, y.shape[-1]), full_matrices=False)

    # NOTE(review): this uses the singular values themselves; the more common
    # definition uses eigenvalues s**2, i.e. sum(s**2)**2 / sum(s**4). Confirm intent.
    participation_ratio = s.sum() ** 2 / (s ** 2).sum()

    x_projected = x @ vh.T[:, :max_pc]

    ndim = x.ndim
    # Normalize negative axes. Out-of-range positive values keep the original
    # behaviour: no axis matches, so the variance is global.
    leg = var_leg + ndim if var_leg < 0 else var_leg
    reduce_axes = tuple(ax for ax in range(ndim) if ax != leg)
    var_x = np.var(x, axis=reduce_axes)
    broadcast_shape = [1 if ax != leg else -1 for ax in range(ndim)]
    var_x = np.clip(var_x.reshape(broadcast_shape), tol, np.inf)

    explained = []
    for n_pc in range(1, max_pc + 1):
        reconstruction = x_projected[..., :n_pc] @ vh[:n_pc]
        explained.append(1 - np.mean(np.square(x - reconstruction) / var_x))

    return explained, participation_ratio


def get_axs(rows, columns, axs_3d=None, axs_ignored=None):
    """Create a figure holding a ``rows`` x ``columns`` grid of subplots.

    Args:
        rows: number of grid rows.
        columns: number of grid columns.
        axs_3d: optional iterable of ``(column, row)`` pairs whose subplot uses
            a ``'3d'`` projection.
        axs_ignored: optional iterable of ``(column, row)`` pairs for which no
            subplot is created; the corresponding entry is ``None``.

    Returns:
        numpy array of shape ``(columns, rows)``: ``axs[j, i]`` is the axes in
        column ``j``, row ``i``.
    """
    fig = plt.figure(figsize=(columns * 4, rows * 4 + 0.25), constrained_layout=True, dpi=80)
    # Store coordinates as tuples so lists and tuples are both accepted.
    cells_3d = {tuple(cell) for cell in (axs_3d or ())}
    cells_ignored = {tuple(cell) for cell in (axs_ignored or ())}
    gs = fig.add_gridspec(ncols=columns, nrows=rows)

    def _make_ax(j, i):
        # Build the subplot at column j, row i, or None if that cell is ignored.
        if (j, i) in cells_ignored:
            return None
        projection = '3d' if (j, i) in cells_3d else None
        return fig.add_subplot(gs[i, j], projection=projection)

    # Outer loop over columns, inner over rows: this fixes the (columns, rows) layout.
    axs = [[_make_ax(j, i) for i in range(rows)] for j in range(columns)]
    return np.array(axs)

def loss(ax_temp, Losses, cmap_eig):
    """Plot the mean +/- std loss curve on ``ax_temp``.

    Statistics are taken over axis 0 of ``Losses`` (presumably repeated runs;
    confirm with the caller), and iterations are numbered from 1. The first
    and last points are marked with white-filled circles coloured
    ``cmap_eig(0.0)`` and ``cmap_eig(1.0)``.
    """
    grey = (0.3, 0.3, 0.3)
    std_curve = Losses.std(axis=0)
    mean_curve = Losses.mean(axis=0)
    utils.set_bottom_axis(ax_temp)

    iterations = np.arange(len(mean_curve)) + 1
    ax_temp.fill_between(iterations, mean_curve - std_curve, mean_curve + std_curve,
                         color=grey, alpha=0.1, zorder=3)
    ax_temp.plot(iterations, mean_curve, color=grey, markerfacecolor=(1, 1, 1), zorder=4)

    # Highlight the two endpoints of the curve.
    endpoints = ((1, mean_curve[0], 0.0), (len(mean_curve), mean_curve[-1], 1.0))
    for position, value, shade in endpoints:
        ax_temp.scatter(position, value, color=cmap_eig(shade), facecolor=(1, 1, 1), zorder=5)

    ax_temp.set_xlabel('Iteration')
    ax_temp.set_ylabel('MSE')

def activation(ax_temp, net):
    """Plot ``net.activation`` on 101 evenly spaced points in [-3, 3]."""
    utils.set_bottom_axis(ax_temp)
    xx = torch.linspace(-3, 3, 101)
    # numpy(force=True) detaches from autograd and copies to CPU. A plain .numpy()
    # would fail if the activation has trainable parameters or returns a CUDA tensor.
    yy = net.activation(xx).numpy(force=True)
    ax_temp.plot(xx.numpy(), yy, color=(0.4, 0.4, 0.95))
    ax_temp.set_xlabel('$x$')
    # Raw string: '\p' is an invalid escape in a normal literal (SyntaxWarning on 3.12+).
    ax_temp.set_ylabel(r'$\phi(x)$')

def activity(ax_temp, ts, Xs, cmap_eig, parameters, neuron_sample=2, max_neuron=30):
    """Plot neuron traces ``Xs[0, -1, :, 0, n]`` against ``ts``.

    Neurons ``neuron_sample .. neuron_sample + max_neuron - 1`` are drawn in
    translucent grey as background. The first ``neuron_sample`` neurons are
    drawn on top in ``cmap_eig(1.0)``. When ``parameters['activation']`` is
    ``'tanh'`` the y-range is fixed to [-1, 1]. (Axis 1 of ``Xs`` is
    presumably the trial/step index; confirm with the caller.)
    """
    utils.set_bottom_axis(ax_temp)
    # Background traces.
    ax_temp.plot(ts.numpy(force=True), Xs[0, -1, :, 0, neuron_sample:max_neuron + neuron_sample],
                 color=(0.3, 0.3, 0.3), alpha=0.5)
    # Highlighted traces.
    ax_temp.plot(ts.numpy(force=True), Xs[0, -1, :, 0, :neuron_sample], color=cmap_eig(1.0))
    ax_temp.set_xlabel('Time')
    # Raw string: '\p' is an invalid escape in a normal literal (SyntaxWarning on 3.12+).
    ax_temp.set_ylabel(r'$\phi(x)$')
    if parameters['activation'] == 'tanh':
        ax_temp.set_ylim(-1, 1)

def activity_over_trials(ax_temp, parameters, ts, cmap_eig, Xs, neuron_sample=2, max_neuron=30):
    """Overlay the first ``neuron_sample`` neuron traces for a subset of steps.

    Every ``max(1, parameters['steps'] // 10)``-th index along axis 1 of ``Xs``
    is drawn, starting from 0. Each trace ``Xs[0, step, :, 0, neuron]`` is
    coloured ``cmap_eig(step / (steps - 1))``. A tiny epsilon in the
    denominator keeps it non-zero when there is a single step. ``max_neuron``
    is accepted for signature parity with ``activity`` but unused here.
    """
    utils.set_bottom_axis(ax_temp)
    n_steps = parameters['steps']
    stride = max(1, n_steps // 10)
    denom = n_steps - 1 + 10 ** -6
    for step in range(0, n_steps, stride):
        for neuron in range(neuron_sample):
            ax_temp.plot(ts.numpy(force=True), Xs[0, step, :, 0, neuron],
                         color=cmap_eig(step / denom), alpha=0.5)
    if parameters['activation'] == 'tanh':
        ax_temp.set_ylim(-1, 1)

def singular_values(Ss, parameters, axs_plot, axs_insets, ax_id, colors, labels, zorders, linestyles):
    """Plot mean +/- std singular-value spectra for several conditions.

    Statistics are taken over axis 1 of ``Ss``, so ``Ss[i]`` holds the
    repetitions for condition ``i``. The y-label suggests the values are
    log10 singular values; confirm. Entries whose mean falls below
    ``parameters['tol_sv']`` are masked (set to NaN).

    For each condition ``i``, the full spectrum is drawn on
    ``axs_insets[ax_id[i]]``. The first ``parameters['singular_values_plotted']``
    values (or fewer if the spectrum is shorter) are drawn on
    ``axs_plot[ax_id[i]]``. Styling comes from ``colors[i]``, ``labels[i]``,
    ``linestyles[i]`` (format string) and ``zorders[i]`` (an offset added to 4).
    On every ``axs_plot`` axis, dashed reference lines are added at ``R+d`` and
    ``2R+m+d`` when both fit within the plotted range.
    """
    n_shown = parameters['singular_values_plotted']

    Ss_std, Ss_mean = Ss.std(axis=1), Ss.mean(axis=1)
    below_tol = Ss_mean < parameters['tol_sv']
    Ss_std[below_tol] = np.nan
    Ss_mean[below_tol] = np.nan

    for ax in axs_plot:
        utils.set_bottom_axis(ax)
    for ax in axs_insets:
        utils.set_bottom_axis(ax)

    # Clamp to the spectrum length so the x and y arrays passed to
    # fill_between/plot always agree in length.
    n_drawn = min(n_shown, Ss_mean.shape[-1])
    shown_idx = np.arange(n_drawn) + 1

    for i in range(len(Ss)):
        ax_inset, ax_main = axs_insets[ax_id[i]], axs_plot[ax_id[i]]
        mean_i, std_i = Ss_mean[i], Ss_std[i]

        all_idx = np.arange(len(mean_i)) + 1
        ax_inset.fill_between(all_idx, mean_i - std_i, mean_i + std_i,
                              color=colors[i], alpha=0.1, zorder=3)
        ax_inset.plot(all_idx, mean_i, color=colors[i], label=labels[i], zorder=4)

        ax_main.fill_between(shown_idx, (mean_i - std_i)[:n_drawn], (mean_i + std_i)[:n_drawn],
                             color=colors[i], alpha=0.1, zorder=3)
        ax_main.plot(shown_idx, mean_i[:n_drawn], linestyles[i], linewidth=2.0,
                     color=colors[i], label=labels[i], zorder=4 + zorders[i], markeredgecolor=colors[i],
                     markerfacecolor=(1, 1, 1, 1), markeredgewidth=1.5, markersize=6)

        # Limits span all conditions, with a 0.5 margin and a floor at tol_sv - 0.5.
        ax_inset.set_ylim(max(parameters['tol_sv'] - 0.5, np.nanmin(Ss_mean) - 0.5),
                          np.nanmax(Ss_mean) + 0.5)
        ax_main.set_ylim(
            max(parameters['tol_sv'] - 0.5, np.nanmin(Ss_mean[:, :n_shown]) - 0.5),
            np.nanmax(Ss_mean) + 0.5)

    for ax in axs_plot:
        rank, m, d = parameters['rank'], parameters['decoder_dim'], parameters['input_dim']
        if 2 * rank + m + d <= n_shown:
            ax.axvline(rank + d, color=(0.8, 0.8, 1.0), linestyle='--', label='R+d')
            ax.axvline(2 * rank + m + d, color=(1.0, 0.8, 0.8), linestyle='--', label='2R+m+d')

        ax.set_ylabel(r'Log singular value $\log_{10} \left (\sigma_i\right )$')
        ax.set_xlabel(r'Singular value index $i$')
        left, right = ax.get_xlim()
        ax.set_xlim(left, right + 1)

        ax.legend(loc='lower left', fancybox=True, fontsize=11, ncol=2)

    for ax in axs_insets:
        ax.set_xlim(0, parameters['dim'])

def eig(ax_temp, Ws, cmap_eig, Ls, labels):
    """Draw the eigenspectrum of ``Ws[-1]`` and overlay two highlighted sets of eigenvalues.

    After ``plot_2d.eigenspectrum_per_trial`` draws ``Ws[-1]``, the complex
    values ``Ls[0, -1]`` and ``Ls[1, -1]`` are scattered in the complex plane
    (real vs. imaginary part). They use white-filled markers coloured
    ``cmap_eig(0.0)`` / ``cmap_eig(1.0)`` and legend labels ``labels[0]`` /
    ``labels[1]``.
    """
    plot_2d.eigenspectrum_per_trial(ax_temp, Ws[-1], cmap=cmap_eig)
    for row, shade in enumerate((0.0, 1.0)):
        eigvals = Ls[row, -1]
        ax_temp.scatter(eigvals.real, eigvals.imag, color=cmap_eig(shade), s=10,
                        label=labels[row], facecolor=(1, 1, 1, 1))


def tensor_rank(ax_temp, Ws, parameters, cmap):
    """Plot variance explained vs. rank for tensor, column, row and trial decompositions.

    ``Delta_Ws = Ws[:, 1:] - Ws[:, 0:1]`` is the change along axis 1 relative
    to its first entry. Each ``w`` in ``Delta_Ws`` must be 3-D (it is
    transposed with three axes), so ``Ws`` is 4-D. For each ``w``, variance
    explained is computed for ranks ``1..parameters['max_tensor_rank']`` by:
    TCA (``tca``); a truncated SVD basis of ``w``'s last axis (column); the
    same on ``w.transpose(0, 2, 1)`` (row); and the same on
    ``w.transpose(1, 2, 0)`` normalized along its last axis (trial). The mean
    +/- std over axis 0 of ``Ws`` is plotted for each.

    NOTE(review): colours come from ``cmap(i)`` with an integer ``i``. That
    picks distinct colours for a listed colormap but near-identical entries
    for a continuous one; confirm which the caller passes.
    """
    utils.set_bottom_axis(ax_temp)
    max_rank = parameters['max_tensor_rank']
    Delta_Ws = Ws[:, 1:] - Ws[:, 0:1]

    tensor_var_exp = np.array([
        [tca(w, r, parameters['tca_max_iter']) for r in range(1, 1 + max_rank)]
        for w in Delta_Ws
    ])

    # (axis permutation, var_leg) for the column, row and trial unfoldings.
    unfoldings = (((0, 1, 2), 0), ((0, 2, 1), 0), ((1, 2, 0), 2))
    matrix_var_exps = [
        np.array([var_exp_per_pc(w.transpose(*perm), max_pc=max_rank, var_leg=leg)[0] for w in Delta_Ws])
        for perm, leg in unfoldings
    ]

    var_exps = np.stack([tensor_var_exp, *matrix_var_exps], axis=1)
    var_exps_mean = var_exps.mean(axis=0)
    var_exps_std = var_exps.std(axis=0)

    labels = ['Tensor rank', 'Column rank', 'Row rank', 'Trial rank']
    for i, label in enumerate(labels):
        mean_i, std_i = var_exps_mean[i], var_exps_std[i]
        ranks = np.arange(len(mean_i)) + 1
        ax_temp.plot(ranks, mean_i, '-o', color=cmap(i), markerfacecolor=(1, 1, 1), label=label)
        ax_temp.fill_between(ranks, mean_i - std_i, mean_i + std_i,
                             color=cmap(i), alpha=0.1, zorder=3)
    ax_temp.set_xlabel('Rank')
    ax_temp.set_ylabel('Variance explained')