import fix_imports
import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import utils
from geometry_1d import load_geometry
from model import make_model


def make_least_squares_model(model, x, y, reg=1e-2):
    """Refit the output weights of a Linear -> ReLU -> Linear model by least squares.

    The first layer (``"0.weight"``, ``"0.bias"``) is held fixed. Its ReLU
    features are evaluated at the sample points ``x`` to build a design matrix
    ``m`` with one row per sample and one column per hidden unit. The output
    weights ``"2.weight"`` are then replaced by the least-squares solution of
    ``(m + reg * I) c = y``.

    Args:
        model: Module whose ``state_dict()`` contains ``"0.weight"``,
            ``"0.bias"`` and ``"2.weight"``. Assumes a scalar input, i.e.
            ``"0.weight"`` has shape ``(hidden, 1)`` -- TODO confirm for other
            input dimensions.
        x: Sample inputs, one scalar per sample (any shape that flattens to
            ``(n_samples,)``).
        y: Targets, one scalar per sample.
        reg: Value added to the leading diagonal of the design matrix before
            solving.

    Returns:
        ``(model_params, cond)``:
            * ``model_params`` is a deep copy of the model's state dict, with
              ``"2.weight"`` replaced by the fitted weights. The new weights
              are cast to ``x``'s dtype and device.
            * ``cond`` is the condition number of the regularized design
              matrix.
    """
    # state_dict() tensors are already detached; one deep copy is enough to
    # keep the caller's model untouched.
    model_params = copy.deepcopy(model.state_dict())

    # reshape(-1) rather than squeeze(): squeeze() collapses to a 0-d tensor
    # when there is a single hidden unit (or a single sample), which broke the
    # broadcasting below.
    a = model_params["0.weight"].detach().reshape(-1)
    b = model_params["0.bias"].detach().reshape(-1)
    x_col = x.detach().reshape(-1, 1)
    print(a.shape, b.shape, x.shape, y.shape)

    # Hidden ReLU features: m[i, j] = max(a_j * x_i + b_j, 0).
    m = x_col * a.unsqueeze(0) + b.unsqueeze(0)
    m = torch.max(m, torch.zeros_like(m)).cpu().numpy()

    # NOTE(review): for a rectangular m this adds reg only on the leading
    # diagonal; it is not the ridge formulation (m^T m + reg I) c = m^T y.
    # Confirm this is the intended regularization.
    m2 = m + reg * np.eye(m.shape[0], m.shape[1])
    cond = np.linalg.cond(m2)

    target = y.detach().cpu().reshape(-1).numpy()
    c = np.linalg.lstsq(m2, target, rcond=-1)[0]
    c = c.reshape(model_params["2.weight"].shape)
    model_params["2.weight"] = torch.from_numpy(c).to(x)

    return model_params, cond


