import os
from networks import LeNet5Feats, ResNetFeats18, classifier
#import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
#from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle
import numpy as np
import random
from ada_sls_bilevel import *

# Command-line interface.
# Every option except --dataset is a plain (flag, type, default) triple, so they
# are registered from a table.  Table order == order shown by --help.
_ARG_SPECS = (
    ('--data', str, './data'),            # root directory handed to the MNIST constructors
    ('--output_dir', str, 'Results_mnist'),  # directory created at start-up
    ('--opt_out', str, 'SGD'),            # only appears in the results file name in the visible script
    ('--opt_lower', int, 2),              # lower-level step reset rule (see reset_step_lower)
    ('--opt_upper', int, 1),              # forwarded to AdaSLS as reset_option
    ('--eval_interval', int, 10),         # not read anywhere else in the visible script
    ('--seed', int, 0),                   # seed for torch / numpy / random
    ('--alpha', float, 10.0),             # fixed lower-level step size when --lower_search is 0
    ('--beta', float, 0.001),             # Adam learning rate when --upper_search is 0
    ('--lamba', float, 0.001),            # weight of the L2 penalty in inner_loss
    ('--eta_max_upper', float, 1.0),      # forwarded to AdaSLS; initial value of step_size_upper
    ('--eta_max_lower', float, 10.0),     # initial lower-level step size, and reset value for opt_lower == 1
    ('--c', float, 0.1),                  # sufficient-decrease constant used by check_term
    ('--beta_b', float, 0.9),             # factor by which check_term shrinks a rejected step size
    ('--gamma', float, 10.0),             # growth factor for opt_lower == 2; also forwarded to AdaSLS
    ('--delta', float, 0.01),             # slack allowed by check_term when upper=True; forwarded to AdaSLS
    ('--bs', int, 256),                   # batch size of every data loader
    ('--n', int, 10000),                  # not read anywhere else in the visible script
    ('--n_train', int, 10000),            # passed to reset_step_lower, which currently ignores it
    ('--n_test', int, 60000),             # not read anywhere else in the visible script
    ('--K', int, 50),                     # number of outer iterations
    ('--inner_steps', int, 10),           # lower-level updates per outer iteration
    ('--line_inner_steps', int, 1),       # only appears in the results file name in the visible script
    ('--cg_steps', int, 5),               # passed to hg.CG
    ('--lower_search', int, 1),           # 0: fixed step --alpha, otherwise line search
    ('--upper_search', int, 1),           # 0: Adam, otherwise AdaSLS
)

parser = argparse.ArgumentParser(description='Bilevel Training')
# --dataset is the only option with a restricted set of values.
parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'cifar10'])
for _flag, _kind, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default)

args = parser.parse_args()

# Make sure the output directory exists.  exist_ok=True replaces the former
# "check, then create" pair, which could raise FileExistsError when several
# runs were started at the same time with the same --output_dir.
os.makedirs(args.output_dir, exist_ok=True)

# Seed the torch, NumPy and Python RNGs from --seed.
torch.random.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Preprocessing shared by both datasets: resize to 32x32, convert to a tensor,
# then normalize with the constants below.
_mnist_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# NOTE(review): `data_test` is built with the constructor's default `train`
# flag and `data_train` with train=False, so the variable names do not follow
# the torchvision split names -- presumably intentional; confirm before changing.
data_test = MNIST(args.data, download=True, transform=_mnist_transform)
data_train = MNIST(args.data, train=False, download=True, transform=_mnist_transform)

class CustomTensorIterator:
    """Endless batch iterator over a dataset.

    Wraps a ``DataLoader`` and transparently restarts it once it is exhausted,
    so ``next(it)`` keeps returning batches for as long as the dataset is
    non-empty.

    Args:
        dataset: dataset handed to ``DataLoader``.
        batch_size: batch size handed to ``DataLoader``.
        **loader_kwargs: any further ``DataLoader`` keyword arguments
            (e.g. ``shuffle``, ``num_workers``).
    """

    def __init__(self, dataset, batch_size, **loader_kwargs):
        self.loader = DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
        self.iterator = iter(self.loader)

    def __iter__(self):
        """Return ``self`` so the object satisfies the full iterator protocol."""
        return self

    def __next__(self, *args):
        """Return the next batch, starting a new pass over the loader if needed.

        Extra positional arguments are accepted and ignored.
        """
        try:
            batch = next(self.iterator)
        except StopIteration:
            # Current pass is finished: begin a fresh one and take its first batch.
            self.iterator = iter(self.loader)
            batch = next(self.iterator)
        return batch

