#!/usr/bin/env python3
# from https://github.com/JunjieYang97/stocBiO

import random
import os

import numpy as np
import torch
from torch import nn
from torch import optim
import argparse
from torch.utils.tensorboard import SummaryWriter
import shutil
from torch.autograd import grad
from torch.nn import DataParallel

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
import torchvision.transforms as transforms

import utils
from utils import log
from tqdm import tqdm
import higher
import time

from torchviz import make_dot

# Command-line interface. Arguments are parsed at import time and the result is
# stored in the module-level `args`, which the helper functions below read.
parser = argparse.ArgumentParser()

# Value written verbatim to CUDA_VISIBLE_DEVICES right after parsing.
parser.add_argument('--gpuid', default='0', type=str)
# meta-learning configurations
# ways / shots feed the NWays / KShots task transforms (3 * shots samples are
# drawn per class so each task can be split three ways).
parser.add_argument('--ways', type=int, default=5)
parser.add_argument('--shots', type=int, default=5)
# Number of tasks in the training TaskDataset.
parser.add_argument('--num_train_tasks', type=int, default=20000)
# Number of tasks sampled per outer iteration.
parser.add_argument('--meta_batch_size', type=int, default=16)
# 'mini' or 'tiered' are the two values handled by the data-loading code.
parser.add_argument('--dataset', default='mini', type=str)
parser.add_argument('--data_dir', default='data', type=str)
# minimax-bilevel configurations
parser.add_argument('--lr_gamma', type=float, default=0.02)
parser.add_argument('--lr_outer', type=float, default=0.02)
parser.add_argument('--lr_inner', type=float, default=0.05)
# NOTE(review): the script later recomputes lr_v from lambda_value and
# lr_inner, so this flag currently has no effect - confirm whether intended.
parser.add_argument('--lr_v', type=float, default=0.1)
parser.add_argument('--lambda_value', type=float, default=1)
# Scale of the Gaussian noise vector passed to the outer loss.
parser.add_argument('--eps', type=float, default=1e-3)
# training configuration
# k is forwarded to the outer loss and selects which sorted warm-up loss
# (index k - 1) initialises gamma.
parser.add_argument('--k', type=int, default=12)
# Number of outer iterations to run.
parser.add_argument('--training_steps', type=int, default=200)
# Gradient steps taken on the head parameters per task.
parser.add_argument('--inner_steps', type=int, default=5)
# 'cnn4' or 'resnet12' are the two values handled by the model setup.
parser.add_argument('--model', default='cnn4', type=str)
# logging configuration
parser.add_argument('--save_dir', default='exp')
# load model configuration
# File name (relative to save_dir) of a checkpoint to resume from.
parser.add_argument('--ckpt_path', default=None)

# Denominator of the per-iteration lr_gamma decrement.
parser.add_argument('--total', type=int, default=3000)
parser.add_argument('--seed', type=int, default=10086)

# noise setting
# --rand / --flip are integer switches (non-zero enables); --rand wins if both
# are set because it is checked first in the training loop.
parser.add_argument('--rand', type=int, default=0)
parser.add_argument('--rand_max', type=float, default=0.5)

parser.add_argument('--flip', type=int, default=0)
# Probability that a task gets label noise, and the noise ratio it then gets.
parser.add_argument('--flip_thresh', type=float, default=0.6)
parser.add_argument('--flip_ratio', type=float, default=0.8)

# When scale_gamma is non-zero, lr_gamma is multiplied by gamma_scale.
parser.add_argument('--scale_gamma', type=int, default=0)
parser.add_argument('--gamma_scale', type=float, default=1)

args = parser.parse_args()

# Must be set before CUDA is initialised for the restriction to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid

'''
################################################################
Supporting functions
################################################################
'''

