""" Main script to generate samples of FINAL data 

	A stack of {X, Y} samples from the final model: for each input x taken from the pickled data, repeated draws y ~ P(y | x).

"""
import torch
import pickle
from torch.utils.data import DataLoader
import argparse
import os
from tqdm.auto import tqdm

from cfg.dataloader_pickle import PickleDataset
from cfg.embedding import JointEmbedding2, MNISTEmbedding
from cfg.unet import Unet
from cfg.utils import get_named_beta_schedule
from cfg.diffusion import GaussianDiffusion

# =============================================
# =           Loading blocks                  =
# =============================================



def load_data(pkl_loc, batch_size):
	"""Build a DataLoader over the observational data stored in a pickle file.

	Args:
		pkl_loc: path to the pickle file, forwarded to PickleDataset.
		batch_size: number of examples per batch.

	Returns:
		A DataLoader with 8 worker processes that iterates the dataset in
		order (no shuffling) and keeps the trailing partial batch.
	"""
	dataset = PickleDataset(pkl_file=pkl_loc)
	loader_options = dict(
		num_workers=8,
		batch_size=batch_size,
		shuffle=False,    # iterate in dataset order
		drop_last=False,  # keep the final, possibly smaller, batch
	)
	return DataLoader(dataset, **loader_options)


def _check_state_dict_load(result, component):
	"""Raise RuntimeError if a load_state_dict result reports any key mismatch."""
	# Explicit check instead of `assert`, which is stripped under `python -O`.
	if result.missing_keys or result.unexpected_keys:
		raise RuntimeError(
			'%s: checkpoint/architecture mismatch (missing=%s, unexpected=%s)'
			% (component, list(result.missing_keys), list(result.unexpected_keys)))


def load_model(diffuser_loc, device, w=0.0):
	""" Loads the architecture for the diffusion model we've already trained.

	Rebuilds the Unet and the MNISTEmbedding conditioning layer with fixed
	hyper-parameters, restores their weights from the checkpoint's 'net' and
	'cemblayer' entries, wraps the Unet in a GaussianDiffusion sampler with a
	1000-step beta schedule, and switches both networks to eval mode.

	Args:
		diffuser_loc: path to a checkpoint readable by torch.load that holds
			'net' and 'cemblayer' state dicts.
		device: torch device the networks (and the sampler) are placed on.
		w: forwarded to GaussianDiffusion as ``w``; the default 0.0 was noted
			as truly conditional sampling with no upweighting.

	Returns:
		dict with keys 'diffusion' (the GaussianDiffusion sampler) and
		'cemblayer' (the conditioning embedding layer).

	Raises:
		RuntimeError: if a state dict does not match its architecture.
	"""
	net = Unet(
		in_ch=3,
		mod_ch=64,
		out_ch=3,
		ch_mul=[1,2,2,2],
		num_res_blocks=2,
		cdim=64,
		use_conv=True,
		droprate=0,
		dtype=torch.float32).to(device)

	# Load tensors on CPU; load_state_dict copies them onto the parameters' device.
	checkpoint = torch.load(diffuser_loc, map_location='cpu')
	_check_state_dict_load(net.load_state_dict(checkpoint['net']), 'net')

	betas = get_named_beta_schedule(num_diffusion_timesteps=1000)
	diffusion = GaussianDiffusion(
					dtype=torch.float32,
					model=net,
					betas=betas,
					w=w,  # w=0.0 (default): truly conditional sampling with no upweighting
					v=1.0,
					device=device)

	# The conditioning layer must match the one used at training time
	# (an earlier variant of this script built a JointEmbedding2 here).
	cemblayer = MNISTEmbedding(3, dim=64, hw=32).to(device)
	_check_state_dict_load(cemblayer.load_state_dict(checkpoint['cemblayer']), 'cemblayer')

	diffusion.model.eval()
	cemblayer.eval()
	return {'diffusion': diffusion, 'cemblayer': cemblayer}



# ========================================================
# =           Sample 1/multiple batches                  =
# ========================================================
@torch.no_grad()
def sample_from_batch(batch, diffuser, device, ddim=True, drop_label=True,
					  n_draws=20, ddim_steps=50):
	"""Draw repeated samples y ~ P(y | x) for one batch of conditioning inputs x.

	Args:
		batch: mapping whose 'X' entry is a tensor with the batch on dim 0.
		diffuser: dict returned by load_model, with 'diffusion' and 'cemblayer'.
		device: device the inputs are moved to before embedding.
		ddim: if True use diffusion.ddim_sample, otherwise diffusion.sample.
		drop_label: currently unused; kept for interface compatibility.
		n_draws: number of independent draws per input (default 20).
		ddim_steps: second positional argument of ddim_sample (default 50),
			presumably the number of DDIM steps — confirm against GaussianDiffusion.

	Returns:
		(x, y) on CPU. x is batch['X']; y stacks the n_draws sampler outputs
		along a new leading dim, so y[k] is the k-th draw for the whole batch
		(each draw requested with shape (bsz, 3, 32, 32)).
	"""
	x = batch['X'].to(device)
	bsz = x.shape[0]
	sample_shape = (bsz, 3, 32, 32)

	# The conditioning embedding depends only on x, so compute it once.
	cemb = diffuser['cemblayer'](x)
	diffusion = diffuser['diffusion']

	draws = []
	# leave=False: the outer progress bar in sample_batches tracks overall progress.
	for _ in tqdm(range(n_draws), leave=False):
		if ddim:
			# 0 and 'linear' are passed exactly as before; NOTE(review): presumably
			# eta and the timestep-selection scheme — confirm in GaussianDiffusion.
			generated = diffusion.ddim_sample(sample_shape, ddim_steps, 0, 'linear', cemb=cemb)
		else:
			generated = diffusion.sample(sample_shape, cemb=cemb)
		draws.append(generated)

	y = torch.stack(draws, dim=0)
	return (x.cpu(), y.cpu())





