import os
import argparse
import numpy as np
import matplotlib.pyplot as plt

from halo import Halo

from itertools import product
from collections import defaultdict

from sklearn.metrics import r2_score, mean_squared_error

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, default_collate
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam

from einops import reduce
import itertools

from torch.utils.tensorboard import SummaryWriter

from core import VecRecogNetBaseline

import torch.nn.parallel
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU

import torch_geometric.transforms as T
from torch_geometric.transforms import KNNGraph
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import MessagePassing


class PHDataset(Dataset):
    """Dataset of vectorized persistent-homology time series plus auxiliary targets.

    Both files are loaded with ``torch.load`` and cast to float:

    * ``vec_file`` holds the observations, unpacked as ``N, T, D = obs.shape``.
    * ``aux_file`` holds per-sample auxiliary targets; ``num_aux_dim`` reads
      ``shape[1]``.

    Each item is a dict of tensors:

    * ``inp_*`` ("input") tensors: after ``subsample`` the kept time points are
      packed to the front and the rest are zero-padded.
    * ``evd_*`` ("evidence") tensors: they keep the original time layout, with
      dropped time points zeroed.

    Before ``subsample`` is called, both views cover the full time grid.
    """

    def __init__(self, vec_file, aux_file):
        for path in (vec_file, aux_file):
            if not os.path.exists(path):
                raise FileNotFoundError(path)

        obs = torch.load(vec_file).float()
        aux = torch.load(aux_file).float()
        N, T, D = obs.shape

        self.T = T
        self.N = N
        self.D = D

        self.aux_obs = aux
        # Reuse the tensor loaded above instead of reading vec_file a second time.
        self.evd_obs = obs
        self.evd_msk = torch.ones_like(obs).long()
        self.evd_tid = torch.arange(T).view(1, T).repeat(N, 1)

        self.inp_tid = torch.clone(self.evd_tid)
        self.inp_obs = torch.clone(self.evd_obs)
        self.inp_msk = torch.clone(self.evd_msk)

        # Full time grid by default. This is a copy (not an alias of evd_tid),
        # because subsample() zeroes evd_tid in place.
        self.indices = torch.clone(self.evd_tid)
        self.is_subsampled = False

    def __getitem__(self, idx):
        inp_and_evd = {
            'inp_obs' : self.inp_obs[idx],
            'inp_msk' : self.inp_msk[idx],
            'inp_tid' : self.inp_tid[idx],
            'inp_tps' : self.inp_tid[idx]/self.T,  # time indices scaled by T
            'evd_obs' : self.evd_obs[idx],
            'evd_msk' : self.evd_msk[idx],
            'evd_tid' : self.evd_tid[idx],
            'aux_obs' : self.aux_obs[idx],
            'raw_tid' : self.indices[idx]
        }
        return inp_and_evd

    @property
    def num_timepts(self):
        return self.T

    @property
    def num_aux_dim(self):
        return self.aux_obs.shape[1]

    @property
    def num_vec_dim(self):
        return self.D

    @property
    def num_samples(self):
        return self.N

    def get_collate(self):
        """Return a collate function; items are plain tensor dicts, so ``default_collate`` suffices."""
        def collate(batch):
            return default_collate(batch)
        return collate

    def subsample(self, indices):
        """Keep only selected time points for every sample. Only the first call has an effect.

        ``indices`` must have one row per sample; row ``i`` lists the time-point
        indices kept for sample ``i``.

        * ``inp_*`` tensors get the kept entries packed to the front, followed by
          zeros.
        * ``evd_*`` tensors keep their layout, with every dropped time point zeroed.
        """
        if self.is_subsampled:
            return
        if indices.shape[0] != self.N:
            raise ValueError(
                f'expected {self.N} rows of sampling indices, got {indices.shape[0]}')
        self.indices = indices

        for i in range(self.N):
            idx_set = indices[i]
            num_kept = len(idx_set)

            # Pack the kept time points to the front; everything after stays zero.
            inp_msk = torch.zeros_like(self.inp_msk[i])
            inp_obs = torch.zeros_like(self.inp_obs[i])
            inp_tid = torch.zeros_like(self.inp_tid[i])
            inp_msk[:num_kept] = self.inp_msk[i][idx_set]
            inp_obs[:num_kept] = self.inp_obs[i][idx_set]
            inp_tid[:num_kept] = self.inp_tid[i][idx_set]

            # Zero the evidence at every time point that was not kept.
            dropped = torch.ones(self.T, dtype=torch.bool)
            dropped[idx_set] = False
            self.evd_obs[i][dropped] = 0
            self.evd_tid[i][dropped] = 0
            self.evd_msk[i][dropped] = 0

            self.inp_msk[i] = inp_msk
            self.inp_obs[i] = inp_obs
            self.inp_tid[i] = inp_tid

        self.is_subsampled = True

    def __len__(self):
        return self.num_samples
    

