# coding=utf-8
# Copyright 2018 The Condconv NeurIPS19 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CondConv-MobileNetV1 Implementation.

Implements the CondConv block, SepCondConv block, routing function
and CondConv-MobileNetV1 model for CondConv: Conditionally Parameterized
Convolutions for Efficient Inference.

The cc_mobilenet_v1 function builds a CondConv-MobileNetV1 model for classification,
to be used with the Tensorflow Estimators API. The cc_mobilenet_v1_base
function builds a feature extractor to be used with the Tensorflow
Object Detection API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow as tf

# Short module-level aliases for the tf.contrib symbols used throughout this
# file. The functions below refer to these names both to build layers and as
# the keys of the arg_scope entries they configure.
arg_scope = tf.contrib.framework.arg_scope
avg_pool2d = tf.contrib.layers.avg_pool2d
batch_norm = tf.contrib.layers.batch_norm
conv2d = tf.contrib.layers.conv2d
dropout = tf.contrib.layers.dropout
separable_conv2d = tf.contrib.layers.separable_conv2d
fully_connected = tf.contrib.layers.fully_connected

# Layer-by-layer description of the base feature extractor.
#   'Conv'      -> regular 3x3 convolution.
#   'SplitConv' -> separable 3x3 convolution
#                  (3x3 depth-wise convolution followed by 1x1 point-wise
#                  convolution).
# For layers whose index is at or after cond_index_begin:
#   'Conv' layers are replaced by CondConv blocks.
#   'SplitConv' layers are replaced by 'SepCondConv' blocks.
ConvDef = collections.namedtuple('ConvDef', 'stride depth layer_type')

_CONV_DEFS = [
    ConvDef(stride=2, depth=32, layer_type='Conv'),
    ConvDef(stride=1, depth=64, layer_type='SplitConv'),
    ConvDef(stride=2, depth=128, layer_type='SplitConv'),
    ConvDef(stride=1, depth=128, layer_type='SplitConv'),
    ConvDef(stride=2, depth=256, layer_type='SplitConv'),
    ConvDef(stride=1, depth=256, layer_type='SplitConv'),
    ConvDef(stride=2, depth=512, layer_type='SplitConv'),
    ConvDef(stride=1, depth=512, layer_type='SplitConv'),
    ConvDef(stride=1, depth=512, layer_type='SplitConv'),
    ConvDef(stride=1, depth=512, layer_type='SplitConv'),
    ConvDef(stride=1, depth=512, layer_type='SplitConv'),
    ConvDef(stride=1, depth=512, layer_type='SplitConv'),
    ConvDef(stride=2, depth=1024, layer_type='SplitConv'),
    ConvDef(stride=1, depth=1024, layer_type='SplitConv'),
]


def compute_routing_weights(inputs, num_branches):
  """Implements the routing function to compute routing weights.

  The input is globally average-pooled over its spatial dimensions, projected
  to `num_branches` logits by a fully connected layer (no activation), and
  passed through a sigmoid, so every routing weight lies in (0, 1).

  Args:
    inputs: a tensor of [batch_size, height, width, channels]. The spatial
      dimensions do not have to be statically known.
    num_branches: the number of branches to compute routing weights for

  Returns:
    output_weights: the routing weights of dimension [batch_size, num_branches]
  """
  # Global average pooling. Reducing over the spatial axes directly, rather
  # than sizing a pooling window from the static shape, yields the same
  # [batch_size, channels] features and also works when height/width are
  # unknown at graph construction time.
  pooled = tf.reduce_mean(inputs, axis=[1, 2])

  # The projection weights start at zero; with fully_connected's default zero
  # bias every routing weight therefore starts at sigmoid(0) = 0.5.
  output_logits = fully_connected(
      inputs=pooled,
      num_outputs=num_branches,
      activation_fn=None,
      weights_initializer=tf.zeros_initializer())
  return tf.nn.sigmoid(output_logits)