def sample_batches(dataloader, diffuser, n_samples, device, ddim=True, drop_label=True):
	"""Collect conditional draws for at least n_samples inputs.

	Calls sample_from_batch on successive batches of dataloader, restarting
	the loader whenever it is exhausted, until at least n_samples inputs have
	been processed. The last batch is not truncated, so slightly more than
	n_samples inputs may be returned.

	Returns:
		{'X': inputs concatenated along dim 0,
		 'Y': draws concatenated along dim 1}.
		sample_from_batch returns y with the draws on dim 0 and the batch on
		dim 1, so Y is joined on dim 1 to keep Y[:, i] aligned with X[i]
		(joining on dim 0 fails when the last batch is smaller).

	Raises:
		ValueError: if a full pass over dataloader yields no inputs, which
			would otherwise loop forever.
	"""
	xs, ys = [], []
	count = 0
	with tqdm(total=n_samples) as progress:
		while count < n_samples:
			count_at_pass_start = count
			for batch in dataloader:
				x, y = sample_from_batch(batch, diffuser, device, ddim=ddim, drop_label=drop_label)
				xs.append(x)
				ys.append(y)
				count += x.shape[0]
				progress.update(n=x.shape[0])
				if count >= n_samples:
					break
			if count == count_at_pass_start:
				raise ValueError('dataloader yielded no samples; cannot reach n_samples=%d' % n_samples)

	return {'X': torch.cat(xs, dim=0), 'Y': torch.cat(ys, dim=1)}


def save_datadict(data_dict, save_dir, filename='final_Y_X.pkl'):
	"""Pickle data_dict into save_dir, creating the directory if needed.

	Args:
		data_dict: object to serialize (the {'X', 'Y'} dict from sample_batches).
		save_dir: output directory; created (with parents) if missing.
		filename: name of the pickle file inside save_dir (default 'final_Y_X.pkl').

	Returns:
		The path of the written file.
	"""
	os.makedirs(save_dir, exist_ok=True)
	save_loc = os.path.join(save_dir, filename)
	with open(save_loc, 'wb') as f:
		pickle.dump(data_dict, f)
	return save_loc


# ===========================================
# =           Main block                    =
# ===========================================


def main():
	"""Command-line entry point: load data and model, sample, and pickle the results."""
	parser = argparse.ArgumentParser(description='script to generate samples for retraining purposes')
	parser.add_argument('--pkl_loc', type=str, required=True, help='location of pickle training data file') # Base data here
	parser.add_argument('--diffuser_loc', type=str, required=True, help='location of saved diffusion model')
	parser.add_argument('--n_samples', type=int, default=10_000, help='how many samples to generate')
	parser.add_argument('--batch_size', type=int, default=256, help='batch size')
	parser.add_argument('--device', type=int, required=True, help='which gpu to use')
	parser.add_argument('--ddim', type=int, default=1, help='1 if we want to use ddim sampling, 0 ow')
	parser.add_argument('--w', type=float, default=0.0, help='w passed to GaussianDiffusion (0.0 = no upweighting)')
	parser.add_argument('--drop_label', type=int, default=0, help='1/0 flag forwarded to the sampler (currently unused there)')
	parser.add_argument('--save_dir', type=str, required=True, help='location of where to save the final_Y_X.pkl dataset')

	params = parser.parse_args()
	params.ddim = bool(params.ddim)
	params.drop_label = bool(params.drop_label)
	params.device = 'cuda:%s' % params.device

	dataloader = load_data(params.pkl_loc, params.batch_size)
	diffuser = load_model(params.diffuser_loc, params.device, w=params.w)
	# Forward drop_label explicitly; previously the parsed flag was silently ignored.
	data_dict = sample_batches(dataloader, diffuser, params.n_samples, params.device,
							   ddim=params.ddim, drop_label=params.drop_label)
	save_datadict(data_dict, params.save_dir)


# Example invocation:
# 	python3 gen_conditional_dataY_X.py --pkl_loc=baseline_samples/do_X.pkl --diffuser_loc=conditional_model_Y_X/ckpt_300_checkpoint.pt --n_samples=200 --batch_size=40 --device=1 --save_dir=baseline_samples/
if __name__ == '__main__':
	main()