
from __future__ import print_function

import os
import sys

HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F
import logging
from dataloader import kitti_submission_collector as ls
from dataloader import preprocess
from PIL import Image

# Drop any already-imported 'models' / 'models.submodules' modules so the
# 'from models...' imports below are resolved again via sys.path, whose first
# entry is this file's directory (inserted above), instead of being served from
# the sys.modules cache.
# NOTE(review): presumably guards against another package named 'models' being
# loaded earlier in the same process — confirm. Other cached 'models.*' entries
# (e.g. 'models.deeppruner', 'models.config') are not evicted and would still
# be reused; verify that is intended.
if 'models' in sys.modules :
	sys.modules.pop('models')
if 'models.submodules' in sys.modules :
	sys.modules.pop('models.submodules')
	
from models.deeppruner import DeepPruner
from models.config import config as config_args

from setup_logging import setup_logging

import numpy as np

sys.path.pop(0)

class Bunch(object):
  """Wrap a mapping so each of its keys is readable as an attribute.

  The entries of ``adict`` are copied directly into the instance ``__dict__``.
  """

  def __init__(self, adict):
    namespace = vars(self)
    namespace.update(adict)

def getModel(pretrained, dataset = 'kitti', arch = 'best', withRefinement = True) :
	"""Build a DeepPruner model, load pretrained weights, and return an inference callable.

	Mutates the shared ``config_args`` object (imported from ``models.config``)
	before constructing the network.

	Args:
		pretrained: Checkpoint path (or file-like object) accepted by
			``torch.load``; the loaded object must contain a ``'state_dict'``
			entry whose keys match the DataParallel-wrapped model exactly.
		dataset: ``'sceneflow'`` selects the ``'patch_match'`` post-CRP
			sampler; any other value uses the sampler type that was configured
			before ``getModel`` was first called.
		arch: ``'fast'`` uses cost-aggregator scale 8 and 64
			refinement-level feature planes; any other value uses scale 4 and
			32 planes.
		withRefinement: Forwarded to the ``DeepPruner`` constructor.

	Returns:
		``testFunc(imgL, imgR)``, which prepends a batch axis to each input,
		moves both to CUDA, runs the model without autograd tracking, and
		returns the model output unchanged.
	"""
	config_args.mode = 'evaluation'

	if arch == 'fast' :
		config_args.cost_aggregator_scale = 8
		config_args.feature_extractor_refinement_level_outplanes = 64
	else :
		config_args.cost_aggregator_scale = 4
		config_args.feature_extractor_refinement_level_outplanes = 32

	# config_args is a process-wide object: without this, one 'sceneflow' call
	# would leave 'patch_match' behind for every later call. Remember the
	# sampler type present before getModel first ran and restore it for other
	# datasets.
	if not hasattr(getModel, '_initial_sampler_type') :
		getModel._initial_sampler_type = getattr(config_args, 'post_CRP_sampler_type', None)
	if dataset == 'sceneflow' :
		config_args.post_CRP_sampler_type = 'patch_match'
	elif getModel._initial_sampler_type is not None :
		config_args.post_CRP_sampler_type = getModel._initial_sampler_type

	model = DeepPruner(args = config_args, withRefinement=withRefinement)
	# Wrap before loading: DataParallel prefixes parameter names with
	# 'module.', and strict=True below requires the checkpoint keys to match.
	model = nn.DataParallel(model)
	model.cuda()

	# Deserialize onto the CPU; load_state_dict copies each tensor into the
	# already-CUDA parameters. This avoids failures when the checkpoint was
	# saved from a GPU index that is not present on this machine.
	state_dict = torch.load(pretrained, map_location='cpu')
	model.load_state_dict(state_dict['state_dict'], strict=True)

	model.eval()

	def testFunc(imgL, imgR) :
		# NOTE(review): imgL/imgR are presumably unbatched image tensors; the
		# np.newaxis index adds the leading batch dimension — confirm with callers.
		# Inference only: skip building an autograd graph.
		with torch.no_grad() :
			disparity = model(imgL[np.newaxis,...].cuda(), imgR[np.newaxis,...].cuda())
		return disparity

	return testFunc