def cond_conv2d(inputs, num_outputs, kernel_size, stride, normalizer_fn,
                activation_fn, scope_base, num_branches, routing_weights,
                all_routing_weights):
  """Implements the CondConv block.

  A bank of `num_branches` flattened kernels is stored in a single variable.
  For every example, the kernel actually applied is the routing-weighted sum
  of the branch kernels; each example is then convolved with its own kernel
  and the per-example results are concatenated back into one batch.

  Args:
    inputs: a tensor of [batch_size, height, width, channels]. batch_size must
      be statically known because the batch is split into single examples.
    num_outputs: the number of output channels
    kernel_size: a [height, width] list (or tuple) of the kernel dimensions
    stride: the stride for the convolution
    normalizer_fn: the normalizer function to be used after the convolution,
      or None to skip normalization
    activation_fn: the activation function to be used after the normalizer_fn,
      or None to skip the activation
    scope_base: the base scope name for the CondConv block
    num_branches: the number of branches used in the CondConv block
    routing_weights: routing weights to be used in the CondConv block, or None
      to generate new routing weights in the block
    all_routing_weights: a list to append routing weights to for logging; it
      is only appended to when new routing weights are generated here

  Returns:
    routing_weights: the routing weights used in the CondConv block
    output: the output tensor of the CondConv block

  Raises:
    ValueError: if the batch size of `inputs` is not statically known.
  """
  batch_size = inputs.shape[0].value
  if batch_size is None:
    raise ValueError(
        'cond_conv2d requires a statically known batch size, but inputs has '
        'shape %s' % inputs.shape)

  # `routing_weights` may be a tf.Tensor, whose truth value cannot be
  # evaluated while building a graph, so compare against None explicitly.
  if routing_weights is None:
    routing_weights = compute_routing_weights(inputs, num_branches)
    all_routing_weights.append(routing_weights)

  # NOTE(review): get_arg_scope_for_op is not defined in the visible part of
  # this file; it is expected to return the current arg_scope settings for
  # `conv2d` as a dict -- confirm it is provided elsewhere.
  conv_arg_scope = get_arg_scope_for_op(conv2d)
  num_input_channels = inputs.shape[3].value
  kernel_dims = list(kernel_size)
  kernel_shape = kernel_dims + [num_input_channels, num_outputs]

  # One row per branch; each row is a flattened convolution kernel.
  all_kernels_shape = [
      num_branches,
      kernel_dims[0] * kernel_dims[1] * num_input_channels * num_outputs
  ]
  all_kernels = tf.get_variable(
      scope_base + '_mix/weights',
      initializer=conv_arg_scope['weights_initializer'],
      regularizer=conv_arg_scope['weights_regularizer'],
      shape=all_kernels_shape,
      trainable=True)

  # [batch_size, num_branches] x [num_branches, flat_kernel] gives one mixed,
  # still flattened, kernel per example.
  per_example_kernels = tf.matmul(routing_weights, all_kernels)

  # A batch of one needs no split/concat ops.
  if batch_size == 1:
    examples = [inputs]
    flat_kernels = [per_example_kernels]
  else:
    examples = tf.split(inputs, batch_size, axis=0)
    flat_kernels = tf.split(per_example_kernels, batch_size, axis=0)

  strides = [1, stride, stride, 1]
  outputs = [
      tf.nn.conv2d(
          example,
          tf.reshape(flat_kernel, kernel_shape),
          strides=strides,
          padding='SAME')
      for example, flat_kernel in zip(examples, flat_kernels)
  ]
  output = outputs[0] if batch_size == 1 else tf.concat(outputs, axis=0)

  if normalizer_fn is not None:
    output = normalizer_fn(output)
  if activation_fn is not None:
    output = activation_fn(output)
  return routing_weights, output


