from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms

import time
import logging
import utils
import numpy as np


# Shared module-level softmax over dimension 1, created at import time.
# Not referenced anywhere else in the visible part of this file.
softmax = nn.Softmax(dim=1).cuda()
def saliency_bbox(img, lam):
    """Return a box centred on the most salient pixel of ``img``.

    The box side lengths are ``sqrt(1 - lam)`` times the extents found at
    indices 1 and 2 of ``img.size()``; the corners are clipped to those
    extents.

    Args:
        img: tensor with three dimensions; it is moved to the CPU and
            transposed with ``(1, 2, 0)`` before the saliency computation
            (presumably a channel-first image -- TODO confirm with callers).
        lam: mixing ratio; must not exceed 1 because ``1 - lam`` is passed to
            a square root.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` of clipped box coordinates.
    """
    # cv2 is used below but is not imported at the top of the visible file,
    # so it is brought into scope here.
    import cv2

    size = img.size()
    W = size[1]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    # The ``np.int`` alias was removed from NumPy; the builtin ``int``
    # truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # initialize OpenCV's static fine grained saliency detector and
    # compute the saliency map
    temp_img = img.cpu().numpy().transpose(1, 2, 0)
    saliency = cv2.saliency.StaticSaliencyFineGrained_create()
    (success, saliencyMap) = saliency.computeSaliency(temp_img)
    saliencyMap = (saliencyMap * 255).astype("uint8")

    # Location of the (first) maximum of the saliency map.
    maximum_indices = np.unravel_index(np.argmax(saliencyMap, axis=None), saliencyMap.shape)
    x = maximum_indices[0]
    y = maximum_indices[1]

    bbx1 = np.clip(x - cut_w // 2, 0, W)
    bby1 = np.clip(y - cut_h // 2, 0, H)
    bbx2 = np.clip(x + cut_w // 2, 0, W)
    bby2 = np.clip(y + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean.

       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

       Attributes set by ``reset``/``update``: ``val`` (last value),
       ``sum`` (weighted total), ``count`` (total weight), ``avg`` (sum / count).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of plain cross-entropy training.

    NOTE: a second function named ``train`` is defined further down in this
    module; at import time that later definition rebinds the name.

    Args:
        args: object exposing ``log_interval`` (progress is printed whenever
            the batch index is a multiple of it).
        model: network to optimise; switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute and a length.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    for step, (batch, labels) in enumerate(train_loader):
        batch = batch.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # forward pass, loss, backward pass, parameter update
        logits = model(batch)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        # only report on every ``log_interval``-th batch
        if step % args.log_interval != 0:
            continue
        seen = step * len(batch)
        total = len(train_loader.dataset)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, total, percent, loss.item()))
            
def train_for_one_epoch(model, loss, train_loader, optimizer, epoch_number):
    """Train ``model`` for one pass over ``train_loader`` and log a summary.

    Args:
        model: network to train; put into training mode.
        loss: loss module (also put into training mode). It may return either
            a loss value or a ``(loss value, outputs)`` tuple.
        train_loader: iterable of ``(images, labels)`` batches.
        optimizer: optimiser; stepped and then zeroed after every batch.
        epoch_number: value reported in the log line.
    """
    model.train()
    loss.train()

    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    tick = time.time()
    for images, labels in train_loader:
        n_items = images.size(0)

        if utils.is_model_cuda(model):
            images, labels = images.cuda(), labels.cuda()

        # Forward pass.
        predictions = model(images)
        criterion_result = loss(predictions, labels)

        # The loss may hand back replacement outputs; those are the ones
        # used for the accuracy computation below.
        if not isinstance(criterion_result, tuple):
            batch_loss = criterion_result
        else:
            batch_loss, predictions = criterion_result

        # Backward pass, parameter update, then gradient reset.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Bookkeeping: loss and top-1 / top-5 accuracy.
        loss_meter.update(batch_loss.item(), n_items)
        acc1, acc5 = utils.topk_accuracy(predictions, labels, recalls=(1, 5))
        top1_meter.update(acc1, n_items)
        top5_meter.update(acc5, n_items)

        # Bookkeeping: wall-clock time of this batch.
        batch_time_meter.update(time.time() - tick)
        tick = time.time()

    summary = (
        'Epoch: [{epoch}] ---- TRAINING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Loss {loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '
    ).format(
        epoch=epoch_number, batch_time=batch_time_meter,
        loss=loss_meter, top1=top1_meter, top5=top5_meter)
    logging.info(summary)

def train_for_one_epoch_with_l1(model, loss, train_loader, optimizer, epoch_number, l1_lambda):
    """Train for one epoch with an L1 penalty on all model parameters.

    Identical in structure to ``train_for_one_epoch`` except that
    ``l1_lambda`` times the sum of absolute parameter values is added to the
    loss before the backward pass (and is therefore part of the logged loss).

    Args:
        model: network to train; put into training mode.
        loss: loss module (also put into training mode). It may return either
            a loss value or a ``(loss value, outputs)`` tuple.
        train_loader: iterable of ``(images, labels)`` batches.
        optimizer: optimiser; stepped and then zeroed after every batch.
        epoch_number: value reported in the log line.
        l1_lambda: weight of the L1 penalty.
    """
    model.train()
    loss.train()

    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    tick = time.time()
    for images, labels in train_loader:
        n_items = images.size(0)

        if utils.is_model_cuda(model):
            images, labels = images.cuda(), labels.cuda()

        # Forward pass.
        predictions = model(images)
        criterion_result = loss(predictions, labels)

        # The loss may hand back replacement outputs; those are the ones
        # used for the accuracy computation below.
        if not isinstance(criterion_result, tuple):
            batch_loss = criterion_result
        else:
            batch_loss, predictions = criterion_result

        # L1 penalty: accumulate |p| over every parameter tensor.
        penalty = 0
        for param in model.parameters():
            penalty = penalty + param.abs().sum()
        batch_loss = batch_loss + l1_lambda * penalty

        # Backward pass, parameter update, then gradient reset.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Bookkeeping: loss and top-1 / top-5 accuracy.
        loss_meter.update(batch_loss.item(), n_items)
        acc1, acc5 = utils.topk_accuracy(predictions, labels, recalls=(1, 5))
        top1_meter.update(acc1, n_items)
        top5_meter.update(acc5, n_items)

        # Bookkeeping: wall-clock time of this batch.
        batch_time_meter.update(time.time() - tick)
        tick = time.time()

    summary = (
        'Epoch: [{epoch}] ---- TRAINING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Loss {loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '
    ).format(
        epoch=epoch_number, batch_time=batch_time_meter,
        loss=loss_meter, top1=top1_meter, top5=top5_meter)
    logging.info(summary)
   
  
def rand_bbox(size, lam):
    """Sample a random box whose sides scale with ``sqrt(1 - lam)``.

    A centre is drawn uniformly from ``[0, W) x [0, H)`` where ``W = size[2]``
    and ``H = size[3]``; the box extends half of the cut width/height to each
    side and is clipped to ``[0, W]`` / ``[0, H]``.

    Args:
        size: indexable with at least four entries; entries 2 and 3 give the
            extents the box is sampled in.
        lam: mixing ratio; must not exceed 1 because ``1 - lam`` is passed to
            a square root.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` of clipped box coordinates.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # The ``np.int`` alias was removed from NumPy; the builtin ``int``
    # truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

 
def test_for_one_epoch(model, loss, test_loader, epoch_number):
    """Evaluate ``model`` on ``test_loader``, log a summary, return top-1.

    Args:
        model: network to evaluate; put into eval mode.
        loss: loss module (also put into eval mode). It may return either a
            loss value or a ``(loss value, outputs)`` tuple.
        test_loader: iterable of ``(images, labels)`` batches.
        epoch_number: value reported in the log line.

    Returns:
        ``average`` attribute of the top-1 accuracy meter.
    """
    model.eval()
    loss.eval()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    tick = time.time()
    for images, labels in test_loader:
        n_items = images.size(0)

        if utils.is_model_cuda(model):
            images, labels = images.cuda(), labels.cuda()

        # Time spent obtaining (and moving) the batch.
        data_time_meter.update(time.time() - tick)

        # Gradient-free forward pass.
        with torch.no_grad():
            predictions = model(images)
            criterion_result = loss(predictions, labels)

        # The loss may hand back replacement outputs; those are the ones
        # used for the accuracy computation below.
        if not isinstance(criterion_result, tuple):
            batch_loss = criterion_result
        else:
            batch_loss, predictions = criterion_result

        # Bookkeeping: loss and top-1 / top-5 accuracy.
        loss_meter.update(batch_loss.item(), n_items)
        acc1, acc5 = utils.topk_accuracy(predictions, labels, recalls=(1, 5))
        top1_meter.update(acc1, n_items)
        top5_meter.update(acc5, n_items)

        # Bookkeeping: wall-clock time of this batch.
        batch_time_meter.update(time.time() - tick)
        tick = time.time()

    summary = (
        'Epoch: [{epoch}] -- TESTING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Data {data_time.sum:.2f}   '
        'Loss {loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '
    ).format(
        epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
        loss=loss_meter, top1=top1_meter, top5=top5_meter)
    logging.info(summary)
    return top1_meter.average


def eval_train(model, device, train_loader):
    """Report cross-entropy loss and accuracy of ``model`` on the training set.

    Args:
        model: network to evaluate; put into eval mode.
        device: device the batches are moved to (matches ``eval_test``; the
            previous implementation ignored it and always called ``.cuda()``).
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute whose length is the sample count.

    Returns:
        ``(average loss per sample, accuracy as a fraction in [0, 1])``.
    """
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per batch; the division by the dataset size happens once
            # below.  ``reduction='sum'`` replaces deprecated size_average=False.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(train_loader.dataset)
    train_loss /= num_samples
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, num_samples,
        100. * correct / num_samples))
    training_accuracy = correct / num_samples
    return train_loss, training_accuracy


def eval_test(model, device, test_loader):
    """Report cross-entropy loss and accuracy of ``model`` on the test set.

    Args:
        model: network to evaluate; put into eval mode.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute whose length is the sample count.

    Returns:
        ``(average loss per sample, accuracy as a fraction in [0, 1])``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per batch; the division by the dataset size happens once
            # below.  ``reduction='sum'`` replaces deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    test_accuracy = correct / num_samples
    return test_loss, test_accuracy



# adversarial training
class LinfPGDAttack(object):
    """Projected gradient descent attack inside an L-infinity ball.

    Args:
        model: callable mapping an input batch to logits.
        epsilon: radius of the L-infinity ball around the clean input.
        alpha: step size of each signed-gradient step.
        k: number of gradient steps.

    The defaults reproduce the values that used to be hard-coded.
    """
    def __init__(self, model, epsilon=0.0314, alpha=0.00784, k=7):
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.k = k

    def perturb(self, x_natural, y):
        """Return an adversarial version of ``x_natural`` for labels ``y``.

        Starts from a uniformly random point in the epsilon-ball, then takes
        ``k`` signed-gradient ascent steps on the cross-entropy loss.  After
        every step the result is projected back into the epsilon-ball and
        clamped to ``[0, 1]``.  The returned tensor is detached.
        """
        epsilon, alpha, k = self.epsilon, self.alpha, self.k
        x = x_natural.detach()
        x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)
        for _ in range(k):
            x.requires_grad_()
            # enable_grad so the attack also works when the caller is inside
            # a no_grad context.
            with torch.enable_grad():
                logits = self.model(x)
                loss = F.cross_entropy(logits, y)
            # autograd.grad returns the input gradient without touching the
            # .grad fields of the model parameters.
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + alpha * torch.sign(grad.detach())
            # project back into the epsilon-ball, then into the valid range
            x = torch.min(torch.max(x, x_natural - epsilon), x_natural + epsilon)
            x = torch.clamp(x, 0, 1)
        return x


def train(train_loader,model,criterion,optimizer,epoch,use_cuda,device=torch.device('cuda'),num_batchs=999999,debug_='MEDIUM',batch_size=32, uniform_reg=False):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches.
        model: network to train.  It may return the logits directly or a
            tuple/list whose first element is the logits.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimiser stepped once per batch.
        epoch: accepted for interface compatibility; unused.
        use_cuda: when true, batches are moved to ``device``.
        device: target device for the batches.
        num_batchs, debug_, batch_size: accepted for interface compatibility;
            unused.
        uniform_reg: when true, a KL term between the uniform distribution
            over classes and the predicted distribution is added to the loss.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``.
    """
    # switch to train mode
    model.train()

    losses = AverageMeter()
    top1 = AverageMeter()

    for inputs, targets in train_loader:
        if use_cuda:
            inputs, targets = inputs.to(device), targets.to(device)

        # One forward pass.  The previous try/except cascade re-ran the model
        # up to three times and, by tuple-unpacking the result, silently
        # split a plain logits tensor whose batch size happened to be 3 or 2.
        outputs = model(inputs)
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]

        loss = criterion(outputs, targets)
        if uniform_reg:
            # NOTE(review): the previous code passed a CPU vector of length
            # batch_size and the raw outputs to F.kl_div, which cannot work
            # for general shapes/devices.  This assumes the intent was to
            # pull the predicted distribution towards uniform -- confirm.
            log_probs = F.log_softmax(outputs, dim=1)
            uniform = torch.full_like(log_probs, 1.0 / outputs.size(1))
            loss = loss + F.kl_div(log_probs, uniform, reduction='batchmean')

        # measure accuracy and record loss
        prec1, = accuracy(outputs.detach(), targets, topk=(1,))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


def test(test_loader,model,criterion,use_cuda,device=torch.device('cuda'), debug_='MEDIUM',batch_size=64, isAdvReg=0):
    """Evaluate ``model`` on ``test_loader``.

    Args:
        test_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; put into eval mode.  If it returns a
            tuple, the first element is taken as the logits.
        criterion: loss callable ``criterion(outputs, targets)``.
        use_cuda: accepted for interface compatibility; batches are always
            moved to ``device`` (unchanged from the previous behaviour).
        device: target device for the batches.
        debug_, batch_size, isAdvReg: accepted for interface compatibility;
            unused.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # compute output
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, = accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
    return (losses.avg, top1.avg)

def advtune_defense( model, attack_model, private_dataloader, test_dataloader, ref_dataloader, private_dataloader_origin, scheduler, optimizer,criterion, attack_optimizer,attack_criterion,config=None, save_path=None, num_epochs=50, use_cuda=True,batch_size=64,alpha=0,lr=0.0005,schedule=[25,80],gamma=0.1,tr_epochs=100,at_lr=0.0001,at_schedule=[100],at_gamma=0.5,at_epochs=200,n_classes=10):
    """Alternate attack-model training and privacy-regularised model training.

    Epoch 0 trains ``model`` normally for one pass, evaluates it, and trains
    the attack model for five passes.  Every later epoch runs
    ``len(private_dataloader) // 2`` rounds, each consisting of one
    ``train_attack`` call (at most 26 batches) followed by one
    ``train_privatly`` call (at most 2 batches), then evaluates the model and
    the attack model and prints a summary line.

    Args:
        model: classifier being protected; ``model.model.state_dict()`` is
            what gets checkpointed.
        attack_model: membership-inference model.
        private_dataloader: loader used for (private) model training.
        test_dataloader: loader used for validation accuracy.
        ref_dataloader: loader of reference (non-member) data for the attack.
        private_dataloader_origin: loader of member data for the attack and
            for measuring training accuracy.
        scheduler: optional LR scheduler, stepped at the start of each epoch.
        optimizer, criterion: optimiser and loss for ``model``.
        attack_optimizer, attack_criterion: optimiser and loss for
            ``attack_model``.
        config: optional object exposing ``save_freq``; checkpoints are only
            written when both ``config`` and ``save_path`` are given.
        save_path: checkpoint destination.
        num_epochs: number of epochs to run.
        use_cuda, batch_size, alpha: forwarded to the helper routines.
        lr, schedule, gamma, tr_epochs, at_lr, at_schedule, at_gamma,
        at_epochs, n_classes: accepted for interface compatibility; unused
            (the list defaults are never read or mutated).
    """
    ############################################################ private training ############################################################
    print('Training using adversarial tuning...')
    best_acc=0
    # Never updated below; kept because it is part of the printed summary.
    best_test_acc=0
    for epoch in range(num_epochs):
        # NOTE(review): the scheduler is stepped before any optimizer step of
        # the epoch; kept as-is to preserve the existing LR schedule.
        if scheduler is not None:
            scheduler.step()

        c_batches = len(private_dataloader)
        if epoch == 0:
            print('----> NORMAL TRAINING MODE: c_batches %d '%(c_batches), flush=True)

            train(private_dataloader,
                  model,criterion,optimizer,epoch,use_cuda,debug_='MEDIUM')
            _test_loss, test_acc = test(test_dataloader,model,criterion,use_cuda, batch_size=batch_size)
            for _ in range(5):
                _at_loss, at_acc = train_attack(private_dataloader_origin,ref_dataloader,model,attack_model,criterion,
                                                attack_criterion,optimizer,attack_optimizer,epoch,use_cuda, batch_size=batch_size,debug_='MEDIUM')

            print('Initial test acc {} train att acc {}'.format(test_acc, at_acc), flush=True)
            continue

        att_accs = []
        rounds = c_batches // 2

        for _ in range(rounds):
            # Attack step: at most 26 batches (num_batchs), skip_batch=None.
            _at_loss, at_acc = train_attack(private_dataloader_origin, ref_dataloader,
                                            model,attack_model,criterion,attack_criterion,optimizer,
                                            attack_optimizer,epoch,use_cuda,52//2,None,batch_size=batch_size)
            att_accs.append(at_acc)

            # Defense step: at most 2 batches (num_batchs), skip_batch=None.
            train_privatly(private_dataloader,model,
                           attack_model,criterion,optimizer,epoch,use_cuda,
                           2,None,alpha=alpha,batch_size=batch_size)

        _train_loss, train_acc = test(private_dataloader_origin,model,criterion,use_cuda)
        _val_loss, val_acc = test(test_dataloader,model,criterion,use_cuda)
        best_acc = max(val_acc, best_acc)

        _at_val_loss, at_val_acc = test_attack(private_dataloader_origin,ref_dataloader,
                                               model,attack_model,criterion,attack_criterion,
                                               optimizer,attack_optimizer,epoch,use_cuda,debug_='MEDIUM')

        # With fewer than two batches there are no rounds; report NaN rather
        # than taking the mean of an empty list.
        att_epoch_acc = np.mean(att_accs) if att_accs else float('nan')

        # Checkpointing is skipped when config/save_path are left at their
        # None defaults (previously this raised on ``config.save_freq``).
        if config is not None and save_path is not None and epoch % config.save_freq == 0:
            print(save_path)
            torch.save(model.model.state_dict(),save_path)

        print('epoch %d | tr_acc %.2f | val acc %.2f | best val acc %.2f | best te acc %.2f | attack avg acc %.2f | attack val acc %.2f'%(epoch,train_acc,val_acc,best_acc,best_test_acc,att_epoch_acc,at_val_acc), flush=True)
    ############################################################ private training ############################################################
def train_privatly(train_loader,model,inference_model,criterion,optimizer,epoch,use_cuda,
                   num_batchs=10000,skip_batch=0,alpha=0.5,verbose=False,batch_size=16,loss_fun='mean'):
    """Train ``model`` with an added membership-inference penalty.

    For each batch the loss is
    ``criterion(outputs, targets) + alpha * (mean(inference_output) - 0.5)``,
    where ``inference_output`` is produced by ``inference_model`` from the
    model outputs, the hidden layer and the one-hot labels.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches.
        model: network to train; must return ``(outputs, h_layer)``.
        inference_model: attack model, kept in eval mode; called as
            ``inference_model(outputs, h_layer, one_hot_labels)``.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimiser for ``model``; stepped once per batch.
        epoch: accepted for interface compatibility; unused.
        use_cuda: when true, batches are moved with ``.cuda()``.
        num_batchs: maximum number of batches to process.
        alpha: weight of the inference penalty.
        skip_batch, verbose, batch_size, loss_fun: accepted for interface
            compatibility; unused.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``.
    """
    model.train()
    inference_model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    for ind,(inputs,targets) in enumerate(train_loader):
        if ind >= num_batchs:
            break

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # compute output
        outputs,h_layer = model(inputs)

        # One-hot labels on the same device as the outputs (previously
        # hard-coded to CUDA regardless of ``use_cuda``).
        one_hot_labels = torch.zeros(outputs.size(0), outputs.size(1),
                                     dtype=torch.float32, device=outputs.device)
        label_index = targets.detach().view(-1, 1).long().to(outputs.device)
        one_hot_labels.scatter_(1, label_index, 1)

        inference_output = inference_model(outputs,h_layer,one_hot_labels)

        loss = criterion(outputs, targets) +(alpha*(inference_output.mean()-0.5))

        # measure accuracy and record loss
        prec1, = accuracy(outputs.detach(), targets, topk=(1,))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


class InferenceAttack_HZ(nn.Module):
    """Membership-inference network over class scores and one-hot labels.

    Two encoders (``features`` for the score vector, ``labels`` for the
    one-hot label) each produce a 64-dimensional embedding; the concatenation
    is passed through ``combine`` and a sigmoid.

    Args:
        num_classes: length of the score vector and of the one-hot label.
    """
    def __init__(self,num_classes):
        # nn.Module must be initialised before attributes are assigned.
        super(InferenceAttack_HZ, self).__init__()
        self.num_classes=num_classes
        self.features=nn.Sequential(
            nn.Linear(num_classes,1024),
            nn.ReLU(),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Linear(512,64),
            nn.ReLU(),
            )

        self.labels=nn.Sequential(
            nn.Linear(num_classes,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            )
        self.combine=nn.Sequential(
            nn.Linear(64*2,512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,1),
            )
        # Every weight ~ N(0, 0.01^2), every bias zero.  Visiting the linear
        # layers directly avoids rebuilding state_dict() for each key.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.01)
                nn.init.zeros_(module.bias)
        self.output= nn.Sigmoid()

    def forward(self,x1,x2,l):
        """Return the sigmoid membership score for each row.

        Args:
            x1: score vectors, fed to ``features``.
            x2: accepted for interface compatibility; not used.
            l: one-hot labels, fed to ``labels``.
        """
        out_x1 = self.features(x1)
        out_l = self.labels(l)
        is_member = self.combine(torch.cat((out_x1,out_l),1))
        return self.output(is_member)

def train_attack(train_loader,ref_loader,model,attack_model,criterion,attack_criterion,optimizer,
                 attack_optimizer,epoch,use_cuda,num_batchs=100000,skip_batch=0,debug_='MEDIUM',batch_size=16):
    """Train the membership-inference ``attack_model`` against ``model``.

    Batches from ``train_loader`` are labelled as members (1) and batches
    from ``ref_loader`` as non-members (0).  The two loaders are iterated in
    lockstep, so iteration stops when the shorter one is exhausted.

    Args:
        train_loader: iterable of member ``(inputs, targets)`` batches.
        ref_loader: iterable of non-member ``(inputs, targets)`` batches.
        model: target classifier, kept in eval mode; must return
            ``(outputs, h_layer)``.
        attack_model: network being trained; called as
            ``attack_model(outputs, h_layer, one_hot_labels)``.
        attack_criterion: loss between attack outputs and membership labels.
        attack_optimizer: optimiser for ``attack_model``.
        use_cuda: when true, batches are moved with ``.cuda()``.
        num_batchs: maximum number of batch pairs to process.
        criterion, optimizer, epoch, skip_batch, debug_, batch_size: accepted
            for interface compatibility; unused.

    Returns:
        ``(average attack loss, average attack accuracy as a fraction)``.
    """
    model.eval()
    attack_model.train()

    losses = AverageMeter()
    top1 = AverageMeter()

    for ind,((inputs,targets),(inputs_attack,targets_attack)) in enumerate(zip(train_loader,ref_loader)):

        if ind >= num_batchs:
            break

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
            inputs_attack , targets_attack = inputs_attack.cuda(), targets_attack.cuda()

        # The target model is only queried here, never updated, so its
        # forward passes do not need an autograd graph.
        with torch.no_grad():
            outputs, h_layer = model(inputs)
            outputs_non, h_layer_non = model(inputs_attack)

        comb_inputs_h = torch.cat((h_layer,h_layer_non))
        attack_input = torch.cat((outputs,outputs_non))

        # One-hot labels on the same device as the attack input (previously
        # hard-coded to CUDA regardless of ``use_cuda``).
        all_targets = torch.cat((targets,targets_attack)).detach().view(-1, 1).long()
        one_hot_labels = torch.zeros(attack_input.size(0), outputs.size(1),
                                     dtype=torch.float32, device=attack_input.device)
        one_hot_labels.scatter_(1, all_targets.to(attack_input.device), 1)

        attack_output = attack_model(attack_input,comb_inputs_h,one_hot_labels).view([-1])

        # Membership labels: 1 for the member batch, 0 for the reference batch.
        is_member_labels = torch.cat((
            torch.ones(inputs.size(0), dtype=torch.float32),
            torch.zeros(inputs_attack.size(0), dtype=torch.float32),
        )).to(attack_output.device)

        loss_attack = attack_criterion(attack_output, is_member_labels)

        # Fraction of rows whose thresholded score matches the label.
        prec1=np.mean(np.equal((attack_output.detach().cpu().numpy() >0.5),(is_member_labels.cpu().numpy()> 0.5)))
        losses.update(loss_attack.item(), attack_input.size(0))
        top1.update(prec1, attack_input.size(0))

        # compute gradient and do SGD step
        attack_optimizer.zero_grad()
        loss_attack.backward()
        attack_optimizer.step()

    return (losses.avg, top1.avg)


def test_attack(private_loader,ref_loader,model,attack_model,criterion,attack_criterion,
                optimizer,attack_optimizer,epoch,use_cuda,batch_size=16,debug_='MEDIUM'):
    """Evaluate the membership-inference ``attack_model`` against ``model``.

    Batches from ``private_loader`` are labelled as members (1) and batches
    from ``ref_loader`` as non-members (0).  The two loaders are iterated in
    lockstep, so iteration stops when the shorter one is exhausted.

    Args:
        private_loader: iterable of member ``(inputs, targets)`` batches.
        ref_loader: iterable of non-member ``(inputs, targets)`` batches.
        model: target classifier, put into eval mode; must return
            ``(outputs, h_layer)``.
        attack_model: attack network, put into eval mode; called as
            ``attack_model(outputs, h_layer, one_hot_labels)``.
        attack_criterion: loss between attack outputs and membership labels.
        use_cuda: when true, batches are moved with ``.cuda()``.
        criterion, optimizer, attack_optimizer, epoch, batch_size, debug_:
            accepted for interface compatibility; unused.

    Returns:
        ``(average attack loss, average attack accuracy as a fraction)``.
    """
    model.eval()
    attack_model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    # Pure evaluation: run everything without autograd.  Previously no_grad
    # only wrapped the Variable() calls, not the forward passes.
    with torch.no_grad():
        for (inputs,targets),(inputs_attack,targets_attack) in zip(private_loader,ref_loader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
                inputs_attack , targets_attack = inputs_attack.cuda(), targets_attack.cuda()

            # compute output
            outputs,h_layer = model(inputs)
            outputs_non,h_layer_non = model(inputs_attack)

            comb_inputs_h = torch.cat((h_layer,h_layer_non))
            attack_input = torch.cat((outputs,outputs_non))

            # One-hot labels on the same device as the attack input
            # (previously hard-coded to CUDA regardless of ``use_cuda``).
            all_targets = torch.cat((targets,targets_attack)).view(-1, 1).long()
            one_hot_labels = torch.zeros(attack_input.size(0), outputs.size(1),
                                         dtype=torch.float32, device=attack_input.device)
            one_hot_labels.scatter_(1, all_targets.to(attack_input.device), 1)

            attack_output = attack_model(attack_input,comb_inputs_h,one_hot_labels).view([-1])

            # Membership labels: 1 for the member batch, 0 for the reference batch.
            is_member_labels = torch.cat((
                torch.ones(inputs.size(0), dtype=torch.float32),
                torch.zeros(inputs_attack.size(0), dtype=torch.float32),
            )).to(attack_output.device)

            loss = attack_criterion(attack_output, is_member_labels)

            # measure accuracy and record loss
            prec1=np.mean(np.equal((attack_output.cpu().numpy() >0.5),(is_member_labels.cpu().numpy()> 0.5)))
            losses.update(loss.item(), attack_input.size(0))
            top1.update(prec1, attack_input.size(0))

    return (losses.avg, top1.avg)

def accuracy(output, target, topk=(1,)):
    """Return the precision@k, in percent, for every k in ``topk``.

    Args:
        output: 2-D score tensor, one row per sample.
        target: 1-D tensor of true class indices.
        topk: iterable of k values.

    Returns:
        List of 0-dim tensors, one per entry of ``topk``, in the same order.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the ``largest_k`` highest scores, one column per sample.
    top_indices = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / num_samples)
        for k in topk
    ]