# Copyright 2020 The "Attribution for Graph Neural Networks" Authors.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualize attributions."""
from typing import List, Text

import graph_nets
import IPython.display
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import rdkit
import rdkit.Chem.Draw
import seaborn as sns
from rdkit.Chem import AllChem

# Typing aliases.
# graph_nets graph container; attributions in this module are passed as
# GraphsTuples and only their `.nodes` field is read.
GraphsTuple = graph_nets.graphs.GraphsTuple
# rdkit molecule type.
Mol = rdkit.Chem.Mol
# matplotlib scalar-to-colour mapper; used below via its `to_rgba` method.
ColorMap = matplotlib.cm.ScalarMappable


def display_html_header(text: Text, level: int = 1):
  """Display `text` as an HTML header element in an IPython frontend.

  `text` is interpolated into the markup verbatim (it is not escaped), so it
  may contain inline HTML.

  Args:
    text: Contents of the header.
    level: Header level, from 1 (<h1>) to 6 (<h6>).

  Returns:
    The return value of `IPython.display.display`.

  Raises:
    ValueError: If `level` is not one of 1, 2, 3, 4, 5 or 6.
  """
  # A dict lookup (rather than f'h{level}') keeps values that compare equal to
  # a valid int, e.g. numpy integers, resolving to the right tag.
  tag = {n: f'h{n}' for n in range(1, 7)}.get(level)
  if tag is None:
    raise ValueError(f'level must be between 1 and 6, got {level!r}')
  return IPython.display.display(IPython.display.HTML(f'<{tag}>{text}</{tag}>'))


def mol_to_nxgraph(mol):
  """Build a networkx graph mirroring the atoms and bonds of an rdkit molecule.

  If `mol` has no conformer, 2D coordinates are first computed on `mol`
  itself via `AllChem.Compute2DCoords`.

  Args:
    mol: rdkit molecule.

  Returns:
    An undirected `nx.Graph` with one node per atom, keyed by atom index and
    carrying `atom_num` (atomic number), `atom` (element symbol) and `xy`
    (first two coordinates of conformer 0); and one edge per bond carrying
    `bond_type` (the rdkit bond type).
  """
  if mol.GetNumConformers() == 0:
    AllChem.Compute2DCoords(mol)
  coords = mol.GetConformer(0).GetPositions()

  nx_graph = nx.Graph()
  for atom in mol.GetAtoms():
    idx = atom.GetIdx()
    nx_graph.add_node(
        idx,
        atom_num=atom.GetAtomicNum(),
        atom=atom.GetSymbol(),
        xy=coords[idx][:2])
  for bond in mol.GetBonds():
    nx_graph.add_edge(
        bond.GetBeginAtomIdx(),
        bond.GetEndAtomIdx(),
        bond_type=bond.GetBondType())
  return nx_graph


def draw_mlp_nxmol(mol, att=None, cmap=None):
  """Draw a molecule with matplotlib, optionally colouring atoms by attribution.

  Creates a new 12x12 figure. Atoms are drawn as circles labelled with their
  element symbol. Bonds are drawn as lines labelled with the lowercased first
  letter of their rdkit bond type name (e.g. 's' for SINGLE).

  Args:
    mol: rdkit molecule. It is converted with `mol_to_nxgraph`, which computes
      2D coordinates on it if it has no conformer.
    att: Optional attribution. Its `.nodes` is iterated to produce one colour
      per atom, matched to atoms in index order. Must be given together with
      `cmap`.
    cmap: Optional `ScalarMappable` whose `to_rgba` turns each value of
      `att.nodes` into a node colour. Must be given together with `att`.

  Raises:
    ValueError: If exactly one of `att` and `cmap` is provided.
  """
  # Validate before creating a figure; previously a lone `att` or `cmap`
  # failed with an AttributeError on None.
  if (att is None) != (cmap is None):
    raise ValueError(
        'att and cmap must be given together or both omitted.')

  plt.figure(figsize=(12, 12))
  graph = mol_to_nxgraph(mol)
  # Node positions.
  pos = {index: node['xy'] for index, node in graph.nodes.items()}
  if att is None:
    node_colors = 'w'
  else:
    node_colors = [cmap.to_rgba(v) for v in att.nodes]

  # Node labels: element symbols.
  node_labels = {
      index: str(node['atom']) for index, node in graph.nodes.items()
  }

  # Edge labels: first letter of the bond type name, e.g. 'd' for DOUBLE.
  edge_labels = {
      index: str(edge['bond_type'])[:1].lower()
      for index, edge in graph.edges.items()
  }

  # Matplotlib drawing per graph element for more control.
  nx.draw_networkx_nodes(
      graph,
      pos,
      node_color=node_colors,
      node_size=25**2,
      linewidths=1.0,
      edgecolors='k')
  nx.draw_networkx_labels(graph, pos, node_labels)
  nx.draw_networkx_edges(graph, pos, width=3.0)
  nx.draw_networkx_edge_labels(
      graph, pos, edge_labels, font_color='0.25', font_size=14)
  plt.axis('off')
  plt.axis('equal')


def get_regression_colormaps(atts: List[GraphsTuple]) -> List[ColorMap]:
  """Build one 'viridis' colormap per attribution, spanning its node range.

  Args:
    atts: Attributions; only each element's `.nodes` is read.

  Returns:
    A list of `ScalarMappable`s aligned with `atts`. Each one normalizes
    linearly from the minimum to the maximum of the corresponding
    `att.nodes` and maps the result onto the 'viridis' colormap.
  """

  def _scalar_mappable(att):
    lo = np.min(att.nodes)
    hi = np.max(att.nodes)
    return matplotlib.cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=lo, vmax=hi), cmap='viridis')

  return [_scalar_mappable(att) for att in atts]


def get_binaryclass_colormaps(atts: List[GraphsTuple]) -> List[ColorMap]:
  """Gets one sequential colormap per attribution, anchored at zero.

  Every colormap uses the same single-hue seaborn cubehelix palette, running
  from light (1.0) to dark (0.5). Values are normalized linearly from 0 to the
  maximum of the corresponding `att.nodes` and clipped to that range, so
  non-positive attributions map to the lightest colour.

  Args:
    atts: Attributions; only each element's `.nodes` is read.

  Returns:
    A list of `ScalarMappable`s aligned with `atts`.
  """
  # The palette does not depend on the attribution, so build it once and
  # share it across all returned mappables.
  colors = sns.cubehelix_palette(
      start=2.0, rot=0, light=1.0, dark=0.5, as_cmap=True)
  cmap_list = []
  for att in atts:
    # Normalize raises ValueError when vmin > vmax, which happened for
    # all-negative attributions. Clamping at 0 makes such inputs render at
    # the lightest colour instead of crashing when the colormap is used.
    vmax = max(float(np.max(att.nodes)), 0.0)
    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=vmax, clip=True)
    cmap_list.append(matplotlib.cm.ScalarMappable(norm=norm, cmap=colors))
  return cmap_list
