import fix_imports
import matplotlib
from sys import platform as sys_pf
# On macOS, explicitly select matplotlib's Tk-based "TkAgg" backend. This runs
# before matplotlib.pyplot is imported (below), so pyplot starts with it.
# NOTE(review): presumably works around the default macOS backend; confirm.
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")

import matplotlib.pyplot as plt
import matplotlib.style as style

import argparse

import numpy as np
import torch


from model import model_from_state

from scipy.interpolate import CubicSpline


def lsq_abc(model, state):
    """Refit the weights of ``model[2]`` by linear least squares.

    The first layer ``model[0]`` (weight ``a``, bias ``b``) is kept fixed. Its
    rectified responses ``max(uv @ a^T + b, 0)`` on the sample inputs
    ``state['uv']`` form a design matrix with one row per sample. The weights of
    ``model[2]`` are replaced by the least-squares solution that best reproduces
    the targets ``state['x']``. If ``model[2]`` has a bias, it is held fixed and
    subtracted from the targets first.

    All tensor work is done in the dtype and on the device of ``model[0].weight``.
    The new weights are stored in the dtype and device of ``model[2].weight``, so
    float64 and non-CPU models are handled.

    Args:
        model: indexable model. ``model[0]`` needs ``weight`` and ``bias``;
            ``model[2]`` needs ``weight`` (an optional ``bias`` is honoured).
        state: mapping with array-likes ``'uv'`` (inputs, one row per sample)
            and ``'x'`` (targets, one entry per sample).

    Returns:
        ``model``, whose ``model[2].weight`` has been replaced in place.
    """
    hidden, out = model[0], model[2]
    ref = hidden.weight

    with torch.no_grad():
        uv = torch.from_numpy(np.ascontiguousarray(state['uv'], dtype=np.float64)).to(ref)
        # (num_samples, num_hidden) matrix of rectified hidden responses.
        feats = torch.relu(uv @ ref.transpose(0, 1) + hidden.bias)
        design = feats.cpu().numpy()

        target = np.squeeze(np.ascontiguousarray(state['x'], dtype=design.dtype))
        out_bias = getattr(out, 'bias', None)
        if out_bias is not None:
            # The output is design @ c + bias, so fit c against target - bias.
            target = target - out_bias.detach().cpu().numpy()

        coeffs = np.linalg.lstsq(design, target, rcond=-1)[0]

        weight = out.weight
        weight.data = torch.from_numpy(coeffs).to(weight).unsqueeze(0)

    return model


