import os
import argparse
import random
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from models_ZSL import VGG19_Norm_Triplet_3Scale_LateFusion
from dataset import Dataset_MA_STEP3_3Scale
from utils import *

# Command-line configuration for the script. The parsed namespace `opt` is a
# module-level global that the functions below read directly.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default='0', type=str, help='index of GPU to use')
parser.add_argument('--dataset', default='CUB', type=str, help='which dataset to use: CUB,AWA1,AWA2,FLO')
# Hyper-Parameter
parser.add_argument('--cls_LAMBDA', default=1.0, type=float, help='softmax + cross_entropy loss')
parser.add_argument('--trp_LAMBDA', default=1.0, type=float, help='triplet embedding loss')
parser.add_argument('--BATCH_SIZE', default=16, type=int, help='batch size')
# parser.add_argument('--BATCH_SIZE_TEST', default=16, type=int, help='test batch size')
parser.add_argument('--margin', default=0.5, type=float, help='margin in the triplet loss')
# Optimization
parser.add_argument('--LEARNING_RATE', default=0.0005, type=float, help='base learning rate')
parser.add_argument('--MOMENTUM',      default=0.9,    type=float, help='base momentum')
parser.add_argument('--WEIGHT_DECAY',  default=0.0005, type=float, help='base weight decay')
# exp
parser.add_argument('--split',    default='PP', type=str, help='split mode; standard split or proposed split PP/SP')
parser.add_argument('--resume',      default=None, type=str, help='path of model to resume')
parser.add_argument('--manualSeed', type=int, help='if use random seed to fix result')
# display and log
parser.add_argument('--disp_interval', default=50, type=int, help='display interval')
parser.add_argument('--evl_interval',  default=1, type=int, help='Epoch zero-shot learning evl interval')
parser.add_argument('--save_interval', default=5, type=int, help='Epoch save model interval')