def depthwise_cond_conv2d(inputs, kernel_size, stride, rate, normalizer_fn,
                          activation_fn, scope_base, num_branches,
                          routing_weights, all_routing_weights):
  """Implements the depth-wise CondConv block.

  Works like cond_conv2d, but the per-example kernel is applied with a
  depth-wise convolution (channel multiplier 1), so the output has the same
  number of channels as the input.

  Args:
    inputs: a tensor of [batch_size, height, width, channels]. batch_size must
      be statically known because the batch is split into single examples.
    kernel_size: a [height, width] list (or tuple) of the kernel dimensions
    stride: the stride for the convolution
    rate: the dilation rate
    normalizer_fn: the normalizer function to be used after the depthwise
      convolution, or None to skip normalization
    activation_fn: the activation function to be used after the normalizer_fn,
      or None to skip the activation
    scope_base: the base scope name for the depthwise CondConv block
    num_branches: the number of branches used in the CondConv block
    routing_weights: routing weights to be used in the CondConv block, or None
      to generate new routing weights in the block
    all_routing_weights: a list to append routing weights to for logging; it
      is only appended to when new routing weights are generated here

  Returns:
    routing_weights: the routing weights used in the depth-wise CondConv block
    output: the output tensor of the depth-wise CondConv block

  Raises:
    ValueError: if the batch size of `inputs` is not statically known.
  """
  batch_size = inputs.shape[0].value
  if batch_size is None:
    raise ValueError(
        'depthwise_cond_conv2d requires a statically known batch size, but '
        'inputs has shape %s' % inputs.shape)

  num_input_channels = inputs.shape[3].value
  # Depth-wise convolution with channel multiplier 1 keeps the channel count.
  num_output_channels = num_input_channels

  # `routing_weights` may be a tf.Tensor, whose truth value cannot be
  # evaluated while building a graph, so compare against None explicitly.
  if routing_weights is None:
    routing_weights = compute_routing_weights(inputs, num_branches)
    all_routing_weights.append(routing_weights)

  # NOTE(review): get_arg_scope_for_op is not defined in the visible part of
  # this file; it is expected to return the current arg_scope settings for
  # `separable_conv2d` as a dict -- confirm it is provided elsewhere.
  sep_conv_arg_scope = get_arg_scope_for_op(separable_conv2d)
  kernel_dims = list(kernel_size)
  kernel_shape = kernel_dims + [num_input_channels, 1]

  # One row per branch; each row is a flattened depth-wise kernel.
  all_kernels_shape = [
      num_branches, kernel_dims[0] * kernel_dims[1] * num_output_channels
  ]
  all_kernels = tf.get_variable(
      scope_base + '_mix/depthwise_weights',
      initializer=sep_conv_arg_scope['weights_initializer'],
      regularizer=sep_conv_arg_scope['weights_regularizer'],
      shape=all_kernels_shape,
      trainable=True)

  # [batch_size, num_branches] x [num_branches, flat_kernel] gives one mixed,
  # still flattened, kernel per example.
  per_example_kernels = tf.matmul(routing_weights, all_kernels)

  # A batch of one needs no split/concat ops.
  if batch_size == 1:
    examples = [inputs]
    flat_kernels = [per_example_kernels]
  else:
    examples = tf.split(inputs, batch_size, axis=0)
    flat_kernels = tf.split(per_example_kernels, batch_size, axis=0)

  strides = [1, stride, stride, 1]
  outputs = [
      tf.nn.depthwise_conv2d(
          example,
          tf.reshape(flat_kernel, kernel_shape),
          strides=strides,
          padding='SAME',
          rate=[rate, rate])
      for example, flat_kernel in zip(examples, flat_kernels)
  ]
  output = outputs[0] if batch_size == 1 else tf.concat(outputs, axis=0)

  if normalizer_fn is not None:
    output = normalizer_fn(output)
  if activation_fn is not None:
    output = activation_fn(output)
  return routing_weights, output


