# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Densenet model configuration.

References:
  "Densely Connected Convolutional Networks": https://arxiv.org/pdf/1608.06993
"""
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from models import model as model_lib
from tfm_constants import convert_flattened_densenet121_deps

class Densenet121Model(model_lib.CNNModel):
    """DenseNet-121 network configuration with configurable layer widths.

    Layer widths are not hard-coded. They come from ``params.deps``, converted by
    ``convert_flattened_densenet121_deps`` into the structure that
    ``add_inference`` indexes:

      deps[0]          -> output channels of the 7x7 stem conv ('conv1')
      deps[1, 3, 5, 7] -> dense stages 2, 3, 4, 5; each a sequence with one
                          (inter_channel, out_channel) pair per dense layer
      deps[2, 4, 6]    -> output channels of the transition after stages 2, 3, 4
    """

    def __init__(self, model_name, params=None):
        """Create the model.

        Args:
          model_name: model name forwarded to ``CNNModel``.
          params: run parameters. They must expose a ``deps`` attribute with the
            flattened layer widths. Despite the ``None`` default, this argument
            is required.

        Raises:
          ValueError: if ``params`` is None.
        """
        if params is None:
            # Previously this surfaced as an opaque AttributeError on params.deps.
            raise ValueError(
                'Densenet121Model requires `params` with a `deps` attribute.')
        # Positional CNNModel args; presumably image size 224, default batch
        # size 256 and learning rate 0.1 -- confirm against CNNModel.
        super(Densenet121Model, self).__init__(
            model_name, 224, 256, 0.1, params=params)
        self.deps = convert_flattened_densenet121_deps(params.deps)

    def add_inference(self, cnn):
        """Build the DenseNet-121 graph on top of ``cnn.top_layer``.

        Layer names follow the 'conv{stage}_{branch}_x{1,2}' /
        'conv{stage}_blk' scheme and must stay stable so that existing
        checkpoints keep loading.
        """
        # Batch norm is inserted explicitly in every block below, so the
        # builder's default per-conv batch norm is turned off.
        cnn.set_whether_use_batch_norm_by_default(False)
        cnn.set_default_batch_norm_config(decay=0.9, epsilon=1.1e-5, scale=True)
        nb_blocks_in_stage = [6, 12, 24, 16]  # For DenseNet-121

        def conv_block(x, stage_idx, branch, filters):
            """One dense layer: BN-ReLU-1x1 conv (bottleneck), then BN-ReLU-3x3 conv."""
            conv_name_base = 'conv' + str(stage_idx) + '_' + str(branch)
            inter_channel, out_channel = filters
            # 1x1 Convolution (Bottleneck layer)
            x = cnn.batch_norm(input_layer=x, name=conv_name_base + '_x1_bn')
            x = cnn.relu(x)
            x = cnn.conv(num_out_channels=inter_channel, k_height=1, k_width=1, d_height=1, d_width=1, mode='SAME',
                input_layer=x, activation=None, bias=None, name=conv_name_base + '_x1')

            # 3x3 Convolution producing this layer's new feature maps
            x = cnn.batch_norm(input_layer=x, name=conv_name_base + '_x2_bn')
            x = cnn.relu(x)
            x = cnn.conv(num_out_channels=out_channel, k_height=3, k_width=3, d_height=1, d_width=1, mode='SAME',
                input_layer=x, activation=None, bias=None, name=conv_name_base + '_x2')
            return x

        def dense_block(x, stage_idx, num_blocks, num_filters_list):
            """Stack num_blocks dense layers, concatenating each output onto the running features."""
            concat_feat = x
            for i in range(num_blocks):
                branch = i + 1  # branch names are 1-based
                x = conv_block(concat_feat, stage_idx=stage_idx, branch=branch, filters=num_filters_list[i])
                concat_feat = cnn.channel_concat([concat_feat, x])
            return concat_feat

        def transition_block(x, stage_idx, num_filters):
            """BN-ReLU-1x1 conv to num_filters channels, then 2x2/stride-2 average pooling."""
            conv_name_base = 'conv' + str(stage_idx) + '_blk'
            x = cnn.batch_norm(input_layer=x, name=conv_name_base + '_bn')
            x = cnn.relu(x)
            x = cnn.conv(num_out_channels=num_filters, k_height=1, k_width=1, d_height=1, d_width=1, mode='SAME',
                input_layer=x, activation=None, bias=None, name=conv_name_base)
            x = cnn.apool(2, 2, 2, 2, input_layer=x)
            return x

        # Stem: 7x7/2 conv with explicit padding of 3, then padded 3x3/2 max pool.
        cnn.conv(num_out_channels=self.deps[0], k_height=7, k_width=7, d_height=2, d_width=2,
            mode='VALID', activation='relu', use_batch_norm=True, specify_padding=3, name='conv1', bias=None)
        cnn.pad2d(1)
        cnn.mpool(3, 3, 2, 2)

        # Stages 2-4: dense block followed by a transition. For stage s the
        # dense widths are deps[2s-3] and the transition width is deps[2s-2].
        output = cnn.top_layer
        for stage in range(2, 5):
            output = dense_block(output, stage, nb_blocks_in_stage[stage - 2], self.deps[stage * 2 - 3])
            output = transition_block(output, stage, self.deps[stage * 2 - 2])
        # Stage 5: final dense block, no transition.
        output = dense_block(output, 5, nb_blocks_in_stage[3], self.deps[7])

        cnn.batch_norm(input_layer=output, name='stage5_blk_bn')
        cnn.relu()
        cnn.spatial_mean()

    def get_learning_rate(self, global_step, batch_size):
        """Step-decay schedule.

        The rate starts at 0.1 and is divided by 10 at epochs 30, 60, 90 and 120.
        """
        # Presumably the ImageNet-1k training-set size.
        num_batches_per_epoch = int(1281167 / batch_size)
        boundaries = list(num_batches_per_epoch * np.array([30, 60, 90, 120],
            dtype=np.int64))
        values = [0.1, 0.01, 0.001, 0.0001, 0.00001]
        return tf.train.piecewise_constant(global_step, boundaries, values)


class LRUDensenet121Model(Densenet121Model):
    """Placeholder for an "LRU" variant of DenseNet-121; graph construction is not implemented.

    Construction (``__init__``, ``get_learning_rate``) is inherited unchanged
    from ``Densenet121Model``. Only ``add_inference`` is unavailable.
    """

    def add_inference(self, cnn):
        """Not implemented.

        No ``LRUDenseNet121Builder`` is defined or imported in this module, so
        this model cannot build a graph. It fails explicitly here rather than
        with a NameError or a ``-O``-strippable ``assert False``.

        Raises:
          NotImplementedError: always.
        """
        raise NotImplementedError(
            'LRUDensenet121Model.add_inference is not implemented: '
            'LRUDenseNet121Builder is not available in this module.')


class DensenetCifar10Model(model_lib.CNNModel):
    """DenseNet configuration for 32x32 CIFAR-10-sized inputs.

    Per-layer widths are consumed from ``params.deps`` in graph order:

      deps[0]                       -> initial 3x3 conv ('conv0')
      next layer_counts[0] entries  -> dense layers in scope 'stage1'
      next entry                    -> 1x1 conv in scope 'transition1'
      next layer_counts[1] entries  -> dense layers in scope 'stage2'
      next entry                    -> 1x1 conv in scope 'transition2'
      next layer_counts[2] entries  -> dense layers in scope 'stage3'
    """

    def __init__(self, model, layer_counts, params):
        # Positional CNNModel args; presumably image size 32, default batch
        # size 64 and learning rate 0.1 -- confirm against CNNModel.
        super(DensenetCifar10Model, self).__init__(
            model, 32, 64, 0.1, layer_counts=layer_counts, params=params)
        self.batch_norm_config = dict(decay=0.9, epsilon=1e-5, scale=True)
        self.deps = params.deps

    @staticmethod
    def _channel_axis(cnn):
        """Return the tensor axis holding channels for the builder's layout."""
        if cnn.channel_pos == 'channels_last':
            return 3
        return 1

    def dense_block(self, cnn, num_filters):
        """Append one BN-ReLU-3x3 conv layer and concatenate its output onto its input."""
        block_input = cnn.top_layer
        activated = tf.nn.relu(cnn.batch_norm(block_input, **self.batch_norm_config))
        new_features = cnn.conv(num_filters, 3, 3, 1, 1, activation=None,
                                input_layer=activated, bias=None)
        cnn.top_layer = tf.concat([block_input, new_features],
                                  self._channel_axis(cnn))

    def transition_layer(self, cnn, num_filters):
        """BN -> ReLU -> 1x1 conv (ReLU) to num_filters channels -> 2x2/2 average pool."""
        config = self.batch_norm_config
        cnn.batch_norm(**config)
        cnn.top_layer = tf.nn.relu(cnn.top_layer)
        cnn.conv(num_filters, 1, 1, 1, 1, activation='relu', bias=None)
        cnn.apool(2, 2, 2, 2)

    def add_inference(self, cnn):
        """Build three dense stages separated by two transitions, then BN-ReLU and global mean."""
        cnn.set_whether_use_batch_norm_by_default(False)
        cnn.conv(self.deps[0], 3, 3, 1, 1, activation=None, name='conv0', bias=None)

        stage_scopes = ('stage1', 'stage2', 'stage3')
        transition_scopes = ('transition1', 'transition2')
        cursor = 1  # index into self.deps of the next width to consume
        for stage_idx, stage_scope in enumerate(stage_scopes):
            with tf.variable_scope(stage_scope):
                depth = self.layer_counts[stage_idx]
                for layer in range(depth):
                    self.dense_block(cnn, self.deps[cursor + layer])
            cursor += depth
            # Every stage except the last is followed by a transition layer.
            if stage_idx < len(transition_scopes):
                with tf.variable_scope(transition_scopes[stage_idx]):
                    self.transition_layer(cnn, self.deps[cursor])
                cursor += 1

        cnn.batch_norm(**self.batch_norm_config)
        cnn.top_layer = tf.nn.relu(cnn.top_layer)
        final_shape = cnn.top_layer.get_shape().as_list()
        cnn.top_size = final_shape[self._channel_axis(cnn)]
        cnn.spatial_mean()

    def get_learning_rate(self, global_step, batch_size):
        """Step decay: 0.1, divided by 10 at epochs 150, 225 and 300."""
        # Presumably the CIFAR-10 training-set size.
        steps_per_epoch = int(50000 / batch_size)
        decay_epochs = np.array([150, 225, 300], dtype=np.int64)
        boundaries = list(steps_per_epoch * decay_epochs)
        values = [0.1, 0.01, 0.001, 0.0001]
        return tf.train.piecewise_constant(global_step, boundaries, values)


def create_densenet40_k12_model(params):
    """Return the 'densenet40_k12' CIFAR-10 model: three stages of 12 dense layers each."""
    stage_depths = (12, 12, 12)
    return DensenetCifar10Model('densenet40_k12', stage_depths, params=params)


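# NOTE: the factories below are disabled and out of date. DensenetCifar10Model
# now takes `params` (carrying the per-layer widths in `params.deps`) as its
# third argument instead of a growth rate. Update them before re-enabling.
# def create_densenet100_k12_model():
#     return DensenetCifar10Model('densenet100_k12', (32, 32, 32), 12)
#
#
# def create_densenet100_k24_model():
#     return DensenetCifar10Model('densenet100_k24', (32, 32, 32), 24)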

def create_densenet121_model(params):
    """Return a Densenet121Model registered under the name 'densenet121'."""
    model = Densenet121Model(model_name='densenet121', params=params)
    return model

def create_lrudensenet121_model(params):
    """Return an LRUDensenet121Model registered under the name 'lru-densenet121'."""
    model = LRUDensenet121Model(model_name='lru-densenet121', params=params)
    return model