# NOTE: arguments are parsed at import time, so importing this module also
# consumes sys.argv.
opt = parser.parse_args()
# set random seed
# When no seed is given, draw one at random and print it so that the run can
# be repeated later with --manualSeed.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
# Seed every random number generator this script touches (NumPy, Python,
# PyTorch CPU and all CUDA devices).
np.random.seed(opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
# NOTE(review): cuDNN benchmark mode presumably lets cuDNN pick kernels that
# are not deterministic, so a fixed seed may not give bit-identical runs --
# confirm if exact reproducibility matters.
cudnn.benchmark = True

# Restrict the GPUs visible to this process to the ones requested on the
# command line.
# NOTE(review): this is set after the torch.cuda seeding call above; it
# presumably still takes effect because CUDA has not been initialised yet --
# confirm for the PyTorch version in use.
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
# Dump the full configuration (including the seed chosen above) to stdout.
print('Running parameters:')
print(json.dumps(vars(opt), indent=4, separators=(',', ':  ')))

# opt.resume = '/media/evl/Public/Ethan/MA_ZSL_CVPR19/models/Sisley/CUB_New_BBox_CLS1.0_TRP1.0_Margin0.5_BS32_LR0.0005_PP/ZSL_GSC_MA_MyModel_Epoch35.tar'

def _log_zsl_results(fout, results):
    """Log the zero-shot accuracies for every feature type and every beta.

    Args:
        fout: path of the log file handed to ``log_print``.
        results: mapping with the keys 'all', 'org', 'body' and 'part'; each
            value maps a beta key to a 6-item result whose first five entries
            are accuracies in [0, 1] (UA, LA_Agents, LA_Mean, CO_Agents,
            CO_Mean) and whose last entry is the selected lambda.
    """
    for key in ['all', 'org', 'body', 'part']:
        result_diff_feature = results[key]
        log_print(fout, key)
        for _key, result in result_diff_feature.items():
            # The beta keys are strings, so they are concatenated instead of
            # being pushed through a float format spec (which would raise).
            log_print(fout,
                      _key + ' [UA, LA_Agents, LA_Mean, CO_Agents, CO_Mean][{:.2f} || {:.2f} || {:.2f} || {:.2f} || {:.2f}] Lambda={}'. \
                      format(result[0] * 100, result[1] * 100, result[2] * 100, result[3] * 100,
                             result[4] * 100, result[5]))


def train():
    """Train the three-scale network, evaluating and checkpointing on the way.

    All settings are read from the module-level ``opt``. Checkpoints and the
    log file are written below
    ``models_ZSL/<dataset>/CLS.._TRP.._Margin.._BS.._LR.._<split>/``.
    The model is evaluated once before training and then every
    ``opt.evl_interval`` epochs; it is saved every ``opt.save_interval``
    epochs. When ``opt.resume`` is set, training continues from the epoch
    stored in that checkpoint.
    """
    LAMBDA_CLS = opt.cls_LAMBDA
    LAMBDA_TRP = opt.trp_LAMBDA
    lr = opt.LEARNING_RATE
    run_name = 'CLS{}_TRP{}_Margin{}_BS{}_LR{}_{}'.format(opt.cls_LAMBDA,
                                                          opt.trp_LAMBDA,
                                                          opt.margin,
                                                          opt.BATCH_SIZE,
                                                          opt.LEARNING_RATE,
                                                          opt.split)
    output_dir = os.path.join('models_ZSL', opt.dataset, run_name)
    # Creates every missing level in one call and tolerates existing ones.
    os.makedirs(output_dir, exist_ok=True)
    fout = output_dir + '/log_{}.txt'.format(strftime("%a, %d %b %Y %H:%M:%S", gmtime()))

    # load dataset
    dataset = Dataset_MA_STEP3_3Scale(opt)

    # compute beta, ||A_u - B*A_c||_22 + lambda * ||B||_22
    # Closed-form ridge solution: beta = Au Ac^T (Ac Ac^T + lambda I)^-1,
    # obtained through a linear solve rather than an explicit inverse.
    Ac = dataset.train_semantic_feat
    Au = dataset.test_semantic_feat
    B = np.matmul(Au, Ac.T)  # does not depend on lambda, so computed once
    gram = np.matmul(Ac, Ac.T)
    beta = dict()
    for lambda_beta in [0.5, 0.1, 0.05]:
        A = gram + lambda_beta * np.eye(Ac.shape[0])
        beta[str(lambda_beta)] = (np.linalg.solve(A.T, B.T)).T

    # load net
    net = VGG19_Norm_Triplet_3Scale_LateFusion(train_ncls=dataset.train_ncls, att_ndim=dataset.att_ndim)

    # The checkpoint stores "last finished epoch + 1", so resuming starts at
    # the next epoch that has not been trained yet.
    start_epoch = resume_model(opt, net, fout) if opt.resume else 1

    if torch.cuda.device_count() > 1:
        log_print(fout, "Let's use {} GPUs!".format(torch.cuda.device_count()), color='red', attrs=['bold'])
        net = nn.DataParallel(net)
    net = net.cuda()

    criterion_cls = nn.CrossEntropyLoss().cuda()

    def criterion_embed_softmax(x_embed, z_embed_all_cls, labels):
        """Cross-entropy over embedding/class-semantic dot products.

        Returns the loss tensor and the batch accuracy of the arg-max class.
        """
        scores = x_embed.mm(z_embed_all_cls.transpose(1, 0))
        pred_lb = np.argmax(scores.detach().cpu().numpy(), axis=1)
        accuracy = (pred_lb == labels.cpu().numpy()).mean()
        return criterion_cls(scores, labels), accuracy

    t = Timer()
    t.tic()

    # optimizer and scheduler
    optimizer = set_optimizer_LayerEqual(net, opt, lr)
    # NOTE(review): min_lr equals the default --LEARNING_RATE, so with the
    # default settings the scheduler presumably cannot lower the rate --
    # confirm that this is intended.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True, cooldown=5,
                                  min_lr=0.0005)

    train_semantic_feat_all = torch.from_numpy(dataset.train_semantic_feat).cuda()

    # Evaluate once before any training to record the initial results.
    net.eval()
    log_print(fout, 'Evaluating...')
    t.tic()
    results = zsl_evaluation_CE(net, dataset, beta, cosine=True)
    duration = t.toc(average=False)
    log_print(fout, 'train acc {:.2f}, Time cost: {:3d} s'.format(0.0, int(duration)))
    _log_zsl_results(fout, results)
    t.tic()

    # set the mode to train
    net.train()
    num_batches = len(dataset.dataloader_train)
    # Defined up front so that saving a checkpoint cannot hit an unbound name
    # when no progress line has been produced yet.
    log_text = ''
    for epoch in range(start_epoch, 200 + 1):  # 400
        log_print(fout, time2str())
        train_acc_epoch = []
        # training
        train_cls_loss, train_trp_loss, train_acc = [], [], []
        train_pos_dist, train_neg_dist = [], []
        step_cnt = 0
        for i_batch, sample_batched in enumerate(dataset.dataloader_train):
            # get one batch
            im_org = sample_batched['im_org']
            im_body = sample_batched['im_body']
            im_part = sample_batched['im_part1']
            # Moved to the GPU once and reused by every loss below.
            labels = sample_batched['labels'].cuda()

            visual_embed, x_cls_normalized_org, x_cls_normalized_body, x_cls_normalized_part, \
            class_agents_normalized_org, class_agents_normalized_body, class_agents_normalized_part, \
            dist_mat_org, dist_mat_body, dist_mat_part = net(im_org.cuda(), im_body.cuda(), im_part.cuda())

            triplet_loss_org, pos_dist_org, neg_dist_org = General_Triplet_loss(dist_mat_org, labels)
            triplet_loss_body, pos_dist_body, neg_dist_body = General_Triplet_loss(dist_mat_body, labels)
            triplet_loss_part, pos_dist_part, neg_dist_part = General_Triplet_loss(dist_mat_part, labels)

            # Average the three scales.
            triplet_loss = (triplet_loss_org + triplet_loss_body + triplet_loss_part) / 3
            pos_dist = (pos_dist_org + pos_dist_body + pos_dist_part) / 3
            neg_dist = (neg_dist_org + neg_dist_body + neg_dist_part) / 3

            cls_loss, accuracy = criterion_embed_softmax(visual_embed, train_semantic_feat_all, labels)

            train_pos_dist += [pos_dist.item()]
            train_neg_dist += [neg_dist.item()]

            cls_loss = LAMBDA_CLS * cls_loss
            trp_loss = LAMBDA_TRP * triplet_loss
            train_cls_loss += [cls_loss.item()]
            train_trp_loss += [trp_loss.item()]
            train_acc += [accuracy]
            train_acc_epoch += [accuracy]
            step_cnt += 1

            # backward
            optimizer.zero_grad()
            T_loss = cls_loss + trp_loss
            T_loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.)
            optimizer.step()

            # Report every disp_interval batches and on the last batch, but
            # never on batch 0.
            if (i_batch % opt.disp_interval == 0 or (i_batch == num_batches - 1)) and i_batch:
                duration = t.toc(average=False)
                inv_fps = duration / step_cnt
                mean_cls_loss = np.asarray(train_cls_loss).mean()
                mean_trp_loss = np.asarray(train_trp_loss).mean()
                log_text = ('Epoch:{:2d} [{:3d}/{:3d}], loss: [{:.4f}|{:.4f}|{:.4f}]  acc_cls: {:.2f}'
                            ' Pos/Neg dist [{:.4f}||{:.4f}] (lr: {}, {:.2f}s per iteration)').format(
                    epoch, i_batch, num_batches,
                    mean_cls_loss + mean_trp_loss,
                    mean_cls_loss, mean_trp_loss,
                    np.asarray(train_acc).mean() * 100,
                    np.asarray(train_pos_dist).mean(), np.asarray(train_neg_dist).mean(),
                    scheduler.optimizer.param_groups[0]['lr'], inv_fps)
                log_print(fout, log_text, color='green', attrs=['bold'])

                # reset the counter
                train_cls_loss, train_trp_loss, train_acc = [], [], []
                train_pos_dist, train_neg_dist = [], []
                step_cnt = 0
                t.tic()

        train_acc_epoch = np.asarray(train_acc_epoch).mean()
        if epoch % opt.evl_interval == 0:
            net.eval()
            log_print(fout, 'Evaluating...')
            t.tic()
            results = zsl_evaluation_CE(net, dataset, beta, cosine=True)
            duration = t.toc(average=False)
            log_print(fout, 'train acc {:.2f}, Time cost: {:3d} s'.format(train_acc_epoch * 100, int(duration)))
            _log_zsl_results(fout, results)
            t.tic()
            net.train()

        if (epoch % opt.save_interval == 0) and epoch:
            save_name = os.path.join(output_dir, 'ZSL_Epoch{}.tar'.format(epoch))
            # Save the bare model so the checkpoint loads without DataParallel.
            net2save = net.module if isinstance(net, nn.DataParallel) else net
            torch.save({
                'epoch': epoch + 1,
                'state_dict': net2save.state_dict(),
                'optimizer': optimizer.state_dict(),
                'log':   log_text
            }, save_name)
            log_print(fout, 'save model: {}'.format(save_name))
        # The plateau scheduler monitors the loss of the epoch's last
        # mini-batch (not an epoch average).
        scheduler.step(T_loss.item())

