import copy

import torch
import sys
import numpy as np
from collections import OrderedDict
sys.path.append('..')
from utils import Averager, clip_perturbed_image
from loss import *


def _perturb_input(args, adapt, x):
    """Run the adapter on ``x`` and build the clipped, perturbed model input.

    Returns:
        ``(x_tilda, delta)`` where ``delta = adapt(x)`` and ``x_tilda`` is
        ``clip_perturbed_image(x, x + args.ad_scale * delta)``.
    """
    delta = adapt(x)
    x_tilda = clip_perturbed_image(x, x + args.ad_scale * delta)
    return x_tilda, delta


def _zo_estimate_gradients(args, model, adapt, x, x_tilda, targets, loss_fn):
    """Fill ``p.grad`` of every adapter parameter with a zeroth-order estimate.

    For each of ``args.q`` queries a Gaussian direction ``u`` (std
    ``args.sigma``) is drawn per parameter, the adapter is evaluated at
    ``theta + args.mu * u``, and ``(loss_q - loss_0) / (args.mu * args.q) * u``
    is accumulated, where ``loss_0`` is the loss at the unperturbed ``x_tilda``.
    No backward pass through ``model`` or ``adapt`` is performed.

    Args:
        args: Namespace providing ``mu``, ``q``, ``sigma`` and ``ad_scale``.
        model: Frozen model that is only queried for logits.
        adapt: Adapter module whose parameters receive the estimated gradients.
        x: Clean input batch.
        x_tilda: Perturbed input produced with the unperturbed adapter weights.
        targets: Class indices the loss is computed against.
        loss_fn: Callable ``(logits, targets) -> scalar tensor``.

    Returns:
        Mean of the per-query losses as a float (``nan`` when ``args.q`` is 0).
    """
    # Only trainable weights are perturbed. state_dict() would also expose
    # buffers (e.g. BatchNorm running statistics / integer step counters),
    # which must not receive noise.
    params = dict(adapt.named_parameters())
    mu, q = args.mu, args.q
    query_losses = []

    with torch.no_grad():
        loss_0 = loss_fn(model(x_tilda.detach()), targets)

        # Snapshot the weights so they can be restored exactly, instead of
        # undoing each perturbation by subtraction (which accumulates
        # floating-point drift over the queries).
        saved = {name: p.detach().clone() for name, p in params.items()}
        grads = {name: torch.zeros_like(p) for name, p in params.items()}

        try:
            for _ in range(q):
                directions = {}
                for name, p in params.items():
                    # Sampled exactly as before (default generator) so seeded
                    # runs draw the same directions, then moved next to p.
                    u = torch.normal(0, args.sigma, size=p.size())
                    u = u.to(device=p.device, dtype=p.dtype)
                    directions[name] = u
                    p.copy_(saved[name] + mu * u)

                x_q, _ = _perturb_input(args, adapt, x)
                loss_q = loss_fn(model(x_q), targets)

                coeff = (loss_q - loss_0) / (mu * q)
                for name, u in directions.items():
                    grads[name] += coeff * u
                query_losses.append(loss_q)
        finally:
            # Always put the original weights back, even if a query failed.
            for name, p in params.items():
                p.copy_(saved[name])

        for name, p in params.items():
            p.grad = grads[name]

    if not query_losses:
        return float("nan")
    # One host sync per batch instead of one per query.
    return torch.stack(query_losses).mean().item()