# Dataset sizes as noted by the original author:
# data_test - 60000
# data_train - 10000
# NOTE(review): the statements below are kept in their original order; if
# iterator creation or network initialisation draws from the seeded RNG,
# reordering them would change the run -- confirm before moving anything.

# Weight of the L2 penalty read by inner_loss.
lamba = args.lamba
# Endless batch sources: train_iterator feeds inner_solver, hg.GradientDescent
# and AdaSLS; val_iterator feeds outer_func.
train_iterator = CustomTensorIterator(data_train, batch_size=args.bs, shuffle=True, num_workers=1)
val_iterator = CustomTensorIterator(data_test, batch_size=args.bs, shuffle=True, num_workers=1)
# Finite loader over data_test, iterated once per call by evaluate().
val_loader = DataLoader(data_test, batch_size=args.bs, shuffle=True)
nd_train = len(data_train)
# nd_test normalises loss/accuracy in evaluate() and is passed to AdaSLS.
nd_test = len(data_test)
# Batches drawn by outer_func are appended here; the training loop reads the
# last one and then resets the list.
data_outer = []
# Feature network; its parameters (hparams) are what the outer optimizer updates.
hypernet = LeNet5Feats().cuda()
# Classifier on top of the 84 features; its parameters (params) are updated by inner_solver.
cnet = classifier(n_features=84, n_classes=10).cuda()
# Patched versions of both networks, called below with explicit `params=` lists.
fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).cuda()
hparams = list(hypernet.parameters())
hparams = [hparam.requires_grad_(True) for hparam in hparams]
fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).cuda()
params = list(cnet.parameters())
params = [param.requires_grad_(True) for param in params]
# Per-step history of the lower-level line search, filled by inner_solver.
lower_step_track = []
search_lower_cost_track = []

# Classification loss shared by inner_loss, outer_func, evaluate and AdaSLS.
criterion = torch.nn.CrossEntropyLoss().cuda()

def evaluate(params, hparams):
    """Compute loss and accuracy over every batch of ``val_loader``.

    Args:
        params: parameter list passed to ``fcnet``.
        hparams: parameter list passed to ``fhnet``.

    Returns:
        ``(loss, acc)``: the sample-weighted sum of batch losses divided by
        ``nd_test`` (a Python float) and the number of correct predictions
        divided by ``nd_test`` (a NumPy value moved off the GPU).
    """
    total_loss = 0.
    n_correct = 0.
    for batch_x, batch_y in val_loader:
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()
        logits = fcnet(fhnet(batch_x, params=hparams), params=params)
        # criterion returns a batch mean, so weight it by the batch size.
        total_loss = total_loss + criterion(logits, batch_y).item() * len(batch_x)
        # max(1) returns (values, indices); the indices are the predicted classes.
        _, predicted = logits.data.max(1)
        n_correct = n_correct + predicted.eq(batch_y.data.view_as(predicted)).sum()
    accuracy = n_correct / nd_test
    mean_loss = total_loss / nd_test
    return mean_loss, accuracy.cpu().numpy()

def inner_loss(params, hparams, data):
    """Lower-level objective on one batch.

    Args:
        params: parameter list passed to ``fcnet``.
        hparams: parameter list passed to ``fhnet``.
        data: an ``(images, labels)`` pair; both are moved to the GPU here.

    Returns:
        ``criterion(outputs, labels)`` plus ``0.5 * lamba`` times the sum of
        squared entries of ``hparams`` (note: the penalty is on ``hparams``,
        not on ``params``).
    """
    images, labels = (tensor.cuda() for tensor in data)
    logits = fcnet(fhnet(images, params=hparams), params=params)
    data_term = criterion(logits, labels)
    squared_norm = sum([(w**2).sum() for w in hparams])
    return data_term + 0.5 * lamba * squared_norm

