import sys
import shutil
import os
import torch
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import argparse
from evaluation import  markov_scores_evaluation, mh_scores_evaluation
from DG_wrapper import *


def get_params(args=None):
    """Parse the command-line options of this evaluation script.

    Args:
        args: list of argument strings to parse; ``None`` (the default) makes
            ``ArgumentParser.parse_args`` read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(
        description='Score a generator with Markov-chain and MH sampling and save the scores.')
    parser.add_argument('--dataset', required=True,
                        help='folder|cifar; only used to name the output directory')
    parser.add_argument('--data_f', required=True,
                        help='data file handed to ZDataset to draw the initial (x, z) batch')
    # Parsed but not read by main().
    parser.add_argument('-w', '--workers', type=int, default=40)
    parser.add_argument('-b', '--batch_size', type=int, default=2,
                        help='batch size of the loader that draws the initial (x, z) batch')
    # Parsed but not read by main().
    parser.add_argument('--dim_z', type=int, default=100)
    parser.add_argument('--device', required=True,
                        help='device passed to MarkovWrap and to the evaluation functions')
    parser.add_argument('--gan', required=True,
                        help='generator architecture: dcgan_bn | dcgan_isn | wpgan')
    parser.add_argument('--D_pf', default=None,
                        help='path of the discriminator-wrapper state dict')
    parser.add_argument('--G_pf', default=None,
                        help='path of the generator state dict')
    # Parsed but not read by main().
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--act_path', required=True,
                        help='act_path forwarded to markov_scores_evaluation / mh_scores_evaluation')
    parser.add_argument('--mg_loops', type=int, default=5,
                        help='number of markov_scores_evaluation repetitions')
    parser.add_argument('--mh_loops', type=int, default=5,
                        help='number of mh_scores_evaluation repetitions')
    parser.add_argument('-k', required=True,
                        help='tag used in the output file name (_score_g_<k>)')
    parser.add_argument('--m_T', required=True, type=int,
                        help='m_T argument of mh_scores_evaluation')
    parser.add_argument('--gT', required=True, type=int,
                        help='gT argument of markov_scores_evaluation')
    parser.add_argument('--chain_size', required=True, type=int,
                        help='chain_size argument of mh_scores_evaluation')
    parser.add_argument('--time', required=True, type=float,
                        help='time argument of HmcTranstion')
    parser.add_argument('--name',
                        help='suffix of the output directory name')
    # Any value (even "False") selects the relative discriminator; only absence means "cat".
    parser.add_argument('--relative', default=None,
                        help='if given (any value), use LogitDiscriminator + DWrapperRelativeMCE '
                             'instead of Discriminator + DWrapper')

    return parser.parse_args(args=args)


def get_init(data_f, batch_size, dim_z=100):
    """Load the first batch of a ``ZDataset`` and flatten it into chain initialisers.

    Args:
        data_f: data file handed to ``ZDataset``.
        batch_size: number of dataset items drawn in the (single) batch.
        dim_z: size of the trailing latent dimension of ``z_init``; defaults to
            100, the value this function previously hard-coded.

    Returns:
        ``(x_init, z_init)`` where ``x_init`` has its first two dimensions merged
        and ``z_init`` is reshaped to ``(-1, dim_z)``.

    Raises:
        ValueError: if the dataset yields no batch at all.
    """
    z_matched_dataset = ZDataset(data_f)
    z_matched_loader = torch.utils.data.DataLoader(z_matched_dataset, batch_size=batch_size,
                                                   shuffle=False, num_workers=1)
    # NOTE(review): with shuffle=False every call returns the same first batch, so
    # repeated evaluation loops start from identical initialisers -- confirm intended.
    try:
        x_init, z_init = next(iter(z_matched_loader))
    except StopIteration:
        # A bare StopIteration escaping a function is confusing (and becomes a
        # RuntimeError if this is ever called from a generator).
        raise ValueError('ZDataset(%r) yielded no data' % (data_f,)) from None
    x_init = x_init.flatten(0, 1)
    # Infer the row count from the data instead of assuming a full batch of
    # `batch_size` items: a small dataset yields a shorter first batch.
    z_init = z_init.reshape(-1, dim_z)

    return x_init, z_init

def _load_generator_class(gan):
    """Import and return the ``Generator`` class selected by the ``--gan`` value *gan*.

    Raises:
        NotImplementedError: if *gan* is not one of the supported names.
    """
    import importlib

    # Each accepted --gan value -> module that defines its Generator class.
    generator_modules = {
        'dcgan_bn': 'gan.dcgan_bn',
        'dcgan_isn': 'gan.dcgan',
        'wpgan': 'gan.wgan_wp',
    }
    try:
        module_name = generator_modules[gan]
    except KeyError:
        raise NotImplementedError('unsupported --gan value: %r' % (gan,)) from None
    return importlib.import_module(module_name).Generator


def _get_nets(Generator, G_pf, D_pf, time, device, relative):
    """Build the Markov-wrapped generator and the discriminator wrapper from checkpoints.

    Returns:
        ``(MG, W)``: the ``MarkovWrap`` around the loaded generator and the loaded
        discriminator wrapper, both put in eval mode.
    """
    # NOTE(review): torch.load unpickles its input -- only use trusted checkpoints.
    G = Generator()
    G.load_state_dict(torch.load(G_pf, map_location='cpu'))
    G.eval()
    trans = HmcTranstion(time)
    MG = MarkovWrap(G, trans, device)
    MG.eval()

    if relative is None:
        D = Discriminator(nc=6)
        # NOTE(review): the checkpoint loaded below presumably overwrites what
        # weights_init sets; kept to preserve the original behaviour.
        D.apply(weights_init)
        W = DWrapper(D)
    else:
        D = LogitDiscriminator()
        W = DWrapperRelativeMCE(D)

    W.load_state_dict(torch.load(D_pf, map_location='cpu'))
    W.eval()

    return MG, W


def main():
    """Score a generator with Markov-chain and MH sampling and save the scores.

    Runs ``markov_scores_evaluation`` ``--mg_loops`` times and
    ``mh_scores_evaluation`` ``--mh_loops`` times, each from a batch returned by
    ``get_init``. It then ``torch.save``s ``(score_mg, score_mh, D_pf)`` to
    ``<output dir>/_score_g_<k>``.

    Returns:
        0 on success.
    """
    sys.path.append(".")
    cudnn.benchmark = True
    params = get_params()
    print(params.D_pf)

    # Must come after sys.path.append so the local `gan` package is importable.
    Generator = _load_generator_class(params.gan)

    # Both checkpoint paths default to None in get_params, but torch.load needs a
    # real path; fail with a readable message instead of deep inside torch.load.
    if params.G_pf is None or params.D_pf is None:
        raise SystemExit('both --G_pf and --D_pf must be given')

    MG, W = _get_nets(Generator, G_pf=params.G_pf, D_pf=params.D_pf, time=params.time,
                      device=params.device, relative=params.relative)

    d_type = 'cat' if params.relative is None else 'rel'
    # --name is optional (default None); concatenating None used to raise TypeError.
    name = params.name if params.name is not None else ''
    path_f = 'gG_MG_tD_score' + params.dataset + '_' + params.gan + '_' + d_type + '_' + name
    os.makedirs(path_f, exist_ok=True)

    score_mg = []
    for _ in range(params.mg_loops):
        z_init = get_init(params.data_f, params.batch_size)[1]
        is_mg, fid_mg = markov_scores_evaluation(MG, params.device, params.gT,
                                                 z_init, params.act_path)
        score_mg.append((is_mg, fid_mg))

    print('-----------------')

    score_mh = []
    for _ in range(params.mh_loops):
        x_init, z_init = get_init(params.data_f, params.batch_size)
        is_mh, fid_mh = mh_scores_evaluation(W, MG, params.device, params.chain_size, params.m_T,
                                             z_init, x_init, params.act_path)
        score_mh.append((is_mh, fid_mh))

    obj_save = (score_mg, score_mh, params.D_pf)
    path_save = os.path.join(path_f, '_score_g_%s' % params.k)
    torch.save(obj=obj_save, f=path_save)
    return 0

if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