def cc_mobilenet_v1_base(inputs,
                         num_branches=1,
                         cond_index_begin=13,
                         final_endpoint='Conv2d_13_pointwise',
                         min_depth=8,
                         depth_multiplier=1.0,
                         conv_defs=None,
                         output_stride=None,
                         scope=None):
  """The CC-MobileNetV1 base model.

  Constructs a CC-MobileNetV1 network from inputs to the given final endpoint.
  Layers with index below `cond_index_begin` are built as regular (separable)
  convolutions; layers at or after that index are built as CondConv /
  SepCondConv blocks. In a SepCondConv block the routing weights computed for
  the depth-wise convolution are reused by the point-wise convolution.

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    num_branches: the number of branches for the CondConv blocks
    cond_index_begin: the layer index to begin using CondConv blocks
    final_endpoint: specifies the endpoint to construct the network up to.
      Endpoints are named 'Conv2d_<i>' for 'Conv' layers and
      'Conv2d_<i>_depthwise' / 'Conv2d_<i>_pointwise' for 'SplitConv' layers,
      where <i> is the index of the layer in conv_defs.
    min_depth: Minimum depth value (number of channels) for all convolution ops,
      enforced when depth_multiplier < 1.
    depth_multiplier: float multiplier for the depth (number of channels)
      for all convolution ops; must be greater than zero.
    conv_defs: a list of ConvDef namedtuples specifying the net architecture,
      None for the default MobileNetV1 architecture.
    output_stride: an integer that specifies the requested ratio of input to
      output spatial resolution. Must be None or one of 8, 16, 32.
    scope: optional variable_scope.

  Returns:
    net: an output tensor corresponding to the final endpoint
    end_points: a dictionary of activations for external use, in particular
      includes 'routing_weights' which contains the list of routing weights
      used for the batch
    all_routing_weights: the list of routing weights used in the batch for
      logging purposes (the same list object stored under
      end_points['routing_weights'])

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
                or depth_multiplier <= 0, or the target output_stride is not
                allowed, or conv_defs is improperly specified.
  """
  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  if conv_defs is None:
    conv_defs = _CONV_DEFS

  if output_stride is not None and output_stride not in [8, 16, 32]:
    raise ValueError('Only allowed output_stride values are 8, 16, 32.')

  # Used to find thinned depths for each layer.
  depth = lambda d: max(int(d * depth_multiplier), min_depth)
  end_points = {}

  with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
    with arg_scope([conv2d, separable_conv2d], padding='SAME'):
      # Keeps track of the output stride of the activations.
      current_stride = 1

      # The atrous convolution rate parameter.
      rate = 1

      net = inputs

      # List of routing weights used in CondConv blocks in the network
      all_routing_weights = []

      for i, conv_def in enumerate(conv_defs):
        end_point_base = 'Conv2d_%d' % i

        if output_stride is not None and current_stride == output_stride:
          # If we have reached the target output_stride, then we need to employ
          # atrous convolution with stride=1 and multiply the atrous rate by the
          # current unit's stride for use in subsequent layers.
          layer_stride = 1
          layer_rate = rate
          rate *= conv_def.stride
        else:
          layer_stride = conv_def.stride
          layer_rate = 1
          current_stride *= conv_def.stride

        num_output_channels = depth(conv_def.depth)

        if conv_def.layer_type == 'Conv':
          # NOTE(review): 'Conv' layers use conv_def.stride and ignore
          # layer_stride/layer_rate, so output_stride is only honoured by
          # 'SplitConv' layers -- confirm this is intended for custom
          # conv_defs.
          if i < cond_index_begin:
            # Conv
            end_point = end_point_base
            net = conv2d(
                net,
                num_output_channels,
                [3, 3],
                stride=conv_def.stride,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope=end_point)
          else:
            # CondConv
            end_point = end_point_base
            _, net = cond_conv2d(
                net,
                num_output_channels,
                [3, 3],
                stride=conv_def.stride,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope_base=end_point,
                num_branches=num_branches,
                routing_weights=None,
                all_routing_weights=all_routing_weights)
        elif conv_def.layer_type == 'SplitConv':
          if i < cond_index_begin:
            # Separable Conv
            end_point = end_point_base + '_depthwise'

            # Depth-wise convolution only since num_outputs is None
            net = separable_conv2d(
                net,
                None,
                [3, 3],
                depth_multiplier=1,
                stride=layer_stride,
                rate=layer_rate,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope=end_point)

            end_points[end_point] = net

            end_point = end_point_base + '_pointwise'

            net = conv2d(
                net,
                num_output_channels, [1, 1],
                stride=1,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope=end_point)
          else:
            # SepCondConv
            end_point = end_point_base + '_depthwise'

            routing_weights, net = depthwise_cond_conv2d(
                net,
                [3, 3],
                stride=layer_stride,
                rate=layer_rate,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope_base=end_point,
                num_branches=num_branches,
                routing_weights=None,
                all_routing_weights=all_routing_weights)
            end_points[end_point] = net
            end_point = end_point_base + '_pointwise'

            # cond_conv2d returns (routing_weights, output); only the output
            # tensor is the next activation. The point-wise convolution reuses
            # the routing weights of the depth-wise convolution above.
            _, net = cond_conv2d(
                net,
                num_output_channels, [1, 1],
                stride=1,
                normalizer_fn=batch_norm,
                activation_fn=tf.nn.relu6,
                scope_base=end_point,
                num_branches=num_branches,
                routing_weights=routing_weights,
                all_routing_weights=all_routing_weights)
        else:
          raise ValueError('Unknown convolution type %s for layer %d'
                           % (conv_def.layer_type, i))

        end_points[end_point] = net
        if end_point == final_endpoint:
          end_points['routing_weights'] = all_routing_weights
          return net, end_points, all_routing_weights

  # No layer produced an endpoint with the requested name.
  raise ValueError('Unknown final endpoint %s' % final_endpoint)