class Lambda(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``."""

    def __init__(self, fn):
        """Store ``fn``; it is invoked on every forward pass."""
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return the result of applying the wrapped callable to ``x``."""
        result = self.fn(x)
        return result


def accuracy(predictions, targets):
    """Return the fraction of rows whose arg-max class equals the target.

    ``predictions`` is reduced with ``argmax`` over dim 1 and reshaped to the
    shape of ``targets`` before comparison; the result is a 0-dim float tensor.
    """
    predicted_labels = predictions.argmax(dim=1).view(targets.shape)
    n_correct = predicted_labels.eq(targets).sum().float()
    n_total = targets.size(0)
    return n_correct / n_total


def update_v(batch, mp_feature, phi, mp_head, v, lr_v, inner_func, inner_steps, warmup=False, train=True, test=False):
    """Adapt the head parameters ``v`` on one task and report its losses.

    Runs ``inner_steps`` plain gradient-descent steps on ``v`` using the
    adaptation split, then evaluates the adapted head on the adaptation and
    evaluation splits (and on the test split when metrics are requested).

    Reads the module-level ``args`` (for ``shots`` / ``ways``) and ``device``.

    Args:
        batch: ``(data, labels)`` pair for one task. Samples at positions
            ``3*j``, ``3*j + 1`` and ``3*j + 2`` for ``j`` in
            ``range(args.shots * args.ways)`` form the adaptation, evaluation
            and test splits respectively.
        mp_feature: callable invoked as ``mp_feature(x, params=phi)``.
        phi: feature-extractor parameters passed to ``mp_feature``.
        mp_head: callable invoked as ``mp_head(features, params=v)``.
        v: list of head parameters; its entries are replaced in place by the
            adapted values.
        lr_v: step size of the updates on ``v``.
        inner_func: loss called as ``inner_func(output, labels)``.
        inner_steps: number of gradient steps on ``v``.
        warmup: if true (and ``test`` is false) only the evaluation loss is
            returned.
        train: if true (and neither ``test`` nor ``warmup``) the gradient of
            the adaptation loss w.r.t. ``phi`` is returned with the
            evaluation loss.
        test: if true, always return the metrics tuple.

    Returns:
        * ``test`` true, or ``warmup`` and ``train`` both false:
          ``(train_loss, train_acc, eval_loss, eval_acc, test_loss, test_acc)``
        * else ``warmup`` true: ``eval_loss``
        * else (``train`` true): ``(train_phi_grad, eval_loss)``
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Build one boolean mask per split; samples are interleaved with stride 3.
    base_positions = np.arange(args.shots * args.ways) * 3
    split_masks = []
    for offset in range(3):
        mask = np.zeros(data.size(0), dtype=bool)
        mask[base_positions + offset] = True
        split_masks.append(torch.from_numpy(mask))
    adaptation_indices, evaluation_indices, test_indices = split_masks

    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
    test_data, test_labels = data[test_indices], labels[test_indices]

    for _ in range(inner_steps):
        feat = mp_feature(adaptation_data, params=phi)
        output = mp_head(feat, params=v)
        inner_loss = inner_func(output, adaptation_labels)
        # No create_graph: the step is first-order, gradients carry no history.
        grads = torch.autograd.grad(inner_loss, v)
        for i, step_grad in enumerate(grads):
            v[i] = v[i] - lr_v * step_grad
        # Drop this step's graph before the next forward pass.
        del grads, inner_loss, output, feat

    train_feat = mp_feature(adaptation_data, params=phi)
    train_output = mp_head(train_feat, params=v)
    train_v_loss = inner_func(train_output, adaptation_labels)

    eval_feat = mp_feature(evaluation_data, params=phi)
    eval_output = mp_head(eval_feat, params=v)
    eval_v_loss = inner_func(eval_output, evaluation_labels)

    # The test split is needed whenever the metrics tuple may be returned.
    # Previously it was only evaluated when ``train`` was false, so
    # ``test=True`` with the default ``train=True`` crashed on an unbound name.
    if test or not train:
        test_feat = mp_feature(test_data, params=phi)
        test_output = mp_head(test_feat, params=v)
        test_v_loss = inner_func(test_output, test_labels)

    if not test:
        if warmup:
            return eval_v_loss
        if train:
            train_phi_grad = torch.autograd.grad(train_v_loss, phi)
            return train_phi_grad, eval_v_loss

    train_acc_v = accuracy(train_output.detach(), adaptation_labels)
    eval_acc_v = accuracy(eval_output.detach(), evaluation_labels)
    test_acc_v = accuracy(test_output.detach(), test_labels)
    return train_v_loss, train_acc_v, eval_v_loss, eval_acc_v, test_v_loss, test_acc_v


def update_inner(batch, mp_feature, phi, mp_head, w, lr_inner, _lambda,
                 inner_func, outer_func, gamma, k, n_tasks, inner_steps, train=True, test=False):
    """Adapt the head parameters ``w`` on one task with the penalised objective.

    Each of the ``inner_steps`` steps descends on
    ``outer_func(k, n_tasks, [eval_loss], gamma, eps) + _lambda * adaptation_loss``
    with respect to ``w``, where ``eps`` is a fresh Gaussian noise vector of
    length ``n_tasks`` scaled by ``args.eps``.

    Reads the module-level ``args`` (``shots`` / ``ways`` / ``eps``) and
    ``device``.

    Args:
        batch: ``(data, labels)`` pair for one task. Samples at positions
            ``3*j``, ``3*j + 1`` and ``3*j + 2`` for ``j`` in
            ``range(args.shots * args.ways)`` form the adaptation, evaluation
            and test splits respectively.
        mp_feature: callable invoked as ``mp_feature(x, params=phi)``.
        phi: feature-extractor parameters passed to ``mp_feature``.
        mp_head: callable invoked as ``mp_head(features, params=w)``.
        w: list of head parameters; its entries are replaced in place by the
            adapted values.
        lr_inner: step size of the updates on ``w``.
        _lambda: weight of the adaptation-loss gradient in each step.
        inner_func: loss called as ``inner_func(output, labels)``.
        outer_func: callable invoked as
            ``outer_func(k, n_tasks, losses, gamma, eps)``.
        gamma, k, n_tasks: forwarded unchanged to ``outer_func``.
        inner_steps: number of gradient steps on ``w``.
        train: if true (and ``test`` is false) return losses and gradients for
            the outer update.
        test: if true, always return the metrics tuple.

    Returns:
        * ``train`` true and ``test`` false:
          ``(eval_loss, train_phi_grad, eval_phi_grad, eps)`` where ``eps`` is
          the noise drawn in the last step (``None`` if ``inner_steps`` is 0).
        * otherwise:
          ``(train_loss, train_acc, eval_loss, eval_acc, test_loss, test_acc)``
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Build one boolean mask per split; samples are interleaved with stride 3.
    base_positions = np.arange(args.shots * args.ways) * 3
    split_masks = []
    for offset in range(3):
        mask = np.zeros(data.size(0), dtype=bool)
        mask[base_positions + offset] = True
        split_masks.append(torch.from_numpy(mask))
    adaptation_indices, evaluation_indices, test_indices = split_masks

    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
    test_data, test_labels = data[test_indices], labels[test_indices]

    # Defined up front so the train-mode return is valid when inner_steps == 0.
    eps = None
    for _ in range(inner_steps):
        inner_feat = mp_feature(adaptation_data, params=phi)
        inner_output = mp_head(inner_feat, params=w)
        inner_loss_w = inner_func(inner_output, adaptation_labels)
        inner_grads_w = torch.autograd.grad(inner_loss_w, w)

        outer_feat = mp_feature(evaluation_data, params=phi)
        outer_output = mp_head(outer_feat, params=w)
        outer_g_loss = inner_func(outer_output, evaluation_labels)
        # Place the noise on the active device instead of hard-coding CUDA, so
        # the CPU fallback chosen by the script works as well.
        eps = torch.Tensor(args.eps * np.random.normal(size=n_tasks)).to(device)

        outer_loss = outer_func(k, n_tasks, torch.stack([outer_g_loss]), gamma, eps)
        outer_grads_w = torch.autograd.grad(outer_loss, w)

        for i in range(len(w)):
            lagrangian_grads_w = outer_grads_w[i] + _lambda * inner_grads_w[i]
            w[i] = w[i] - lr_inner * lagrangian_grads_w
        # Drop this step's graphs before the next forward pass.
        del outer_g_loss, outer_loss, outer_grads_w, inner_grads_w, inner_loss_w

    train_feat = mp_feature(adaptation_data, params=phi)
    train_output = mp_head(train_feat, params=w)
    train_w_loss = inner_func(train_output, adaptation_labels)
    eval_feat = mp_feature(evaluation_data, params=phi)
    eval_output = mp_head(eval_feat, params=w)
    eval_w_loss = inner_func(eval_output, evaluation_labels)

    # The test split is needed whenever the metrics tuple may be returned.
    # Previously it was only evaluated when ``train`` was false, so
    # ``test=True`` with the default ``train=True`` crashed on an unbound name.
    if test or not train:
        test_feat = mp_feature(test_data, params=phi)
        test_output = mp_head(test_feat, params=w)
        test_w_loss = inner_func(test_output, test_labels)

    if train and not test:
        train_phi_grad = torch.autograd.grad(train_w_loss, phi)
        eval_phi_grad = torch.autograd.grad(eval_w_loss, phi)
        return eval_w_loss, train_phi_grad, eval_phi_grad, eps

    train_acc_w = accuracy(train_output.detach(), adaptation_labels)
    eval_acc_w = accuracy(eval_output.detach(), evaluation_labels)
    test_acc_w = accuracy(test_output.detach(), test_labels)
    return train_w_loss, train_acc_w, eval_w_loss, eval_acc_w, test_w_loss, test_acc_w
        

def update_gamma(eval_losses_w, gamma, k, n_tasks, lr_gamma, eps):
    """Take one ascent step on ``gamma`` and return ``(old_gamma, new_gamma)``.

    The step direction is the mean over tasks of
    ``[hinge(loss, gamma) > 0] - (n_tasks - k) / n_tasks``, scaled by
    ``lr_gamma`` and added to ``gamma``. ``eps`` is accepted for signature
    compatibility but is not used.
    """
    # Per-task indicator of a strictly positive hinge value.
    is_active = utils.hinge(eval_losses_w, gamma) > 0
    target_fraction = (n_tasks - k) / n_tasks
    direction = torch.mean(is_active.to(torch.int) - target_fraction)

    previous_gamma = gamma
    updated_gamma = gamma + lr_gamma * direction
    return previous_gamma, updated_gamma

def update_outer(eval_losses_w, eval_grad_phi, train_grad_phi_w, train_grad_phi_v, eps,
                  phi, gamma, lr_outer, n_tasks, _lambda):
    """Apply the per-task outer gradient steps to ``phi`` and return it.

    For every task ``t`` and parameter ``i`` the step is
    ``(g_eval - active[t] * g_eval + _lambda * (g_train_w - g_train_v)) / n_tasks``
    where ``active`` marks tasks whose ``utils.hinge(loss, gamma)`` is
    strictly positive. The entries of the ``phi`` list are replaced in place;
    the same list is returned. ``eps`` is accepted but not used.
    """
    active = utils.hinge(eval_losses_w, gamma) > 0

    # Parameters are independent of each other, so iterating parameter-first
    # applies the task steps to each one in the same order (t = 0 .. n_tasks-1).
    for idx in range(len(phi)):
        for task in range(n_tasks):
            eval_g = eval_grad_phi[task][idx]
            step = (eval_g - active[task] * eval_g + _lambda * (train_grad_phi_w[task][idx] - train_grad_phi_v[task][idx])) / n_tasks
            phi[idx] = phi[idx] - lr_outer * step

    return phi

'''
################################################################
starts here
################################################################
'''
if __name__ == "__main__":
    # Entry point: seed the RNGs, build the few-shot task samplers and the
    # model, initialise (or restore) gamma and phi, then run the outer loop.
    seed = args.seed
    cuda = True

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # `device` is a module-level name that update_v / update_inner read.
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    save_dir = args.save_dir
    resuming = args.ckpt_path is not None
    if not resuming:
        # A fresh run wipes any previous experiment directory.
        if os.path.isdir(save_dir):
            shutil.rmtree(save_dir)
        # makedirs so that a nested save_dir (e.g. "exp/run1") also works.
        os.makedirs(save_dir)
    log_file_path = os.path.join(save_dir, "train_log.txt")
    log(log_file_path, str(vars(args)))

    '''
    ################################################################
    loading data
    ################################################################
    '''
    if args.dataset == 'mini':
        train_dataset = l2l.vision.datasets.MiniImagenet(root=args.data_dir, mode='train')
        valid_dataset = l2l.vision.datasets.MiniImagenet(root=args.data_dir, mode='validation')
        test_dataset = l2l.vision.datasets.MiniImagenet(root=args.data_dir, mode='test')
    elif args.dataset == 'tiered':
        tiered_transform = transforms.ToTensor()
        train_dataset = l2l.vision.datasets.TieredImagenet(root=args.data_dir, mode='train', download=True, transform=tiered_transform)
        valid_dataset = l2l.vision.datasets.TieredImagenet(root=args.data_dir, mode='validation', download=True, transform=tiered_transform)
        test_dataset = l2l.vision.datasets.TieredImagenet(root=args.data_dir, mode='test', download=True, transform=tiered_transform)
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'mini' or 'tiered')")

    train_dataset = l2l.data.MetaDataset(train_dataset)
    valid_dataset = l2l.data.MetaDataset(valid_dataset)
    test_dataset = l2l.data.MetaDataset(test_dataset)

    # 3 * shots samples per class: update_v / update_inner split every task
    # into adaptation / evaluation / test thirds.
    train_transforms = [
            NWays(train_dataset, args.ways),
            KShots(train_dataset, 3*args.shots),
            LoadData(train_dataset),
            RemapLabels(train_dataset),
            ConsecutiveLabels(train_dataset),
        ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                        task_transforms=train_transforms,
                                        num_tasks=args.num_train_tasks)

    # NOTE(review): ConsecutiveLabels / RemapLabels are in the opposite order
    # to the train and test pipelines - confirm this is intentional.
    valid_transforms = [
            NWays(valid_dataset, args.ways),
            KShots(valid_dataset, 3*args.shots),
            LoadData(valid_dataset),
            ConsecutiveLabels(valid_dataset),
            RemapLabels(valid_dataset),
        ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                        task_transforms=valid_transforms,
                                        num_tasks=600)

    test_transforms = [
            NWays(test_dataset, args.ways),
            KShots(test_dataset, 3*args.shots),
            LoadData(test_dataset),
            RemapLabels(test_dataset),
            ConsecutiveLabels(test_dataset),
        ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                        task_transforms=test_transforms,
                                        num_tasks=600)

    '''
    ################################################################
    model configuration
    ################################################################
    '''
    _lambda = args.lambda_value
    lr_gamma = args.lr_gamma
    if args.scale_gamma:
        lr_gamma = lr_gamma * args.gamma_scale
    lr_outer = args.lr_outer
    lr_inner = args.lr_inner
    lr_v = args.lr_v
    # NOTE(review): this overrides --lr_v unconditionally; kept as-is because
    # changing it would alter training behaviour - confirm whether intended.
    lr_v = _lambda * lr_inner if not _lambda == 0 else lr_inner

    if l2l.__version__ == '0.1.5':
        if args.model == 'resnet12':
            print('model is RN12')
            features = utils.ResNet12Backbone(output_size=32).to(device)
            head = torch.nn.Linear(640, args.ways)
        elif args.model == 'cnn4':
            print('model is CNN4')
            features = l2l.vision.models.ConvBase(output_size=32, channels=3, max_pool=True)
            features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, 1600))).to(device)
            head = torch.nn.Linear(1600, args.ways)
        else:
            raise ValueError(f"Unsupported model: {args.model!r} (expected 'cnn4' or 'resnet12')")
        head = l2l.algorithms.MAML(head, lr=lr_inner)
        head.to(device)
    elif l2l.__version__ == '0.1.7' or l2l.__version__ == '0.2.0':
        print('version should be 0.1.7')
        if args.model == 'resnet12':
            features = l2l.vision.models.ResNet12Backbone().to(device)
            head = torch.nn.Linear(640, args.ways)
        elif args.model == 'cnn4':
            features = l2l.vision.models.ConvBase(hidden=64, channels=3, max_pool=True)
            features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, 1600))).to(device)
            head = torch.nn.Linear(1600, args.ways)
        else:
            raise ValueError(f"Unsupported model: {args.model!r} (expected 'cnn4' or 'resnet12')")
        head = l2l.algorithms.MAML(head, lr=lr_inner)
        head.to(device)
    else:
        raise NotImplementedError()

    # Functional (parameter-passing) versions of the two networks; phi and the
    # head weights are carried around as explicit lists of tensors.
    mp_feature = higher.monkeypatch(features, copy_initial_weights=True).to(device)
    phi_list = list(features.parameters())
    phi = [param.requires_grad_(True) for param in phi_list]

    mp_head = higher.monkeypatch(head, copy_initial_weights=True).to(device)
    w_list = list(head.parameters())

    inner_func = nn.CrossEntropyLoss(reduction='mean')
    outer_loss = utils.btk_eps_surrogate

    if not resuming:
        log(log_file_path, "Start initializing gamma...")
        eval_losses_v = []
        # there are a lot of memory leak here
        for task in range(args.meta_batch_size):
            batch = train_tasks.sample()
            v = [param.requires_grad_(True) for param in w_list]
            eval_loss_v = update_v(batch, mp_feature=mp_feature, phi=phi, mp_head=mp_head,
                                v=v, lr_v=lr_v, inner_func=inner_func, inner_steps=args.inner_steps, warmup=True)
            eval_losses_v.append(eval_loss_v)
            del batch, v
            del eval_loss_v

        # gamma starts at the k-th smallest warm-up evaluation loss.
        init_losses = torch.stack(eval_losses_v)
        sorted_losses = torch.sort(init_losses, descending=False)
        sorted_losses = torch.unbind(sorted_losses.values)

        gamma = sorted_losses[args.k-1].item() # gamma

        log(log_file_path, f"Initilization finished, gamma = {gamma}.")
    else:
        ckpt_file = os.path.join(save_dir, args.ckpt_path)
        try:
            # map_location lets a checkpoint saved on GPU be resumed on CPU.
            load_dict = torch.load(ckpt_file, map_location=device)
        except FileNotFoundError as err:
            # Only a missing file gets this hint; other load errors (corrupt
            # file, unpickling problems) now surface with their real type.
            raise FileNotFoundError("Need to provide exact filename of checkpoint!") from err
        phi = load_dict['phi']
        gamma = load_dict['gamma']
        _iter = load_dict['_iter']
        # Replay the per-iteration lr_gamma decay for the iterations done.
        lr_gamma -= args.gamma_scale / args.total * args.lr_gamma * _iter
        print(lr_gamma)

    best_acc_v = 0
    best_iter_v = -1
    best_acc_w = 0
    best_iter_w = -1
    best_train_test_acc_v = 0
    best_train_test_acc_w = 0

    _iter = _iter if resuming else 0

    while _iter < args.training_steps:
        # Per-task label-noise ratios for this outer iteration.
        if args.rand:
            noise_ratio_list = [
                random.choice(list(range(int(10 * args.rand_max)))) / 10.
                for _ in range(args.meta_batch_size)
            ]
        elif args.flip:
            noise_ratio_list = [
                args.flip_ratio if np.random.random() < args.flip_thresh else 0
                for _ in range(args.meta_batch_size)
            ]
        else:
            noise_ratio_list = [0]*args.meta_batch_size
        torch.cuda.empty_cache()

        eval_losses_v = []
        eval_losses_w = []

        train_grads_phi_v = []
        train_grads_phi_w = []
        eval_grads_phi_w = []

        val_losses_w = []
        val_accs_w = []
        val_losses_v = []
        val_accs_v = []

        train_test_acc_v = []
        train_test_acc_w = []
        train_test_loss_v = []
        train_test_loss_w = []
        for task in range(args.meta_batch_size):
            torch.cuda.empty_cache()
            batch = train_tasks.sample()

            # Corrupt the training labels according to this task's ratio.
            images, labels  = batch
            noisy_label = utils.flip_label(labels, noise_ratio_list[task])
            noisy_label = torch.Tensor(noisy_label).type(torch.long)
            batch = [images, noisy_label]

            # Both heads start from the same initial weights; the helpers
            # replace the list entries, not the underlying tensors.
            w = [param.requires_grad_(True) for param in w_list]
            v = [param.requires_grad_(True) for param in w_list]

            train_grad_phi_v, eval_loss_v = update_v(batch,
                             mp_feature=mp_feature, phi=phi, mp_head=mp_head,
                             v=v, lr_v=lr_v, inner_func=inner_func, inner_steps=args.inner_steps)

            train_grads_phi_v.append(train_grad_phi_v)
            eval_losses_v.append(eval_loss_v)

            eval_loss_w, train_grad_phi_w, eval_grad_phi_w, eps = update_inner(batch=batch,
                                       mp_feature=mp_feature, phi=phi,
                                       mp_head=mp_head, w=w, lr_inner=lr_inner, _lambda=_lambda,
                                       inner_func=inner_func, outer_func=outer_loss, gamma=gamma,
                                       k=args.k,
                                       n_tasks=args.meta_batch_size, inner_steps=args.inner_steps)

            eval_losses_w.append(eval_loss_w)
            train_grads_phi_w.append(train_grad_phi_w)
            eval_grads_phi_w.append(eval_grad_phi_w)
            del w, v

            # Monitoring pass on a (clean) validation task.
            val_batch = valid_tasks.sample()
            w = [param.requires_grad_(True) for param in w_list]
            v = [param.requires_grad_(True) for param in w_list]

            train_v_loss, train_acc_v, eval_v_loss, eval_acc_v, test_v_loss, test_acc_v = update_v(val_batch,
                             mp_feature=mp_feature, phi=phi, mp_head=mp_head,
                             v=v, lr_v=lr_v, inner_func=inner_func, inner_steps=args.inner_steps, train=False)

            val_losses_v.append(eval_v_loss.item())
            val_accs_v.append(eval_acc_v.item())
            train_test_loss_v.append(test_v_loss.item())
            train_test_acc_v.append(test_acc_v.item())

            train_w_loss, train_acc_w, eval_w_loss, eval_acc_w, test_w_loss, test_acc_w = update_inner(batch=val_batch,
                                       mp_feature=mp_feature, phi=phi,
                                       mp_head=mp_head, w=w, lr_inner=lr_inner, _lambda=_lambda,
                                       inner_func=inner_func, outer_func=outer_loss, gamma=gamma,
                                       k=args.k,
                                       n_tasks=args.meta_batch_size, inner_steps=args.inner_steps, train=False)

            val_losses_w.append(eval_w_loss.item())
            val_accs_w.append(eval_acc_w.item())
            train_test_loss_w.append(test_w_loss.item())
            train_test_acc_w.append(test_acc_w.item())

            del train_v_loss, train_w_loss, w, v, test_v_loss, test_acc_v, test_w_loss, test_acc_w

        train_set_eval_loss = [item.detach().cpu() for item in eval_losses_w]
        eval_losses_w = torch.stack(eval_losses_w)
        eval_losses_v = torch.stack(eval_losses_v)

        # Linear decay of the gamma step size.
        lr_gamma -= args.gamma_scale / args.total * args.lr_gamma
        # NOTE(review): gamma is updated from the v-losses although the
        # keyword is named eval_losses_w - confirm this is intentional.
        gamma_old, gamma = update_gamma(eval_losses_w=eval_losses_v, gamma=gamma, k=args.k, eps=eps,
                     n_tasks=args.meta_batch_size, lr_gamma=lr_gamma)

        # phi is updated against the gamma used during this iteration.
        phi = update_outer(eval_losses_w=eval_losses_w, eval_grad_phi=eval_grads_phi_w, eps=eps,
                     train_grad_phi_w=train_grads_phi_w, train_grad_phi_v=train_grads_phi_v,
                     gamma=gamma_old, lr_outer=lr_outer, phi=phi,
                     n_tasks=args.meta_batch_size, _lambda=_lambda)

        # Cut the autograd history so graphs do not accumulate across
        # iterations.
        gamma_temp = gamma.detach().requires_grad_(True)
        del gamma
        gamma = gamma_temp

        phi = [item.detach().requires_grad_(True) for item in phi]

        del eval_losses_w, gamma_old, eps

        iter_loss_v = np.mean(val_losses_v)
        iter_acc_v = np.mean(val_accs_v)
        iter_loss_w = np.mean(val_losses_w)
        iter_acc_w = np.mean(val_accs_w)

        iter_test_loss_v = np.mean(train_test_loss_v)
        iter_test_acc_v = np.mean(train_test_acc_v)
        iter_test_loss_w = np.mean(train_test_loss_w)
        iter_test_acc_w = np.mean(train_test_acc_w)

        if best_acc_v < iter_acc_v:
            best_iter_v = _iter
            best_acc_v = iter_acc_v
        if best_acc_w < iter_acc_w:
            best_iter_w = _iter
            best_acc_w = iter_acc_w
        if best_train_test_acc_v < iter_test_acc_v:
            best_train_test_acc_v = iter_test_acc_v
        if best_train_test_acc_w < iter_test_acc_w:
            best_train_test_acc_w = iter_test_acc_w

        utils.log(log_file_path, f'''Iteration {_iter} Finished!
                  Train Loss w : {torch.mean(torch.stack(train_set_eval_loss)).item()}
                  Loss using v: {iter_loss_v}
                  Accuracy using v: {iter_acc_v}
                  Best accuracy using v: {best_acc_v}
                  Loss using w: {iter_loss_w}
                  Accuracy using w: {iter_acc_w}
                  Best accuracy using w: {best_acc_w}
                  Test loss using v: {iter_test_loss_v}
                  Test acc using v: {iter_test_acc_v}
                  Best test acc using v: {best_train_test_acc_v}
                  Test loss using w: {iter_test_loss_w}
                  Test acc using w: {iter_test_acc_w}
                  Best test acc using w: {best_train_test_acc_w}
                  gamma: {gamma}
                  lr_ouuter: {lr_outer}
                  lambda: {_lambda}
                  ''')
        del train_set_eval_loss, train_grads_phi_v, train_grads_phi_w, \
            eval_grads_phi_w, val_losses_w, val_accs_w, val_losses_v, val_accs_v, \
            train_test_acc_v, train_test_acc_w, train_test_loss_v, train_test_loss_w

        _iter += 1
        if _iter % 50 == 0:
            torch.save({
                        'phi': phi,
                        'gamma': gamma,
                        '_iter': _iter
                }, os.path.join(save_dir, 'latest.ckpt'))

    # The summary previously printed two counters that were never updated and
    # therefore always reported 0; it now reports the tracked best test
    # accuracies.
    utils.log(log_file_path, f'''Training Finished! 
              Best accuracy using v: {best_acc_v}
              Best accuracy using w: {best_acc_w}
              Best test accuracy using v: {best_train_test_acc_v}
              Best test accuracy using w: {best_train_test_acc_w}
              ''')
        
        