# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Templates for models, tasks and attribution techniques."""
import abc
from typing import Any, Callable, List, MutableMapping, Text, Tuple

import numpy as np
import tensorflow as tf

import graph_nets

# Callable aliases: an activation maps one tensor to a tensor; a loss function
# maps two tensors to a tensor.
Activation = Callable[[tf.Tensor], tf.Tensor]
LossFunction = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]

# Graph container re-exported from graph_nets, plus a two-tensor tuple
# (by its name, presumably node tensors followed by edge tensors).
GraphsTuple = graph_nets.graphs.GraphsTuple
NodeEdgeTensors = Tuple[tf.Tensor, tf.Tensor]

# NOTE: despite its name, this alias is typing.MutableMapping, so on its own it
# makes no ordering guarantee to type checkers.
OrderedDict = MutableMapping


class TransparentModel(abc.ABC):
  """Abstract class for a Model that can be probed with AttributionTechnique.

  Every member below is abstract: a subclass can only be instantiated once it
  overrides all of them.
  """

  @property
  @abc.abstractmethod
  def name(self):
    """Name identifying the model."""

  @abc.abstractmethod
  def __call__(self, inputs):
    """Typical forward pass for the model.

    Args:
      inputs: Model inputs.
    """

  @abc.abstractmethod
  def predict(self, inputs):
    """Forward pass with output set on the task of interest, returns 1D tensor.

    Often models will have a multi-dimensional output, or several outputs.
    The output of predict is the subset of the model output that is relevant
    for attribution. This is used in many TransparentModel methods.

    Args:
      inputs: Model inputs.
    """

  @abc.abstractmethod
  def get_gradient(self, inputs):
    """Gets the gradient of the target w.r.t. the input.

    Args:
      inputs: Model inputs.
    """

  @abc.abstractmethod
  def get_gap_activations(self, inputs):
    """Gets node-wise and edge-wise contributions to the graph embedding.

    Assumes there is a global average pooling (GAP) layer to produce the
    graph embedding (u). This returns the pre-pooled activations.
    With this layer the graph embedding is of the form
    u = sum_i nodes_i + sum_j edges_j, where nodes_i and edges_j have been
    transformed to the same dim as u (e.g. via an MLP). Useful for CAM.

    Args:
      inputs: Model inputs.
    """

  @abc.abstractmethod
  def get_prediction_weights(self):
    """Gets last layer prediction weights.

    Assumes the layer is of type Linear with_bias=False. Useful for CAM.
    """

  @abc.abstractmethod
  def get_intermediate_activations_gradients(self, x):
    """Gets intermediate activations and their gradients for an input.

    NOTE(review): this interface does not pin down which layers are covered,
    what the gradients are taken with respect to, or how the result is
    structured -- confirm against the concrete implementations.

    Args:
      x: Input to evaluate; presumably the same kind of model input that the
        other methods receive as `inputs` (confirm with callers).
    """


class AttributionTechnique(abc.ABC):
  """Interface that every attribution technique implements."""

  name: Text
  # Number of graphs to hold in memory per input.
  sample_size: int

  @abc.abstractmethod
  def attribute(self, x, model):
    """Computes node and edge importances for the graphs in x.

    x (a GraphsTuple) is assumed to carry node and edge information as 2D
    arrays. The returned attribution is a list of GraphsTuple, one for each
    graph inside of x, with the same shape but with 1D node and edge arrays.

    Args:
      x: Input to get attributions for.
      model: Model that gives gradients, predictions, activations, etc.
    """


class AttributionTask(abc.ABC):
  """Abstract template for an attribution task.

  Can be thought of as a problem specification. Assumes there is a predictive
  and scorable task for which an underlying attribution 'explains' the
  predictive task. It also has functions that help with building a neural
  network model (get_nn_activation_fn) and with optimizing it
  (get_nn_loss_fn), based on the predictive task.
  """

  @property
  @abc.abstractmethod
  def name(self):
    """Name identifying the task."""

  @abc.abstractmethod
  def get_true_attributions(self, x):
    """Computes the ground truth attribution for a list of inputs x.

    If there are k datapoints, the GraphsTuple will have k graphs.

    Args:
      x: List of datapoints.
    """

  @abc.abstractmethod
  def get_true_predictions(self, x):
    """Gets the true prediction values, useful for training a model.

    Args:
      x: Datapoints to get true prediction values for.
    """

  @abc.abstractmethod
  def evaluate_predictions(self, y_true, y_pred):
    """Evaluates metrics on predictions.

    Args:
      y_true: True prediction values.
      y_pred: Predicted values.

    Returns:
      Results as (metric, value).
    """

  @abc.abstractmethod
  def evaluate_attributions(self, y_true, y_pred, reducer_fn):
    """Evaluates attribution metrics on predicted attributions.

    Assumes attributions are stored as many graphs with 1D node/edge
    information. reducer_fn takes the metrics computed on each graph and
    applies a transformation to them (e.g. np.mean, lambda x: x). Results are
    a dict of the form (metric, reducer_fn(values)).

    Args:
      y_true: True attributions.
      y_pred: Predicted attributions.
      reducer_fn: Function that takes numpy arrays.
    """

  @abc.abstractmethod
  def get_nn_activation_fn(self):
    """Gets an activation function for building a NN for a predictive task."""

  @abc.abstractmethod
  def get_nn_loss_fn(self):
    """Gets a loss function for training a NN in a predictive task."""
