import matplotlib.pyplot as plt
import re
import os
import sys
sys.path.append(os.path.realpath('.'))
import matplotlib.pyplot as plt
import numpy as np
import torch
from toy.ops import f_eff_GD
import matplotlib

# Collect the epochs of all saved teacher checkpoints.
# Checkpoint files are named 'trained_net_epoch-<digits>' (the loader below
# rebuilds the name as 'trained_net_epoch-{:06d}'). Names that do not match
# exactly are skipped instead of crashing int() on a non-numeric suffix.
teacher_list_dir = './experiment/04a/teacher_train/network/'
cdir = './experiment/04a/'
epoch_pattern = re.compile(r'trained_net_epoch-(\d+)')
epoch_list = sorted(
    int(match.group(1))
    for match in map(epoch_pattern.fullmatch, os.listdir(teacher_list_dir))
    if match is not None
)

# Load the ground-truth model. map_location keeps it on the CPU, consistent
# with the teacher checkpoints and with `x` below, which is a CPU tensor.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# files. On torch >= 2.6 (weights_only=True by default) loading a whole
# module may additionally need weights_only=False — confirm for your version.
ground_truth = torch.load(cdir + 'ground_truth', map_location=torch.device('cpu'))

fig = plt.figure(figsize=(8, 6))

# 1000 evenly spaced inputs in [-x_range, x_range], shaped (1000, 1) as a
# column of samples for the models.
x_range = 20.0
x = torch.linspace(-x_range, x_range, 1000).reshape(-1, 1)
# Horizontal y = 0 reference line drawn in every panel.
zero_line = 0 * x

# 3 rows x 10 columns layout shared by the three panels.
grid = plt.GridSpec(
    3, 10, wspace=0.5, hspace=0.2, figure=fig, left=0.10, right=0.95, bottom=0.09, top=0.95
)

# Panel 1: classification boundary, i.e. the sign of the ground-truth output.
plt.subplot(grid[0, 0:9])
plt.plot(x, zero_line, color='C1', linestyle='--', linewidth=1.0)

# detach() (as done for the teacher outputs) so matplotlib can convert the
# result to numpy even if ground_truth has parameters that require grad.
y_gt = ground_truth(x).detach()
y_sgn = torch.sign(y_gt)
plt.plot(x, y_sgn, color='C0', label='class')

plt.ylim(-1.2, 1.2)
plt.legend(loc=(1.02, 0.4))
plt.ylabel(r'$\mathrm{sgn}(y)$', fontdict={'size': 20})
# Panel 2: teacher output at a selection of saved checkpoints, coloured by
# checkpoint index with a reversed viridis colormap and a matching colorbar.
ax_teacher = plt.subplot(grid[1, 0:10])

# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported spelling.
cmp = plt.get_cmap('viridis').reversed()
# A single normalisation is shared by the curve colours and the colorbar, so
# each curve's colour reads off the colorbar correctly. Calling a colormap
# with a bare int indexes its lookup table instead of normalising the value.
# NOTE(review): the scale is the index into epoch_list, not the epoch number
# itself; the two coincide only if checkpoints are contiguous from 0.
epoch_norm = matplotlib.colors.Normalize(vmin=0, vmax=len(epoch_list))

plt.plot(x, zero_line, color='C1', linestyle='--', linewidth=1.0)
# Every 15th checkpoint, starting from index 5.
for i in range(5, len(epoch_list), 15):
    teacher = torch.load(
        teacher_list_dir + 'trained_net_epoch-{:06d}'.format(epoch_list[i]),
        map_location=torch.device('cpu')
    )
    y_teacher = teacher(x).detach().numpy()
    plt.plot(x, y_teacher, color=cmp(epoch_norm(i)))

plt.ylim(-15, 15)

# Pass ax explicitly: a free-standing ScalarMappable belongs to no Axes, and
# newer Matplotlib warns (or errors) when it must guess which Axes to shrink.
plt.colorbar(
    matplotlib.cm.ScalarMappable(norm=epoch_norm, cmap=cmp),
    ax=ax_teacher,
    orientation='vertical',
    aspect=20,
    fraction=0.0533,
    label='epoch'
)
plt.yscale('symlog')
plt.ylabel(r'$y_{\mathrm{t}}$', fontdict={'size': 20})

# Panel 3: f_eff_GD (opaque helper from toy.ops) evaluated on the final
# teacher checkpoint for several values of rho, then save the figure.
plt.subplot(grid[2, 0:9])
plt.plot(x, zero_line, color='C1', linestyle='--', linewidth=1.0)

T = 5.0
# Five q_e values in [0, 7] mapped to rho; q_e = 0 yields rho = 1.
q_e = torch.linspace(0.0, 7.0, 5)
rhos = (torch.sigmoid(q_e) - 1.0) / (torch.sigmoid(q_e) - 1.0 - (torch.sigmoid(q_e / T) - 0.5) / T)

teacher = torch.load(
    teacher_list_dir + 'trained_net_epoch-{:06d}'.format(epoch_list[-1]),
    map_location=torch.device('cpu')
)
y_teacher = teacher(x).detach()

for i, rho in enumerate(rhos):
    # Second argument: 1 where the teacher output is positive, else 0.
    y_eff = f_eff_GD(y_teacher, (y_teacher > 0).to(y_teacher.dtype), rho, T)
    plt.plot(x, y_eff, color='C{:d}'.format(i), label='{:.02f}'.format(rho))

plt.legend(loc=(1.02, 0.1), title=r'$\rho$')

plt.ylim(-10, 10)
plt.ylabel(r'$y_{\mathrm{eff,s}}$', fontdict={'size': 20})
plt.xlabel(r'$x$', fontdict={'size': 20})

# savefig does not create missing directories.
os.makedirs('figure', exist_ok=True)
fig.savefig('figure/04a.pdf', format='pdf')
