import random
from collections import defaultdict
from typing import OrderedDict
from functools import partial
import argparse

import torch
import numpy as np
import matplotlib.pyplot as plt

import sys
import math
 
from optim import SGD, NSGD, Adam, Adagrad

# from torch.optim import Adagrad


from tensorboard_logger import Logger

# Command-line arguments for the experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=8, help='random seed')
parser.add_argument('--n_iter', type=int, default=3000, help='number of gradient calls')
parser.add_argument('--lr_x', type=float, default=0.01, help='learning rate of x')
parser.add_argument('--init_x', type=float, default=None, help='init value of x')
# This value multiplies a standard-normal sample that is added to the gradient,
# so it is the standard deviation of the noise, not its variance.
parser.add_argument('--grad_noise_x', type=float, default=1e-3,
                    help='gradient noise standard deviation')

parser.add_argument('--func', type=str, default='quadratic', help='function name')
parser.add_argument('--L', type=float, default=1, help='parameter for the test function')

# NOTE(review): the default 'adam' is not one of the names the optimizer
# selection further down recognises -- confirm the intended default.
parser.add_argument('--optim', type=str, default='adam',
                    help='optimizer (AdaGrad, SGD, NSGD or AMSGrad)')
args = parser.parse_args()

# Use double precision for every tensor created from here on.
torch.set_default_dtype(torch.float64)

# Coefficient of the test objective, taken from the command line.
L = args.L


def _quadratic(x):
    """Return L * 0.5 * x**2 (L is read from the module-level name at call time)."""
    # A disabled extra term from the original: + 3 * (torch.sin(x) ** 2)
    return L * 0.5 * (x ** 2)


# Registry of test objectives, keyed by the name passed via --func.
functions = OrderedDict()
functions["quadratic"] = {"func": _quadratic}

def _seed_all(seed):
    """Seed the Python, NumPy and PyTorch random number generators."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


# Reproducibility: fix all RNGs from the --seed argument.
_seed_all(args.seed)

# Number of optimisation steps to run.
n_iter = args.n_iter

# Report the configuration, then fetch the selected objective.
print(f"Function: {args.func}", f"Optimizer: {args.optim}", sep="\n")
fun = functions[args.func]["func"]

# The optimised variable has a single component.
dim = 1

# Tensorboard: the log path encodes optimizer, objective, learning rate and noise level.
_objective_tag = f"L{args.L}" if args.func == 'quadratic' else f"{args.func}"
filename = "".join([
    f"./logs/{args.optim}_",
    _objective_tag,
    f"_lr_{args.lr_x}",
    f"_noise_{args.grad_noise_x}",
])

# Create the logger and record the full argument set with it.
logger = Logger(filename)
logger.config_summary(args)

# Step size passed to the optimizer.
lr_x = args.lr_x

# Starting point: the --init_x value when given, otherwise a standard-normal draw.
init_x = torch.randn(dim) if args.init_x is None else torch.Tensor([args.init_x])

# Optimise a clone so that init_x itself is not modified by the updates.
x = torch.nn.Parameter(init_x.clone())

# Name of the requested optimizer, exactly as given on the command line.
optim_name = args.optim

# Factories for the supported optimizers. Each is called lazily so that only
# the selected optimizer is constructed.
_OPTIMIZER_FACTORIES = {
    "AdaGrad": lambda: Adagrad([x], lr=lr_x),
    "SGD": lambda: SGD([x], lr=lr_x),
    "NSGD": lambda: NSGD([x], lr=lr_x),
    # AMSGrad is obtained from Adam with both betas set to 0 and amsgrad on.
    "AMSGrad": lambda: Adam([x], lr=lr_x, betas=(0, 0), amsgrad=True),
}

# Fail early with a clear message instead of leaving optim_x undefined, which
# would only surface later as a NameError inside the training loop.
# NOTE(review): the parser default for --optim ('adam') is not in this table.
if optim_name not in _OPTIMIZER_FACTORIES:
    raise ValueError(
        f"Unsupported optimizer {optim_name!r}; "
        f"expected one of: {', '.join(_OPTIMIZER_FACTORIES)}"
    )

optim_x = _OPTIMIZER_FACTORIES[optim_name]()

# Optimisation loop: `i` counts gradient evaluations and is the logging step.
i = 0
save_gap = 10  # NOTE(review): not referenced anywhere in the visible code
while i < n_iter:
    optim_x.zero_grad()
    l = fun(x)
    l.backward()
    # Log the gradient norm before noise is injected, so the recorded value is
    # the deterministic (noise-free) gradient norm.
    i += 1
    x_grad_norm = torch.norm(x.grad).item()
    logger.scalar_summary('x_grad', step=i, value=x_grad_norm)
    logger.scalar_summary('x', step=i, value=x.item())
    # Stochastic gradient: perturb the exact gradient with scaled Gaussian noise.
    with torch.no_grad():
        x.grad += args.grad_noise_x * torch.randn(dim)
    optim_x.step()

    logger.scalar_summary('x_effective_stepsize', step=i, value=optim_x.effective_stepsize)
    # Stop once the run has diverged. NaN compares false against any threshold,
    # so non-finite norms are checked explicitly.
    if not math.isfinite(x_grad_norm) or x_grad_norm > 1e30:
        break
