from experiment3_doc_experiment import *
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
from magnipy import Magnipy

if __name__ == "__main__":
    from pathlib import Path

    # Render all figure text through LaTeX (needs a working TeX installation).
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
    })

    # Experiment identifiers passed verbatim to ``read_files``.
    datasets = [
        "cnn_dailymail___3.0.0_16384",
        "big_patent___a_16384",
        "EdinburghNLP_-_xsum_16384",
        "gfissore_-_arxiv-abstracts-2021_16384",
    ]
    # Short labels used in the heatmap filenames; must align 1:1 with `datasets`.
    names_nice = ["cnn", "patent", "bbc", "arxiv"]
    if len(names_nice) != len(datasets):
        raise ValueError("names_nice must have exactly one entry per dataset")

    out_dir = Path("./results/doc")
    # Create the output directory up front; otherwise the first to_csv/savefig
    # raises on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)

    scores = {}
    for d, nice in zip(datasets, names_nice):
        mag_results = read_files(d, n_samples=200, n_size=400, n_dims=384)

        # Matrix stored under magnitude_differences["cosine"]; presumably a
        # square (pairwise) matrix -- TODO confirm against read_files.
        mat = mag_results["magnitude_differences"]["cosine"]
        pd.DataFrame(mat).to_csv(out_dir / f"{d}_doc.csv")

        fig, ax = plt.subplots(figsize=(5, 4.5))
        sns.heatmap(mat, annot=False, cmap="Reds", ax=ax)
        # Hide tick marks and tick labels; only the colour pattern is shown.
        ax.tick_params(axis="both", which="both", length=0, labelsize=0)
        fig.savefig(out_dir / f"doc{nice}heat.png")
        plt.show()
        # Release the figure; without this every loop iteration leaves one
        # figure open (plt.show() does not close it on non-interactive backends).
        plt.close(fig)

        scores[d] = mag_results["prediction_results"]
        pd.DataFrame(scores[d]).to_csv(out_dir / f"{d}_pred_doc.csv")

    # Aggregate the per-dataset prediction results into one table.
    all_scores = get_prediction_results(scores)
    all_scores.to_csv(out_dir / "doc_results.csv")