class PointCloudDataset(Dataset):
    """Dataset of point-cloud time series plus auxiliary targets.

    ``pts_file`` holds a tensor unpacked as ``N, T, M, D = obs.shape``. Every
    ``obs[i][t]`` is wrapped as a PyG ``Data(pos=...)`` and given edges by
    ``KNNGraph(k=6)``.

    Items are tuples ``(graphs, aux, time_indices, mask)``, where ``mask`` has
    shape ``(T, 1)`` and is 1 at the observed time points.
    """

    def __init__(self, pts_file, aux_file):
        for path in (pts_file, aux_file):
            if not os.path.exists(path):
                raise FileNotFoundError(path)

        obs = torch.load(pts_file).float()
        aux = torch.load(aux_file).float()
        N, T, M, D = obs.shape

        self.M = M
        self.T = T
        self.N = N
        self.D = D

        tf = KNNGraph(k=6)
        self.aux_obs = aux
        # Maps sample index -> list of per-time-point graphs
        # (subsample() trims each list to the kept time points).
        self.pts_obs = {i: [tf(Data(pos=y)) for y in obs[i]] for i in range(N)}
        self.pts_msk = torch.ones(N, T, 1, dtype=torch.long)
        self.indices = torch.arange(T).view(1, T).repeat(N, 1)
        self.is_subsampled = False

    @property
    def num_aux_dim(self):
        return self.aux_obs.shape[1]

    @property
    def num_vec_dim(self):
        return 0

    @property
    def num_samples(self):
        return self.N

    @property
    def num_timepts(self):
        return self.T

    def subsample(self, indices):
        """Keep only the time points in ``indices[i]`` for sample ``i``. Only the first call has an effect.

        The graph list of each sample is replaced by the kept graphs (in index
        order), and the mask is 1 exactly at the kept time points.
        """
        if self.is_subsampled:
            return
        if indices.shape[0] != self.N:
            raise ValueError(
                f'expected {self.N} rows of sampling indices, got {indices.shape[0]}')
        self.indices = indices
        self.pts_msk.fill_(0)
        for i in range(self.N):
            idx_set = indices[i]
            self.pts_obs[i] = [self.pts_obs[i][int(t)] for t in idx_set]
            self.pts_msk[i][idx_set] = 1
        self.is_subsampled = True

    def get_collate(self):
        """Return a collate function for ``(graphs, aux, tid, msk)`` items.

        The graphs of all samples are flattened into one list. ``pts_cut_batch``
        holds the split points that recover the per-sample groups with
        ``tensor_split``.
        """
        def collate(batch):
            pts_obs_batch = [b[0] for b in batch]
            pts_aux_batch = [b[1] for b in batch]
            pts_tid_batch = [b[2] for b in batch]
            pts_msk_batch = [b[3] for b in batch]
            # tensor_split requires a long tensor. Without an explicit dtype, a
            # single-sample batch would produce an empty *float* tensor and crash.
            cuts = torch.tensor([len(b) for b in pts_obs_batch][:-1], dtype=torch.long)
            return {
                'pts_obs_batch': list(itertools.chain.from_iterable(pts_obs_batch)),
                'pts_aux_batch': default_collate(pts_aux_batch),
                'pts_tid_batch': default_collate(pts_tid_batch),
                'pts_msk_batch': default_collate(pts_msk_batch),
                'pts_cut_batch': cuts.cumsum(0),
            }
        return collate

    def __getitem__(self, idx):
        return self.pts_obs[idx], self.aux_obs[idx], self.indices[idx], self.pts_msk[idx]

    def __len__(self):
        return self.N