def main():
    """Fit a shallow ReLU network to a 1-D geometry and save the training run.

    Command-line arguments select:
      * the geometry and how it is sampled;
      * initialization and rescaling;
      * the training regime via ``--laziness``.

    The laziness modes are:
      * ``default``: plain SGD on all model parameters.
      * ``pure``: SGD with Nesterov momentum on the last layer only.
      * ``none``: manual gradient descent on the first layer's weight and
        bias, with the step scaled by the hidden width. The output weights
        are left fixed.

    Direct indexing of ``model[0]`` / ``model[2]`` assumes a single hidden
    layer. TODO confirm for multi-layer ``num_hidden_units``.

    The resulting dict is written to ``--output`` with ``torch.save``. It
    holds:
      * initial and final weights;
      * per-epoch losses;
      * periodic weight snapshots and gradients;
      * the data;
      * the CLI settings.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("num_hidden_units", type=int, default=3, nargs='+',
                           help="The number of units in the hidden layer")
    argparser.add_argument("--output", "-o", type=str, default="out.pt", help="File to output the model to")
    argparser.add_argument("--geometry", "-g", type=str, default="triangle1", help="Which geometry to fit")
    argparser.add_argument("--epochs", "-ne", type=int, default=1000, help="Number of fitting iterations")
    argparser.add_argument("--learning-rate", "-lr", type=float, default=1e-3, help="Step size for gradient descent")
    argparser.add_argument("--seed", type=int, default=-1, help="Random seed")
    argparser.add_argument("--save-every", type=int, default=1, help="Save the weights ever k iterations")
    argparser.add_argument("--last-layer-bias", "-b", action="store_true", default=False,
                           help="Use a bias term in the last layer")
    argparser.add_argument("--device", type=str, default="cpu", help="Which device to store the model on")
    argparser.add_argument("--double", "-dbl", action="store_true")
    argparser.add_argument("--init", type=str, default="default",
                           help="Initialize weights using 'default', 'normal', 'one-over-n', 'one-over-sqrt-n'")
    argparser.add_argument("--num-samples", "-n", type=int, default=32, help="number of samples (s) to fit")
    argparser.add_argument("--sampling", "-s", type=str, default="uniform",
                           help="How to sample the interval. One of 'uniform' or 'random'")
    argparser.add_argument("--noise", type=float, default=0.0, help="Amount of noise to add")
    # choices= rejects typos up front. Without it, any unknown value fell into
    # the "none" branch and then crashed with a NameError on `optimizer`.
    argparser.add_argument("--laziness", type=str, default="default", choices=["pure", "none", "default"],
                           help="Type of laziness. One of 'pure', 'none', 'default'.")
    argparser.add_argument("--lr-decay", type=float, default=0, help="exponential learning rate decay factor")
    argparser.add_argument("--min-lr", type=float, default=-np.inf, help="minimum learning rate")
    argparser.add_argument("-cc", "--clamp-c", action="store_true", help="clamp the parameter c to +/-1")
    argparser.add_argument("-rsa", "--rescale-a", type=float, default=1.0, help="rescale the a and b parameters")
    argparser.add_argument("-rsc", "--rescale-c", type=float, default=1.0, help="rescale the c parameter")
    args = argparser.parse_args()

    seed = utils.seed_everything(args.seed)
    print("Using seed = %d" % seed)

    # The network maps u -> x: u is the model input, x the (optionally noisy)
    # regression target.
    x, u = load_geometry(args.geometry, args.num_samples,
                         dtype=np.float64 if args.double else np.float32, sampling=args.sampling)
    x_clean = copy.deepcopy(x)
    x += args.noise * torch.randn(*x.shape)
    x = x.to(args.device)
    u = u.to(args.device)
    saved_states = []

    model = make_model(1, 1, args.num_hidden_units, last_layer_bias=args.last_layer_bias,
                       device=args.device, use_double=args.double, init=args.init)

    model[0].weight.data *= args.rescale_a
    model[0].bias.data *= args.rescale_a
    model[2].weight.data *= args.rescale_c

    if args.clamp_c:
        model[2].weight.data = torch.sign(model[2].weight.data)

    for k, v in model.state_dict().items():
        print(k, v.shape)

    if args.laziness == "pure":
        last_layer_index = 2*len(args.num_hidden_units)
        optimizer = optim.SGD(model[last_layer_index].parameters(),
                              lr=args.learning_rate, nesterov=True, momentum=0.9)
    elif args.laziness == "default":
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
    else:  # "none": parameters are updated by hand in the training loop
        velocity = list()
        velocity.append(torch.zeros_like(model[0].weight, requires_grad=False))
        velocity.append(torch.zeros_like(model[0].bias, requires_grad=False))
        velocity.append(torch.zeros_like(model[2].weight, requires_grad=False))

        momentum = 0.9

    loss_fun = nn.MSELoss()

    # Deep copy is required. state_dict() tensors share storage with the
    # parameters, which training updates in place, so a shallow
    # `state_dict().copy()` would end up holding the final weights.
    initial_model_state = copy.deepcopy(model.state_dict())

    losses = []
    gradients = []

    last_weights = copy.deepcopy(model.state_dict())

    conds = []
    least_squares_models = []

    epoch = -1  # keeps 'num_epochs' below defined when --epochs is 0
    for epoch in range(args.epochs):
        # Snapshot of the weights at the *start* of this epoch.
        if epoch % args.save_every == 0:
            saved_states.append((epoch, last_weights))

        learning_rate = max(args.learning_rate * np.power(2, -args.lr_decay * epoch), args.min_lr)
        if args.laziness != "none":
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
        # Clear every parameter's gradient, not just the optimizer's. In
        # "pure" mode the optimizer only owns the last layer, so the first
        # layer's gradients (recorded in `gradients`) would otherwise
        # accumulate across epochs.
        model.zero_grad()

        loss = loss_fun(model(u), x)
        loss.backward()

        if args.laziness != "none":
            optimizer.step()
        else:
            m = model[2].weight.shape[1]  # hidden width, used to scale the step

            # The velocities feed only the commented-out momentum variants;
            # the active update is plain gradient descent on the first layer.
            last_v = copy.deepcopy(velocity)
            velocity[0] = last_v[0] * momentum - learning_rate * model[0].weight.grad * m
            velocity[1] = last_v[1] * momentum - learning_rate * model[0].bias.grad * m
            velocity[2] = last_v[2] * momentum - learning_rate * model[2].weight.grad

            # Nesterov
            # model[0].weight.data += -momentum * last_v[0] + (1 + momentum) * velocity[0]
            # model[0].bias.data += -momentum * last_v[1] + (1 + momentum)* velocity[1]
            # model[2].weight.data += -momentum * last_v[2] + (1 + momentum) * velocity[2]

            # Regular momentum
            # model[0].weight.data += velocity[0]
            # model[0].bias.data += velocity[1]
            # model[2].weight.data += velocity[2]

            # Boring ol' gradient descent (output weights model[2] stay fixed)
            model[0].weight.data += -learning_rate * model[0].weight.grad * m
            model[0].bias.data += -learning_rate * model[0].bias.grad * m
            # model[2].weight.data += -learning_rate * model[2].weight.grad

        last_weights = copy.deepcopy(model.state_dict())
        # lsqr_mdl_params, cond = make_least_squares_model(model, u, x)
        # least_squares_models.append(lsqr_mdl_params)
        # conds.append(cond)

        print("Epoch %04d: Loss = %0.5f, LR = %.5e" %
              (epoch, loss.item(), learning_rate))

        if epoch % args.save_every == 0:
            g = {}
            for name, param in model.named_parameters():
                g[name] = copy.deepcopy(param.grad.detach().cpu().numpy())
            gradients.append(g)

        losses.append(loss.item())

    if losses:
        print("Final Loss: %0.9f" % losses[-1])

    output = {
        'in_dim': 1,
        'out_dim': 1,
        'last_layer_bias': args.last_layer_bias,
        'initial_state': initial_model_state,
        'final_state': model.state_dict(),
        'losses': losses,
        'seed': seed,
        'learning_rate': args.learning_rate,
        'max_num_epochs': args.epochs,
        # NOTE(review): this is the last epoch index (epochs - 1), not a count;
        # confirm what downstream consumers expect.
        'num_epochs': epoch,
        'num_hidden_units': args.num_hidden_units,
        'x': x.detach().cpu().numpy(),
        'uv': u.detach().cpu().numpy(),
        'device': args.device,
        'saved_states': saved_states,
        'save_every': args.save_every,
        'geometry': args.geometry,
        'use_double': args.double,
        'gradients': gradients,
        'num_samples': args.num_samples,
        'sampling_type': args.sampling,
        'least_squares_models': least_squares_models,
        'condition_numbers': conds,
        'x_clean': x_clean.squeeze().detach().cpu().numpy(),
        'init': args.init,
        'laziness': args.laziness,
        'rescale_ab': args.rescale_a,
        'rescale_c': args.rescale_c,
    }
    torch.save(output, args.output)


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()