
from __future__ import print_function

import os
import sys

# Put this file's directory first on the import path so that the `models`
# package imported further down (models.GANet_deep) is resolved from the
# directory containing this file. The entry is popped again after the imports.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

from math import log10
import sys
import shutil
import os
import re
from struct import unpack
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

# Evict any already-imported `models` / `models.submodules` from the module
# cache so the subsequent `from models.GANet_deep import GANet` re-imports the
# package from HERE instead of reusing a cached one (presumably a same-named
# `models` package loaded earlier from another location -- confirm).
# NOTE(review): only these two keys are removed; a cached
# `models.GANet_deep` entry would still be reused as-is.
if 'models' in sys.modules :
	sys.modules.pop('models')
if 'models.submodules' in sys.modules :
	sys.modules.pop('models.submodules')
	
from models.GANet_deep import GANet

import numpy as np

# Undo the sys.path.insert(0, HERE) performed at the top of the module so the
# path change does not leak into later imports.
# NOTE(review): assumes nothing else was inserted at index 0 in the meantime.
sys.path.pop(0)

class Bunch(object):
    """Expose the entries of a mapping as instance attributes.

    ``adict`` may be anything ``dict.update`` accepts (a mapping or an
    iterable of key/value pairs); each key becomes an attribute, e.g.
    ``Bunch({'lr': 0.1}).lr == 0.1``.
    """

    def __init__(self, adict):
        attributes = vars(self)
        attributes.update(adict)

def getModel(pretrained, level=-1, maxdisp=192):
	"""Build a GANet, load ``pretrained`` weights and return an inference closure.

	Args:
		pretrained: Anything ``torch.load`` accepts (e.g. a checkpoint path).
			The loaded object must contain a ``'state_dict'`` entry.
		level: Index applied to the model's output to select the returned
			value; ``-1`` (default) picks the last element.
		maxdisp: First argument passed to the ``GANet`` constructor (was
			hard-coded to 192; presumably the maximum disparity -- confirm
			against models.GANet_deep).

	Returns:
		A function ``testFunc(imgL, imgR)`` that adds a leading batch axis to
		both input tensors, moves them to CUDA, runs the model without
		gradient tracking and returns ``output[level]``.
	"""
	model = GANet(maxdisp)
	model = torch.nn.DataParallel(model).cuda()

	# Deserialize the checkpoint onto the CPU: load_state_dict copies the
	# tensors into the (CUDA) parameters anyway, and this avoids failing or
	# temporarily doubling GPU memory when the checkpoint was saved on a
	# different device.
	checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
	# strict=False: keys missing from / unexpected in the checkpoint are
	# silently ignored. NOTE(review): confirm this leniency is intended.
	model.load_state_dict(checkpoint['state_dict'], strict=False)

	model.eval()

	def testFunc(imgL, imgR):
		# Inference only: without no_grad autograd records the forward pass
		# and keeps every intermediate activation alive for a backward pass
		# that never happens.
		with torch.no_grad():
			output = model(imgL.unsqueeze(0).cuda(), imgR.unsqueeze(0).cuda())
		return output[level]

	return testFunc