def resume_model(opt, net, fout):
    """Restore the model weights from the checkpoint at ``opt.resume``.

    Only the model weights are restored; an optimizer state stored in the
    checkpoint is not loaded here.

    Args:
        opt: options object; only ``opt.resume`` (checkpoint path) is read.
        net: model whose ``load_state_dict`` receives the stored weights.
        fout: path of the log file handed to ``log_print``.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.

    Raises:
        ValueError: if ``opt.resume`` is not an existing file.
    """
    if not os.path.isfile(opt.resume):
        print("=> no checkpoint found at '{}'".format(opt.resume))
        raise ValueError('End')

    log_print(fout, "=> loading checkpoint '{}'".format(opt.resume), color='blue', attrs=['bold'])
    # Deserialize onto the CPU so that loading does not depend on the GPU
    # index the checkpoint was saved from; load_state_dict copies the values
    # into the model's own parameters afterwards.
    checkpoint = torch.load(opt.resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    log_print(fout, "Resume Epoch log: {}".format(checkpoint['log']), color='blue', attrs=['bold'])
    log_print(fout, "=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']),
              color='blue', attrs=['bold'])
    return checkpoint['epoch']

def General_Triplet_loss(dist_mat, labels, margin=None):
    """Hardest-negative triplet loss on a sample-to-class distance matrix.

    For every sample the distance to its own class (column ``labels[i]``) is
    the positive distance. Every other column is a negative and contributes
    ``relu(margin + positive - negative)``; only the largest violation per
    sample is kept.

    Args:
        dist_mat: 2-D tensor indexed as [sample, class].
        labels: 1-D integer tensor holding the class column of each sample;
            must live on the same device as ``dist_mat``.
        margin: triplet margin. Defaults to the module-level ``opt.margin``
            when omitted.

    Returns:
        A tuple ``(loss, mean_positive_dist, mean_negative_dist)`` of scalar
        tensors, where the negative mean is taken over all non-label entries.
    """
    if margin is None:
        margin = opt.margin
    num, dim = dist_mat.shape
    label_idx = labels.unsqueeze(1)
    positive_dist = torch.gather(dist_mat, 1, label_idx)
    # Margin matrix on the same device/dtype as the input; the entry of each
    # sample's own class is zeroed in one scatter so that column yields a
    # violation of exactly zero.
    margin_mat = torch.full_like(dist_mat, margin)
    margin_mat.scatter_(1, label_idx, 0.0)
    # positive_dist has shape (num, 1) and broadcasts across the class axis.
    diff_mat = F.relu(margin_mat + positive_dist - dist_mat)
    hardest_diff = torch.max(diff_mat, dim=1)[0]
    negative_mean = (torch.sum(dist_mat) - torch.sum(positive_dist)) / (num * (dim - 1))
    return torch.mean(hardest_diff), torch.mean(positive_dist), negative_mean


def set_optimizer_LayerEqual(net, opt, lr):
    """Build the SGD optimizer with one parameter group per sub-module.

    Every group shares the same learning rate. The three class-agent
    parameters get their own groups with weight decay switched off.

    Args:
        net: the model, optionally wrapped in ``nn.DataParallel``.
        opt: options object providing ``MOMENTUM`` and ``WEIGHT_DECAY``.
        lr: learning rate applied to every group.

    Returns:
        A ``torch.optim.SGD`` instance using Nesterov momentum.
    """
    # Unwrap based on the wrapper actually present rather than on the number
    # of visible GPUs, so an unwrapped model on a multi-GPU host also works.
    if isinstance(net, nn.DataParallel):
        net = net.module

    # Group order is kept fixed; callers read param_groups[0] for the lr.
    submodule_names = [
        'vgg19_org', 'vgg19_body', 'vgg19_part',
        'FC_emb_org', 'FC_cls_org',
        'FC_emb_body', 'FC_cls_body',
        'FC_emb_part', 'FC_cls_part',
    ]
    agent_names = ['class_agents_org', 'class_agents_body', 'class_agents_part']

    param_groups = [{'params': getattr(net, name).parameters()} for name in submodule_names]
    param_groups += [{'params': getattr(net, name), 'weight_decay': 0.0} for name in agent_names]

    optimizer = torch.optim.SGD(param_groups, lr=lr, momentum=opt.MOMENTUM,
                                weight_decay=opt.WEIGHT_DECAY, nesterov=True)
    return optimizer


def zsl_evaluation_CE(net, dataset, beta, cosine=True):
    """Evaluate zero-shot accuracy for each feature type and each beta.

    Test and training features are extracted with ``net``. For every entry
    of ``beta`` the training class means and the learned class agents are
    mapped through that beta matrix and handed to ``get_all_MCA`` together
    with the test features.

    Args:
        net: the model, optionally wrapped in ``nn.DataParallel``.
        dataset: dataset object providing the data loaders, sample counts,
            ``att_ndim`` and ``train_ncls``.
        beta: dict mapping a key to a matrix that is left-multiplied onto the
            per-training-class matrices.
        cosine: passed through to ``get_all_MCA``.

    Returns:
        Dict with the keys 'org', 'body', 'part' and 'all'; each value maps a
        beta key to the tuple returned by ``get_all_MCA``.
    """
    att_ndim = dataset.att_ndim
    train_ncls = dataset.train_ncls

    # Forward the test set. The builtin ``int`` is what the removed alias
    # ``np.int`` used to stand for.
    label_test = np.zeros(dataset.num_test_sample, dtype=int)
    visual_embed_test = np.zeros((dataset.num_test_sample, att_ndim))
    visual_triplet_test = np.zeros((dataset.num_test_sample, att_ndim * 3))
    extract_visual_features(net, dataset.dataloader_test, embed_store=visual_embed_test,
                            triplet_store=visual_triplet_test, label_store=label_test)

    # Forward the training set; its per-class feature means are needed for
    # the "LA mean" prediction.
    label_train = np.zeros(dataset.num_train_sample, dtype=int)
    visual_triplet_train = np.zeros((dataset.num_train_sample, att_ndim * 3))
    extract_visual_features(net, dataset.dataloader_train_for_test,
                            triplet_store=visual_triplet_train, label_store=label_train)

    train_class_mean = np.zeros((train_ncls, att_ndim * 3))
    for i in range(train_ncls):
        train_class_mean[i] = np.mean(visual_triplet_train[label_train == i], axis=0)

    # The class agents live on the bare model, so unwrap it if needed.
    if isinstance(net, nn.DataParallel):
        net = net.module

    class_agents_org = net.L2_Normalization(net.class_agents_org).cpu().detach().numpy()
    class_agents_body = net.L2_Normalization(net.class_agents_body).cpu().detach().numpy()
    class_agents_part = net.L2_Normalization(net.class_agents_part).cpu().detach().numpy()
    class_agents_all = np.hstack((class_agents_org, class_agents_body, class_agents_part))

    # The triplet feature is laid out as [org | body | part], each block
    # being att_ndim wide.
    org_cols = slice(0, att_ndim)
    body_cols = slice(att_ndim, att_ndim * 2)
    part_cols = slice(att_ndim * 2, None)

    visual_embed_T = visual_embed_test

    def compute_MCA_diff_feature(visual_triplet_T, train_class_mean, class_agents):
        """Run get_all_MCA once per beta for one feature type."""
        result_diff_feature = {}
        for _key in beta.keys():
            class_mean_Unseen_T = np.matmul(beta[_key], train_class_mean)
            class_agent_Unseen_T = np.matmul(beta[_key], class_agents)
            result_diff_feature[_key] = \
                get_all_MCA(dataset, label_test, visual_embed_T, visual_triplet_T, class_mean_Unseen_T,
                            class_agent_Unseen_T, cosine)
        return result_diff_feature

    results = {'org': {}, 'body': {}, 'part': {}, 'all': {}}
    results['all'] = compute_MCA_diff_feature(visual_triplet_test, train_class_mean, class_agents_all)
    results['org'] = compute_MCA_diff_feature(visual_triplet_test[:, org_cols],
                                              train_class_mean[:, org_cols], class_agents_org)
    results['body'] = compute_MCA_diff_feature(visual_triplet_test[:, body_cols],
                                               train_class_mean[:, body_cols], class_agents_body)
    results['part'] = compute_MCA_diff_feature(visual_triplet_test[:, part_cols],
                                               train_class_mean[:, part_cols], class_agents_part)
    return results

def extract_visual_features(net, dataloader, embed_store=None, triplet_store=None, label_store=None):
    """Run ``net`` over ``dataloader`` and write the outputs into the stores.

    Batches are written consecutively, starting at row 0, into whichever of
    the pre-allocated arrays are given. Nothing is returned.

    Args:
        net: model called as ``net(im_org, im_body, im_part,
            extract_embed=True)`` and returning ``(visual_embed,
            visual_triplet)``.
        dataloader: iterable of dicts with the keys 'im_org', 'im_body',
            'im_part1' and 'labels'.
        embed_store: optional array receiving ``visual_embed`` row-wise.
        triplet_store: optional array receiving ``visual_triplet`` row-wise.
        label_store: optional 1-D array receiving the labels.
    """
    # Feed the inputs to the device the model lives on instead of assuming
    # CUDA; a model without parameters falls back to the current CUDA device.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    cnt = 0
    for sample_batched in dataloader:
        # get image data and class labels
        im_org = sample_batched['im_org'].to(device)
        im_body = sample_batched['im_body'].to(device)
        im_part = sample_batched['im_part1'].to(device)
        labels = sample_batched['labels']
        with torch.no_grad():
            visual_embed, visual_triplet = net(im_org, im_body, im_part, extract_embed=True)
        if embed_store is not None:
            embed_store[cnt:cnt + visual_embed.shape[0], :] = visual_embed.detach().cpu().numpy()
        if triplet_store is not None:
            triplet_store[cnt:cnt + visual_triplet.shape[0], :] = visual_triplet.detach().cpu().numpy()
        if label_store is not None:
            label_store[cnt:cnt + labels.shape[0]] = labels.numpy()
        # Advance the write position by the number of samples in this batch.
        cnt += visual_embed.shape[0]

def get_all_MCA(dataset, label_test, visual_embed_T, visual_triplet_T, class_mean_Unseen_T, class_agent_Unseen_T, cosine):
    """Compute the mean class accuracies of all prediction variants.

    Three similarity matrices are built with ``get_sim``: test embeddings
    against ``dataset.test_semantic_feat`` (UA), test triplet features
    against the mapped class means (LA mean), and against the mapped class
    agents (LA agents). Each one, plus the sums UA + LA agents and
    UA + weight * LA mean, is scored with ``get_MCA``.

    Args:
        dataset: object providing ``test_semantic_feat``.
        label_test: ground-truth labels of the test samples.
        visual_embed_T: test embeddings, or a list of such arrays whose
            similarities are averaged.
        visual_triplet_T: test triplet features.
        class_mean_Unseen_T: class representatives derived from class means.
        class_agent_Unseen_T: class representatives derived from class agents.
        cosine: forwarded to ``get_sim``.

    Returns:
        ``(MCA_UA, MCA_LA_Agents, MCA_LA_Mean, MCA_CO_Agents, MCA_CO_Means,
        best_weight)`` where ``best_weight`` is the LA-mean weight that gave
        ``MCA_CO_Means``.
    """
    # UA: similarity between visual embeddings and the test semantic features
    if not isinstance(visual_embed_T, list):
        sim_UA = get_sim(visual_embed_T, dataset.test_semantic_feat, cosine)
    else:
        n_samples = (visual_embed_T[0]).shape[0]
        running_sum = np.zeros((n_samples, dataset.test_semantic_feat.shape[0]))
        for embed in visual_embed_T:
            running_sum += get_sim(embed, dataset.test_semantic_feat, cosine)
        sim_UA = running_sum / len(visual_embed_T)
    MCA_UA = get_MCA(sim_UA, label_test)

    # LA prediction with mean of training feature
    sim_LA_mean = get_sim(visual_triplet_T, class_mean_Unseen_T, cosine)
    MCA_LA_Mean = get_MCA(sim_LA_mean, label_test)

    # LA prediction with optimized training class agents
    sim_LA_agents = get_sim(visual_triplet_T, class_agent_Unseen_T, cosine)
    MCA_LA_Agents = get_MCA(sim_LA_agents, label_test)

    # combine UA, LA_agents
    MCA_CO_Agents = get_MCA(sim_LA_agents + sim_UA, label_test)

    # combine UA, LA_means: try every candidate weight and keep the best one
    _lambda = [1]
    co_mean_scores = np.zeros(len(_lambda))
    for pos, weight in enumerate(_lambda):
        co_mean_scores[pos] = get_MCA(weight * sim_LA_mean + sim_UA, label_test)
    best = np.argmax(co_mean_scores)
    MCA_CO_Means = co_mean_scores[best]

    return MCA_UA, MCA_LA_Agents, MCA_LA_Mean, MCA_CO_Agents, MCA_CO_Means, _lambda[best]

def get_MCA(sim, y):
    """Return the mean per-class accuracy of arg-max predictions.

    Args:
        sim: 2-D score array; the predicted class of a row is the index of
            its largest entry.
        y: 1-D array of ground-truth class indices, one per row of ``sim``.

    Returns:
        The accuracy computed separately for every class present in ``y``
        and then averaged, so each class weighs the same.
    """
    predicted = np.argmax(sim, axis=1)
    per_class_acc = [(predicted[y == cls] == cls).mean() for cls in np.unique(y)]
    return np.asarray(per_class_acc).mean()

def get_sim(X1, X2, cosine):
    """Return the pairwise similarity between the rows of X1 and X2.

    Args:
        X1: array with one sample per row.
        X2: array with one sample per row.
        cosine: when truthy use ``cosine_similarity``, otherwise the plain
            dot product ``X1 . X2^T``.
    """
    if not cosine:
        return np.dot(X1, X2.T)
    return cosine_similarity(X1, X2)

def log_print(log_file, text, color=None, on_color=None, attrs=None):
    """Print ``text`` to the console and append it to ``log_file``.

    Args:
        log_file: path of the file that ``text`` plus a newline is appended to.
        text: the message string.
        color, on_color, attrs: styling forwarded to ``cprint``; ignored when
            the module-level name ``cprint`` is None and plain ``print`` is
            used instead.
    """
    if cprint is None:
        print(text)
    else:
        cprint(text, color=color, on_color=on_color, attrs=attrs)

    line = text + '\n'
    with open(log_file, 'a') as log_handle:
        log_handle.write(line)



# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train()