def plot_recon(model, x, u, n, plot_knots=True, use_double=False, title="Reconstruction vs. Ground Truth",
               show=True, newfig=True, legend=True, savefig=""):
    """Plot the function represented by ``model`` against the samples ``(u, x)``.

    The model is evaluated only at its "knots". These are the points
    ``e_i = -b_i / a_i`` where a pre-activation ``a_i * u + b_i`` of ``model[0]``
    is zero, restricted to ``[min(u), max(u)]``, plus both ends of that interval.
    The reconstruction is drawn as the polyline through ``(e_i, f(e_i))``. That
    polyline is exact only if ``model[1]`` is a ReLU-type activation, which is
    assumed here and not checked.

    Args:
        model: indexable model. ``model[0]`` needs ``weight``/``bias``, and the
            whole model is called on a column of knot locations.
        x: sample values, drawn on the vertical axis.
        u: sample locations, drawn on the horizontal axis.
        n: unused; kept for backward compatibility.
        plot_knots: if True, mark the knots with hollow circles and draw the
            samples as a dashed line with "x" markers. Otherwise draw the
            samples with circle markers and the reconstruction as a plain line.
        use_double: unused; kept for backward compatibility.
        title: unused (no title is drawn); kept for backward compatibility.
        show: call ``plt.show()`` at the end.
        newfig: start a new figure (10x10 inches at 500 dpi if ``savefig`` is set).
        legend: draw a legend.
        savefig: if non-empty, path to save the figure to at 240 dpi.
    """
    min_u, max_u = torch.min(u).item(), torch.max(u).item()
    # reshape(-1) rather than squeeze() so that a single hidden unit still
    # yields a 1-D array (np.concatenate rejects 0-d arrays).
    a = model[0].weight.detach().cpu().numpy().reshape(-1)
    b = model[0].bias.detach().cpu().numpy().reshape(-1)

    # Units with a_i == 0 give inf/nan roots; the range test drops them silently.
    with np.errstate(divide="ignore", invalid="ignore"):
        roots = -b / a
        inside = roots[np.logical_and(min_u <= roots, roots <= max_u)]
    # Use the sampled range as end points (not a fixed [-1, 1]) so that the
    # polyline always spans the whole data range.
    e = np.sort(np.concatenate([[min_u], inside, [max_u]]))

    with torch.no_grad():
        # Evaluate in the model's own dtype/device rather than in those of x.
        knots = torch.from_numpy(e).to(model[0].weight).unsqueeze(1)
        fe = model(knots).squeeze(-1).cpu().numpy()

    if newfig:
        if savefig:
            plt.figure(figsize=(10, 10), dpi=500)
        else:
            plt.figure()

    mew = 3   # marker edge width
    lw = 3    # line width
    ms = 12   # marker size
    fs = 25   # legend font size

    samples_u = u.squeeze().cpu().numpy()
    samples_x = x.squeeze().cpu().numpy()
    if plot_knots:
        plt.plot(e, fe, marker='o', markersize=ms, mew=mew, mfc='none', linewidth=lw, label="$(e_i, f(e_i))$")
        plt.plot(samples_u, samples_x, "--",
                 marker='x', markersize=ms, mew=mew, linewidth=lw, label="$(x_j, y_j)$")
    else:
        plt.plot(samples_u, samples_x, "--",
                 marker='o', markersize=ms, mew=mew, linewidth=lw, label="$(x_j, y_j)$")
        plt.plot(e, fe, linewidth=lw, label="$(e_i, f(e_i))$")

    # Hide the tick labels on both axes.
    ax = plt.gca()
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])

    if legend:
        plt.legend(fontsize=fs, frameon=True)

    if savefig:
        plt.savefig(savefig, dpi=240)

    if show:
        plt.show()


def main():
    """Command-line entry point: load a fitted model and plot its reconstruction.

    Loads the state file given on the command line and restores either the
    initial (``--init``) or the final model weights. With ``--lsq``, it refits the
    output layer by least squares. It then plots the model against the stored
    samples. With ``--output``, the figure is saved to that file instead of
    being shown.
    """
    # matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8", and 3.8
    # removed the old name. Use whichever name this matplotlib provides.
    for style_name in ("seaborn-v0_8", "seaborn"):
        if style_name in style.available:
            style.use(style_name)
            break

    argparser = argparse.ArgumentParser()
    argparser.add_argument("state_file", type=str, help="Fitted model (out.pt) generated with fit_model.py")
    argparser.add_argument("--init", action="store_true", help="If set, plot the state at initialization")
    argparser.add_argument("--plot-knots", action="store_true", help="Plot the knots")
    argparser.add_argument("--lsq", action="store_true",
                           help="Plot the result after doing a least squares kernel fit")
    argparser.add_argument("--output", "-o", default="", type=str, help="Filename to save the figure as")
    args = argparser.parse_args()

    # NOTE: this fully unpickles the file, so only load state files you trust.
    # The state stores NumPy arrays, which the weights_only=True default of
    # torch >= 2.6 refuses to load, so full unpickling is requested explicitly.
    try:
        state = torch.load(args.state_file, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument.
        state = torch.load(args.state_file)

    x, u = state['x'], state['uv']

    model = model_from_state(state)

    if args.init:
        model.load_state_dict(state['saved_states'][0][1])
    else:
        model.load_state_dict(state['final_state'])

    if args.lsq:
        model = lsq_abc(model, state)

    plot_recon(model, torch.from_numpy(x).to(state['device']),
               torch.from_numpy(u).to(state['device']),
               500, use_double=state['use_double'], plot_knots=args.plot_knots,
               savefig=args.output, show=not args.output)


# Script entry point; importing this module does not run main().
if __name__ == "__main__":
    main()