def cc_mobilenet_v1(inputs,
                    num_classes=1000,
                    dropout_keep_prob=0.999,
                    cond_final_layer=False,
                    is_training=True,
                    num_branches=1,
                    cond_index_begin=13,
                    min_depth=8,
                    depth_multiplier=1.0,
                    conv_defs=None,
                    output_stride=None,
                    final_endpoint=None,
                    prediction_fn=tf.contrib.layers.softmax,
                    spatial_squeeze=True,
                    reuse=None,
                    scope='CCMobilenetV1'):
  """CC-MobileNet-V1 model for classification.

  Builds the base network with cc_mobilenet_v1_base, average-pools the final
  feature map, applies dropout and a 1x1 classification layer (a regular
  convolution, or a CondConv block when `cond_final_layer` is True).

  Args:
    inputs: a tensor of shape [batch_size, height, width, channels].
    num_classes: number of predicted classes. if 0 or None, the logits layer
      is omitted and the input features to the logits layer (before dropout)
      are returned instead.
    dropout_keep_prob: the fraction of activation values that are retained,
      only used if `num_classes` is not 0 or `None`. The value is expected to
      be between 0 and 1.
    cond_final_layer: True to use CondConv for the final classification layer,
      False otherwise
    is_training: True if model is used for training, False otherwise
    num_branches: number of branches to use for CondConv blocks
    cond_index_begin: the index of the layer in conv_defs to begin using
      CondConv and SepCondConv blocks
    min_depth: minimum depth value (number of channels) for all convolution ops,
      enforced when depth_multiplier < 1.
    depth_multiplier: float multiplier for the depth (number of channels)
      for all convolution ops; the value must be greater than zero.
    conv_defs: a list of ConvDef namedtuples specifying the net architecture,
      None for the default MobileNetV1 architecture.
    output_stride: an integer that specifies the requested ratio of input to
      output spatial resolution.
    final_endpoint: A string that specifies the endpoint to construct the
      network up to. Use `None` for the default value 'Conv2d_13_pointwise'.
    prediction_fn: a function to get predictions out of logits, or None to
      skip adding 'Predictions' to end_points.
    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes], if `num_classes` is not 0 or `None`, or the
      average pooling activations otherwise.
    end_points: a dictionary from components of the network to the corresponding
      activation. in particular, 'routing_weights' specifies the list of
      routing weights used in the network.

  Raises:
    ValueError: input rank is invalid.
  """
  input_shape = inputs.get_shape().as_list()
  if len(input_shape) != 4:
    raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
                     len(input_shape))

  with tf.variable_scope(scope, 'MobilenetV1', [inputs, num_classes],
                         reuse=reuse) as scope:
    if final_endpoint is None:
      final_endpoint = 'Conv2d_13_pointwise'
    with arg_scope([batch_norm, dropout], is_training=is_training):
      net, end_points, all_routing_weights = cc_mobilenet_v1_base(
          inputs,
          num_branches=num_branches,
          cond_index_begin=cond_index_begin,
          scope=scope,
          min_depth=min_depth,
          depth_multiplier=depth_multiplier,
          conv_defs=conv_defs,
          output_stride=output_stride,
          final_endpoint=final_endpoint)
      with tf.variable_scope('Logits'):
        # Pool over the whole feature map, shrinking the 7x7 window when the
        # (statically known) feature map is smaller than that.
        kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
        net = avg_pool2d(net, kernel_size, padding='VALID', scope='AvgPool_1a')
        end_points['AvgPool_1a'] = net
        if not num_classes:
          # Skip the dropout and the final logit layer.
          return net, end_points

        # 1 x 1 x 1024
        net = dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
        if cond_final_layer:
          # cond_conv2d returns (routing_weights, output); only the output
          # tensor holds the logits. The new routing weights are appended to
          # all_routing_weights by cond_conv2d itself.
          _, logits = cond_conv2d(
              net,
              num_classes, [1, 1],
              stride=1,
              normalizer_fn=None,
              activation_fn=None,
              scope_base='Conv2d_1c_1x1',
              num_branches=num_branches,
              routing_weights=None,
              all_routing_weights=all_routing_weights)
        else:
          logits = conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              scope='Conv2d_1c_1x1')
        if spatial_squeeze:
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
      end_points['Logits'] = logits
      if prediction_fn:
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points