class JointDataset(Dataset):
    """Combine several datasets sample by sample.

    As used in this file, the datasets are a ``PHDataset`` followed by a
    ``PointCloudDataset``. All wrapped datasets must agree on the number of
    samples, the number of time points and the auxiliary dimension.

    ``datasets[0]`` is treated as the TDA dataset: it supplies ``num_aux_dim``,
    ``num_vec_dim`` and ``len``. ``get_collate`` expects ``datasets[1]`` to yield
    ``(graphs, aux, tid, msk)`` tuples.
    """

    def __init__(self, datasets): # first dataset needs to be the TDA one
        self.datasets = datasets
        for attr in ('num_samples', 'num_timepts', 'num_aux_dim'):
            values = [getattr(ds, attr) for ds in self.datasets]
            if len(np.unique(values)) != 1:
                raise ValueError(f'datasets disagree on {attr}: {values}')

        self.T = self.datasets[0].num_timepts
        self.N = self.datasets[0].num_samples

    @property
    def num_aux_dim(self):
        return self.datasets[0].num_aux_dim

    @property
    def num_timepts(self):
        return self.T

    @property
    def num_samples(self):
        return self.N

    @property
    def num_vec_dim(self):
        assert hasattr(self.datasets[0], 'num_vec_dim')
        return self.datasets[0].num_vec_dim

    def __getitem__(self, idx):
        return [ds[idx] for ds in self.datasets]

    def subsample(self, indices):
        """Apply the same time-point subsampling to every wrapped dataset."""
        for ds in self.datasets:
            ds.subsample(indices)

    def get_collate(self):
        """Return a collate function for ``[tda_item, (graphs, aux, tid, msk)]`` items.

        TDA items go through ``default_collate``. Graphs are flattened into one
        list, and ``pts_cut_batch`` holds the ``tensor_split`` points that
        recover each sample's group.
        """
        def collate(batch):
            tda_obs_batch = [b[0] for b in batch]
            pts_obs_batch = [b[1][0] for b in batch]
            pts_aux_batch = [b[1][1] for b in batch]
            pts_tid_batch = [b[1][2] for b in batch]
            pts_msk_batch = [b[1][3] for b in batch]
            # tensor_split requires a long tensor. Without an explicit dtype, a
            # single-sample batch would produce an empty *float* tensor and crash.
            cuts = torch.tensor([len(b) for b in pts_obs_batch][:-1], dtype=torch.long)
            return {
                'tda_obs_batch': default_collate(tda_obs_batch),
                'pts_obs_batch': list(itertools.chain.from_iterable(pts_obs_batch)),
                'pts_aux_batch': default_collate(pts_aux_batch),
                'pts_tid_batch': default_collate(pts_tid_batch),
                'pts_cut_batch': cuts.cumsum(0),
                'pts_msk_batch': default_collate(pts_msk_batch)
            }
        return collate

    def __len__(self):
        return len(self.datasets[0])


def create_sampling_indices(num_samples, num_timepts, N):
    """Draw ``N`` distinct time indices in ``[0, num_timepts)`` for each of ``num_samples`` samples.

    Each row is an independent random draw (one ``torch.randperm`` per sample,
    using the global torch RNG), sorted in ascending order. The result is
    stacked into a ``(num_samples, N)`` tensor.
    """
    rows = []
    for _ in range(num_samples):
        perm = torch.randperm(num_timepts)
        chosen, _order = torch.sort(perm[:N])
        rows.append(chosen)
    return torch.stack(rows)


"""The following implementation is taken from 

https://pytorch-geometric.readthedocs.io/en/stable/tutorial/point_cloud.html

with minor modifications (to implement an encoder for point clouds).
"""

