import torch
import numpy as np
import scipy
from slicetca.core import decompositions
from import_data import import_data
from matplotlib import pyplot as plt
import matplotlib
from library.plotting import utils

import os


def _resolve_save_directory(save_directory):
    """Return the folder the figures are written to.

    An explicitly passed value wins; otherwise the module-level
    ``save_directory`` is used (it is assigned when this file runs as a script).

    Raises:
        ValueError: if neither an argument nor a module-level value is available.
    """
    if save_directory is None:
        save_directory = globals().get('save_directory')
    if save_directory is None:
        raise ValueError(
            "No output directory: pass save_directory=... to tca() or set the "
            "module-level save_directory before calling it."
        )
    return save_directory


def _unfolding_variance_explained(w, var, ranks):
    """Return, for each of the 3 axes of ``w``, the R^2 of a truncated-SVD projection per rank.

    For axis ``j`` the tensor is transposed so that axis ``j`` comes last, flattened
    to a matrix, and the data are projected onto the leading right singular vectors.
    """
    scores = []
    for axis in range(3):
        # Swap `axis` with the last axis so the SVD basis lives along `axis`.
        order = np.arange(3)
        order[[axis, 2]] = order[[2, axis]]
        w_t = w.transpose(order)
        var_t = var.transpose(order)

        # scipy returns V already transposed (rows are right singular vectors).
        _, _, vh = scipy.linalg.svd(w_t.reshape(-1, w_t.shape[-1]), full_matrices=False)

        axis_scores = []
        for rank in ranks:
            basis = vh[:rank].T
            residual = w_t - w_t @ basis @ basis.T
            axis_scores.append(1 - (residual**2 / var_t).mean())
        scores.append(axis_scores)
    return scores


def _tca_variance_explained(w, var, ranks, device):
    """Fit one ``decompositions.TCA`` model per rank and return the R^2 of each reconstruction."""
    weights_torch = torch.tensor(w, device=device)
    var_torch = torch.tensor(var, device=device)
    # The model is fitted on data rescaled by the global standard deviation; the
    # reconstruction is scaled back before being compared with the raw data.
    scale = w.std()

    scores = []
    for rank in ranks:
        model = decompositions.TCA(w.shape, rank, device=device)
        optimizer = torch.optim.Adam(model.parameters(), lr=5*10**-3)

        model.fit(weights_torch/scale, optimizer, min_std=10**-6, iter_std=500, max_iter=2*10**4)

        residual = model.construct()*scale - weights_torch
        scores.append((1 - (residual**2 / var_torch).mean()).item())
    return scores


def _plot_variance_explained(ranks, scores, indices, labels, path):
    """Plot the rows ``indices`` of ``scores`` against ``ranks`` and save the figure to ``path``."""
    fig = plt.figure(figsize=(4, 4), constrained_layout=True, dpi=100)
    ax = fig.add_subplot()

    for idx in indices:
        ax.plot(ranks, scores[idx], '-o', color=matplotlib.colormaps['Set1'](idx / len(labels)),
                markerfacecolor=(1, 1, 1, 1), markeredgewidth=1.5, label=labels[idx])

    ax.set_xlabel('Rank')
    ax.set_ylabel('Variance of W explained')
    # Reference line at R^2 = 1 (perfect reconstruction).
    ax.axhline(1, linestyle='--', color=(0.9, 0.9, 0.9))
    ax.legend()

    fig.savefig(path)


def tca(epoch, weights, min_rank=1, max_rank=6, subtract_init=False, tol=10**-6, extra_filename='', var=None,
        save_directory=None):
    """Plot the variance of a weight tensor explained by low-rank models, per rank.

    For every rank in ``min_rank..max_rank`` (inclusive) this computes an entry-wise
    R^2 for (a) a truncated-SVD projection along each of the three tensor axes
    (labelled 'Trial', 'Row' and 'Column covariance' in the plots) and (b) a TCA
    model fitted with slicetca. Two PDFs are written to the output directory and
    the figures are then shown with ``plt.show()``.

    Args:
        epoch: array compared element-wise with each of its unique values to find
            where every epoch starts. Only used when ``subtract_init`` is true.
        weights: 3-D NumPy array (it is transposed with three axes and reduced
            over axes 1 and 2).
        min_rank: smallest rank evaluated.
        max_rank: largest rank evaluated (inclusive).
        subtract_init: if true, analyse ``weights[k+1:] - weights[0]`` where ``k`` is
            the second-smallest epoch onset index, instead of ``weights`` itself.
        tol: lower bound applied to the per-slice variance used as normaliser.
        extra_filename: text inserted into both output file names.
        var: optional normalising variance; it is transposed with three axes and
            divided into the squared residual, so it presumably must be 3-D and
            broadcastable against the analysed tensor. Defaults to the variance
            over axes 1 and 2 of the analysed tensor, clipped below at ``tol``.
        save_directory: folder for the PDFs. When omitted, the module-level
            ``save_directory`` is used, which keeps existing script usage working.

    Raises:
        ValueError: if no output directory can be determined. This is checked
            before any fitting so a long run is not lost at save time.
    """
    save_directory = _resolve_save_directory(save_directory)

    device = ('cuda' if torch.cuda.is_available() else 'cpu')

    if subtract_init:
        # Index of the first sample of each epoch, in increasing order.
        epochs = np.sort([np.argmax(epoch == e) for e in np.unique(epoch)])
        w = weights[epochs[1]+1:] - weights[0]
    else:
        w = weights

    # The variance explained is R^2 entry-wise of the matrix
    if var is None:
        var = np.clip(w.var(axis=(1, 2))[:, np.newaxis, np.newaxis], tol, np.inf)

    ranks = range(min_rank, max_rank+1)

    # Rows 0-2: multilinear ranks (one per axis); row 3: tensor rank.
    ls = _unfolding_variance_explained(w, var, ranks)
    ls.append(_tca_variance_explained(w, var, ranks, device))
    ls = np.array(ls)

    model_names = ['Trial covariance', 'Row covariance', 'Column covariance', 'TCA']
    x = np.arange(min_rank, max_rank+1)

    # Plot rank of row, column covariance matrix as well as tensor rank
    _plot_variance_explained(
        x, ls, range(1, len(model_names)), model_names,
        f"{save_directory}/var_exp_per_rank_{extra_filename}_sub_init_{subtract_init!s}.pdf")

    # Plot rank of trial covariance matrix
    _plot_variance_explained(
        x, ls, [0], model_names,
        f"{save_directory}/var_exp_per_rank_trial_{extra_filename}_sub_init_{subtract_init!s}.pdf")

    plt.show()

if __name__ == '__main__':

    # Run folder to analyse, relative to the current working directory.
    dir_name = 'pretrained_model'
    directory = './task_training/perturbation/task_runs/' + dir_name

    # NOTE: `tca` reads this module-level name to decide where its PDFs go,
    # so it must be assigned before `tca` is called.
    save_directory = directory+'/plots'

    # Only create the plots folder inside an existing run directory.
    # exist_ok=True makes this idempotent and avoids the check-then-create race
    # when several jobs start at the same time.
    if os.path.exists(directory):
        os.makedirs(save_directory, exist_ok=True)

    # Presumably configures matplotlib fonts for the figures - see library.plotting.utils.
    utils.set_font()

    # import_data's result is unpacked into seven values; only the second is used here.
    _, data, _, _, _, _, _ = import_data(directory)

    tca(data['epoch'], data['weights'])