cc_mobilenet_v1.default_image_size = 224


def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
  """Clamps a pooling kernel size to the static spatial size of the input.

  When either spatial dimension of the input is unknown at graph construction
  time, the input is assumed to be large enough and the requested kernel size
  is returned unchanged.

  Args:
    input_tensor: input tensor of size [batch_size, height, width, channels].
    kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]

  Returns:
    kernel_size_out: a two-element list [kernel_height, kernel_width], or
      `kernel_size` itself when a spatial dimension is unknown.
  """
  static_dims = input_tensor.get_shape().as_list()
  height, width = static_dims[1], static_dims[2]
  if height is not None and width is not None:
    return [min(height, kernel_size[0]), min(width, kernel_size[1])]
  return kernel_size


def cc_mobilenet_v1_arg_scope(is_training=True,
                              weight_decay=0.00004,
                              routefn_weight_decay=0.00004,
                              batch_norm_decay=0.9997,
                              stddev=0.09,
                              regularize_depthwise=False,
                              fuse_batch_norm=True):
  """Builds the default arg scope for the cc_mobilenet_v1 model.

  Args:
    is_training: whether or not we're training the model. If None, no
      'is_training' entry is set for batch norm by this scope.
    weight_decay: the L2 weight decay applied to conv2d weights (and to
      separable_conv2d weights when regularize_depthwise is True).
    routefn_weight_decay: the L2 weight decay applied to the fully connected
      layers of the routing function.
    batch_norm_decay: the batch norm decay to use for moving averages.
    stddev: the standard deviation of the truncated normal weight initializer.
    regularize_depthwise: whether or not apply regularization on depthwise.
    fuse_batch_norm: whether to use fused or the regular batch norm.

  Returns:
    an `arg_scope` to use for the cc_mobilenet_v1 model.
  """
  bn_kwargs = dict(
      center=True,
      scale=True,
      decay=batch_norm_decay,
      epsilon=0.001,
      fused=fuse_batch_norm)
  if is_training is not None:
    bn_kwargs['is_training'] = is_training

  initializer = tf.truncated_normal_initializer(stddev=stddev)
  conv_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  routing_regularizer = tf.contrib.layers.l2_regularizer(routefn_weight_decay)
  # Depth-wise kernels are only regularized on request.
  depthwise_regularizer = conv_regularizer if regularize_depthwise else None

  # The scopes are entered left to right, exactly as if they were nested; the
  # innermost one captures the combined settings that are returned.
  with arg_scope([conv2d, separable_conv2d],
                 weights_initializer=initializer,
                 normalizer_fn=batch_norm), \
       arg_scope([batch_norm], **bn_kwargs), \
       arg_scope([conv2d], weights_regularizer=conv_regularizer), \
       arg_scope([separable_conv2d],
                 weights_regularizer=depthwise_regularizer), \
       arg_scope([fully_connected],
                 weights_regularizer=routing_regularizer) as sc:
    return sc
