#%%
from heapq import nsmallest
from tkinter import font
from turtle import pos
import matplotlib.pyplot as plt
import numpy as np
import torch
import sys

sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from data_set import *


#%%


# helper functions
def get_pos(vel, dt=None):
    """Integrate velocities into positions by a cumulative sum over dimension 1.

    Positions start at the origin:
    ``pos[:, t] = dt * (vel[:, 0] + vel[:, 1] + ... + vel[:, t])``.

    Args:
        vel: tensor integrated along dimension 1. In this script it is indexed
            as (trial, time, xy). Any tensor of rank >= 2 is accepted.
        dt: integration step. When omitted, a module-level ``dt`` is used if
            one is defined, so existing callers that set a global ``dt`` keep
            working. Otherwise 0.01 is used, the value the plotting cells in
            this script assign.

    Returns:
        Tensor with ``vel``'s shape on ``vel``'s device, cast to the default
        floating-point dtype (``torch.get_default_dtype()``).
    """
    if dt is None:
        # Backward compatibility: callers historically configured the step
        # through a global ``dt`` assigned inside the plotting cells.
        dt = globals().get("dt", 0.01)
    # A single vectorised cumulative sum replaces the per-time-step loop.
    return torch.cumsum(vel * dt, dim=1).to(torch.get_default_dtype())


# Directory the figures below are saved into. Keep the trailing slash: one of
# the savefig calls appends the file name directly after it.
# NOTE(review): this is an absolute path at the filesystem root, whereas the
# "results/..." checkpoint paths below are relative. Confirm this is intended
# (perhaps "plots/" was meant). Nothing in this script creates the directory,
# so it must already exist before plt.savefig is called.
plots_path = "/plots/"

#%%

############## FIGURE 1A ################
# Pre-adaptation reaches: the phase-0 trained network is run on the rotated
# center-out task, once with feedback (solid lines) and once with feedback
# disabled (dashed lines). Squares mark the reach targets.

i = 0  # index of the run sub-directory to load
rot = 30  # task rotation in degrees (converted to radians below)
experiment_timestamp = "2024May08-205259" # EXAMPLE!

tm_savname = f"results/{experiment_timestamp}_no_delay_2k_b2b_large_batch/0003_fb_freeze_False__delay_0__go_to_peak_50__vel_10.0__error_detach_False__nlayers_1__dataset_name_Reaching__task_random_pushed__fb_density_1/{i}"
dataT = torch.load(tm_savname + "/phase0_training")
params = dataT["params"]

# Rebuild the network with the saved hyper-parameters, then restore its weights.
model_fb_ad = RNN(
    params["model"]["input_dim"],
    params["model"]["output_dim"],
    params["model"]["n"],
    torch.cuda.FloatTensor,
    params["model"]["dt"],
    params["model"]["tau"],
    fb_delay=params["model"]["fb_delay"],
    fb_density=params["model"]["fb_density"],
)
model_fb_ad = model_fb_ad.cuda()
model_fb_ad.load_state_dict(dataT["model_state_dict"])
if params["data"]["dataset_name"] == "Reaching":
    model_fb_ad.pos_err = True

# Seed NumPy before building the dataset (presumably the test set is drawn
# from NumPy's global RNG, so this keeps the figure reproducible).
np.random.seed(0)
dataset = Reaching()
params["model"]["rot_phi"] = rot / 180 * np.pi
target_30, stimulus_30, pert_30, tids_30, stim_ref_30 = dataset.prepare_pytorch(
    params, "center-out-reach_rotated", test_set=True
)

# Same trials with feedback input enabled and disabled.
output_30, hidden_30, extras_30 = model_fb_ad(
    stimulus_30, pert_30, stim_ref_30, analysis=True,fb_in=True
)

output_30_nofb, hidden_30_nofb, extras_30_nofb = model_fb_ad(
    stimulus_30, pert_30, stim_ref_30, analysis=True,fb_in=False
)

fig = plt.figure(figsize=(4, 4))

# Read the outputs back from the stacked hidden states. The transpose yields
# the (trial, time, xy) layout indexed by the plotting loops below.
output_30 = model_fb_ad.get_output(torch.stack(hidden_30)).cpu().detach().numpy().transpose(1, 0, 2)
output_30_nofb= model_fb_ad.get_output(torch.stack(hidden_30_nofb)).cpu().detach().numpy().transpose(1, 0, 2)
tid = tids_30
target = target_30
output = output_30_nofb
# setup: one colour per target id (8 targets)
cols2 = [plt.cm.magma(i) for i in np.linspace(0.1, 0.9, 8)]
dt = 0.01
# torch.as_tensor accepts ndarrays and tensors alike without copying and
# without the copy-construct warning torch.tensor emits for tensor inputs.
pos = get_pos(torch.as_tensor(output))
posT = get_pos(torch.as_tensor(target))
# No-feedback trajectories: dashed.
for j in range(pos.shape[0]):
    plt.plot(pos[j, :, 0], pos[j, :, 1], color=cols2[tid[j]], alpha=0.2,linestyle="--",lw=2)
# Plot the targets too: final position of the first trial aimed at target j.
for j in range(8):
    tmp = posT[tid == j, -1][0]
    plt.scatter(tmp[0], tmp[1], edgecolor=cols2[j], facecolor="None", marker="s", s=200, lw=2)

# Feedback-driven trajectories: solid. The target positions above are reused.
output = output_30
pos = get_pos(torch.as_tensor(output))
for j in range(pos.shape[0]):
    plt.plot(pos[j, :, 0], pos[j, :, 1], color=cols2[tid[j]], alpha=0.2,lw = 2)