def offline_tta(args, model, adapt, train_loader, test_loader, device=None):
    """Adapt an input-perturbation module against a frozen model and track accuracy.

    ``model`` is put in eval mode with all parameters frozen. At epoch 0 its
    predictions on ``train_loader`` are stored as pseudo-labels. For epochs
    ``1..args.steps`` the adapter ``adapt`` is trained so that ``model`` applied
    to the perturbed input ``clip_perturbed_image(x, x + args.ad_scale * adapt(x))``
    fits those pseudo-labels under cross-entropy. Every ``args.eval_interval``
    epochs the accuracy on ``train_loader`` is measured and printed.

    Args:
        args: Namespace providing ``optim`` ('SGD' selects SGD, anything else
            Adam), ``lr``, ``steps``, ``ad_scale``, ``zo``, ``wdelta``,
            ``eval_interval`` and, when ``zo`` is set, ``mu``, ``q``, ``sigma``.
        model: Model to adapt to; its weights are never updated.
        adapt: Module mapping an input batch to a perturbation.
        train_loader: Loader yielding ``(x, y, idx)`` batches, where ``idx``
            indexes into ``train_loader.dataset``. Used for both adaptation
            and evaluation.
        test_loader: Not used by this function; kept for interface
            compatibility.
        device: Device that batches and the pseudo-label buffer are moved to.

    Returns:
        List with one accuracy value per evaluated epoch.

    Notes:
        With ``args.zo`` false the adapter is trained by backpropagating
        through ``model``, with an L1-norm penalty on the perturbation
        weighted by ``args.wdelta``. With ``args.zo`` true the gradients come
        from a zeroth-order estimate built from forward queries only; the
        penalty is then reported but does not contribute to the gradient.
    """
    # black-box model
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    # optimizer
    parameters = adapt.parameters()
    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=0.9, weight_decay=1e-5)
    else:
        optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=1e-5)

    accs = []
    # Pseudo-labels, one per dataset sample, filled during the epoch-0 evaluation.
    pl = torch.zeros((len(train_loader.dataset),), dtype=torch.long).to(device)
    # Referenced through ``torch`` explicitly rather than relying on a
    # star-import to provide ``nn``.
    ce_loss = torch.nn.CrossEntropyLoss()

    for epoch in range(0, args.steps + 1):
        average = Averager()
        ave_norm = Averager()

        if epoch > 0:
            for data_bat in train_loader:
                x, idx = data_bat[0].to(device), data_bat[2].to(device)
                optimizer.zero_grad()

                adapt.train()
                model.eval()

                if args.zo:
                    # The zeroth-order path never backpropagates, so the base
                    # forward pass does not need an autograd graph.
                    with torch.no_grad():
                        x_tilda, delta = _perturb_input(args, adapt, x)
                        # delta norm regularizer (logged only in this branch)
                        delta_norm = torch.linalg.norm(delta, ord=1, dim=(-2, -1)).mean()
                    loss_record = _zo_estimate_gradients(
                        args, model, adapt, x, x_tilda, pl[idx], ce_loss
                    )
                else:
                    x_tilda, delta = _perturb_input(args, adapt, x)
                    logits_tilda = model(x_tilda)
                    loss_tr = ce_loss(logits_tilda, pl[idx])
                    # delta norm regularizer
                    delta_norm = torch.linalg.norm(delta, ord=1, dim=(-2, -1)).mean()
                    loss = loss_tr + args.wdelta * delta_norm
                    loss.backward()
                    loss_record = loss.item()

                optimizer.step()
                # Log plain floats (as the evaluation loop does) so the
                # averagers never keep autograd graphs alive.
                average.update(loss_record)
                ave_norm.update(delta_norm.item())

        if epoch % args.eval_interval == 0:
            avgr = Averager()
            adapt.eval()
            model.eval()
            with torch.no_grad():
                for x, y, idx in train_loader:
                    x, y, idx = x.to(device), y.to(device), idx.to(device)
                    if epoch > 0:
                        x_tilda, _ = _perturb_input(args, adapt, x)
                    else:
                        # Before any adaptation the model sees the clean input.
                        x_tilda = x
                    logits = model(x_tilda)
                    ypred = logits.argmax(dim=-1)
                    if epoch == 0:
                        # Unadapted predictions become the training targets.
                        pl[idx] = ypred
                    avgr.update((ypred == y).float().mean().item(), nrep=len(y))
            acc = avgr.avg
            accs.append(acc)
            print(f"epoch {epoch:.1f}, loss = {average.avg}, acc = {acc:3f}, delta norm = {ave_norm.avg:3f}.", flush=True)

    return accs