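''' Run SimpleITK fast symmetric-forces demons registration on every ordered
subject pair of a brain-MRI dataset (IBSR18, CUMC12, LPBA40 or MGH10) and warp
each moving segmentation into the fixed image space. '''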
from glob import glob
import time
import numpy as np
import torch
import SimpleITK as sitk
sitk.ProcessObject_SetGlobalWarningDisplay(False)
sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(8)
import argparse
from tqdm import tqdm
import pickle
import nibabel as nib
from subprocess import call
import os
import pprint
from multiprocessing import Pool
from os import path as osp
import threading
from queue import Queue
from itertools import product

# Module-wide mutex for serialising writes made from several threads of this
# process. Being a threading.Lock it is process-local: it does NOT coordinate
# separate multiprocessing.Pool worker processes.
lock: threading.Lock = threading.Lock()
# DATA_DIR = "../neurite-OASIS"

def command_iteration(filter):
    """Iteration observer that logs registration progress.

    Prints ``<iteration> = <metric>`` whenever the elapsed iteration count is a
    multiple of 50 (including 0); all other iterations are silent.

    The parameter keeps its historical name ``filter`` (shadowing the builtin)
    so existing callers are unaffected.
    """
    elapsed = filter.GetElapsedIterations()
    if elapsed % 50 != 0:
        return
    print(f"{elapsed:3} = {filter.GetMetric():10.5f}")

# def worker(queue):
#     while True:
#         arg = queue.get()
#         if arg is None:
#             queue.task_done()
#             break
#         fix, mov, out = arg
#         if os.path.exists(out):
#             print("Exists already", out)
#             continue
#         register(fix, mov, out)
#         queue.task_done()

def register(fixed_image_path, moving_image_path, moving_seg_path, output_warp, output_deformed, gaussian=1, iterations=200):
    """Register a moving image onto a fixed image with demons and warp its labels.

    Pipeline:
      1. The moving image and its segmentation take over the fixed image's
         spacing, direction and origin.
      2. The moving intensities are histogram-matched to the fixed image.
      3. SimpleITK's fast symmetric-forces demons estimates a displacement field.
      4. The moving segmentation is resampled through that field onto the fixed
         grid, using nearest-neighbour interpolation.

    Args:
        fixed_image_path: Path of the fixed (reference) intensity image.
        moving_image_path: Path of the moving intensity image.
        moving_seg_path: Path of the label image belonging to the moving image.
            Assumed to share the moving image's voxel grid (TODO confirm per
            dataset).
        output_warp: Destination for the displacement-field transform. It is
            not written at present because transform writing is disabled.
        output_deformed: Destination of the warped segmentation. If this file
            already exists the pair is skipped, so interrupted runs can resume.
        gaussian: Standard deviation(s) passed to ``SetStandardDeviations`` to
            smooth the displacement field.
        iterations: Number of demons iterations.

    Returns:
        None. The only output is the file written to ``output_deformed``.
    """
    if osp.exists(output_deformed):
        print(f"Skipping (output exists): {output_deformed}")
        return

    print(f"Registering {moving_image_path} to {fixed_image_path} with seg: {moving_seg_path}")
    print(f"Output warp: {output_warp}")
    print(f"Output deformed: {output_deformed}")
    print()

    fixed = sitk.ReadImage(fixed_image_path)
    # Force the moving image into the fixed image's physical frame.
    moving = _adopt_geometry(sitk.ReadImage(moving_image_path), fixed)
    moving = _match_histograms(moving, fixed)
    transform = _demons_transform(fixed, moving, gaussian, iterations)

    # The segmentation must live in the same physical frame the field was
    # estimated in. Otherwise the resampler looks labels up at the wrong voxels
    # whenever the original headers differ from the fixed image's.
    moving_seg = _adopt_geometry(sitk.ReadImage(moving_seg_path), fixed)
    sitk.WriteImage(_warp_labels(moving_seg, fixed, transform), output_deformed)

    # NOTE: writing the transform to ``output_warp`` is disabled. Re-enable with
    # ``sitk.WriteTransform(transform, output_warp)`` if the warps are needed.


def _adopt_geometry(image, reference):
    """Overwrite ``image``'s spacing/direction/origin with ``reference``'s (in place) and return it."""
    image.SetSpacing(reference.GetSpacing())
    image.SetDirection(reference.GetDirection())
    image.SetOrigin(reference.GetOrigin())
    return image


def _match_histograms(moving, fixed):
    """Return ``moving`` with its intensity histogram matched to ``fixed``."""
    matcher = sitk.HistogramMatchingImageFilter()
    # 8-bit images have at most 256 grey values, so use fewer histogram bins.
    if fixed.GetPixelID() in (sitk.sitkUInt8, sitk.sitkInt8):
        matcher.SetNumberOfHistogramLevels(128)
    else:
        matcher.SetNumberOfHistogramLevels(1024)
    matcher.SetNumberOfMatchPoints(7)
    # Per the ITK docs, this uses only voxels above the mean intensity, which
    # mostly excludes background.
    matcher.ThresholdAtMeanIntensityOn()
    return matcher.Execute(moving, fixed)


def _demons_transform(fixed, moving, gaussian, iterations):
    """Run fast symmetric-forces demons and wrap the resulting field as a transform."""
    demons = sitk.FastSymmetricForcesDemonsRegistrationFilter()
    demons.SetNumberOfIterations(iterations)
    # Standard deviation of the Gaussian that smooths the displacement field.
    demons.SetStandardDeviations(gaussian)
    demons.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(demons))
    displacement_field = demons.Execute(fixed, moving)
    return sitk.DisplacementFieldTransform(displacement_field)


def _warp_labels(labels, reference, transform):
    """Resample a label image onto ``reference``'s grid through ``transform``."""
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    # Nearest neighbour keeps label values intact (no blending between labels).
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(transform)
    return resampler.Execute(labels)


# Per-dataset file layout. Templates are filled with ``d`` (dataset name),
# ``i`` (fixed subject id) and ``j`` (moving subject id). ``pad`` is the
# zero-padding width applied to the ids (0 = no padding).
_DATASET_LAYOUTS = {
    'IBSR18': {
        'num_subjects': 18,
        'pad': 2,
        'fixed': "{d}/IBSR_{i}/IBSR_{i}_ana_strip.nii.gz",
        'moving': "{d}/IBSR_{j}/IBSR_{j}_ana_strip.nii.gz",
        'moving_seg': "{d}/IBSR_{j}/IBSR_{j}_seg_ana.nii.gz",
        'deformed_seg': "{d}/Demons/outputs/deformed_{i}_{j}_seg_ana.nii.gz",
    },
    'CUMC12': {
        'num_subjects': 12,
        'pad': 0,
        'fixed': "{d}/Brains/m{i}.img",
        'moving': "{d}/Brains/m{j}.img",
        'moving_seg': "{d}/Atlases/m{j}.img",
        'deformed_seg': "{d}/Demons/outputs/deformed_{i}_{j}_seg.img",
    },
    'LPBA40': {
        'num_subjects': 40,
        'pad': 0,
        'fixed': "{d}/registered_pairs/l{i}_to_l{i}.img",
        'moving': "{d}/registered_pairs/l{j}_to_l{i}.img",
        'moving_seg': "{d}/registered_label_pairs/l{j}_to_l{i}.img",
        'deformed_seg': "{d}/Demons/outputs/deformed_{i}_{j}_seg.img",
    },
    'MGH10': {
        'num_subjects': 10,
        'pad': 0,
        'fixed': "{d}/Brains/g{i}.img",
        'moving': "{d}/Brains/g{j}.img",
        'moving_seg': "{d}/Atlases/g{j}.img",
        'deformed_seg': "{d}/Demons/outputs/deformed_{i}_{j}_seg.img",
    },
}
# Shared by all datasets; ``i`` is the fixed id and ``j`` the moving id.
_WARP_TEMPLATE = "{d}/Demons/outputs/output_{i}_to_{j}_warp.h5"


def build_jobs(dataset):
    """Return the ``register`` argument tuples for every ordered pair i != j.

    Each tuple is ``(fixed, moving, moving_seg, warp_path, deformed_seg_path)``.
    The fixed subject ``i`` varies in the outer loop and the moving subject
    ``j`` in the inner loop.

    Raises:
        ValueError: if ``dataset`` is not a key of ``_DATASET_LAYOUTS``.
    """
    try:
        layout = _DATASET_LAYOUTS[dataset]
    except KeyError:
        raise ValueError(f"Unknown dataset: {dataset!r}") from None

    subjects = range(1, layout['num_subjects'] + 1)
    jobs = []
    for i, j in product(subjects, subjects):
        if i == j:
            continue
        ids = {
            'd': dataset,
            'i': str(i).zfill(layout['pad']),
            'j': str(j).zfill(layout['pad']),
        }
        jobs.append((
            layout['fixed'].format(**ids),
            layout['moving'].format(**ids),
            layout['moving_seg'].format(**ids),
            _WARP_TEMPLATE.format(**ids),
            layout['deformed_seg'].format(**ids),
        ))
    return jobs


def _run_job(job):
    """Run ``register(*job)`` in a pool worker.

    Returns ``None`` on success, or ``(job, repr(exc))`` on failure. Catching
    here means every failing pair is reported and the run still finishes with
    a summary; ``Pool.starmap`` would only re-raise the first worker exception.
    """
    import traceback

    try:
        register(*job)
    except Exception as exc:  # worker boundary: log, report, keep going
        traceback.print_exc()
        return job, repr(exc)
    return None


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run demons registration on all datasets')
    # ``required=True`` would make any default unreachable, so none is given.
    parser.add_argument('--dataset', type=str, choices=list(_DATASET_LAYOUTS), required=True)
    parser.add_argument('--num_threads', type=int, default=8,
                        help='number of worker processes in the pool')
    args = parser.parse_args()

    os.makedirs(f"{args.dataset}/Demons/outputs", exist_ok=True)
    jobs = build_jobs(args.dataset)

    failures = []
    with Pool(args.num_threads) as pool:
        # chunksize=1: each registration is long and run times vary, so hand
        # out one pair at a time for better load balancing across workers.
        for failure in pool.imap_unordered(_run_job, jobs, chunksize=1):
            if failure is not None:
                failures.append(failure)

    print(f"Finished {len(jobs) - len(failures)}/{len(jobs)} registrations.")
    for job, error in failures:
        print(f"FAILED {job[1]} -> {job[0]}: {error}")
    if failures:
        raise SystemExit(1)
