import os
import sys

HERE = os.path.dirname(os.path.abspath(__file__))

# Temporarily put this file's directory first on sys.path so the bundled
# ``core`` package next to this file is importable as a top-level package.
sys.path.insert(0, HERE)
try:
    import os
    import random
    import torch
    import torch.nn as nn
    import torchvision.transforms as transforms
    import torch.nn.functional as F
    import numpy as np
    import time
    import math

    from core.raft_stereo import RAFTStereo, autocast
finally:
    # Undo the path change even if an import above fails, and remove the
    # entry by value: sys.path.pop(0) would drop the wrong entry if any
    # import prepended its own path to sys.path in the meantime.
    if HERE in sys.path:
        sys.path.remove(HERE)

class Bunch:
    """Expose the entries of a mapping as instance attributes.

    ``Bunch({'a': 1}).a == 1``. Handy for turning a plain dict of options
    into an object that supports attribute-style access.
    """

    def __init__(self, adict):
        # Copy every key/value pair straight into the instance namespace.
        vars(self).update(adict)

def getModel(pretrained, iters=32):
    """Build a RAFT-Stereo model from a checkpoint and return an inference closure.

    The network is built with the fixed configuration below, wrapped in
    ``nn.DataParallel`` on device 0, loaded from ``pretrained`` with
    ``strict=True``, moved to CUDA and put in eval mode.

    Args:
        pretrained: anything ``torch.load`` accepts (e.g. a file path) that
            holds the model's state dict.
        iters: number of iterations passed to the model on every call.

    Returns:
        ``testFunc(imgL, imgR)``: runs the model on one left/right pair and
        returns the squeezed, sign-flipped prediction as a CUDA tensor that
        does not require grad.
    """
    arguments = {
        'hidden_dims': [128, 128, 128],
        'corr_implementation': 'reg',
        'shared_backbone': False,
        'corr_levels': 4,
        'corr_radius': 4,
        'n_downsample': 2,
        'slow_fast_gru': False,
        'n_gru_layers': 3,
        'mixed_precision': False,
    }
    # The model constructor reads its options as attributes (args.xxx).
    args = Bunch(arguments)

    # The state dict is loaded into the DataParallel wrapper, so with
    # strict=True the checkpoint keys must carry the wrapper's "module." prefix.
    model = nn.DataParallel(RAFTStereo(args), device_ids=[0])
    checkpoint = torch.load(pretrained)
    model.load_state_dict(checkpoint, strict=True)
    model.cuda()
    model.eval()

    def testFunc(imgL, imgR):
        """Run the model on a single stereo pair and return the squeezed result.

        Assumes ``imgL``/``imgR`` are torch tensors without a batch dimension
        (presumably (C, H, W) images; TODO confirm the expected layout/range).
        """
        # Inference only: without no_grad autograd would record the whole
        # iterative forward pass (wasting GPU memory), and the returned
        # tensor would require grad, so e.g. ``.numpy()`` on it would fail.
        with torch.no_grad():
            # [np.newaxis, ...] prepends a batch dimension of size 1.
            left = imgL[np.newaxis, ...].cuda()
            right = imgR[np.newaxis, ...].cuda()
            outputs = model(left, right, iters=iters, test_mode=True)
            # Second model output, first (only) batch item, negated.
            # NOTE(review): presumably the model predicts horizontal flow whose
            # negative is disparity. Confirm against the RAFT-Stereo sources.
            disparity = -outputs[1][0, ...]
        return torch.squeeze(disparity)

    return testFunc
