
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from models import __models__, model_loss_train_attn_only, model_loss_train_freeze_attn, model_loss_train, model_loss_test

from datasets.exrDatasetLoader import exrImagePairDataset

import argparse as args
import collections

def parseBoolFlag(value) :
	"""Convert a command-line flag value into a real boolean.

	Used as an argparse ``type=`` converter. With ``type=str`` every supplied
	value (including "False" or "0") is a non-empty string and therefore
	truthy, so such a flag could never be switched off explicitly.

	Args:
		value: a bool (returned unchanged) or a string such as "true",
			"False", "1", "no" (case-insensitive, surrounding spaces ignored).

	Returns:
		The boolean the value stands for.

	Raises:
		ValueError: if the text is not a recognised boolean spelling.
			argparse turns this into a regular usage error.
	"""
	if isinstance(value, bool) :
		return value
	
	text = str(value).strip().lower()
	if text in ('true', 't', 'yes', 'y', '1') :
		return True
	if text in ('false', 'f', 'no', 'n', '0') :
		return False
	
	raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__" : 
	
	# NOTE: `args` is the argparse module alias here; it is rebound to the
	# parsed namespace at the end of this section.
	parser = args.ArgumentParser(description='Finetune ActiveStereoNet on our dataset')
	
	parser.add_argument("--traindata", help="Path to the training images")
	#parser.add_argument("--validationdata", help="Path to the validation images")
	parser.add_argument('--model', default='acvnet', help='select a model structure', choices=__models__.keys())
	parser.add_argument("--numepochs", default=10, type=int, help="Number of epochs to run")
	parser.add_argument("--batchsize", default=4, type=int, help="Batch size")
	parser.add_argument("--numworkers", default=4, type=int, help="Number of workers threads used for loading dataset")
	parser.add_argument('--learningrate', default = 5e-5, type=float, help="Learning rate for the optimizer")
	parser.add_argument('--ramcache', action="store_true", help="cache the whole dataset into ram. Do this only if you are certain it can fit.")
	
	parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
	# Parsed with parseBoolFlag so that "--attention_weights_only False" really
	# yields False; usage with "True" keeps working as before.
	parser.add_argument('--attention_weights_only', default=False, type=parseBoolFlag,  help='only train attention weights')
	parser.add_argument('--freeze_attention_weights', default=False, type=parseBoolFlag,  help='freeze attention weights parameters')
	
	parser.add_argument('-p', '--pretrained', default='./pretrained/sceneflow.ckpt', help="Pretrained weights")
	parser.add_argument('-o', '--output', default='./pretrained/finetuned_sim_stereo.pth', help="Trained weights")
	
	# From here on `args` is the parsed namespace, no longer the argparse module.
	args = parser.parse_args()
	
	# Instantiate the network selected with --model.
	model = __models__[args.model](args.maxdisp, args.attention_weights_only, args.freeze_attention_weights)
	
	# The checkpoint is loaded through a DataParallel wrapper, as before
	# (presumably its keys carry the wrapper's "module." prefix - confirm
	# against the checkpoint). The wrapper is no longer pushed to CUDA and the
	# tensors are mapped to the CPU, so loading also works on a machine
	# without a GPU; the training below runs on the CPU anyway.
	wrapper = nn.DataParallel(model)
	
	checkpoint = torch.load(args.pretrained, map_location='cpu')
	wrapper.load_state_dict(checkpoint['model'])
	
	# Unwrap the plain module and make sure it lives on the CPU.
	model = wrapper.module.cpu()
	
	# Training dataset read from --traindata. The dataset's `cache` option is
	# left disabled; RAM caching is opt-in through --ramcache.
	dats = exrImagePairDataset(
		imagedir=args.traindata,
		left_nir_channel='Left.SimulatedNir.A',
		right_nir_channel='Right.SimulatedNir.A',
		cache=False,
		ramcache=args.ramcache,
		direction='l2r',
	)
	
	# Shuffled loader; batch size and worker count come from the command line.
	loaderOptions = dict(
		batch_size=args.batchsize,
		shuffle=True,
		num_workers=args.numworkers,
	)
	datl = DataLoader(dats, **loaderOptions)
	
	def buildOptimizer(parameters) :
		"""Return an Adam optimizer over `parameters`.

		The learning rate is taken from --learningrate; the betas are fixed
		to (0.9, 0.999).
		"""
		adamBetas = (0.9, 0.999)
		opt = Adam(parameters, lr=args.learningrate, betas=adamBetas)
		return opt
	
	def getLoss() :
		"""Pick the training loss function matching the command-line flags.

		--attention_weights_only takes precedence over
		--freeze_attention_weights; when neither is set the regular training
		loss is returned.
		"""
		if args.attention_weights_only :
			return model_loss_train_attn_only
		
		if args.freeze_attention_weights :
			return model_loss_train_freeze_attn
		
		return model_loss_train
	
	optimizer = buildOptimizer(model.parameters())
	loss = getLoss()
	
	# Finetuning loop: one optimizer step per batch, for --numepochs epochs.
	for ep in range(args.numepochs) :
		
		for batch_id, sampl in enumerate(datl) :
			
			imgLeft = sampl['frameLeft']
			imgRight = sampl['frameRight']
			gtDisp = sampl['trueDisparity']
			
			# Drop the singleton dimensions of the ground truth. A bare
			# squeeze() also removes the batch dimension when the batch holds
			# a single sample (batch size 1, or a trailing incomplete batch),
			# so that dimension is put back to keep the rank the same for
			# every batch.
			imgGtDisp = torch.squeeze(gtDisp)
			if gtDisp.shape[0] == 1 :
				imgGtDisp = imgGtDisp.unsqueeze(0)
			
			# Replicate each image three times along dim 1 (presumably the
			# network expects 3-channel input while the frames are
			# single-channel - confirm against the dataset).
			imgLeft = torch.cat((imgLeft, imgLeft, imgLeft), dim=1)
			imgRight = torch.cat((imgRight, imgRight, imgRight), dim=1)
			
			disps = model(imgLeft, imgRight)
			
			# Only pixels with 0 < disparity < maxdisp contribute to the loss.
			mask = (imgGtDisp < args.maxdisp) & (imgGtDisp > 0)
			
			l = loss(disps, imgGtDisp, mask)
			
			optimizer.zero_grad()
			
			l.backward()
			optimizer.step()
			
			lval = l.item()
			
			print(f"Epoch {ep}, batch {batch_id}: loss = {lval}")
			
	# Store the finetuned weights under the "upt_state_dict" key.
	torch.save({"upt_state_dict" : model.state_dict()}, args.output)
