import fix_imports
import argparse

import numpy as np
import torch
import matplotlib
from sys import platform as sys_pf
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import matplotlib.collections as mc
import matplotlib.style as style
import seaborn as sns

from model import model_from_state


def sample_vectors(state, scale=1.0):
    """Build one line segment per data point for the (u, v) plot.

    Each sorted value ``x`` of ``state['uv']`` becomes the segment from
    ``(scale, -x * scale)`` to ``(-scale, x * scale)``. That is the line
    ``v = -x * u`` clipped to ``u in [-scale, scale]``.

    Args:
        state: Mapping with a ``'uv'`` entry holding the data points. After
            squeezing out singleton dimensions it must be a scalar or 1-D.
        scale: Multiplier applied to every segment coordinate.

    Returns:
        ``np.ndarray`` of shape ``(n, 2, 2)``. ``lines[i, j]`` is the ``(u, v)``
        coordinate of endpoint ``j`` of the segment for the i-th smallest point.
        Endpoint 0 sits at ``u = scale`` and endpoint 1 at ``u = -scale``.
    """
    # A single data point squeezes to a 0-d array, which np.sort and len()
    # reject. atleast_1d turns it back into a length-1 vector.
    x = np.sort(np.atleast_1d(np.squeeze(state['uv'])))

    lines = np.empty((len(x), 2, 2))
    lines[:, 0, 0] = 1.0
    lines[:, 0, 1] = -x
    lines[:, 1, 0] = -1.0
    lines[:, 1, 1] = x
    return lines * scale


def plot_bg(lines, ax, alpha, s=1.0):
    """Shade the background square and the bands between neighbouring lines.

    First the whole square ``[-s, s] x [-s, s]`` is filled. Then, for every
    pair of consecutive segments ``lines[k]`` and ``lines[k + 1]``, the region
    between them over ``u in [-s, s]`` is filled. Every fill uses the same
    ``alpha``.

    Args:
        lines: Array of shape ``(n, 2, 2)``, e.g. from ``sample_vectors``.
            Endpoint 1's v-coordinate is drawn at ``u = -s`` and endpoint 0's
            at ``u = s``. This assumes the segments were built with scale
            ``s`` (as ``main`` does) — TODO confirm for other callers.
        ax: Matplotlib axes to draw on.
        alpha: Opacity passed to every ``fill_between`` call.
        s: Half-width of the square plotting region.
    """
    edges = [-s, s]

    # Full-square base layer, from the top edge down to the bottom edge.
    ax.fill_between(edges, [s, s], [-s, -s], alpha=alpha)

    # One band per pair of neighbouring segments.
    for this_line, next_line in zip(lines[:-1], lines[1:]):
        y_this = [this_line[1, 1], this_line[0, 1]]
        y_next = [next_line[1, 1], next_line[0, 1]]
        ax.fill_between(edges, y_this, y_next, alpha=alpha)


def _choose_style():
    """Return the name of the seaborn matplotlib style available here, or None.

    matplotlib 3.6 renamed the bundled ``'seaborn'`` style to ``'seaborn-v0_8'``,
    and later releases dropped the old name. So the new name is tried first.
    """
    for name in ("seaborn-v0_8", "seaborn"):
        if name in style.available:
            return name
    return None


def _parse_args():
    """Parse the command-line options of this plotting script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("state", type=str, help="Fitted model (out.pt) generated with fit_model.py")
    argparser.add_argument("-e", "--epoch", default=-1, type=int, help="Which epoch to plot the model at")
    argparser.add_argument("-s", "--scale", default=1.0, type=float, help="Scale factor for the figure")
    argparser.add_argument("-o", "--output", default="", help="Filename to save the figure as")
    return argparser.parse_args()


def _signed_points(model):
    """Return ``(u_pos, v_pos, u_neg, v_neg)`` derived from the model parameters.

    With ``a``/``b`` the squeezed weight/bias of ``model[0]`` and ``c`` the
    squeezed weight of ``model[2]``, the points are ``u = a*|c|`` and
    ``v = b*|c|``. They are split by ``sign(c) > 0`` versus ``sign(c) <= 0``.
    """
    a = model[0].weight.data.squeeze()
    b = model[0].bias.data.squeeze()
    c = model[2].weight.data.squeeze()

    magnitude = torch.abs(c)
    u = (a * magnitude).detach().cpu().numpy()
    v = (b * magnitude).detach().cpu().numpy()
    eps = torch.sign(c).detach().cpu().numpy()

    positive, non_positive = eps > 0, eps <= 0
    return u[positive], v[positive], u[non_positive], v[non_positive]


def main():
    """Command-line entry point: plot a fitted model in the (u, v) plane.

    The checkpoint given on the command line is loaded, and the model is
    restored at ``--epoch``. The plot then shows:

    * one segment per data point (``sample_vectors``), with shaded bands
      between neighbouring segments (``plot_bg``);
    * the parameter points from ``_signed_points``, coloured by the sign of
      ``c``.

    The figure is shown interactively, or saved to ``--output`` at 240 dpi
    when that option is given.
    """
    style_name = _choose_style()
    if style_name is not None:  # purely cosmetic; skip if no seaborn style
        style.use(style_name)

    args = _parse_args()

    # NOTE: torch.load unpickles the file, so only open trusted checkpoints.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    # It also keeps state['uv'] convertible to NumPy in sample_vectors.
    state = torch.load(args.state, map_location="cpu")

    model = model_from_state(state)
    # Entries of saved_states are indexed with [1] for the state dict,
    # presumably (epoch, state_dict) pairs — TODO confirm with fit_model.py.
    model.load_state_dict(state['saved_states'][args.epoch][1])

    up, vp, un, vn = _signed_points(model)

    s = args.scale
    lines = sample_vectors(state, scale=s)

    dpi = 240
    if args.output:
        plt.figure(figsize=(10, 10), dpi=dpi)
    else:
        plt.figure()
    ax = plt.axes()

    cmap = sns.color_palette()
    ax.axis([-s, s, -s, s])
    plot_bg(lines, ax, alpha=0.4, s=s)
    ax.add_collection(mc.LineCollection(lines, linewidths=3, label="$x_i$"))
    # Raw strings: "\e" is not a valid escape; the label text is unchanged.
    ax.scatter(up, vp, s=128, linewidths=2, label=r"$\epsilon > 0$", color=cmap[0])
    ax.scatter(un, vn, s=128, linewidths=2, label=r"$\epsilon < 0$", color=cmap[2])
    ax.legend(fontsize=25, frameon=True)

    # Only the geometry matters, so the numeric tick labels are hidden.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    if args.output:
        plt.savefig(args.output, dpi=dpi)
    else:
        plt.show()


if __name__ == "__main__":
    # Run the plotting CLI only when executed as a script, not when imported.
    main()
