import os
import time
import argparse
import math
import json
from collections import OrderedDict
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split

from ..constant import DATA_DIR
from .utils import compute_trajectory_time
from .static import EvaluationMetrics
from ..utils import choose_device, EarlyStopper

# Special tokens: padding [PAD] and sentence summary [CLS].
# Because pretrained embeddings are used, the summary representation is obtained by pooling the output of the Transformer encoder,
# so the special sentence token [CLS] is not needed.
# For [PAD], an all-zero vector is created as the padding-token embedding.


# Epsilon handed to the Transformer encoder layers as `layer_norm_eps`.
EPS=1e-7
# Number of worker processes used by every DataLoader built in this module.
NUM_WORKERS = 2
# The early stopper is only consulted once the epoch index reaches this value.
START_EPOCHS = 30
# Special-token name -> token index. The indices are reassigned whenever a
# RoadEmbedding is constructed, so the initial 0 is only a placeholder.
Special_tokens = OrderedDict({'PAD': 0})


# TODO: test the dataset and dataloader.
class TrajectoryDataset(Dataset):
    """
    Dataset of map-matched trajectories with one scalar regression target each.

    The whole table is loaded at once; use `random_split` to derive the train
    and test subsets. Padding is not applied here (sequences have different
    lengths); it is left to `transform` or to the DataLoader's collate function.

    Attributes:
        data: list of 1-D LongTensors, one sequence of road ids per trajectory.
        targets: float32 tensor taken from the `pingtimestamp` column.
        trj_id: long tensor taken from the `id` column.
        max_seq_len: length of the longest sequence in `data` (0 when empty).

    `data`, `targets` and `trj_id` are index-aligned: rows whose `cpath` is
    missing are dropped from all three.
    """
    def __init__(self, dataname, data_dir=DATA_DIR, transform=None, target_transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.dataname = dataname
        self.data_dir = os.path.join(data_dir, dataname)
        data = self._load_data()
        trj_id = torch.from_numpy(data['id'].values).to(torch.long)
        targets = torch.from_numpy(data['pingtimestamp'].values).to(torch.float32)
        self.data, index = self._parse_trj(data)  # List of tensors
        # Apply the same row filter to every per-trajectory tensor so that
        # `data[i]`, `targets[i]` and `trj_id[i]` always describe the same row.
        keep = torch.as_tensor(index, dtype=torch.long)
        self.trj_id = trj_id[keep]
        self.targets = targets[keep]
        self._raw_data = data
        # `default=0` keeps construction from raising when no row survives.
        self.max_seq_len = max(map(len, self.data), default=0)

    def _load_data(self, filename = 'trajectory.csv'):
        """
        Return the task table, building and caching it on first use.

        If `<data_dir>/task/<filename>` exists it is read directly. Otherwise
        the table is built by merging the `id`, `cpath` and `length` columns of
        `trajectory/stmr.txt` with the output of `compute_trajectory_time` on
        `<dataname>_trjs.parquet`, and the result is written to that CSV.
        """
        dataname = self.dataname
        data_dir = self.data_dir
        task_dir = os.path.join(data_dir, 'task')
        cache_file = os.path.join(task_dir, filename)
        if not os.path.exists(cache_file):
            # Preprocessing from raw data.
            trjs = pd.read_parquet(os.path.join(data_dir, f'{dataname}_trjs.parquet'))
            trj_time = compute_trajectory_time(trjs, sort=False)  # Already sorted.
            matched_trjs = pd.read_csv(os.path.join(data_dir, 'trajectory/stmr.txt'), sep=';')
            trjs = matched_trjs[['id', 'cpath', 'length']]
            trjs = trjs.sort_values(by='id')
            data = pd.merge(trjs, trj_time, left_on='id', right_index=True)
            # The task directory may not exist yet on a fresh dataset.
            os.makedirs(task_dir, exist_ok=True)
            data.to_csv(cache_file, index=False)
        else:
            data = pd.read_csv(cache_file)
        return data

    @staticmethod
    def _parse_trj(data):
        """
        Parse the `cpath` column into tensors of road ids.

        Returns:
            (parsed, index): `parsed` is a list of 1-D LongTensors and `index`
            lists the positional row numbers that were kept. Rows whose
            `cpath` is NaN are skipped.
        """
        paths = data['cpath'].values
        parsed = []
        index = []
        for i, path in enumerate(paths):
            if isinstance(path, float) and np.isnan(path):
                # Missing path: drop the row.
                continue
            elif isinstance(path, str) and ',' in path:
                parsed.append(torch.LongTensor(list(map(int, path.split(',')))))
            else:
                # A single road id, either as a string or as a number.
                parsed.append(torch.LongTensor([int(path)]))
            index.append(i)
        return parsed, index

    def __len__(self):
        return self.targets.size(0)

    def __getitem__(self, idx):
        """Return `(trajectory, target)` for row `idx`, after the optional transforms."""
        # NOTE: the target is truncated to an int here, as in the original code.
        trj, target = self.data[idx], int(self.targets[idx])
        if self.transform:
            trj = self.transform(trj)
        if self.target_transform:
            target = self.target_transform(target)
        return trj, target


class TargetScaler():
    """
    Standardise regression targets to zero mean and unit standard deviation.

    Until `fit` is called the scaler is the identity (mean 0, std 1). When the
    fitted standard deviation is zero or not finite (constant targets, or a
    single sample), a standard deviation of 1 is used instead so that
    `transform` never produces inf/NaN.
    """
    def __init__(self) -> None:
        self.mean = 0.0
        self.std = 1.0

    def fit(self, target: torch.Tensor):
        """Estimate mean and standard deviation from `target`; return `self`."""
        self.mean = target.mean().item()
        std = target.std().item()
        # Guard against division by zero / NaN in `transform`.
        self.std = std if math.isfinite(std) and std > 0 else 1.0
        return self

    def transform(self, target: torch.Tensor):
        """Return `(target - mean) / std`."""
        return (target - self.mean) / self.std

    def fit_transform(self, target: torch.Tensor):
        """Fit on `target`, then return the transformed `target`."""
        self.fit(target)
        return self.transform(target)

    def inverse_transform(self, target: torch.Tensor):
        """Map standardised values back to the original scale."""
        return target * self.std + self.mean


def collate_fn_trj(batch, padding_value=0, scaler=None):
    """
    Collate `(trajectory, target)` samples into one padded batch.

    Arguments:
        batch: iterable of `(trajectory, target)` pairs, each trajectory a 1-D tensor.
        padding_value: value used to pad the shorter trajectories.
        scaler: optional object with a `transform(tensor)` method applied to the targets.

    Return:
        `(data_pack, target_batch)` where `data_pack['data']` has shape
        (max_len, batch_size), `data_pack['seq_len']` holds the unpadded
        lengths, and `target_batch` is a float32 tensor of shape (batch_size,).
    """
    data_batch, target_batch = zip(*batch)
    seq_len = torch.LongTensor([data.size(0) for data in data_batch])
    data_batch = pad_sequence(data_batch, padding_value=padding_value)
    data_pack = dict(data=data_batch, seq_len=seq_len)
    # Always return a tensor: callers move the targets with `.to(device)`,
    # which a plain tuple does not support.
    target_batch = torch.tensor(target_batch, dtype=torch.float32)
    if scaler is not None:
        target_batch = scaler.transform(target_batch)
    return data_pack, target_batch


def create_mask(seq, pad_index=0):
    """Return a boolean mask, with dims 0 and 1 of `seq` swapped, that is True where `seq` equals `pad_index`."""
    is_pad = seq.eq(pad_index)
    return torch.transpose(is_pad, 0, 1)


class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    Arguments:
        emb_size: embedding dimension (odd values are supported).
        dropout: dropout probability applied to the sum.
        maxlen: number of positions precomputed; inputs must not be longer.
    """
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # For an odd `emb_size` there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on shape.
        pos_embedding[:, 1::2] = torch.cos(pos * den[:emb_size // 2])
        # Shape (maxlen, 1, emb_size): broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # A buffer follows the module across devices but is not trained.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        """
        token_embedding: (seq_len, batch_size, emb_size)

        return: tensor of the same shape
        """
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class RoadEmbedding(nn.Module):
    """
    Token embedding for road ids, with the special tokens appended after the vocabulary.

    Rows `0 .. vocab_size - 1` embed road ids; the entries of `Special_tokens`
    occupy the following rows. Constructing an instance rewrites the indices
    stored in the module-level `Special_tokens` mapping to match this layout.
    """
    def __init__(self, vocab_size=10000, dim=128, emb: torch.Tensor = None):
        """
        If load embedding from file, emb should be a torch.Tensor, otherwise, emb should be None.

        Arguments:
            vocab_size: number of road ids; ignored when `emb` is given.
            dim: embedding dimension; ignored when `emb` is given.
            emb: optional pretrained (vocab_size, dim) matrix. It is used through
                `nn.Embedding.from_pretrained`, whose default keeps it frozen.
        """
        super(RoadEmbedding, self).__init__()
        n_special_tokens = len(Special_tokens)
        if isinstance(emb, torch.Tensor):
            vocab_size, emb_dim = emb.shape[0], emb.shape[1]
            pad_index = self._pad_index(vocab_size)
            # One all-zero row per special token, created with the dtype and
            # device of `emb` so the concatenation cannot fail on a mismatch.
            special_emb = torch.zeros(n_special_tokens, emb_dim, dtype=emb.dtype, device=emb.device)
            emb = torch.cat([emb, special_emb], dim=0)
            self.tok_embed = nn.Embedding.from_pretrained(emb, padding_idx=pad_index)
            self.size = emb.shape[0]
            self.emb_dim = emb.size(1)
        else:
            # `padding_idx` zero-initialises the PAD row and keeps it out of
            # the gradient, matching the all-zero PAD row of the pretrained branch.
            self.tok_embed = nn.Embedding(vocab_size + n_special_tokens, dim,
                                          padding_idx=self._pad_index(vocab_size))  # token embedding
            self.size = vocab_size + n_special_tokens
            self.emb_dim = dim
        self.vocab_size = self.size - n_special_tokens
        # Publish the special-token indices of this embedding table.
        for i, k in enumerate(Special_tokens.keys()):
            Special_tokens[k] = self.size - n_special_tokens + i

    @staticmethod
    def _pad_index(vocab_size):
        """Return the row index 'PAD' gets for `vocab_size`, or None if it is not a special token."""
        for i, k in enumerate(Special_tokens.keys()):
            if k == 'PAD':
                return vocab_size + i
        return None

    def forward(self, x):
        """
        x: (seq_len, batch_size)

        return: (seq_len, batch_size, dim)
        """
        # Embeddings are scaled by sqrt(dim) before being returned.
        return self.tok_embed(x) * math.sqrt(self.emb_dim)


class TrajectoryTransformer(nn.Module):
    """
    Transformer encoder that predicts one scalar per trajectory.

    Road ids are embedded, position-encoded and passed through a Transformer
    encoder; the outputs of the real (non-padded) positions are averaged and
    projected to a single value by a linear layer.
    """
    def __init__(self, vocab_size=10000, dim=128, emb=None, max_seq_len=10000, num_heads=8, dim_ff=128, num_layers=6, dropout=0.1):
        """
        Arguments:
            vocab_size, dim: embedding table size; overridden by `emb.shape` when `emb` is given.
            emb: optional pretrained embedding matrix forwarded to `RoadEmbedding`.
            max_seq_len: accepted for interface compatibility; currently not used,
                the positional table keeps `PositionalEncoding`'s default length.
            num_heads, dim_ff, num_layers, dropout: Transformer encoder hyper-parameters.
        """
        super(TrajectoryTransformer, self).__init__()
        self.model_type = 'TrajectoryTransformer'
        if emb is not None:
            vocab_size, dim = emb.shape
        self.embedding = RoadEmbedding(vocab_size, dim, emb)
        self.positional_encoding = PositionalEncoding(dim, dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, dim_feedforward=dim_ff, 
                                                   dropout=dropout, layer_norm_eps=EPS)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.n_dim = dim
        self.linear = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor, seq_len: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Arguments:
            x: (seq_len, batch_size) token indices.
            seq_len: (batch_size,) unpadded length of every sequence.
            mask: optional (batch_size, seq_len) boolean key-padding mask, True at padded positions.

        Return: 
            (batch_size, 1)
        """
        x = self.positional_encoding(self.embedding(x))
        x = self.encoder(x, src_key_padding_mask=mask)
        # The key-padding mask only stops attention *to* padded positions; the
        # encoder still emits values *at* them. Zero those out so the mean
        # below covers real tokens only. The mask is rebuilt from `seq_len`
        # so pooling does not depend on which token id was used for padding.
        steps = torch.arange(x.size(0), device=x.device).unsqueeze(1)
        lengths = seq_len.to(x.device)
        is_pad = steps >= lengths.unsqueeze(0)  # (seq_len, batch_size)
        x = x.masked_fill(is_pad.unsqueeze(-1), 0.0)
        # clamp avoids a division by zero for an (unexpected) empty sequence.
        out = torch.sum(x, dim=0) / lengths.clamp(min=1).unsqueeze(1)
        pred = self.linear(out)
        return pred


class TravelTimeTrainer():
    """
    Train and evaluate a `TrajectoryTransformer` regressor on a trajectory dataset.

    Typical use: `get_args` -> `set_args` -> `load_road_embedding` ->
    `set_env` -> `train_and_evaluate`.
    """
    def __init__(self):
        self.args: argparse.Namespace = None
        self.device: torch.device = None
        self.road_embedding: torch.Tensor = None
        self.max_seq_len = -1
    
    @staticmethod
    def get_args(argv_list=None):
        """Parse the command-line options; `argv_list=None` reads `sys.argv`."""
        #TODO: reset the default parameters.
        parser = argparse.ArgumentParser()
        parser.add_argument('--dataset', type=str, default='singapore')
        parser.add_argument('--gpu', type=int, default=-1, help='-1 for cpu')
        parser.add_argument('--seed', type=int, default=-1, 
                            help='Random seed. Negative for random seed.')
        parser.add_argument('--epochs', type=int, default=64,
                            help='Number of epochs to train.')
        parser.add_argument('--emb-filename', type=str, default=None,
                            help='Embedding file.')
        parser.add_argument('--num-layers', type=int, default=6, help='Number of layers.')
        parser.add_argument('--num-heads', type=int, default=8,)
        parser.add_argument('--vocab-size', type=int, default=10000, 
                            help='Vocabulary size. Usable when emb-filename is None.')
        parser.add_argument('--dim', type=int, default=512, help='Embedding and hidden dimension.')
        parser.add_argument('--dim-ff', type=int, default=512, help='Feedforward dimension.')
        parser.add_argument('--patience', type=int, default=50, 
                            help='Patient epochs to wait before early stopping. 0 for no early stopping.')
        parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate.')
        parser.add_argument('--wd', type=float, default=5e-4, help='Weight decay.')
        parser.add_argument('--batch-size', type=int, default=64, help='Batch size.')
        parser.add_argument('--runs', type=int, default=10, help='Number of runs.')
        parser.add_argument('--test-ratio', type=float, default=0.2,
                            help='Ratio of test set.')
        parser.add_argument('--print-steps', type=int, default=100,
                            help='Number of steps to print the loss.')
        parser.add_argument('--emb-path', type=str, default=None)
        # `parse_args(None)` falls back to `sys.argv[1:]`, so one call covers both cases.
        return parser.parse_args(argv_list)
    
    def set_args(self, args):
        """Store the parsed arguments and print them."""
        self.args = args
        print(args)

    def load_road_embedding(self, emb_path = None):
        """Load the pretrained road embedding from `emb_path`, or clear it when `emb_path` is None."""
        if emb_path is not None:
            # Load onto the CPU regardless of the device the file was saved
            # from; the model is moved to the training device afterwards.
            # NOTE(review): `torch.load` unpickles the file; only load trusted files.
            self.road_embedding = torch.load(emb_path, map_location='cpu')
        else:
            self.road_embedding = None

    def set_env(self, args):
        """Set the environment."""
        if args.seed >= 0:
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed(args.seed)
        self.device = choose_device(args.gpu)

    def data_process(self, road_dataset):
        raise NotImplementedError

    def train(self, dataset, scaler=None):
        """
        Train a new model on `dataset` and return it.

        Arguments:
            dataset: dataset yielding `(trajectory, target)` samples.
            scaler: optional target scaler forwarded to the collate function.

        The epoch-mean loss is given to the early stopper once the epoch index
        reaches `START_EPOCHS` (and `patience > 0`). After training, the state
        dict returned by `stopper.load_checkpoint()` is loaded into the model
        when it is not None.
        """
        args = self.args
        device = self.device
        #TODO: add change the `vocab_size` and `dim` to the parameters.
        # Build the model first: constructing its embedding reassigns
        # `Special_tokens['PAD']`, and the DataLoader must pad with that
        # final index so padding and the padding mask agree.
        model = TrajectoryTransformer(vocab_size=args.vocab_size, dim=args.dim, dim_ff=args.dim_ff, 
                                      emb=self.road_embedding, max_seq_len=self.max_seq_len, 
                                      num_heads=args.num_heads,
                                      num_layers=args.num_layers, dropout=0.1).to(device)
        print(model)
        pad_index = Special_tokens['PAD']
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=NUM_WORKERS, 
                                collate_fn=partial(collate_fn_trj, padding_value=pad_index, scaler=scaler), 
                                pin_memory=True)
        criterion = nn.MSELoss()
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
        stopper = EarlyStopper(patience=args.patience)

        model.train()
        global_step = 0
        for epoch in range(args.epochs):
            loss_values = []
            iter_step = 0
            epoch_time_start = time.time()
            for data_pack, target in dataloader:
                iter_time_start = time.time()
                data = data_pack['data']
                seq_len = data_pack['seq_len']
                data, target = data.to(device), target.to(device)
                seq_len = seq_len.to(self.device)
                padding_mask = create_mask(data, pad_index=pad_index)
                pred = model(data, seq_len, padding_mask)
                # squeeze(-1) keeps the batch dimension even for a batch of one.
                loss = criterion(pred.squeeze(-1), target)
                iter_loss = loss.item()
                loss_values.append(iter_loss)
                opt.zero_grad()
                loss.backward()
                opt.step()

                global_step += 1
                iter_step += 1
                iter_time_end = time.time()
                if global_step % args.print_steps == 0:
                    print(f'Epoch {epoch} | Step {iter_step} | Mini-batch Loss {iter_loss:.4f} | Iter Time {iter_time_end - iter_time_start:.2f}s')

            epoch_loss = sum(loss_values) / len(loss_values)
            if args.patience > 0 and epoch >= START_EPOCHS and stopper.step(epoch_loss, model):
                break
            epoch_time_end = time.time()
            print(f'Epoch {epoch} | Epoch Loss {epoch_loss:.4f} | Epoch Time {epoch_time_end - epoch_time_start:.2f}s')
        
        state_dict = stopper.load_checkpoint()
        if state_dict is not None:
            model.load_state_dict(state_dict)
        return model
    
    @torch.no_grad()
    def evaluate(self, dataset, model, metrics=None, scaler=None):
        """
        Evaluate `model` on `dataset`; print and return the metric results.

        Arguments:
            dataset: dataset yielding `(trajectory, target)` samples.
            model: trained model; it is switched to eval mode.
            metrics: optional metrics object; a regression `EvaluationMetrics`
                is created when omitted.
            scaler: optional target scaler forwarded to the collate function,
                so metrics are computed on the scaled targets when it is given.
        """
        args = self.args
        model.eval()
        pad_index = Special_tokens['PAD']
        # Order is irrelevant for evaluation, so the data is not shuffled.
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS, 
                                collate_fn=partial(collate_fn_trj, padding_value=pad_index, scaler=scaler), 
                                pin_memory=True)
        if metrics is None:
            metrics = EvaluationMetrics(mode='regression')
        preds = []
        targets = []
        for data_pack, target in dataloader:
            data = data_pack['data']
            seq_len = data_pack['seq_len']
            data, target = data.to(self.device), target.to(self.device)
            seq_len = seq_len.to(self.device)
            padding_mask = create_mask(data, pad_index=pad_index)
            pred = model(data, seq_len, padding_mask)
            # Keep every chunk 1-D: a bare squeeze() would turn a final batch
            # of one sample into a 0-d tensor, which torch.cat rejects.
            preds.append(pred.squeeze(-1))
            targets.append(target.reshape(-1))
        preds = torch.cat(preds, dim=0)
        targets = torch.cat(targets, dim=0)
        results = metrics.compute(preds, targets, update=True)
        print(json.dumps(results))
        return results

    def train_and_evaluate(self, dataset):
        """
        Run `args.runs` independent random train/test splits of `dataset`.

        For every run a fresh split is drawn, a `TargetScaler` is fitted on the
        training targets only, a model is trained and then evaluated on the
        test split. Results accumulate in one shared metrics object, which is
        printed at the end.
        """
        args = self.args
        runs = args.runs
        data_size = len(dataset)
        metrics = EvaluationMetrics(mode='regression')
        for _ in range(runs):
            train_size = int(data_size * (1 - args.test_ratio))
            test_size = data_size - train_size
            train_set, test_set = random_split(dataset, [train_size, test_size])
            scaler = TargetScaler().fit(train_set.dataset.targets[train_set.indices])
            model = self.train(train_set, scaler)
            self.evaluate(test_set, model, metrics, scaler)
        print(f'Total results: {metrics.dump_results_json()}')


def travel_time_estimation(argv: list = None):
    """
    Entry point: parse the arguments, then train and evaluate on the chosen dataset.

    Arguments:
        argv: optional list of command-line tokens forwarded to
            `TravelTimeTrainer.get_args`.
    """
    runner = TravelTimeTrainer()
    parsed = runner.get_args(argv)
    runner.set_args(parsed)
    # NOTE(review): the embedding is loaded from `--emb-filename`; the parsed
    # `--emb-path` option is not read here - confirm which one is intended.
    runner.load_road_embedding(parsed.emb_filename)
    runner.set_env(parsed)
    dataset = TrajectoryDataset(parsed.dataset)
    runner.train_and_evaluate(dataset)
