
import torch.nn.functional as F
from utils.metrics import topk_corrects
import torch
from torch.autograd import grad
import numpy as np

from core.utils import gather_flat_grad,neumann_hyperstep_preconditioner,logit_adjust_ly,loss_adjust_cross_entropy,cross_entropy,assign_hyper_gradient

def apply_policy(augment_list,p,u,pi):
    """Draw a Gumbel-softmax sample from ``pi`` and find its largest entry.

    Only ``pi`` is read. ``augment_list`` and ``p`` are never touched, the
    incoming value of ``u`` is ignored, and nothing is returned, so the only
    observable effect is the random draw made by ``F.gumbel_softmax``.

    NOTE(review): this looks like an unfinished stub -- presumably the
    selected index was meant to pick an entry of ``augment_list``; confirm
    the intended behaviour before relying on this function.
    """
    relaxed_sample = F.gumbel_softmax(pi)
    # Index of the largest component of the sample; currently unused.
    chosen_index = torch.argmax(relaxed_sample)

def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train_epoch_DA(cur_epoch, model, in_loader, in_criterion , in_optimizer, in_logit_adjust=None, in_params=None,
    is_out=False, out_loader=None, out_optimizer=None, out_criterion=None, out_logit_adjust=None, out_params=None,out_posthoc=False,
    ITER_LR=None, ARCH_EPOCH=0,num_classes=10,ARCH_INTERVAL=1,ARCH_TRAIN_SAMPLE=1,ARCH_VAL_SAMPLE=1,group_size=1,agumentation_list=None):
    """Performs one epoch of bilevel optimization.

    For every batch of ``in_loader`` the model weights are updated with
    ``in_optimizer`` on ``in_criterion``. When ``is_out`` is true, the outer
    parameters are updated first with ``out_optimizer`` in one of two ways:

    * ``out_posthoc`` false: every ``ARCH_INTERVAL`` iterations a
      hyper-gradient is built from ``ARCH_TRAIN_SAMPLE`` extra training
      batches and ``ARCH_VAL_SAMPLE`` batches of ``out_loader``, then
      written into ``out_params`` via ``assign_hyper_gradient``.
    * ``out_posthoc`` true: ``out_criterion`` on one ``out_loader`` batch
      (after ``out_logit_adjust``) is back-propagated directly.

    Args:
        cur_epoch: Epoch index; only used in the log line.
        model: Network being trained; put into training mode here.
        in_loader: Iterable yielding ``(data, targets)`` training batches.
        in_criterion: Called as ``in_criterion(preds, targets, in_params)``.
        in_optimizer: Optimizer stepping the model weights.
        in_logit_adjust: Optional callable applied to the training logits as
            ``in_logit_adjust(preds, in_params)``.
        in_params: Extra argument forwarded to the inner criterion/adjuster.
        is_out: Enables the outer update.
        out_loader: Iterable yielding ``(data, targets)`` outer batches.
        out_optimizer: Optimizer stepping the outer parameters.
        out_criterion: Called as ``out_criterion(preds, targets, out_params)``.
        out_logit_adjust: Called as ``out_logit_adjust(preds, params=out_params)``
            in the post-hoc branch only.
        out_params: Outer parameters receiving the hyper-gradient.
        out_posthoc: Selects the post-hoc outer update instead of the
            hyper-gradient one.
        ITER_LR, ARCH_EPOCH, group_size: Accepted but not used by this function.
        num_classes: Forwarded to ``assign_hyper_gradient``.
        ARCH_INTERVAL: Hyper-gradient steps run when
            ``cur_iter % ARCH_INTERVAL == 0``.
        ARCH_TRAIN_SAMPLE: Number of training batches averaged per hyper-step.
        ARCH_VAL_SAMPLE: Number of outer batches averaged per hyper-step.
        agumentation_list: Optional object whose ``apply_transform`` is called
            once per batch; skipped when ``None``.

    Returns:
        None. Epoch loss and top-1 accuracy are printed.
    """
    # Enable training mode
    model.train()
    bilevel = is_out and not out_posthoc
    if is_out:
        print('lr: ',in_optimizer.param_groups[0]['lr'],'  arch lr: ',out_optimizer.param_groups[0]['lr'])
        out_iter = iter(out_loader)
        in_iter_alt = iter(in_loader)
    else:
        print('lr: ',in_optimizer.param_groups[0]['lr'])

    if bilevel:
        # The file header does not import this helper although the hyper-step
        # needs it. NOTE(review): assumed to live in core.utils next to the
        # other hyper-gradient helpers -- confirm.
        from core.utils import get_trainable_hyper_params
        num_weights = sum(p.numel() for p in model.parameters())

    total_correct = 0.
    total_sample = 0.
    total_loss = 0.

    for cur_iter, (in_data, in_targets) in enumerate(in_loader):
        # Transfer the data to the current GPU device
        in_data = in_data.cuda(non_blocking=True)
        in_targets = in_targets.cuda(non_blocking=True)

        if agumentation_list is not None:
            # NOTE(review): ``in_optimizer`` is used as an optimizer everywhere
            # else (``param_groups``, ``zero_grad``, ``step``), and standard
            # torch optimizers cannot be indexed. The arguments were probably
            # meant to come from a parameter list -- confirm and fix.
            agumentation_list.apply_transform(in_optimizer[3],in_optimizer[4])

        # Hyper-gradient update of the outer parameters
        if bilevel:
            model.train()
            out_optimizer.zero_grad()

            if cur_iter % ARCH_INTERVAL == 0:
                # Averaged training-loss gradient w.r.t. the weights, kept
                # differentiable (create_graph=True) so it can be
                # differentiated again w.r.t. the outer parameters below.
                # The training logits are not passed through
                # ``in_logit_adjust`` here.
                d_train_loss_d_w = torch.zeros(num_weights, device='cuda')
                for _ in range(ARCH_TRAIN_SAMPLE):
                    (alt_data, alt_targets), in_iter_alt = _next_batch(in_iter_alt, in_loader)
                    alt_data = alt_data.cuda(non_blocking=True)
                    alt_targets = alt_targets.cuda(non_blocking=True)
                    in_optimizer.zero_grad()
                    alt_loss = in_criterion(model(alt_data), alt_targets, in_params)
                    d_train_loss_d_w += gather_flat_grad(
                        grad(alt_loss, model.parameters(), create_graph=True))
                d_train_loss_d_w /= ARCH_TRAIN_SAMPLE

                # Averaged outer-loss gradient w.r.t. the weights.
                d_val_loss_d_theta = torch.zeros(num_weights, device='cuda')
                for _ in range(ARCH_VAL_SAMPLE):
                    (out_data, out_targets), out_iter = _next_batch(out_iter, out_loader)
                    out_data = out_data.cuda(non_blocking=True)
                    out_targets = out_targets.cuda(non_blocking=True)
                    model.zero_grad()
                    in_optimizer.zero_grad()
                    out_loss = out_criterion(model(out_data), out_targets, out_params)
                    d_val_loss_d_theta += gather_flat_grad(
                        grad(out_loss, model.parameters(), retain_graph=True))
                d_val_loss_d_theta /= ARCH_VAL_SAMPLE

                # Presumably an approximate inverse-Hessian-vector product;
                # see ``neumann_hyperstep_preconditioner`` for the details.
                preconditioner = neumann_hyperstep_preconditioner(
                    d_val_loss_d_theta, d_train_loss_d_w, 1.0, 5, model)
                # Only the indirect term is used; no direct gradient of the
                # outer loss w.r.t. the outer parameters is added.
                hyper_grad = gather_flat_grad(
                    grad(d_train_loss_d_w, get_trainable_hyper_params(out_params),
                         grad_outputs=preconditioner.view(-1), allow_unused=True))
                out_optimizer.zero_grad()
                assign_hyper_gradient(out_params, -hyper_grad, num_classes)
                out_optimizer.step()
                # Drop the differentiable gradient so its graph can be freed.
                del d_train_loss_d_w

        # Post-hoc update of the outer parameters
        if is_out and out_posthoc:
            (out_data, out_targets), out_iter = _next_batch(out_iter, out_loader)
            out_data = out_data.cuda(non_blocking=True)
            out_targets = out_targets.cuda(non_blocking=True)
            out_preds = out_logit_adjust(model(out_data), params=out_params)
            out_loss = out_criterion(out_preds, out_targets, out_params)
            out_optimizer.zero_grad()
            out_loss.backward()
            out_optimizer.step()

        # Perform the forward pass
        in_preds = model(in_data)
        if in_logit_adjust is not None:
            in_preds = in_logit_adjust(in_preds, in_params)
        # Compute the loss
        loss = in_criterion(in_preds, in_targets, in_params)
        # Perform the backward pass
        in_optimizer.zero_grad()
        loss.backward()
        in_optimizer.step()

        # Compute the errors
        mb_size = in_data.size(0)
        top1_correct = topk_corrects(in_preds, in_targets, [1])[0]

        # Copy the stats from GPU to CPU (sync point)
        total_correct += top1_correct.item()
        total_sample += mb_size
        total_loss += loss.item() * mb_size

    # Log epoch stats; an empty loader reports NaN instead of dividing by zero
    if total_sample:
        mean_loss = total_loss / total_sample
        accuracy = total_correct / total_sample * 100.
    else:
        mean_loss = accuracy = float('nan')
    print(f'Epoch {cur_epoch} :  Loss = {mean_loss}   ACC = {accuracy}')
