import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import argparse
import time
import math
import numpy as np
import random
import torch
import torch.nn as nn

import data
import model
import wandb

from utils import batchify, get_batch, repackage_hidden, compute_grad_norm
from polyak import *
from dowg import *


# Modified for LSTM
from dog import DoG, DoG2
from parameterfree import COCOB
from dadaptation import DAdaptSGD

# Command-line interface. The parsed namespace is later mirrored into
# wandb.config (see wandb.config.update(args) below), which is what most of the
# script reads; `args` itself is still handed to batchify/get_batch.
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/penn/',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (LSTM, QRNN, GRU)')
parser.add_argument('--emsize', type=int, default=400,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=1150,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=3,
                    help='number of layers')
parser.add_argument('--learning_rate', type=float, default=30,
                    help='initial learning rate')
parser.add_argument('--grad_clip', type=float, default=0.0,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=200,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=80, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=70,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.4,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.3,
                    help='dropout for rnn layers (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.65,
                    help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0.1,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--wdrop', type=float, default=0.5,
                    help='amount of weight dropout to apply to the RNN hidden to hidden matrix')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
# The help text used to be a copy-paste of the --seed help.
parser.add_argument('--nonmono', type=int, default=5,
                    help='non-monotone interval')
# store_false: the value defaults to True and passing --cuda sets it to False,
# so the flag actually *disables* CUDA. The behaviour is kept as-is for
# backward compatibility; only the help text is corrected.
parser.add_argument('--cuda', action='store_false',
                    help='disable CUDA (CUDA is on by default; passing this flag turns it off)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
# Default checkpoint name: the current timestamp with the decimal point removed.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str,  default=randomhash+'.pt',
                    help='path to save the final model')