def check_term(step_size, loss_next, loss, grad_norm,upper=False):
    """Test the sufficient-decrease condition for a trial step size.

    The trial is accepted when
    ``loss_next - (loss - step_size * args.c * grad_norm**2)`` is at most the
    tolerance, which is ``args.delta`` when ``upper`` is true and ``0``
    otherwise.

    Args:
        step_size: trial step size.
        loss_next: loss after taking the trial step.
        loss: loss before the step.
        grad_norm: norm of the gradient used for the step.
        upper: selects the ``args.delta`` tolerance instead of ``0``.

    Returns:
        ``(found, step_size)``: ``found`` is 1 and the step size is unchanged
        when the trial is accepted; otherwise ``found`` is 0 and the step size
        has been multiplied by ``args.beta_b``.
    """
    target = loss - (step_size) * args.c * grad_norm**2
    gap = loss_next - target
    tolerance = args.delta if upper else 0
    if gap <= tolerance:
        return 1, step_size
    return 0, step_size * args.beta_b

def reset_step_lower(step_size,args,n,upper=False,bs=256):
    """Pick the step size a new lower-level line search starts from.

    Args:
        step_size: step size accepted by the previous search.
        args: namespace providing ``opt_lower``, ``eta_max_lower`` and ``gamma``.
        n: accepted for interface compatibility; not used.
        upper: accepted for interface compatibility; not used.
        bs: accepted for interface compatibility; not used.

    Returns:
        ``args.eta_max_lower`` when ``args.opt_lower == 1``,
        ``step_size * args.gamma`` when ``args.opt_lower == 2``,
        and ``step_size`` unchanged for any other value.
    """
    mode = args.opt_lower
    if mode == 1:
        # Always restart from the maximum step size.
        return args.eta_max_lower
    if mode == 2:
        # Grow the previously accepted step by a constant factor.
        return step_size * args.gamma
    return step_size

def line_search_lower(params, hparams, step_size_old, loss, grads, args, data, step_size_inner=None,upper=False, n_search=200):
    """Backtracking line search for one lower-level gradient step.

    Starting from the step size returned by ``reset_step_lower``, repeatedly
    evaluates ``inner_loss`` at ``params - step_size * grads`` on ``data`` and
    lets ``check_term`` either accept the step or shrink it.

    Args:
        params: current lower-level parameters (detached copies are used).
        hparams: current upper-level parameters (detached copies are used).
        step_size_old: step size accepted by the previous search.
        loss: value of the inner loss at ``params`` on ``data``.
        grads: gradients of that loss with respect to ``params``.
        args: namespace providing ``n_train``, ``bs`` and the options read by
            ``reset_step_lower``.
        data: the batch the loss and gradients were computed on.
        step_size_inner: accepted for interface compatibility; not used.
        upper: accepted for interface compatibility; not used.
        n_search: maximum number of trial step sizes (default 200).

    Returns:
        ``(step_size, n_trials)``.  ``n_trials`` is 0 when the gradient norm
        is below ``1e-8`` (no search is run and the reset step size is
        returned as is).  If no trial is accepted, ``step_size`` falls back to
        ``1e-6`` and ``n_trials`` equals ``n_search``.
    """
    # reset step size
    step_size = reset_step_lower(step_size_old, args, args.n_train, upper=False, bs=args.bs)
    grad_norm = compute_grad_norm(grads)
    # Work on detached copies so the trial evaluations do not touch the caller's tensors.
    params_temp = [p.detach().clone() for p in params]
    hparams_temp = [p.detach().clone() for p in hparams]
    n_trials = 0
    if grad_norm >= 1e-8:
        found = 0
        for n_trials in range(1, n_search + 1):
            params_new = [p - step_size * g for p, g in zip(params_temp, grads)]
            loss_next = inner_loss(params_new, hparams_temp, data)
            found, step_size = check_term(step_size, loss_next, loss, grad_norm, upper=False)
            if found == 1:
                break
        if found == 0:
            # Report the real search budget (the message used to say 100 while the budget was 200).
            print("Watch: not found after {} eps".format(n_search))
            step_size = 1e-6
            n_trials = n_search
    return step_size, n_trials

