from numpy.random import multivariate_normal
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import torch


def save_data(data, label, usefor="train"):
    """Serialize a dataset split to the current working directory with ``torch.save``.

    Two files are written: ``./data_<usefor>.pt`` holding ``data`` and
    ``./lable_<usefor>.pt`` holding ``label`` (data is written first).

    Args:
        data: Object to store as the split's features.
        label: Object to store as the split's labels.
        usefor: Split name embedded in both filenames (e.g. "train", "val").
    """
    data_path = f"./data_{usefor}.pt"
    # NOTE: "lable" (sic) is intentional to keep the on-disk name stable --
    # other code may load this exact filename. Rename only together with readers.
    label_path = f"./lable_{usefor}.pt"
    torch.save(data, data_path)
    torch.save(label, label_path)


def gen_data(usefor, n=2500, positive_r=.5, *, rng=None, save=True,
             dim=200, n_informative=10, rho=0.8):
    """Generate a two-class Gaussian dataset and (by default) save it to disk.

    Positive rows are drawn from N(mu_pos, Sigma) and negative rows from
    N(0, Sigma), where ``mu_pos`` is 1 on the first ``n_informative``
    coordinates and 0 elsewhere, and ``Sigma[i, j] = rho ** |i - j|``.
    Rows are ordered positives first, then negatives.

    Args:
        usefor: Split name passed to ``save_data`` (used in the filenames).
        n: Total number of samples; exactly ``n`` rows are returned.
        positive_r: Fraction of positive samples, in [0, 1]. The positive
            count is ``round(n * positive_r)``; the rest are negative.
        rng: Optional numpy ``Generator``/``RandomState`` for reproducible
            sampling. When None, the global ``numpy.random`` state is used.
        save: When True (default), write the arrays with ``save_data``.
        dim: Feature dimension (default 200).
        n_informative: Number of leading coordinates where the positive
            mean is 1 (default 10).
        rho: Base of the covariance decay ``rho ** |i - j|`` (default 0.8).

    Returns:
        ``(data, label)`` where ``data`` has shape (n, dim) and ``label`` has
        shape (n, 1), holding 1.0 for positive rows and 0.0 for negative rows.

    Raises:
        ValueError: if ``positive_r`` is not in [0, 1], ``n`` is negative, or
            ``n_informative`` is not in [0, dim].
    """
    if not 0 <= positive_r <= 1:
        raise ValueError(f"positive_r must be in [0, 1], got {positive_r!r}")
    n_total = int(n)
    if n_total < 0:
        raise ValueError(f"n must be non-negative, got {n!r}")
    if not 0 <= n_informative <= dim:
        raise ValueError(
            f"n_informative must be in [0, dim={dim}], got {n_informative!r}")

    mu_negative = np.zeros(dim)
    mu_positive = np.zeros(dim)
    mu_positive[:n_informative] = 1
    # Toeplitz covariance whose entries decay as rho ** |i - j|.
    Spl = np.fromfunction(lambda i, j: np.power(rho, abs(i - j)), shape=(dim, dim))

    # Truncating both n*r and n*(1-r) separately can lose a sample to float
    # error (e.g. n=100, r=0.29 -> 28 + 71 = 99); derive negatives from the
    # positive count so the total is always exactly n.
    n_positive = int(round(n_total * positive_r))
    n_negative = n_total - n_positive

    # Default to the module-level numpy sampler to preserve the original RNG stream.
    sample = multivariate_normal if rng is None else rng.multivariate_normal
    data_positive = sample(mu_positive, cov=Spl, size=n_positive)
    data_negative = sample(mu_negative, cov=Spl, size=n_negative)

    label1 = np.ones(shape=(n_positive, 1))
    label0 = np.zeros(shape=(n_negative, 1))
    data = np.concatenate((data_positive, data_negative))
    label = np.concatenate((label1, label0))
    if save:
        save_data(data, label, usefor=usefor)
    return data, label



# Generate and save each split in order: (split name, sample count, positive ratio).
# The train split is heavily imbalanced (99% positive); val/test are balanced.
for _split, _n_samples, _pos_ratio in (
    ("train", 5000, 0.99),
    ("val", 2500, .5),
    ("test", 2500, .5),
):
    gen_data(usefor=_split, n=_n_samples, positive_r=_pos_ratio)
# label = label.astype("int")
# reducer = PCA(n_components=2)
# data_emb = reducer.fit_transform(data)

# cdict = {0: 'navy', 1: 'red'}
# ldict = {0: 'negative', 1: 'positive'}
# fig, ax = plt.subplots()
# for g in np.unique(label):
#     ix = np.where(label == g)[0]
#     if g == 1:  # positive samples
#         ax.scatter(data_emb[ix, 0], data_emb[ix, 1], c=cdict[g], label=ldict[g], s=10, alpha=0.2, marker="x")
#     else:  # negative samples
#         ax.scatter(data_emb[ix, 0], data_emb[ix, 1], c=cdict[g], label=ldict[g], s=10, alpha=0.2, marker="o")

# from matplotlib.patches import Patch
# legend_elements = [Patch(facecolor='red', edgecolor='red',
#                          label='positive', alpha=0.5),
#                    Patch(facecolor='navy', edgecolor='navy',
#                          label='negative', alpha=0.5),
#                          ]
# ax.legend(handles=legend_elements)
# plt.savefig("2d-data.png")

# mu_negative = np.zeros(200)
# mu_positive = np.array([*[1 for i in range(10)], *[0 for i in range(190)]])
# Spl = np.fromfunction(lambda i, j: np.power(0.25, 1+abs(i-j)), shape=(200, 200))
# w_star = np.dot(mu_positive, np.linalg.inv(Spl))
# w_star[w_star < 1e-4] = 0
# print(w_star)


