import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import argparse
import numpy as np
import voxelmorph as vxm
import tensorflow as tf
import evalutils
from kleindataloader import KleinDatasets
from tqdm import tqdm
from torch.nn import functional as F
import torch
import os
import pickle as pkl
import time
from glob import glob
from natsort import natsorted
import nibabel as nib
from keras import backend as K

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
def _build_parser():
    """Return the ArgumentParser describing this script's command line.

    NOTE(review): --warp, --dry_run and --multichannel are parsed but are not
    read anywhere in the code visible in this file.
    """
    # (flag, keyword arguments forwarded to add_argument)
    options = [
        ('--model', {'required': True, 'help': 'keras model for nonlinear registration'}),
        ('--warp', {'help': 'output warp deformation filename'}),
        ('--dry_run', {'action': 'store_true', 'help': 'dry run'}),
        ('--multichannel', {'action': 'store_true',
                            'help': 'specify that data has multiple channels'}),
        ('--gpu_id', {'type': int, 'default': 0}),
    ]
    cli = argparse.ArgumentParser()
    for flag, kwargs in options:
        cli.add_argument(flag, **kwargs)
    return cli


parser = _build_parser()
args = parser.parse_args()

# TensorFlow device handling: select the device for --gpu_id (main() runs
# the model under `device`).
device, nb_devices = vxm.tf.utils.setup_device(args.gpu_id)

def dataset(start=0, root="/mnt/anon_data2/neurite-OASIS"):
    """Yield consecutive (moving, fixed) subject pairs from the OASIS folder.

    Subjects are the directories directly under ``root`` containing
    ``aligned_norm.nii.gz`` (image) and ``aligned_seg35.nii.gz`` (label map),
    natural-sorted by path. The first ``start`` subjects are skipped. For each
    neighbouring pair, subject ``k`` is the fixed volume and subject ``k + 1``
    the moving one.

    Args:
        start: number of leading subjects (after sorting) to skip.
        root: dataset directory; defaults to the original hard-coded location.

    Yields:
        ``(movimg, fiximg, movseg, fixseg, fid, mid)``. Images are the squeezed
        ``get_fdata()`` arrays reshaped to ``(1, *vol.shape, 1)``; label maps
        are squeezed and reshaped to ``(1, *vol.shape)``. ``fid``/``mid`` are
        the names of the fixed/moving subject directories.

    Raises:
        ValueError: if the image and label files are not found in the same
            subject directories (pairing them by index would then mix
            subjects up silently).
    """
    images = natsorted(glob(os.path.join(root, "*", "aligned_norm.nii.gz")))[start:]
    segs = natsorted(glob(os.path.join(root, "*", "aligned_seg35.nii.gz")))[start:]

    # Images and labels come from two independent globs; make sure index k
    # refers to the same subject in both lists before pairing them.
    if [os.path.dirname(p) for p in images] != [os.path.dirname(p) for p in segs]:
        raise ValueError(
            f"image and segmentation files under {root!r} are not paired by "
            f"subject ({len(images)} images, {len(segs)} segmentations)"
        )

    def load(path):
        # Full volume as a float array with singleton axes removed.
        return nib.load(path).get_fdata().squeeze()

    subjects = list(zip(images, segs))
    for (fix_img_path, fix_seg_path), (mov_img_path, mov_seg_path) in zip(subjects, subjects[1:]):
        # Subject id = name of the directory containing the file.
        fid = os.path.basename(os.path.dirname(fix_img_path))
        mid = os.path.basename(os.path.dirname(mov_img_path))
        fiximg = load(fix_img_path)[None, ..., None]
        movimg = load(mov_img_path)[None, ..., None]
        fixseg = load(fix_seg_path)[None]
        movseg = load(mov_seg_path)[None]
        yield movimg, fiximg, movseg, fixseg, fid, mid

def _spatial_jacobian(field):
    """Return the finite-difference Jacobian of a dense 3-D vector field.

    Args:
        field: array of shape (X, Y, Z, 3), one 3-vector per voxel.

    Returns:
        Array of shape (X, Y, Z, 3, 3) whose entry ``[..., i, j]`` is
        d field[..., j] / d x_i, computed with ``np.gradient`` (central
        differences inside, one-sided differences on the borders). This is
        the transpose of the usual layout; the nuclear norm and determinant
        used by main() are invariant under transposition.
    """
    columns = [
        np.stack([np.gradient(field[..., j], axis=i) for i in range(3)], axis=-1)
        for j in range(3)
    ]
    return np.stack(columns, axis=-1)