def inner_solver(params, hparams, args, step_size, steps=1):
    """Run ``steps`` gradient updates of the lower-level parameters.

    Each update draws a batch from ``train_iterator``, differentiates
    ``inner_loss`` with respect to ``params`` and takes a gradient step.  The
    step size is ``args.alpha`` when ``args.lower_search == 0``; otherwise it
    comes from ``line_search_lower`` and is recorded, together with the
    search cost, in ``lower_step_track`` / ``search_lower_cost_track``.

    Args:
        params: lower-level parameters; ``requires_grad_`` is switched on in place.
        hparams: upper-level parameters (held fixed here).
        args: parsed command-line options.
        step_size: step size carried over from the previous call.
        steps: number of updates to perform.

    Returns:
        ``(params, step_size, search_cost_lower)``: the updated parameter
        list, the last step size used and the cost of the last line search
        (0 if no search was run).
    """
    search_cost_lower = 0
    params = [p.requires_grad_(True) for p in params]
    for _ in range(steps):
        batch = next(train_iterator)
        batch_loss = inner_loss(params, hparams, batch)
        grads = torch.autograd.grad(batch_loss, params)
        if args.lower_search != 0:
            step_size, search_cost_lower = line_search_lower(
                params, hparams, step_size, batch_loss.item(), grads, args, batch, upper=False)
            lower_step_track.append(step_size)
            search_lower_cost_track.append(search_cost_lower)
        else:
            step_size = args.alpha
        # Same update in both modes; only the choice of step size differs.
        params = [p - step_size * g for p, g in zip(params, grads)]
    return params, step_size, search_cost_lower

# Upper-level objective; passed to hg.CG in the training loop.
def outer_func(params, hparams, more=False):
    """Evaluate the classification loss on the next batch of ``val_iterator``.

    The batch (already moved to the GPU) is appended to ``data_outer`` so the
    caller can reuse it afterwards.

    Args:
        params: parameter list passed to ``fcnet``.
        hparams: parameter list passed to ``fhnet``.
        more: when true, also return the batch accuracy.

    Returns:
        The loss tensor, or ``(loss, acc)`` with ``acc`` a Python float when
        ``more`` is true.
    """
    batch = next(val_iterator)
    images, labels = batch
    images, labels = images.cuda(), labels.cuda()
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)
    predicted = logits.data.max(1)[1]
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    acc = float(n_correct) / labels.size(0)
    data_outer.append([images, labels])
    return (loss, acc) if more else loss

# Upper-level optimizer.
# With --upper_search 0 the upper-level parameters are updated by Adam with
# learning rate --beta; otherwise by AdaSLS (presumably provided by the
# ada_sls_bilevel star-import).
# An earlier, now disabled, variant chose between Adam, SGD and SGD with
# momentum 0.9 according to --opt_out.
if args.upper_search == 0:
    outer_opt = torch.optim.Adam(hparams, lr=args.beta)
else:
    outer_opt = AdaSLS(
        hparams, nd_test, args.bs, train_iterator,
        fhnet, fcnet, criterion,
        reset_option=args.opt_upper,
        gamma=args.gamma,
        eta_max_upper=args.eta_max_upper,
        delta=args.delta,
    )

# Bookkeeping filled in by the training loop.
total_time = 0
val_losses = []
running_time = []
hg_norms = []
val_acc = []
step_size_lower = args.eta_max_lower
step_size_upper = args.eta_max_upper

# Inner-problem description handed to hg.CG.
inner_opt_cg = hg.GradientDescent(inner_loss, 1., data_or_iter=train_iterator)
decay_steps = [25, 40]
upper_step_track = []
search_cost_upper_track = []