parser.add_argument('--alpha', type=float, default=2,
                    help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
parser.add_argument('--beta', type=float, default=1,
                    help='beta slowness regularization applied on RNN activation (beta = 0 means no regularization)')
parser.add_argument('--wdecay', type=float, default=1.2e-6,
                    help='weight decay applied to all weights')
parser.add_argument('--method', type=str,  default='sgd',
                    help='optimizer to use (sgd, adamw, polyak, inexact_polyak, '
                         'layer_inexact_polyak, decsps, adasps, dog, dowg, cocob, dadaptation)')
args = parser.parse_args()
# Weight tying is always on; it is not exposed as a command-line option.
args.tied = True
# Global optimisation-step counter, used as the wandb `step` for every log call.
n_iter = 0

# Start the wandb run and copy the parsed arguments into wandb.config; the rest
# of the script reads its hyper-parameters from wandb.config.
wandb.init(project="polyak_lstm", dir="/work/YamadaU/takezawa/polyak_debug")
#wandb.init(project="polyak_lstm", dir="/data/takezawa/polyak")
wandb.config.update(args)

# Set the random seed manually for reproducibility.
np.random.seed(wandb.config.seed)
torch.manual_seed(wandb.config.seed)
random.seed(wandb.config.seed)

if torch.cuda.is_available():
    if not wandb.config.cuda:
        # --cuda is declared with action='store_false', so reaching this branch
        # means the user *passed* --cuda and thereby switched CUDA off. The old
        # message advised running "with --cuda", which is exactly what causes
        # this situation.
        print("WARNING: You have a CUDA device but CUDA is disabled because --cuda was passed "
              "(omit --cuda to use the GPU)")
    else:
        torch.cuda.manual_seed(wandb.config.seed)
###############################################################################
# Load data
###############################################################################

def model_save(fn):
    """Write the global model, criterion and optimizer to the file ``fn``.

    The three objects are stored together as a single list through
    ``torch.save``, in the order ``[model, criterion, optimizer]``.
    """
    with open(fn, 'wb') as handle:
        checkpoint = [model, criterion, optimizer]
        torch.save(checkpoint, handle)

def model_load(fn):
    """Rebind the global model, criterion and optimizer from the file ``fn``.

    Expects the three-element sequence layout ``[model, criterion, optimizer]``.

    NOTE: ``torch.load`` unpickles the file contents, so only load checkpoints
    from a trusted source.
    """
    global model, criterion, optimizer
    with open(fn, 'rb') as handle:
        restored = torch.load(handle)
    model, criterion, optimizer = restored

import os
import hashlib

# The corpus object is cached in the current working directory. The cache file
# name is derived from the MD5 hex digest of the --data path *string*, so a
# different path spelling yields a different cache file.
data_path_digest = hashlib.md5(wandb.config.data.encode()).hexdigest()
fn = 'corpus.{}.data'.format(data_path_digest)
if not os.path.exists(fn):
    print('Producing dataset...')
    corpus = data.Corpus(wandb.config.data)
    torch.save(corpus, fn)
else:
    # NOTE: torch.load unpickles the cache file; only use caches you created.
    print('Loading cached dataset...')
    corpus = torch.load(fn)

# Validation and test use a fixed batch size of 80 (the original inline notes
# on these two assignments read "10" and "1" respectively).
eval_batch_size = 80
test_batch_size = 80
train_data = batchify(corpus.train, wandb.config.batch_size, args)
val_data = batchify(corpus.valid, eval_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)


###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
criterion = None

ntokens = len(corpus.dictionary)
# NOTE: from here on the name `model` refers to the network instance and no
# longer to the imported `model` module.
model = model.RNNModel(
    wandb.config.model, ntokens, wandb.config.emsize, wandb.config.nhid,
    wandb.config.nlayers, wandb.config.dropout, wandb.config.dropouth,
    wandb.config.dropouti, wandb.config.dropoute, wandb.config.wdrop,
    wandb.config.tied)

# Build the loss. The vocabulary is only split into buckets for very large
# vocabularies; otherwise `splits` stays empty.
if criterion is None:
    splits = []
    if ntokens > 500000:
        # One Billion
        # This produces fairly even matrix mults for the buckets:
        # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422
        splits = [4200, 35000, 180000]
    elif ntokens > 75000:
        # WikiText-103
        splits = [2800, 20000, 76000]
    print('Using', splits)
    criterion = SplitCrossEntropyLoss(wandb.config.emsize, splits=splits, verbose=False)

# Move both the network and the loss module to the GPU when CUDA is enabled.
if wandb.config.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# The optimizer is later built over `params`, i.e. the model's parameters plus
# the criterion's parameters.
params = list(model.parameters()) + list(criterion.parameters())
# numel() counts every element of every parameter. The previous hand-rolled
# expression multiplied only the first two dimensions (under-counting tensors
# with more than two dimensions) and skipped 0-dim parameters entirely.
total_params = sum(p.numel() for p in params)
print('Args:', args)
print('Model total parameters:', total_params)

###############################################################################
# Training code
###############################################################################

def evaluate(data_source, batch_size=10):
    """Return the average loss of the global ``model`` over ``data_source``.

    ``data_source`` is walked in chunks of ``wandb.config.bptt`` rows. Each
    chunk's criterion value is weighted by the chunk length, and the weighted
    sum is divided by ``len(data_source)``.

    Args:
        data_source: batchified data, indexed along dimension 0.
        batch_size: batch size used to initialise the hidden state; it must
            match the batch size ``data_source`` was batchified with.

    Returns:
        The averaged loss as a Python float.

    Side effects:
        Leaves the model in evaluation mode.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    if wandb.config.model == 'QRNN':
        model.reset()
    # Start from a float so the function still returns cleanly (0.0) when
    # data_source is too short for the loop below to run even once.
    total_loss = 0.0
    hidden = model.init_hidden(batch_size)
    # No gradients are needed for evaluation; disabling autograd avoids building
    # a graph for every forward pass.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, wandb.config.bptt):
            data, targets = get_batch(data_source, i, args, evaluation=True)
            output, hidden = model(data, hidden)
            total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets)
            hidden = repackage_hidden(hidden)
    return float(total_loss) / len(data_source)



def train(epoch, n_iter):
    """Run one pass over ``train_data`` and return the updated step counter.

    For every chunk of ``wandb.config.bptt`` rows this performs a forward pass,
    adds the activation regularisers, back-propagates, optionally clips the
    gradient norm, takes an optimizer step and logs to wandb. Every
    ``wandb.config.log_interval`` batches a progress line is printed.

    Args:
        epoch: epoch number, used only for logging.
        n_iter: global step counter on entry; incremented once per batch and
            used as the wandb ``step``.

    Returns:
        The value of ``n_iter`` after the last batch.
    """
    if wandb.config.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    hidden = model.init_hidden(wandb.config.batch_size)
    batch, i = 0, 0
    seq_len = wandb.config.bptt
    # Optimizers whose name contains "polyak" or "sps" receive the loss in
    # step() and expose `old_lr`; the method does not change during a run, so
    # the check is done once instead of twice per batch.
    loss_aware = "polyak" in wandb.config.method or "sps" in wandb.config.method

    while i < train_data.size(0) - 1 - 1:
        if i > 0:
            # NOTE(review): writes explicit None for the validation/test keys at
            # the current step on every batch after the first -- presumably a
            # wandb charting workaround; confirm intent.
            wandb.log({"val/full_loss": None, "test/full_loss": None}, step = n_iter)

        # Turn on training mode which enables dropout.
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)

        loss = raw_loss
        # Activation Regularization (last element of dropped_rnn_hs only)
        if wandb.config.alpha: loss = loss + sum(wandb.config.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness; last element of rnn_hs only)
        if wandb.config.beta: loss = loss + sum(wandb.config.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        if wandb.config.grad_clip > 0.0:
            torch.nn.utils.clip_grad_norm_(params, wandb.config.grad_clip)

        if loss_aware:
            optimizer.step(loss)
        else:
            optimizer.step()

        # Only the un-regularised loss feeds the printed running average.
        total_loss += raw_loss.data

        batch += 1
        i += seq_len
        n_iter += 1

        # Log. The gradient norm is measured after clipping and after the step,
        # over the model's parameters only (the criterion's are not included).
        grad_norm = compute_grad_norm(model.parameters())
        if loss_aware:
            wandb.log({"epoch": epoch, "iter": n_iter, "train/minibatch_loss": loss, "grad_norm": grad_norm, "lr": optimizer.old_lr}, step = n_iter)
        else:
            wandb.log({"epoch": epoch, "iter": n_iter, "train/minibatch_loss": loss, "grad_norm": grad_norm}, step = n_iter)

        if batch % wandb.config.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / wandb.config.log_interval
            elapsed = time.time() - start_time
            # math.exp raises OverflowError for large finite inputs; a diverged
            # loss must not abort the run from inside a progress print.
            try:
                ppl = math.exp(cur_loss)
            except OverflowError:
                ppl = float('inf')
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                epoch, batch, len(train_data) // seq_len, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / wandb.config.log_interval, cur_loss, ppl, cur_loss / math.log(2)))
            total_loss = 0
            start_time = time.time()

    return n_iter

# Loop over epochs.
lr = wandb.config.learning_rate
# History of the validation loss measured at the start of each epoch, and the
# lowest value seen so far.
best_val_loss = []
stored_loss = 100000000

# At any point you can hit Ctrl + C to break out of training early.
try:
    optimizer = None

    # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax)
    if wandb.config.method == 'sgd':
        optimizer = torch.optim.SGD(params, lr=wandb.config.learning_rate, weight_decay=wandb.config.wdecay)
    elif wandb.config.method == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=wandb.config.learning_rate, weight_decay=wandb.config.wdecay)
    elif wandb.config.method == 'inexact_polyak':
        # NOTE(review): total_iteration is passed as an un-rounded float while
        # the value printed below is ceil(...) * epochs -- confirm which one the
        # optimizer is meant to receive.
        optimizer = InexactPolyakOptimizer(params, total_iteration=(train_data.size(0) - 1) / wandb.config.bptt * wandb.config.epochs)
        print(math.ceil((train_data.size(0) - 1) / wandb.config.bptt) * wandb.config.epochs)
    elif wandb.config.method == 'polyak':
        optimizer = PolyakOptimizer(params)
    elif wandb.config.method == "decsps":
        optimizer = DecSPSOptimizer(params)
    elif wandb.config.method == "adasps":
        optimizer = AdaSPSOptimizer(params)
    elif wandb.config.method == "dog":
        optimizer = DoG(params)
    elif wandb.config.method == "dowg":
        optimizer = DoWG(params)
    elif wandb.config.method == "cocob":
        optimizer = COCOB(params)
    elif wandb.config.method == "dadaptation":
        optimizer = DAdaptSGD(params)
    elif wandb.config.method == 'layer_inexact_polyak':
        # NOTE(review): same float-vs-ceil mismatch as for 'inexact_polyak'.
        optimizer = LInexactPolyakOptimizer(params, total_iteration=(train_data.size(0) - 1) / wandb.config.bptt * wandb.config.epochs)
        print(math.ceil((train_data.size(0) - 1) / wandb.config.bptt) * wandb.config.epochs)
    else:
        # Fail immediately with a clear message. Previously an unknown method
        # left `optimizer` as None, and the run only crashed inside train()
        # after a full validation and test evaluation.
        raise ValueError(
            'Unknown --method {!r}; expected one of: sgd, adamw, polyak, inexact_polyak, '
            'layer_inexact_polyak, decsps, adasps, dog, dowg, cocob, dadaptation'.format(
                wandb.config.method))

    for epoch in range(1, wandb.config.epochs+1):
        epoch_start_time = time.time()
        # Validation/test are measured *before* this epoch's training, so the
        # "end of epoch" line below reports the model as it was at the start of
        # the epoch.
        val_loss = evaluate(val_data, eval_batch_size)
        test_loss = evaluate(test_data, test_batch_size)
        wandb.log({"val/full_loss": val_loss, "test/full_loss": test_loss}, step=n_iter)

        n_iter = train(epoch, n_iter)

        # math.exp raises OverflowError for large finite inputs; do not let a
        # diverged loss abort the run from a print statement.
        try:
            val_ppl = math.exp(val_loss)
        except OverflowError:
            val_ppl = float('inf')

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                  epoch, (time.time() - epoch_start_time), val_loss, val_ppl, val_loss / math.log(2)))
        print('-' * 89)

        if val_loss < stored_loss:
            stored_loss = val_loss

        best_val_loss.append(val_loss)

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Run on test data with the model as it stands after training (or after an
# early Ctrl + C).
test_loss = evaluate(test_data, test_batch_size)
# This print is the only place the final test loss is reported, so guard the
# perplexity against math.exp overflowing on a very large loss.
try:
    test_ppl = math.exp(test_loss)
except OverflowError:
    test_ppl = float('inf')
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
    test_loss, test_ppl, test_loss / math.log(2)))
print('=' * 89)

#wandb.finish()