def main():
    """Register consecutive OASIS pairs and print warp-regularity statistics.

    For every pair yielded by ``dataset(414 - 20)`` the model from ``--model``
    is run, and the script prints:
      * ``L``: the maximum over voxels of the nuclear norm of the Jacobian of
        the pre-integration flow (named ``svf`` here);
      * ``Mstar``: ``log2(L)``;
      * ``singularity``: the fraction of interior voxels where the determinant
        of the Jacobian of the deformation (post-integration flow plus the
        identity grid) is negative, i.e. where the warp folds.
    The segmentation-overlap evaluation at the end of the loop is disabled.
    """
    results_dict = {}  # filled only by the disabled metrics code below
    inshape = (160, 192, 224)
    print(device)
    gen = dataset(414 - 20)
    config = dict(inshape=inshape, input_model=None)
    with tf.device(device):
        model = vxm.networks.VxmDense.load(args.model, **config)
        # Expose the intermediate flows in addition to the model's inputs.
        exp_model = tf.keras.Model(
            model.inputs,
            [model.references.preint_flow, model.references.postint_flow, model.references.pos_flow],
        )
        transform = None
        # The number of pairs depends on how many subjects the glob finds, so
        # no fixed total is passed to tqdm.
        for batch in tqdm(gen):
            moving_img, fixed_img, moving_seg, fixed_seg, fid, mid = batch

            # get_fdata() yields float label maps, but tf.one_hot only accepts
            # integer indices: round to integer labels first.
            moving_lbl = np.rint(moving_seg).astype(np.int32)
            fixed_lbl = np.rint(fixed_seg).astype(np.int32)
            maxlabel = int(max(moving_lbl.max(), fixed_lbl.max()))
            # One channel per label, dropping label 0 (presumably background).
            moving_seg = tf.one_hot(moving_lbl, depth=maxlabel + 1)[..., 1:]
            fixed_seg = tf.one_hot(fixed_lbl, depth=maxlabel + 1)[..., 1:]
            nb_feats = moving_seg.shape[-1]
            if transform is None:
                transform = vxm.networks.Transform(inshape, nb_feats=nb_feats)

            a = time.time()  # start time for the disabled timing printout
            svf, warp, warp_up = exp_model.predict([moving_img, fixed_img])

            # Largest nuclear norm of the pre-integration flow's Jacobian.
            norm = np.linalg.norm(_spatial_jacobian(svf[0]), axis=(-2, -1), ord='nuc')
            L = norm.max()
            Mstar = np.log2(L)

            # Deformation = displacement + identity grid. Built out of place so
            # `warp` remains a displacement field for the metrics code below.
            grid = np.stack(
                np.meshgrid(*[np.arange(n) for n in warp.shape[1:-1]], indexing='ij'),
                axis=-1,
            )
            deformation = warp[0] + grid.astype(warp.dtype)
            # Drop border voxels, where np.gradient uses one-sided differences.
            jacwarp = _spatial_jacobian(deformation)[1:-1, 1:-1, 1:-1]
            detjac = np.linalg.det(jacwarp)
            print("L = ", L, "Mstar = ", Mstar, "singularity", (detjac < 0).mean())

            # --- Segmentation metrics (disabled) ---------------------------
            # moved_seg = transform.predict([moving_seg.numpy(), warp])
            # b = time.time()
            # moved_seg = (torch.from_numpy(moved_seg)>=0.5).float()
            # moved_seg = torch.from_numpy(moved_seg.numpy()).permute(0, 4, 1, 2, 3)
            # fixed_seg = torch.from_numpy(fixed_seg.numpy()).permute(0, 4, 1, 2, 3)
            # ret = evalutils.compute_metrics(moved_seg, fixed_seg, warp, onlydice=False, labelmax=maxlabel, method='fireants')
            # results_dict[(fid, mid)] = ret
            # print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})


if __name__ == '__main__':
    # Run only when executed as a script. Argument parsing and device setup
    # already happened at import time, at module level.
    main()