# Main training loop: args.K outer iterations.  Each iteration
#   1. runs args.inner_steps lower-level updates (inner_solver),
#   2. calls hg.CG with args.cg_steps steps to obtain `grads` and `cost`,
#   3. takes one upper-level optimizer step,
#   4. evaluates on the whole val_loader and records timing / metrics.
for k in range(1,args.K+1):
    
    step_start_time = time.time() 
    # NOTE(review): `nes` and `t1` are assigned but never read in the visible script.
    nes = []
    params, step_size_lower, search_cost_lower = inner_solver(params, hparams, args, step_size_lower, steps=args.inner_steps)
    t1 = time.time() - step_start_time # inner loop time
    outer_opt.zero_grad()
    # outer_func is called inside hg.CG and appends the batch it used to data_outer.
    grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=True)
    #step(self, loss_curr, grad, data, step_size_inner, closure=None, clip_grad=False)
    if args.upper_search == 1:
        # AdaSLS-style step: reuses the last batch seen by outer_func.
        outer_opt.step(params, cost.item(), grads, data_outer[-1],step_size_lower)
    else:
        outer_opt.step()

    # Disabled alternative kept by the original author: manual upper-level line search.
    # if args.upper_search == 0:
    #     outer_opt.zero_grad()
    #     grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=True)
    #     outer_opt.step()
    # else:
    #     grads, cost = hg.CG(params, hparams, args.cg_steps, inner_opt_cg, outer_func, stochastic=True, set_grad=False)
    #     step_size_upper, search_cost_upper = line_search(params, hparams, step_size_upper, cost.item(), grads, args,data_outer[-1],step_size_lower, upper=True)
    #     hparams = [p - step_size_upper * g for p,g in zip(hparams,grads)]
    #     print(step_size_upper, search_cost_upper)
    #     upper_step_track.append(step_size_upper)
    #     search_cost_upper_track.append(search_cost_upper)
    #print(outer_opt)
    #print(len(data_outer))
    
    #opt_state = outer_opt.state_dict()
    #print(opt_state["param_groups"])
    #print(opt_state["step_size"], opt_state["search_cost"])
    # NOTE(review): these state keys are presumably set by AdaSLS; with the
    # Adam optimizer .get() presumably returns None -- confirm.
    print(outer_opt.state.get("step_size"),outer_opt.state.get("search_cost"))
    upper_step_track.append(outer_opt.state.get("step_size"))
    search_cost_upper_track.append(outer_opt.state.get("search_cost"))

    # Drop the batches collected during this iteration.
    data_outer = []
    # Full evaluation every iteration (args.eval_interval is not consulted here).
    with torch.no_grad():
        val_loss_temp, acc_temp = evaluate(params, hparams)
        val_losses.append(val_loss_temp)
        val_acc.append(acc_temp)
    # step_time is measured after evaluate(), so it includes evaluation time.
    step_time = time.time()-step_start_time
    total_time +=step_time
    running_time.append(total_time)
    # NOTE(review): this stores the tensor returned by torch.norm, not a Python
    # float, and only the norm of grads[0]; the list is pickled later as is.
    hg_norms.append(torch.norm(grads[0]))
    print('outer step={} | val loss={} | val acc={} |hypergrad norm = {:.3e}'.format(k, val_losses[-1], val_acc[-1],torch.norm(grads[0])))

# Gather everything tracked during training and pickle it.
results = {
    "val_losses": val_losses,
    "val_acc": val_acc,
    "running_time": running_time,
    "upper_step_track": upper_step_track,
    "search_cost_upper_track": search_cost_upper_track,
    "lower_step_track": lower_step_track,
    "search_lower_cost_track": search_lower_cost_track,
    "hg_norms": hg_norms,
}

# The file name encodes the run configuration.
p_file = 'mnist_adamsls_lineInnerSteps{}_innerSteps{}_cgSteps{}_delta{}_lSearch{}_uSearch{}_alpha{}_beta{}_bs{}_opt{}_etaU{}_etaL{}_lamba{}_K{}_optU{}_optL{}_gamma{}_s{}'.format(
    args.line_inner_steps, args.inner_steps, args.cg_steps, args.delta,
    args.lower_search, args.upper_search, args.alpha, args.beta, args.bs, args.opt_out,
    args.eta_max_upper, args.eta_max_lower, args.lamba, args.K,
    args.opt_upper, args.opt_lower, args.gamma, args.seed)

# The target directory was never created, so a fresh checkout lost the whole
# run to FileNotFoundError at this point.  Create it before writing.
# NOTE(review): args.output_dir is created at start-up but results are written
# to this hard-coded directory instead -- confirm which one is intended.
os.makedirs("Results_final", exist_ok=True)
p_file = "Results_final/" + p_file + ".pk"
# Context manager guarantees the file is flushed and closed.
with open(p_file, "wb") as results_file:
    pickle.dump(results, results_file)