import os
import sys

# Put this file's directory first on the import path so that the ``models``
# package imported below is resolved from here.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch

# A different top-level ``models`` package may already be loaded (e.g. by
# another project's wrapper). Evict it together with all of its submodules:
# evicting only ``models`` itself would leave stale ``models.<name>`` entries
# in sys.modules, which the imports below would silently reuse.
for _staleName in [name for name in list(sys.modules) if name == 'models' or name.startswith('models.')] :
	sys.modules.pop(_staleName)

try :
	from models.ActiveStereoNet import Active_StereoNet as ActiveStereoNet
	from models.StereoNet_single import StereoNet
finally :
	# Restore sys.path even when an import fails, and remove our own entry by
	# value rather than whatever happens to sit at index 0 by now.
	if HERE in sys.path :
		sys.path.remove(HERE)

import numpy as np

class Bunch(object):
	"""Attribute-style view of a mapping, e.g. ``Bunch({'a': 1}).a == 1``.

	Every key of ``adict`` becomes an instance attribute holding the
	corresponding value.
	"""

	def __init__(self, adict):
		# vars(self) is the instance __dict__, so the entries are copied in one step.
		vars(self).update(adict)

def getModel(pretrained, mode = "active", refineModuleStereoNet = True) :
	"""Build a stereo network, load its weights and return an inference closure.

	Args:
		pretrained: path to a checkpoint readable by ``torch.load`` whose
			``'state_dict'`` entry holds the network weights. Keys may either
			carry the ``module.`` prefix added by ``torch.nn.DataParallel`` or
			not.
		mode: ``'active'`` builds ``ActiveStereoNet``; ``'passive'`` builds
			``StereoNet(k=3, r=4)``.
		refineModuleStereoNet: in active mode, forwarded as ``refine`` to
			``ActiveStereoNet``; in passive mode, selects which network output
			is returned (the last one when True, the first one otherwise).

	Returns:
		``testFunc(imgL, imgR)``: moves both images to the GPU, runs the
		network in eval mode without tracking gradients and returns the
		squeezed disparity tensor.

	Raises:
		ValueError: if ``mode`` is neither ``'active'`` nor ``'passive'``.
	"""
	if mode == 'active' :
		network = ActiveStereoNet(refine = refineModuleStereoNet)
	elif mode == 'passive' :
		network = StereoNet(k=3, r=4)
	else :
		# Fail fast instead of wrapping ``None`` in DataParallel further down.
		raise ValueError("mode must be 'active' or 'passive', got %r" % (mode,))

	# Load onto the CPU so checkpoints saved on any GPU index can be read;
	# load_state_dict then copies the weights into the CUDA parameters.
	# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
	checkpoint = torch.load(pretrained, map_location='cpu')
	state_dict = checkpoint['state_dict']

	model = torch.nn.DataParallel(network).cuda()

	# Checkpoints saved through a DataParallel wrapper prefix every key with
	# 'module.'; load those through the wrapper and un-prefixed ones directly
	# into the wrapped network so both formats work.
	if all(key.startswith('module.') for key in state_dict) :
		model.load_state_dict(state_dict)
	else :
		model.module.load_state_dict(state_dict)

	model.eval()

	if mode == 'active' :
		def testFunc(imgL, imgR) :
			# Two leading axes are added to each image (presumably batch and
			# channel for a single-channel (H, W) input - TODO confirm).
			with torch.no_grad() :
				disparity = model(imgL[np.newaxis, np.newaxis, ...].cuda(), imgR[np.newaxis, np.newaxis, ...].cuda())[0]
			return torch.squeeze(disparity)
	else :
		# The passive network returns several outputs; pick the refined (last)
		# one or the first one depending on the flag.
		outputIndex = -1 if refineModuleStereoNet else 0

		def testFunc(imgL, imgR) :
			# One leading axis is added (presumably the batch axis for a
			# (C, H, W) input - TODO confirm).
			with torch.no_grad() :
				disparity = model(imgL[np.newaxis, ...].cuda(), imgR[np.newaxis, ...].cuda())[outputIndex]
			return torch.squeeze(disparity)

	return testFunc
