import torch
from torch_geometric.transforms import BaseTransform

from temporal_graph.data import TemporalData


class ToTemporalUndirected(BaseTransform):
    r"""Make a temporal event stream undirected by appending mirrored events.

    Every original event ``(src[i], dst[i])`` is kept. A mirrored event
    ``(dst[i], src[i])`` is appended after all original events, so the result
    holds ``2 * num_events`` events (original half first, mirrored half
    second).

    Any other tensor attribute whose leading dimension equals the original
    number of events is treated as per-event data and extended to match:

    * ``torch.bool`` tensors (e.g. masks) are padded with ``False`` for the
      mirrored half, so the appended events are never selected by them;
    * every other tensor is repeated, so each mirrored event carries the
      same values as its original.

    Scalars (0-dim tensors), non-tensor attributes, and tensors whose leading
    dimension differs from the event count are left as they are.

    A new ``edge_dir`` attribute records the direction of each output event:
    ``0`` for the original half and ``1`` for the mirrored half. It has the
    same dtype as ``src``.

    The results are written back onto ``data``, which is then returned.
    """

    def forward(self, data: TemporalData) -> TemporalData:
        fwd_src, fwd_dst = data.src, data.dst
        n_events = fwd_src.size(0)

        # Direction flag per output event: 0 = original, 1 = mirrored.
        direction = torch.cat(
            [torch.zeros_like(fwd_src), torch.ones_like(fwd_dst)]
        )

        # TODO: self-loops (src == dst) are mirrored as well, which currently
        # duplicates them.

        # Iterate over a snapshot so the write-backs below never touch the
        # mapping that is being walked.
        for name, attr in list(data._store.items()):
            # NOTE(review): per-event tensors are detected purely by their
            # leading size. A non-event tensor whose first dimension happens
            # to equal the event count would also be extended.
            is_per_event = (
                name not in ('src', 'dst')
                and isinstance(attr, torch.Tensor)
                and attr.dim() > 0
                and attr.size(0) == n_events
            )
            if not is_per_event:
                continue

            # Bool masks are not mirrored: the appended reverse events are
            # padded with False instead of copying the original flags.
            if attr.dtype == torch.bool:
                mirrored_half = torch.zeros_like(attr)
            else:
                mirrored_half = attr
            data[name] = torch.cat([attr, mirrored_half])

        data.src = torch.cat([fwd_src, fwd_dst])
        data.dst = torch.cat([fwd_dst, fwd_src])
        data.edge_dir = direction
        return data
