import numpy as np
import torch


# Synthetic-data ground truth over 200 features.
# mu_negative is all zeros; mu_positive is 1 on the first 10 features, 0 elsewhere.
mu_negative = np.zeros(200)
mu_positive = np.zeros(200, dtype=int)
mu_positive[:10] = 1

# Spl[i, j] = 0.25 ** (|i - j| + 0.5); presumably the shared class covariance
# of the synthetic data (NOTE(review): confirm against the data generator).
Spl = np.power(0.25, np.abs(np.subtract.outer(np.arange(200), np.arange(200))) + 0.5)

# Reference weights: mu_positive @ inv(Spl).
W_star = np.dot(mu_positive, np.linalg.inv(Spl))

# Zero every entry below 1e-4 so that the non-zero pattern of W_star defines
# the "true" support used for scoring.
# NOTE(review): this is a signed comparison, so it also zeroes large negative
# entries, not just numerical noise (for this Spl, W_star[10] comes out at
# about -0.53). If only small magnitudes should be dropped, the condition
# would be np.abs(W_star) < 1e-4 -- confirm which support is intended.
W_star = np.where(W_star < 1e-4, 0.0, W_star)


def f1_score(W_star, aggregate_loss_type, w_dir="./W"):
    """Return the F1 score of a saved weight vector's support against W_star's.

    The learned weights are read with ``np.loadtxt`` from
    ``{w_dir}/cross_entropy_{aggregate_loss_type}_synthetic_data.csv``. The
    "support" of a weight array is the set of (flattened) indices holding a
    non-zero entry. Precision and recall are measured between the support of
    the loaded weights and the support of ``W_star``.

    Args:
        W_star: Reference weight array; its non-zero pattern is the ground truth.
        aggregate_loss_type: Tag embedded in the weight file name
            (e.g. "average", "atk").
        w_dir: Directory holding the weight files. Defaults to "./W", the
            location that was previously hard-coded.

    Returns:
        The F1 score as a float in [0, 1]. Returns 0.0 when the two supports
        share no index, including when either support is empty. Previously
        these cases raised ZeroDivisionError.

    Raises:
        OSError: typically FileNotFoundError, raised by ``np.loadtxt`` when the
            weight file cannot be opened.
    """
    W_hat = np.loadtxt(f"{w_dir}/cross_entropy_{aggregate_loss_type}_synthetic_data.csv")
    # flatnonzero indexes the flattened array. For 1-D input it matches
    # np.nonzero(...)[0]; for 2-D input it still gives one index per element
    # instead of only row indices.
    supp_W_hat = set(np.flatnonzero(W_hat).tolist())
    supp_W_star = set(np.flatnonzero(W_star).tolist())
    true_positives = len(supp_W_hat & supp_W_star)
    # No overlap covers every degenerate case (empty W_hat support, empty
    # W_star support, or disjoint supports). Guarding here means none of the
    # divisions below can divide by zero.
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(supp_W_hat)
    recall = true_positives / len(supp_W_star)
    return 2 * precision * recall / (precision + recall)

    
    

# Aggregate-loss variants whose saved weight files f1_score reads; one F1
# score is printed per variant, in this order.
AGGREGATE_LOSS_TYPES = (
    "average",
    "atk",
    "matk",
    "smooth_matk",
    "sgd_matk",
    "smooth_sgd_matk",
)

for aggregate_loss_type in AGGREGATE_LOSS_TYPES:
    score = f1_score(W_star, aggregate_loss_type)
    print(score)