plt.xlim(-6, 6)
plt.ylim(-6, 6)
# R^2 of the first 125 output steps against the target (not printed; inspect
# interactively). `met` comes from one of the star imports (presumably
# sklearn.metrics), and 125 presumably equals the target length. TODO confirm.
score = met.r2_score(output[:, :125].reshape(-1, 2), target.reshape(-1, 2))

plt.gca().spines["top"].set_visible(False)
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["left"].set_visible(False)
plt.gca().spines["bottom"].set_visible(False)

plt.xticks([])
plt.yticks([])

plt.tight_layout()
plt.savefig(f"{plots_path}/final_preadaptation_traj_{rot}.pdf", dpi=300)
plt.show()


#%%

############## FIGURE 1B ################
# Post-adaptation reaches: the network adapted to the rotation with learning
# algorithm `learn_alg` is run with feedback (solid lines) and with feedback
# disabled (dashed lines). Squares mark the reach targets.

i = 0  # index of the run sub-directory to load

rot = 30  # task rotation in degrees (converted to radians below)
learn_alg = 'fed'  # learning-algorithm tag used in the results path and figure name
fb_in = 1  # feedback flag as encoded in the results path and figure name

tm_savname = f"results/2024May08-205259_no_delay_2k_b2b_adaptation_lrs/0017_rot_phi_{rot}__learning_algorithm_{learn_alg}__fb_in_{fb_in}__fb_density_1/{i}"
dataT = torch.load(tm_savname + f"/AD_{learn_alg}")
params = dataT["params"]

# Rebuild the network with the saved hyper-parameters, then restore its weights.
model_fb_ad = RNN(
    params["model"]["input_dim"],
    params["model"]["output_dim"],
    params["model"]["n"],
    torch.cuda.FloatTensor,
    params["model"]["dt"],
    params["model"]["tau"],
    fb_delay=params["model"]["fb_delay"],
    fb_density=params["model"]["fb_density"],
)
model_fb_ad = model_fb_ad.cuda()
model_fb_ad.load_state_dict(dataT["model_state_dict"])
if params["data"]["dataset_name"] == "Reaching":
    model_fb_ad.pos_err = True
model_fb_ad.error_detach = True

# Seed NumPy before building the dataset (presumably the test set is drawn
# from NumPy's global RNG, so this keeps the figure reproducible).
np.random.seed(0)
dataset = Reaching()
params["model"]["rot_phi"] = rot / 180 * np.pi
target_30, stimulus_30, pert_30, tids_30, stim_ref_30 = dataset.prepare_pytorch(
    params, "center-out-reach_rotated", test_set=True
)

# Same trials with feedback input enabled and disabled. Each pass runs once.
output_30, hidden_30, extras_30 = model_fb_ad(
    stimulus_30, pert_30, stim_ref_30, analysis=True,fb_in=True
)

output_30_nofb, hidden_30_nofb, extras_30_nofb = model_fb_ad(
    stimulus_30, pert_30, stim_ref_30, analysis=True,fb_in=False
)

fig = plt.figure(figsize=(4, 4))


# Read the outputs back from the stacked hidden states. The transpose yields
# the (trial, time, xy) layout indexed by the plotting loops below.
output_30 = model_fb_ad.get_output(torch.stack(hidden_30)).cpu().detach().numpy().transpose(1, 0, 2)
output_30_nofb= model_fb_ad.get_output(torch.stack(hidden_30_nofb)).cpu().detach().numpy().transpose(1, 0, 2)
tid = tids_30
# torch.as_tensor accepts ndarrays and tensors alike without copying and
# without the copy-construct warning torch.tensor emits for tensor inputs.
target = torch.as_tensor(target_30)
output = torch.as_tensor(output_30_nofb)
# setup: one colour per target id (8 targets)
cols2 = [plt.cm.magma(i) for i in np.linspace(0.1, 0.9, 8)]
dt = 0.01
pos = get_pos(output)
posT = get_pos(target)
# No-feedback trajectories: dashed.
for j in range(pos.shape[0]):
    plt.plot(pos[j, :, 0], pos[j, :, 1], color=cols2[tid[j]], alpha=0.2,linestyle="--",lw =2)
# Plot the targets too: final position of the first trial aimed at target j.
for j in range(8):
    tmp = posT[tid == j, -1][0]
    plt.scatter(tmp[0], tmp[1], edgecolor=cols2[j], facecolor="None", marker="s", s=200, lw=2)

# Feedback-driven trajectories: solid. `target` and the target positions
# above are reused unchanged.
output = torch.as_tensor(output_30)
pos = get_pos(output)
for j in range(pos.shape[0]):
    plt.plot(pos[j, :, 0], pos[j, :, 1], color=cols2[tid[j]], alpha=0.2,lw=2)


plt.xlim(-6, 6)
plt.ylim(-6, 6)
# R^2 of the first 125 output steps against the target (not printed; inspect
# interactively). `met` comes from one of the star imports (presumably
# sklearn.metrics), and 125 presumably equals the target length. TODO confirm.
score = met.r2_score(output[:, :125].reshape(-1, 2), target.reshape(-1, 2))

plt.gca().spines["top"].set_visible(False)
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["left"].set_visible(False)
plt.gca().spines["bottom"].set_visible(False)

plt.xticks([])
plt.yticks([])

plt.tight_layout()
plt.savefig(f"{plots_path}final_adaptation_traj_{rot}_{learn_alg}_{fb_in}.pdf", dpi=300)
plt.show()