class PointNetLayer(MessagePassing):
    """A PointNet message-passing layer with max aggregation.

    For every edge ``j -> i`` the message is ``mlp([h_j, pos_j - pos_i])``.
    Messages arriving at a node are reduced with an element-wise max. The MLP
    input width is ``in_channels + 3``, because a 3-column relative position is
    appended to the neighbour features.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='max')

        # Keep the Linear -> ReLU -> Linear layout so the state-dict keys
        # (mlp.0.*, mlp.2.*) stay the same.
        layers = [
            Linear(in_channels + 3, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Sequential(*layers)

    def forward(self,
        h: Tensor,
        pos: Tensor,
        edge_index: Tensor,
    ) -> Tensor:
        """Propagate messages along ``edge_index`` and return the aggregated node features."""
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self,
        h_j: Tensor,
        pos_j: Tensor,
        pos_i: Tensor,
    ) -> Tensor:
        """Build one message per edge from the neighbour features and the neighbour-minus-centre offset."""
        # PyG fills h_j/pos_j with source-node values and pos_i with target-node values.
        relative_pos = pos_j - pos_i
        return self.mlp(torch.cat((h_j, relative_pos), dim=-1))


class PointNet(torch.nn.Module):
    """Two-layer PointNet encoder: two ReLU-activated message-passing layers, then a per-graph global max pool."""

    def __init__(self, h_dim:int=32):
        super().__init__()

        # conv1 receives the raw positions as node features, hence 3 input channels.
        self.conv1 = PointNetLayer(3, h_dim)
        self.conv2 = PointNetLayer(h_dim, h_dim)

    def forward(self,
        pos: Tensor,
        edge_index: Tensor,
        batch: Tensor,
    ) -> Tensor:
        """Encode every graph in the batch into one ``h_dim`` vector (max-pooled over its nodes)."""
        features = pos
        for layer in (self.conv1, self.conv2):
            features = layer(h=features, pos=pos, edge_index=edge_index).relu()

        # One row per graph id in `batch`.
        return global_max_pool(features, batch)
        
        
def setup_cmdline_parsing():
    """Build the argument parser for this training script.

    Arguments are grouped into data I/O, training, model configuration and
    preprocessing options. The parser is returned without being run.
    """
    parser = argparse.ArgumentParser()

    io_args = parser.add_argument_group('Data loading/saving arguments')
    for flag in ("--log-out-file", "--vec-inp-file", "--aux-inp-file", "--pts-inp-file"):
        io_args.add_argument(flag, type=str, default=None)
    io_args.add_argument("--run-dir", type=str, default='runs/')
    io_args.add_argument("--experiment-id", type=str, default="42")

    train_args = parser.add_argument_group('Training arguments')
    training_specs = (
        ("--batch-size", int, 64),
        ("--lr", float, 1e-3),
        ("--n-epochs", int, 990),
        ("--seed", int, -1),
        ("--restart", int, 30),
        ("--device", str, "cuda:0"),
        ("--weight-decay", float, 0.0001),
    )
    for flag, kind, default in training_specs:
        train_args.add_argument(flag, type=kind, default=default)

    model_args = parser.add_argument_group('Model configuration arguments')
    model_specs = (
        ("--mtan-h-dim", 128, "Hidden dim. of mTAN module."),
        ("--mtan-embed-time", 128, "Dim. of time embedding."),
        ("--mtan-num-queries", 128, "Number of queries."),
        ("--pointnet-dim", 32, "PointNet++ hidden dim. "),
    )
    for flag, default, text in model_specs:
        model_args.add_argument(flag, type=int, default=default, help=text)
    model_args.add_argument(
        "--backbone",
        choices=['topdyn_only', 'ptsdyn_only', 'joint'],
        default="topdyn_only")

    prep_args = parser.add_argument_group('Data preprocessing arguments')
    prep_args.add_argument("--tps-frac", type=float, default=0.5)
    return parser
    
    
def run_epoch(args, dl, modules, optimizer, aux_loss_fn=None, tracker=None, mode='train'):
    """Run one pass over ``dl`` in training or evaluation mode.

    Args:
        args: Namespace providing ``device``.
        dl: Iterable of batches, e.g. a ``DataLoader``.
        modules: ``nn.ModuleDict`` with two entries:
            ``'recog_net'`` is called as ``recog_net(batch, device)`` and must
            return a 4-tuple whose first element is the encoding and whose last
            element holds the auxiliary targets; ``'regressor'`` maps the
            encoding to predictions.
        optimizer: Stepped once per batch, only when ``mode == 'train'``.
        aux_loss_fn: Callable ``(prediction, target) -> scalar loss``. It is
            applied to flattened tensors.
        tracker: Optional mapping of lists. Receives ``'epoch_loss'``, plus the
            concatenated predictions (``'epoch_aux_p'``) and targets
            (``'epoch_aux_t'``) on the CPU.
        mode: ``'train'`` enables gradients and optimizer steps; any other
            value evaluates.

    Returns:
        The epoch loss: each batch's loss is weighted by its number of samples.
        Returns ``nan`` if ``dl`` yields no batches.
    """
    is_train = mode == 'train'
    modules.train(is_train)  # train(False) is equivalent to eval()

    epoch_loss = 0.
    epoch_instances = 0
    aux_p = [] # predictions
    aux_t = [] # ground truth

    # Track gradients only when training, regardless of the caller's grad mode.
    with torch.set_grad_enabled(is_train):
        for batch in dl:
            aux_enc, _, _, aux_obs = modules['recog_net'](batch, args.device)
            aux_out = modules['regressor'](aux_enc)
            loss = aux_loss_fn(aux_out.flatten(), aux_obs.flatten())

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            aux_p.append(aux_out.detach().cpu())
            aux_t.append(aux_obs.detach().cpu())

            # Weight the batch-mean loss by batch size so a smaller final batch
            # is not over-weighted in the epoch average.
            batch_size = aux_obs.shape[0]
            epoch_loss += loss.item() * batch_size
            epoch_instances += batch_size

    mean_loss = epoch_loss / epoch_instances if epoch_instances else float('nan')

    if tracker is not None:
        tracker['epoch_loss'].append(mean_loss)
        tracker['epoch_aux_p'].append(torch.cat(aux_p) if aux_p else torch.empty(0))
        tracker['epoch_aux_t'].append(torch.cat(aux_t) if aux_t else torch.empty(0))
    return mean_loss


class TDABaselineBackbone(nn.Module):
    """Recognition backbone for ``PHDataset`` batches.

    ``forward`` moves every batch tensor to ``device`` and feeds
    ``(inp_obs, inp_msk, inp_tps)`` to ``VecRecogNetBaseline``. It returns
    ``(encoding, evd_obs, evd_msk, aux_obs)``.
    """

    def __init__(self, args):
        super().__init__()
        self.num_timepts = args.num_timepts
        self.recog_net = VecRecogNetBaseline(
            mtan_input_dim=args.vec_inp_dim,
            mtan_hidden_dim=args.mtan_h_dim,
            mtan_embed_time=args.mtan_embed_time,
            mtan_num_queries=args.mtan_num_queries,
            use_atanh=False,
        )

    def forward(self, batch, device):
        on_device = {}
        for name, tensor in batch.items():
            on_device[name] = tensor.to(device)
        recog_inp = tuple(on_device[key] for key in ('inp_obs', 'inp_msk', 'inp_tps'))
        encoding = self.recog_net(recog_inp)
        return encoding, on_device['evd_obs'], on_device['evd_msk'], on_device['aux_obs']


class PointNetBaselineBackbone(nn.Module):
    """Recognition backbone for ``PointCloudDataset`` batches.

    Each point cloud is encoded with ``PointNet``. The per-time-point encodings
    are zero-padded to ``num_timepts`` and fed to ``VecRecogNetBaseline``.
    ``forward`` returns ``(encoding, evd_obs, evd_msk, aux)``.
    """

    def __init__(self, args):
        super(PointNetBaselineBackbone, self).__init__()
        self.num_timepts = args.num_timepts
        self.point_net = PointNet(h_dim=args.pointnet_dim)
        # Forward the same mTAN options as TDABaselineBackbone, so
        # --mtan-embed-time / --mtan-num-queries are not silently ignored here.
        self.recog_net = VecRecogNetBaseline(
            mtan_input_dim=args.pointnet_dim,
            mtan_hidden_dim=args.mtan_h_dim,
            mtan_embed_time=args.mtan_embed_time,
            mtan_num_queries=args.mtan_num_queries,
            use_atanh=False)

    def forward(self, batch, device):
        """Encode a collated point-cloud batch.

        Expects the keys produced by ``PointCloudDataset.get_collate``. Encodings
        are placed at the front of the input time axis (matching the packed
        layout), and at their original time index in ``evd_obs``.
        """
        pts_msk_batch = batch['pts_msk_batch'].to(device)
        pts_tid_batch = batch['pts_tid_batch'].to(device)
        pts_aux_batch = batch['pts_aux_batch'].to(device)
        # tensor_split needs a CPU long tensor (an empty cut list may arrive as float).
        pts_cut_batch = torch.as_tensor(batch['pts_cut_batch'], dtype=torch.long)

        pts_obs_batch = Batch.from_data_list(batch['pts_obs_batch']).to(device)

        # One PointNet vector per graph, regrouped per sample and stacked.
        enc = self.point_net(pts_obs_batch.pos, pts_obs_batch.edge_index, pts_obs_batch.batch)
        enc = torch.stack(enc.tensor_split(pts_cut_batch, dim=0))

        N, T, D = enc.shape
        parts_inp_obs = torch.zeros(N, self.num_timepts, D, device=device)
        parts_inp_msk = torch.zeros(N, self.num_timepts, D, device=device)
        parts_inp_tps = torch.zeros(N, self.num_timepts, device=device)
        parts_inp_obs[:, :T] = enc
        parts_inp_tps[:, :T] = pts_tid_batch / self.num_timepts
        parts_inp_msk[:, :T] = 1
        inp = (parts_inp_obs, parts_inp_msk, parts_inp_tps)

        # Scatter each encoding back to its original time index.
        tid_index = pts_tid_batch.unsqueeze(-1).expand_as(enc)
        evd_obs = torch.zeros(N, self.num_timepts, D, device=device)
        evd_obs.scatter_(1, tid_index, enc)
        evd_msk = pts_msk_batch.expand(N, self.num_timepts, D)

        return self.recog_net(inp), evd_obs, evd_msk, pts_aux_batch


class JointBaselineBackbone(nn.Module):
    """Recognition backbone that combines PH vectors with PointNet point-cloud encodings.

    Per time point, the PH features from ``batch['tda_obs_batch']`` and the
    PointNet encoding of the matching point cloud are concatenated along the
    feature axis (PH first). The result is fed to ``VecRecogNetBaseline``.
    The evidence tensor uses the same channel order.
    """

    def __init__(self, args):
        super(JointBaselineBackbone, self).__init__()
        self.num_timepts = args.num_timepts
        self.point_net = PointNet(h_dim=args.pointnet_dim)
        # Forward the same mTAN options as TDABaselineBackbone, so
        # --mtan-embed-time / --mtan-num-queries are not silently ignored here.
        self.recog_net = VecRecogNetBaseline(
            mtan_input_dim=args.vec_inp_dim + args.pointnet_dim,
            mtan_hidden_dim=args.mtan_h_dim,
            mtan_embed_time=args.mtan_embed_time,
            mtan_num_queries=args.mtan_num_queries,
            use_atanh=False)

    def forward(self, batch, device):
        """Encode a batch produced by ``JointDataset.get_collate``.

        Returns ``(encoding, evd_obs, evd_msk, aux)``. Both the input features
        and ``evd_obs`` are laid out as ``[PH features | PointNet features]``.
        """
        parts = {key: val.to(device) for key, val in batch['tda_obs_batch'].items()}
        tda_inp_obs = parts['inp_obs']
        tda_inp_msk = parts['inp_msk']
        parts_inp_tps = parts['inp_tps']

        pts_aux_batch = batch['pts_aux_batch'].to(device)
        pts_tid_batch = batch['pts_tid_batch'].to(device)
        # tensor_split needs a CPU long tensor (an empty cut list may arrive as float).
        pts_cut_batch = torch.as_tensor(batch['pts_cut_batch'], dtype=torch.long)
        pts_obs_batch = Batch.from_data_list(batch['pts_obs_batch']).to(device)

        # One PointNet vector per graph, regrouped per sample and stacked.
        enc = self.point_net(pts_obs_batch.pos, pts_obs_batch.edge_index, pts_obs_batch.batch)
        enc = torch.stack(enc.tensor_split(pts_cut_batch, dim=0))

        N, T, D = enc.shape
        enc_ext = torch.zeros(N, self.num_timepts, D, device=device)
        enc_ext[:, :T] = enc

        parts_inp_obs = torch.cat((tda_inp_obs, enc_ext), dim=2)
        # Broadcast channel 0 of the PH input mask over all combined channels.
        parts_inp_msk = tda_inp_msk[:, :, :1].expand_as(parts_inp_obs)

        # Scatter PointNet encodings back to their original time indices.
        tid_index = pts_tid_batch.unsqueeze(-1).expand_as(enc)
        evd_pts = torch.zeros(N, self.num_timepts, D, device=device)
        evd_pts.scatter_(1, tid_index, enc)
        # Same channel order as the input: [PH features | PointNet features].
        evd_obs = torch.cat((parts['evd_obs'], evd_pts), dim=2)
        evd_msk = parts['evd_msk'][:, :, :1].expand_as(evd_obs)

        inp = (parts_inp_obs, parts_inp_msk, parts_inp_tps)
        return self.recog_net(inp), evd_obs, evd_msk, pts_aux_batch


def create_recog_backbone(args):
    """Instantiate the recognition backbone named by ``args.backbone``.

    Raises:
        NotImplementedError: if ``args.backbone`` is not one of
            ``'topdyn_only'``, ``'ptsdyn_only'`` or ``'joint'``.
    """
    backbone = args.backbone
    if backbone == 'topdyn_only':
        return TDABaselineBackbone(args)
    if backbone == 'ptsdyn_only':
        return PointNetBaselineBackbone(args)
    if backbone == 'joint':
        return JointBaselineBackbone(args)
    raise NotImplementedError


def load_data(args):
    """Build the dataset for ``args.backbone``.

    * ``'topdyn_only'`` returns a ``PHDataset`` built from ``vec_inp_file``.
    * ``'ptsdyn_only'`` returns a ``PointCloudDataset`` built from ``pts_inp_file``.
    * ``'joint'`` returns a ``JointDataset`` wrapping both (PH first).

    Every variant reads the auxiliary targets from ``aux_inp_file``.

    Raises:
        NotImplementedError: for any other backbone name.
    """
    backbone = args.backbone
    if backbone not in ('topdyn_only', 'ptsdyn_only', 'joint'):
        raise NotImplementedError()

    if backbone == 'ptsdyn_only':
        return PointCloudDataset(args.pts_inp_file, args.aux_inp_file)

    ds_topdyn = PHDataset(args.vec_inp_file, args.aux_inp_file)
    if backbone == 'topdyn_only':
        return ds_topdyn

    ds_ptsdyn = PointCloudDataset(args.pts_inp_file, args.aux_inp_file)
    return JointDataset([ds_topdyn, ds_ptsdyn])



def _log_epoch_scores(writer, epoch_cnt, num_aux_dim, scorefns, trackers):
    """Score the latest epoch of each tracker with each score function, per auxiliary dimension.

    Each score is written to TensorBoard as ``<score>_<dim>/<split>``. Returns a
    dict mapping ``<score>_<split>`` to the list of per-dimension scores.
    """
    scores = defaultdict(list)
    for scorefn_key, trackers_key in product(scorefns, trackers):
        key_str = scorefn_key + "_" + trackers_key
        aux_t = trackers[trackers_key]['epoch_aux_t'][-1]
        aux_p = trackers[trackers_key]['epoch_aux_p'][-1]
        for aux_d in range(num_aux_dim):
            tmp = scorefns[scorefn_key](aux_t[:, aux_d], aux_p[:, aux_d])
            scores[key_str].append(tmp)
            writer.add_scalar("{}_{}/{}".format(scorefn_key, aux_d, trackers_key), tmp, epoch_cnt)
    return scores


def _log_test_scatter(writer, epoch_cnt, num_aux_dim, tst_tracker):
    """Log one ground-truth vs. prediction scatter plot per auxiliary dimension for the latest test epoch."""
    aux_t = tst_tracker['epoch_aux_t'][-1]
    aux_p = tst_tracker['epoch_aux_p'][-1]
    for aux_d in range(num_aux_dim):
        fig, ax = plt.subplots()
        ax.plot(aux_t[:, aux_d], aux_p[:, aux_d], '.')
        writer.add_figure('r2s/tst_scatter_{}'.format(aux_d), fig, epoch_cnt)
        plt.close(fig)


def main():
    """Train and evaluate a recognition backbone + regressor on auxiliary targets.

    Steps:

    1. Parse the CLI arguments and seed the RNGs if ``--seed >= 0``.
    2. Load the dataset selected by ``--backbone`` and optionally subsample its
       time points (``--tps-frac``).
    3. Split the data 80/20 into train/test sets.
    4. For ``--n-epochs`` epochs: train, evaluate, and log per-dimension R^2/MSE
       scores and test scatter plots to TensorBoard under
       ``<run_dir>/<experiment_id>``.

    If ``--log-out-file`` is given, ``(trn_tracker, tst_tracker, args)`` is saved
    there with ``torch.save`` at the end.
    """
    trn_tracker = defaultdict(list) # track training stats
    tst_tracker = defaultdict(list) # track testing stats

    spinner = Halo(spinner='dots')

    parser = setup_cmdline_parsing()
    args = parser.parse_args()
    print(args)

    # --seed used to be ignored. A negative seed (the default) still leaves every
    # RNG untouched; a non-negative one makes subsampling, init and the split reproducible.
    generator = torch.Generator()
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        generator.manual_seed(args.seed)

    writer = SummaryWriter(os.path.join(args.run_dir, args.experiment_id))
    try:
        spinner.start('Loading data')
        ds = load_data(args)
        spinner.succeed('Loaded data!')

        spinner.start('Patching command line args')
        args_dict = vars(args)
        args_dict['vec_inp_dim'] = ds.num_vec_dim
        args_dict['num_aux_dim'] = ds.num_aux_dim
        args_dict['num_timepts'] = ds.num_timepts
        spinner.succeed('Patched command line args!')
        # Check before num_timepts is first used (subsampling below).
        assert args.num_timepts is not None, "Nr. of timepoints not set!"

        spinner.start('Subsampling')
        if args.tps_frac > 0:
            if not args.tps_frac < 1:
                raise ValueError('Timepoint subsampling not in range (0,1)')
            num_keep = int(args.tps_frac * args.num_timepts)
            if num_keep < 1:
                raise ValueError(
                    f'--tps-frac={args.tps_frac} keeps no time points out of {args.num_timepts}')
            indices = create_sampling_indices(len(ds), args.num_timepts, num_keep)
            ds.subsample(indices)
        spinner.succeed('Subsampled!')

        trn_set, tst_set = torch.utils.data.random_split(ds, [0.8, 0.2], generator=generator)
        collate_fn = ds.get_collate()
        dl_trn = DataLoader(trn_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
        dl_tst = DataLoader(tst_set, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        recog_backbone = create_recog_backbone(args)
        modules = nn.ModuleDict({"recog_net": recog_backbone})
        modules.add_module("regressor", nn.Sequential(
            nn.Linear(args.mtan_h_dim, args.num_aux_dim), nn.Tanh()))
        modules = modules.to(args.device)

        num_params = sum(p.numel() for p in modules.parameters())
        print(f'Number of parameters is {num_params}')

        optimizer = Adam(modules.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = CosineAnnealingLR(optimizer, args.restart, eta_min=0, last_epoch=-1)
        aux_loss_fn = nn.MSELoss()

        scorefns = {'r2s': r2_score,
                    'mse': mean_squared_error}
        trackers = {'trn': trn_tracker,
                    'tst': tst_tracker}

        for epoch_cnt in range(args.n_epochs):
            run_epoch(args, dl_trn, modules, optimizer, aux_loss_fn, trn_tracker, mode='train')
            with torch.no_grad():
                run_epoch(args, dl_tst, modules, optimizer, aux_loss_fn, tst_tracker, mode='test')
            scheduler.step()

            scores = _log_epoch_scores(writer, epoch_cnt, args.num_aux_dim, scorefns, trackers)

            print('{:04d} | trn_loss={:.4f} | avg_trn_mse={:0.4f} | avg_tst_mse={:0.4f} | avg_tst_r2s={:0.4f} | lr={:0.6f}'.format(
                epoch_cnt,
                trn_tracker['epoch_loss'][-1],
                np.mean(scores['mse_trn']),
                np.mean(scores['mse_tst']),
                np.mean(scores['r2s_tst']),
                scheduler.get_last_lr()[-1]))

            writer.add_scalar("r2s_avg/trn", np.mean(scores['r2s_trn']), epoch_cnt)
            writer.add_scalar("r2s_avg/tst", np.mean(scores['r2s_tst']), epoch_cnt)

            _log_test_scatter(writer, epoch_cnt, args.num_aux_dim, tst_tracker)
    finally:
        # Flush TensorBoard events even if training fails part-way.
        writer.close()

    if args.log_out_file:
        torch.save((trn_tracker, tst_tracker, args), args.log_out_file)
    
    
if __name__ == "__main__":
    main()