import numpy as np
from tf_utils import *

def eqcls_indexes_to_delete(eqcls):
    """Return the indexes that become redundant once every class is collapsed.

    Each equivalence class keeps only its smallest member; every other member
    is reported for deletion. The output lists classes in their given order,
    with the members of each class in ascending order.

    :param eqcls: iterable of equivalence classes (iterables of sortable indexes)
    :return: flat list of the indexes to delete
    """
    return [idx for members in eqcls for idx in sorted(members)[1:]]

def num_filters_in_eqcls(eqcls):
    """Count the filters covered by a list of equivalence classes.

    The classes are expected to cover the indexes 0..N-1 exactly, so the
    largest index seen must equal the total count minus one; this is checked
    with an assertion.

    :param eqcls: iterable of non-empty equivalence classes (sized collections of indexes)
    :return: total number of members over all classes
    """
    total = 0
    largest = 0
    for members in eqcls:
        total += len(members)
        top = max(members)
        if top > largest:
            largest = top
    # sanity check: the indexes are dense, i.e. the highest one is total - 1
    assert largest == total - 1
    return total

def shift_eqcls(eqcls, offset):
    """Return a copy of the equivalence classes with every index moved by ``offset``.

    :param eqcls: iterable of equivalence classes (iterables of indexes)
    :param offset: value added to every index
    :return: new list of lists; the input is left untouched
    """
    return [[offset + member for member in members] for members in eqcls]

def calculate_bn_eqcls_dc40(conv_layer_to_eqcls):
    """Build the equivalence classes of BN layers 0..38 from those of the conv layers.

    BN layers 0, 13 and 26 simply take a copy of the classes of the conv layer
    with the same index. Every other BN layer ``i`` takes the classes of BN
    layer ``i - 1`` followed by the classes of conv layer ``i``, the latter
    shifted by the number of filters already covered by BN layer ``i - 1``.

    :param conv_layer_to_eqcls: indexable by 0..38, giving each conv layer's classes
    :return: dict mapping BN layer index (0..38) to its list of classes
    """
    stage_heads = (0, 13, 26)
    bn_eqcls_by_layer = {}
    for bn_idx in range(39):
        if bn_idx in stage_heads:
            # a stage starts afresh from its own conv layer
            bn_eqcls_by_layer[bn_idx] = list(conv_layer_to_eqcls[bn_idx])
            continue
        inherited = bn_eqcls_by_layer[bn_idx - 1]
        width_so_far = num_filters_in_eqcls(inherited)
        appended = shift_eqcls(list(conv_layer_to_eqcls[bn_idx]), offset=width_so_far)
        bn_eqcls_by_layer[bn_idx] = inherited + appended
    return bn_eqcls_by_layer

def calculate_bn_eqcls_dense121(conv_layer_to_eqcls):
    """Build the equivalence classes of BN layers 0..120 from those of the conv layers.

    BN layer 0 copies the classes of conv layer 0. The remaining BN layers are
    filled in four stages of 6, 12, 24 and 16 blocks. A block owns two BN
    layers, and each stage ends with one extra BN layer:

    * the first BN layer of the first block copies the classes of the conv
      layer right before it;
    * the first BN layer of any later block takes the classes of the previous
      block's first BN layer, followed by the shifted classes of the conv
      layer right before it;
    * the second BN layer of a block copies the classes of the conv layer
      right before it;
    * the closing BN layer of the stage extends the last block's first BN
      layer in the same way.

    :param conv_layer_to_eqcls: indexable by conv layer index, giving that layer's classes
    :return: dict mapping BN layer index to its list of classes
    """
    bn_eqcls_by_layer = {0: list(conv_layer_to_eqcls[0])}

    def widen(base_eqcls, conv_idx):
        # base classes followed by the conv layer's classes, shifted past the base width
        base_width = num_filters_in_eqcls(base_eqcls)
        return base_eqcls + shift_eqcls(list(conv_layer_to_eqcls[conv_idx]), offset=base_width)

    def fill_stage(start_bn_layer_idx, num_blocks_in_stage):
        for block in range(num_blocks_in_stage):
            first_bn = start_bn_layer_idx + 2 * block
            second_bn = first_bn + 1
            if block:
                bn_eqcls_by_layer[first_bn] = widen(bn_eqcls_by_layer[first_bn - 2], first_bn - 1)
            else:
                bn_eqcls_by_layer[first_bn] = list(conv_layer_to_eqcls[first_bn - 1])
            bn_eqcls_by_layer[second_bn] = list(conv_layer_to_eqcls[second_bn - 1])
        # the BN layer that closes the stage sits right after the last block
        closing_bn = start_bn_layer_idx + 2 * num_blocks_in_stage
        bn_eqcls_by_layer[closing_bn] = widen(bn_eqcls_by_layer[closing_bn - 2], closing_bn - 1)

    # (first BN index of the stage, number of blocks); each start is the
    # previous start + 2 * previous block count + 1
    for stage_start, stage_blocks in ((1, 6), (14, 12), (39, 24), (88, 16)):
        fill_stage(stage_start, stage_blocks)
    return bn_eqcls_by_layer


def calculate_bn_eqcls_wrnc16(conv_layer_to_eqcls):
    """Map each of the 16 BN layers to the classes of the conv layer it is tied to.

    Position ``i`` of the table below names the conv layer whose equivalence
    classes BN layer ``i`` reuses. BN layers whose conv layer has no entry in
    ``conv_layer_to_eqcls`` are left out of the result. The class lists are
    shared with the input, not copied.

    :param conv_layer_to_eqcls: mapping from conv layer index to its classes
    :return: dict mapping BN layer index to the classes of its conv layer
    """
    conv_idx_of_bn = (0, 0, 2, 1, 4, 1, 1, 7, 6, 9, 6, 6, 12, 11, 14, 11)
    return {
        bn_idx: conv_layer_to_eqcls[conv_idx]
        for bn_idx, conv_idx in enumerate(conv_idx_of_bn)
        if conv_idx in conv_layer_to_eqcls
    }





def tfm_prune_merged_filters_and_save_dense121(model, conv_layer_to_eqcls, bn_layer_to_eqcls, save_file, new_deps=None,
                                               st_decay_on_vecs=None):
    """Prune merged filters of a model with 121 BN layers and save the result.

    In every equivalence class only the smallest index survives (see
    ``eqcls_indexes_to_delete``). The function

    1. deletes the redundant slices of every conv kernel listed in
       ``conv_layer_to_eqcls`` along axis 3,
    2. deletes the redundant entries of the moving mean / moving variance /
       beta / gamma of each BN layer 0..120 and, for BN layers 1..119, folds
       the axis-2 slices of kernel ``i`` that belong to one class into the
       surviving slice before deleting the others,
    3. folds and deletes the rows of the last kernel according to the classes
       of BN layer 120,
    4. copies every remaining key variable unchanged and writes everything to
       ``save_file``.

    :param model: object providing ``get_kernel_variables``,
        ``get_moving_mean_variables``, ``get_moving_variance_variables``,
        ``get_beta_variables``, ``get_gamma_variables``, ``get_key_variables``
        and ``get_value``
    :param conv_layer_to_eqcls: dict from kernel index to its equivalence classes
    :param bn_layer_to_eqcls: indexable by 0..120, giving each BN layer's classes
    :param save_file: output path; a name ending in 'npy' is written with
        ``np.save``, anything else with ``save_hdf5``
    :param new_deps: if not None, stored in the output under the key 'deps'
    :param st_decay_on_vecs: must not be None as soon as a BN class of BN
        layers 1..119 has more than one member; 'cancelbeta' selects the
        gamma-weighted fold, any other value a plain sum
    """

    kernel_tensors = model.get_kernel_variables()
    mu_tensors = model.get_moving_mean_variables()
    var_tensors = model.get_moving_variance_variables()
    beta_tensors = model.get_beta_variables()
    gamma_tensors = model.get_gamma_variables()

    result = {}

    #   step 1, prune all the conv layers along axis 3, NO NEED to adjust following layers here
    for layer_idx, eqcls in conv_layer_to_eqcls.items():
        kernel = kernel_tensors[layer_idx]
        kv = model.get_value(kernel)
        conv_idxes_to_delete = eqcls_indexes_to_delete(eqcls)
        pruned_kv = delete_or_keep(kv, idxes=conv_idxes_to_delete, axis=3)
        result[kernel.name] = pruned_kv

    #   step 2, prune bn layers and re-construct the kernels that consume them
    #   (skipped for the first and the last bn layer)
    for i in range(0, 121):
        bn_eqcls = bn_layer_to_eqcls[i]
        mu = mu_tensors[i]
        var = var_tensors[i]
        beta = beta_tensors[i]
        gamma = gamma_tensors[i]
        bn_eqcls_to_delete = eqcls_indexes_to_delete(bn_eqcls)

        result[mu.name] = delete_or_keep(model.get_value(mu), idxes=bn_eqcls_to_delete)
        result[var.name] = delete_or_keep(model.get_value(var), idxes=bn_eqcls_to_delete)
        result[beta.name] = delete_or_keep(model.get_value(beta), idxes=bn_eqcls_to_delete)
        # the un-pruned gamma is kept around: 'cancelbeta' indexes it with original indexes
        gamma_value = model.get_value(gamma)
        result[gamma.name] = delete_or_keep(gamma_value, idxes=bn_eqcls_to_delete)

        if i > 0 and i < 120:
            follow_kernel = kernel_tensors[i]
            follow_kernel_value = result[follow_kernel.name]
            for eqcl in bn_eqcls:
                if len(eqcl) == 1:
                    continue
                eqc = np.array(sorted(eqcl))
                # The smallest index is the one eqcls_indexes_to_delete keeps, so the
                # folded slice must be written there in BOTH branches below.
                kept = eqc[0]

                assert st_decay_on_vecs is not None
                if st_decay_on_vecs == 'cancelbeta':
                    # gamma-weighted sum of the class, rescaled by the gamma of the survivor
                    k_bar = 0.0
                    for fi in eqcl:
                        k_bar += gamma_value[fi] * follow_kernel_value[:, :, fi, :]
                    follow_kernel_value[:, :, kept, :] = k_bar / gamma_value[kept]
                else:
                    selected_k_follow = follow_kernel_value[:, :, eqc, :]
                    aggregated_k_follow = np.sum(selected_k_follow, axis=2)
                    follow_kernel_value[:, :, kept, :] = aggregated_k_follow

            result[follow_kernel.name] = delete_or_keep(follow_kernel_value, idxes=bn_eqcls_to_delete, axis=2)

    #   step 3, deal with the fc layer (the last kernel), driven by the classes of bn layer 120
    fc_kernel = kernel_tensors[-1]
    fc_value = model.get_value(fc_kernel)
    fc_indexes_to_delete = []
    origin_last_bn_width = num_filters_in_eqcls(bn_layer_to_eqcls[120])
    corresponding_neurons_per_kernel = fc_value.shape[0] // origin_last_bn_width
    # row offsets of the groups of fc rows; channel c owns the rows base + c
    base = np.arange(0, corresponding_neurons_per_kernel * origin_last_bn_width, origin_last_bn_width)
    for eqcl in bn_layer_to_eqcls[120]:
        if len(eqcl) == 1:
            continue
        se = sorted(eqcl)
        for i in se[1:]:
            fc_indexes_to_delete.append(base + i)
        to_concat = []
        for i in se:
            corresponding_neurons_idxes = base + i
            to_concat.append(np.expand_dims(fc_value[corresponding_neurons_idxes, :], axis=0))
        # sum the rows of all members and store them in the rows of the survivor
        merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
        reserved_idxes = base + se[0]
        fc_value[reserved_idxes, :] = merged
    if len(fc_indexes_to_delete) > 0:
        fc_value = delete_or_keep(fc_value, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
    result[fc_kernel.name] = fc_value

    #   step 4, carry over every key variable that was not touched above, then save
    key_variables = model.get_key_variables()
    for var in key_variables:
        if var.name not in result:
            result[var.name] = model.get_value(var)
    if new_deps is not None:
        result['deps'] = new_deps
    print('save {} varialbes to {} after pruning merged filters'.format(len(result), save_file))
    if save_file.endswith('npy'):
        np.save(save_file, result)
    else:
        save_hdf5(result, save_file)


def tfm_prune_merged_filters_and_save_dc40(model, conv_layer_to_eqcls, bn_layer_to_eqcls, save_file, new_deps=None):
    """Prune merged filters of a model with 39 conv/BN layers and save the result.

    In every equivalence class only the smallest index survives. Conv kernels
    are pruned along axis 3. Each BN layer ``i`` has its four vectors pruned
    and, for ``i < 38``, the axis-2 slices of kernel ``i + 1`` belonging to one
    class are summed into the surviving slice before the others are deleted.
    The rows of the last kernel are folded and deleted according to the
    classes of BN layer 38. All other key variables are copied unchanged.

    :param model: object providing the variable accessors used below,
        ``get_key_variables`` and ``get_value``
    :param conv_layer_to_eqcls: dict of exactly 39 entries, kernel index -> classes
    :param bn_layer_to_eqcls: indexable by 0..38, giving each BN layer's classes
    :param save_file: output path; a name ending in 'npy' is written with
        ``np.save``, anything else with ``save_hdf5``
    :param new_deps: if not None, stored in the output under the key 'deps'
    """
    kernels = model.get_kernel_variables()
    moving_means = model.get_moving_mean_variables()
    moving_variances = model.get_moving_variance_variables()
    # NOTE(review): the sibling functions call get_beta_variables / get_gamma_variables;
    # confirm that the model used here really exposes the *_tensors accessors.
    betas = model.get_beta_tensors()
    gammas = model.get_gamma_tensors()
    assert len(gammas) == 39
    assert len(conv_layer_to_eqcls) == 39

    pruned = {}

    #   conv kernels: drop the redundant slices along axis 3
    for conv_idx, conv_eqcls in conv_layer_to_eqcls.items():
        conv_kernel = kernels[conv_idx]
        pruned[conv_kernel.name] = delete_or_keep(model.get_value(conv_kernel),
                                                  idxes=eqcls_indexes_to_delete(conv_eqcls), axis=3)

    #   bn layers, plus the kernel that consumes each of them (if any)
    for bn_idx in range(39):
        bn_eqcls = bn_layer_to_eqcls[bn_idx]
        bn_vectors = (moving_means[bn_idx], moving_variances[bn_idx], betas[bn_idx], gammas[bn_idx])
        redundant = eqcls_indexes_to_delete(bn_eqcls)
        for vector in bn_vectors:
            pruned[vector.name] = delete_or_keep(model.get_value(vector), idxes=redundant)

        if bn_idx >= 38:
            # the last bn layer feeds the fc layer, handled separately below
            continue
        next_kernel = kernels[bn_idx + 1]
        next_value = pruned[next_kernel.name]
        for members in bn_eqcls:
            if len(members) == 1:
                continue
            ordered = np.array(sorted(members))
            # fold the whole class into the slice of its smallest index
            next_value[:, :, ordered[0], :] = np.sum(next_value[:, :, ordered, :], axis=2)
        pruned[next_kernel.name] = delete_or_keep(next_value, idxes=redundant, axis=2)

    #   fc layer (the last kernel), driven by the classes of bn layer 38
    fc_kernel = kernels[-1]
    fc_value = model.get_value(fc_kernel)
    rows_to_drop = []
    last_bn_width = num_filters_in_eqcls(bn_layer_to_eqcls[38])
    groups = fc_value.shape[0] // last_bn_width
    # row offsets of the groups of fc rows; channel c owns the rows group_starts + c
    group_starts = np.arange(0, groups * last_bn_width, last_bn_width)
    for members in bn_layer_to_eqcls[38]:
        if len(members) == 1:
            continue
        ordered = sorted(members)
        rows_to_drop.extend(group_starts + channel for channel in ordered[1:])
        stacked = np.concatenate(
            [np.expand_dims(fc_value[group_starts + channel, :], axis=0) for channel in ordered], axis=0)
        fc_value[group_starts + ordered[0], :] = np.sum(stacked, axis=0)
    if len(rows_to_drop) > 0:
        fc_value = delete_or_keep(fc_value, np.concatenate(rows_to_drop, axis=0), axis=0)
    pruned[fc_kernel.name] = fc_value

    #   everything that was not pruned is carried over unchanged
    for variable in model.get_key_variables():
        if variable.name not in pruned:
            pruned[variable.name] = model.get_value(variable)
    if new_deps is not None:
        pruned['deps'] = new_deps
    print('save {} varialbes to {} after pruning merged filters'.format(len(pruned), save_file))
    if save_file.endswith('npy'):
        np.save(save_file, pruned)
    else:
        save_hdf5(pruned, save_file)


def tfm_prune_merged_filters_and_save_wrn16(model, conv_layer_to_eqcls, bn_layer_to_eqcls, save_file, new_deps=None):
    """Prune merged filters of a model with 16 BN layers and save the result.

    In every equivalence class only the smallest index survives. Both
    ``conv_layer_to_eqcls`` and ``bn_layer_to_eqcls`` may cover only a subset
    of the layers:

    * conv kernels listed in ``conv_layer_to_eqcls`` are pruned along axis 3;
    * BN layers missing from ``bn_layer_to_eqcls`` are skipped; for the others
      the four BN vectors are pruned and, for ``i <= 14``, the axis-2 slices of
      kernel ``i + 1`` belonging to one class are summed into the surviving
      slice before the others are deleted. If kernel ``i + 1`` was not pruned
      in the first step, its value is read from the model;
    * the rows of the last kernel are folded and deleted according to the
      classes of BN layer 15, if that layer is present in ``bn_layer_to_eqcls``.

    All other key variables are copied unchanged.

    :param model: object providing ``get_kernel_variables``,
        ``get_moving_mean_variables``, ``get_moving_variance_variables``,
        ``get_beta_variables``, ``get_gamma_variables``, ``get_key_variables``
        and ``get_value``
    :param conv_layer_to_eqcls: dict from kernel index to its equivalence classes
    :param bn_layer_to_eqcls: dict from BN layer index to its equivalence classes
    :param save_file: output path; a name ending in 'npy' is written with
        ``np.save``, anything else with ``save_hdf5``
    :param new_deps: if not None, stored in the output under the key 'deps'
    """

    kernel_tensors = model.get_kernel_variables()
    mu_tensors = model.get_moving_mean_variables()
    var_tensors = model.get_moving_variance_variables()
    beta_tensors = model.get_beta_variables()
    gamma_tensors = model.get_gamma_variables()
    assert len(gamma_tensors) == 16

    result = {}

    #   step 1, prune all the conv layers along axis 3, NO NEED to adjust following layers here
    for layer_idx, eqcls in conv_layer_to_eqcls.items():
        kernel = kernel_tensors[layer_idx]
        kv = model.get_value(kernel)
        conv_idxes_to_delete = eqcls_indexes_to_delete(eqcls)
        pruned_kv = delete_or_keep(kv, idxes=conv_idxes_to_delete, axis=3)
        result[kernel.name] = pruned_kv

    #   step 2, prune bn layers and re-construct the following conv layers (if exists)
    for i in range(0, 16):
        if i not in bn_layer_to_eqcls:
            continue
        bn_eqcls = bn_layer_to_eqcls[i]
        mu = mu_tensors[i]
        var = var_tensors[i]
        beta = beta_tensors[i]
        gamma = gamma_tensors[i]
        bn_eqcls_to_delete = eqcls_indexes_to_delete(bn_eqcls)

        result[mu.name] = delete_or_keep(model.get_value(mu), idxes=bn_eqcls_to_delete)
        result[var.name] = delete_or_keep(model.get_value(var), idxes=bn_eqcls_to_delete)
        result[beta.name] = delete_or_keep(model.get_value(beta), idxes=bn_eqcls_to_delete)
        result[gamma.name] = delete_or_keep(model.get_value(gamma), idxes=bn_eqcls_to_delete)

        if i <= 14:
            follow_kernel = kernel_tensors[i + 1]
            # The following kernel only has an entry in `result` if its own filters were
            # pruned in step 1; otherwise start from the value held by the model.
            if follow_kernel.name in result:
                follow_kernel_value = result[follow_kernel.name]
            else:
                follow_kernel_value = model.get_value(follow_kernel)
            for eqcl in bn_eqcls:
                if len(eqcl) == 1:
                    continue
                eqc = np.array(sorted(eqcl))
                # fold the whole class into the slice of its smallest index
                selected_k_follow = follow_kernel_value[:, :, eqc, :]
                aggregated_k_follow = np.sum(selected_k_follow, axis=2)
                follow_kernel_value[:, :, eqc[0], :] = aggregated_k_follow
            result[follow_kernel.name] = delete_or_keep(follow_kernel_value, idxes=bn_eqcls_to_delete, axis=2)

    #   step 3, deal with the fc layer (the last kernel), driven by the classes of bn layer 15
    fc_kernel = kernel_tensors[-1]
    fc_value = model.get_value(fc_kernel)
    fc_indexes_to_delete = []
    # Like the loop above, tolerate a bn layer without classes: the fc kernel is then
    # saved unchanged.
    if 15 in bn_layer_to_eqcls:
        last_bn_eqcls = bn_layer_to_eqcls[15]
        origin_last_bn_width = num_filters_in_eqcls(last_bn_eqcls)
        corresponding_neurons_per_kernel = fc_value.shape[0] // origin_last_bn_width
        # row offsets of the groups of fc rows; channel c owns the rows base + c
        base = np.arange(0, corresponding_neurons_per_kernel * origin_last_bn_width, origin_last_bn_width)
        for eqcl in last_bn_eqcls:
            if len(eqcl) == 1:
                continue
            se = sorted(eqcl)
            for i in se[1:]:
                fc_indexes_to_delete.append(base + i)
            to_concat = []
            for i in se:
                corresponding_neurons_idxes = base + i
                to_concat.append(np.expand_dims(fc_value[corresponding_neurons_idxes, :], axis=0))
            # sum the rows of all members and store them in the rows of the survivor
            merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
            reserved_idxes = base + se[0]
            fc_value[reserved_idxes, :] = merged
    if len(fc_indexes_to_delete) > 0:
        fc_value = delete_or_keep(fc_value, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
    result[fc_kernel.name] = fc_value

    #   step 4, carry over every key variable that was not touched above, then save
    key_variables = model.get_key_variables()
    for var in key_variables:
        if var.name not in result:
            result[var.name] = model.get_value(var)
    if new_deps is not None:
        result['deps'] = new_deps
    print('save {} varialbes to {} after pruning merged filters'.format(len(result), save_file))
    if save_file.endswith('npy'):
        np.save(save_file, result)
    else:
        save_hdf5(result, save_file)


#   assume that the filters have been merged and the following variables have been adjusted
def tfm_prune_merged_filters_and_save(model, layer_to_eqcls, save_file, fc_layer_idxes,
                                      subsequent_strategy, layer_idx_to_follow_offset={},
                                      fc_neurons_per_kernel=None, new_deps=None,
                                      st_decay_on_vecs=None):
    """Prune merged filters layer by layer, adjust the layers that follow, and save.

    Only kernels and key variables whose name starts with 'v0' are considered.
    For every layer in ``layer_to_eqcls`` the first member of each equivalence
    class (in the order given, the classes are NOT sorted here) is kept and the
    other members are deleted from the kernel and from its BN vectors. If the
    layer has followers, the followers' slices/rows belonging to one class are
    folded into the kept one before the others are deleted.

    :param model: object providing ``get_kernel_variables``, ``get_key_variables``,
        ``get_value`` and the ``get_*_variable_for_kernel`` accessors used below
    :param layer_to_eqcls: dict from kernel index (into the 'v0' kernels) to its
        list of equivalence classes (lists of filter indexes)
    :param save_file: output path; a name ending in 'npy' is written with
        ``np.save``, anything else with ``save_hdf5``
    :param fc_layer_idxes: kernel index, or list of kernel indexes, that are
        treated as fully-connected followers
    :param subsequent_strategy: None (no follower is adjusted), 'simple'
        (layer ``i`` is followed by layer ``i + 1``), or a mapping from layer
        index to a follower index / list of follower indexes
    :param layer_idx_to_follow_offset: mapping from layer index to the offset
        added to the filter indexes when addressing the follower (default 0);
        only read, never modified here
    :param fc_neurons_per_kernel: if given, the number of row groups of an fc
        follower; otherwise derived from the kernel's axis-3 size plus offset
    :param new_deps: if not None, stored in the output under the key 'deps'
    :param st_decay_on_vecs: must not be None as soon as a class with more than
        one member has a non-fc follower; 'cancelbeta' selects the
        gamma-weighted fold, any other value a plain sum
    """
    result = dict()
    # running totals over all layers, only used for the progress message
    number_filters_seen = 0
    num_filters_alike = 0

    # normalise the strategy into a map: layer index -> follower index (or list of them)
    if subsequent_strategy is None:
        subsequent_map = None
    elif subsequent_strategy == 'simple':
        subsequent_map = {idx : (idx+1) for idx in layer_to_eqcls.keys()}
    else:
        subsequent_map = subsequent_strategy
    if type(fc_layer_idxes) is not list:
        fc_layer_idxes = [fc_layer_idxes]

    kernels = [v for v in model.get_kernel_variables() if v.name.startswith('v0')]

    for layer_idx, eqcls in layer_to_eqcls.items():

        kernel_tensor = kernels[layer_idx]
        print('cur kernel name:', kernel_tensor.name)
        num_filters = kernel_tensor.get_shape().as_list()[3]
        # NOTE: bias_tensor is fetched but not used; the bias pruning below is commented out
        bias_tensor = model.get_bias_variable_for_kernel(kernel_tensor)
        beta_tensor = model.get_beta_variable_for_kernel(kernel_tensor)
        gamma_tensor = model.get_gamma_variable_for_kernel(kernel_tensor)
        moving_mean_tensor = model.get_moving_mean_variable_for_kernel(kernel_tensor)
        moving_variance_tensor = model.get_moving_variance_variable_for_kernel(kernel_tensor)

        # fall back to neutral values (beta = 0, gamma = 1) when get_value yields None
        beta_value = model.get_value(beta_tensor)
        if beta_value is None:
            beta_value = np.zeros(num_filters, dtype=np.float32)
        gamma_value = model.get_value(gamma_tensor)
        if gamma_value is None:
            gamma_value = np.ones(num_filters, dtype=np.float32)

        # the kernel may already have been adjusted as the follower of an earlier layer
        if kernel_tensor.name in result:
            kernel_value = result[kernel_tensor.name]
        else:
            kernel_value = model.get_value(kernel_tensor)

        if subsequent_map is None or layer_idx not in subsequent_map:
            # no follower to adjust: just collect the indexes to delete
            indexes_to_delete = []
            for eqcl in eqcls:
                number_filters_seen += len(eqcl)
                if len(eqcl) == 1:
                    continue
                num_filters_alike += len(eqcl)
                indexes_to_delete += eqcl[1:]
        else:
            follows = subsequent_map[layer_idx]
            print('{} follows {}'.format(follows, layer_idx))
            if type(follows) is not list:
                follows = [follows]

            # NOTE: indexes_to_delete and the two counters are recomputed/updated once per follower
            for follow_idx in follows:
                follow_kernel_tensor = kernels[follow_idx]
                if follow_kernel_tensor.name in result:
                    kvf = result[follow_kernel_tensor.name]
                else:
                    kvf = model.get_value(follow_kernel_tensor)
                print('following kernel name: ', follow_kernel_tensor.name, 'origin shape: ', kvf.shape)

                if follow_idx in fc_layer_idxes:
                    # fully-connected follower: fold and delete rows (axis 0) of kvf

                    offset = layer_idx_to_follow_offset.get(layer_idx, 0)
                    if offset > 0:
                        print('offset,',offset)
                    conv_indexes_to_delete = []
                    fc_indexes_to_delete = []
                    # assert kvf.shape[0] % kernel_value.shape[3] == 0

                    if fc_neurons_per_kernel is None:
                        conv_deps = kernel_value.shape[3] + offset
                        corresponding_neurons_per_kernel = kvf.shape[0] // conv_deps
                    else:
                        corresponding_neurons_per_kernel=fc_neurons_per_kernel
                        conv_deps = kvf.shape[0] // corresponding_neurons_per_kernel
                    print('total conv deps:', conv_deps, corresponding_neurons_per_kernel, 'neurons per kernel')

                    # row offsets of the groups of fc rows; filter f owns the rows base + f
                    base = np.arange(offset, corresponding_neurons_per_kernel*conv_deps+offset, conv_deps)
                    for eqcl in eqcls:
                        number_filters_seen += len(eqcl)
                        if len(eqcl) == 1:
                            continue
                        num_filters_alike += len(eqcl)
                        conv_indexes_to_delete += eqcl[1:]
                        for i in eqcl[1:]:
                            fc_indexes_to_delete.append(base + i)
                        to_concat = []
                        for i in eqcl:
                            corresponding_neurons_idxes = base + i
                            to_concat.append(np.expand_dims(kvf[corresponding_neurons_idxes, :], axis=0))
                        # sum the rows of all members and store them in the rows of the kept filter
                        merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
                        reserved_idxes = base + eqcl[0]
                        kvf[reserved_idxes, :] = merged
                    # the follower is only recorded in result when some row is actually deleted
                    if len(fc_indexes_to_delete) > 0:
                        kvf = delete_or_keep(kvf, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
                        result[follow_kernel_tensor.name] = kvf
                        print('shape of pruned following kernel: ', kvf.shape)
                    indexes_to_delete = conv_indexes_to_delete
                else:
                    # non-fc follower: fold and delete slices along axis 2 of kvf
                    offset = layer_idx_to_follow_offset.get(layer_idx, 0)
                    indexes_to_delete = []
                    for eqcl in eqcls:
                        number_filters_seen += len(eqcl)
                        if len(eqcl) == 1:
                            continue
                        num_filters_alike += len(eqcl)
                        indexes_to_delete += eqcl[1:]

                        #
                        # selected_my_gamma = gamma_value[eqc+offset]
                        # selected_my_beta = beta_value[eqc+offset]

                        assert st_decay_on_vecs is not None
                        if st_decay_on_vecs == 'cancelbeta':
                            # gamma-weighted sum of the class, rescaled by the gamma of the kept filter
                            k_bar = 0.0
                            for fi in eqcl:
                                k_bar += gamma_value[fi + offset] * kvf[:, :, fi + offset, :]
                            kvf[:, :, eqcl[0] + offset, :] = k_bar / gamma_value[eqcl[0] + offset]
                        else:
                            # plain sum of the class, stored in the slice of the kept filter
                            eqc = np.array(eqcl)
                            selected_k_follow = kvf[:, :, eqc + offset, :]
                            aggregated_k_follow = np.sum(selected_k_follow, axis=2)
                            kvf[:, :, eqcl[0] + offset, :] = aggregated_k_follow

                    # NOTE(review): for a 'depth' follower the fold above has already been applied
                    # to kvf in place; only the deletion and the store into result are skipped here.
                    if 'depth' in follow_kernel_tensor.name:
                        print('skip adding up and pruning the following layer, because it is a depthwise layer')
                    else:
                        follow_indexes_to_delete = [offset + p for p in indexes_to_delete]
                        kvf = delete_or_keep(kvf, follow_indexes_to_delete, axis=2)
                        result[follow_kernel_tensor.name] = kvf
                        print('shape of pruned following kernel: ', kvf.shape)
        # kernels with 'depth' in their name are pruned along axis 2, all others along axis 3
        if 'depth' in kernel_tensor.name:
            kernel_value_after_pruned = delete_or_keep(kernel_value, indexes_to_delete, axis=2)
        else:
            kernel_value_after_pruned = delete_or_keep(kernel_value, indexes_to_delete, axis=3)
        result[kernel_tensor.name] = kernel_value_after_pruned

        # TODO compatible with old rc modls, do not know why
        # if bias_tensor is not None:
        #     bias_value = delete_or_keep(model.get_value(bias_tensor), indexes_to_delete)
        #     result[bias_tensor.name] = bias_value

        # prune the BN vectors that exist for this kernel
        if moving_mean_tensor is not None:
            moving_mean_value = delete_or_keep(model.get_value(moving_mean_tensor), indexes_to_delete)
            result[moving_mean_tensor.name] = moving_mean_value
        if moving_variance_tensor is not None:
            moving_variance_value = delete_or_keep(model.get_value(moving_variance_tensor), indexes_to_delete)
            result[moving_variance_tensor.name] = moving_variance_value
        if beta_tensor is not None:
            beta_value = delete_or_keep(beta_value, indexes_to_delete)
            result[beta_tensor.name] = beta_value
        if gamma_tensor is not None:
            gamma_value = delete_or_keep(gamma_value, indexes_to_delete)
            result[gamma_tensor.name] = gamma_value
        print('kernel name: ', kernel_tensor.name)
        print(
            'removed merged filters. {} filters seen. {} filters alike. shape of origin kernel {}, shape of pruned kernel {}'
            .format(number_filters_seen, num_filters_alike, kernel_value.shape, kernel_value_after_pruned.shape))

    # carry over every 'v0' key variable that was not touched above, then save
    key_variables = [v for v in model.get_key_variables() if v.name.startswith('v0')]
    for var in key_variables:
        if var.name not in result:
            result[var.name] = model.get_value(var)
    if new_deps is not None:
        result['deps'] = new_deps
    print('save {} varialbes to {} after pruning merged filters'.format(len(result), save_file))
    if save_file.endswith('npy'):
        np.save(save_file, result)
    else:
        save_hdf5(result, save_file)


def tfm_prune_merged_filters_and_save_for_alexnet_conv5(model, layer_to_eqcls, save_file, fc_layer_idxes,
                                                        subsequent_strategy, layer_idx_to_follow_offset={},
                                                        fc_neurons_per_kernel=None):
    """Prune merged filters where kernels 8 and 9 jointly feed kernel 10, and save.

    The per-layer part keeps the first member of every equivalence class (in
    the order given), deletes the others from the kernel (axis 3) and from its
    bias / BN vectors, and folds the matching slices or rows of each follower
    into the kept one. The adjusted follower of layers 8 and 9 is NOT stored:
    kernel 10 is rebuilt afterwards from the model value, using the classes of
    layer 8 together with those of layer 9 shifted by
    ``layer_idx_to_follow_offset[9]``.

    :param model: object providing ``get_kernel_variables``, ``get_key_variables``,
        ``get_value`` and the ``get_*_variable_for_kernel`` accessors used below
    :param layer_to_eqcls: dict from kernel index to its list of equivalence
        classes (lists of filter indexes); must contain the keys 8 and 9
    :param save_file: output path; a name ending in 'npy' is written with
        ``np.save``, anything else with ``save_hdf5``
    :param fc_layer_idxes: kernel index, or list of kernel indexes, that are
        treated as fully-connected followers
    :param subsequent_strategy: None (no follower is adjusted), 'simple'
        (layer ``i`` is followed by layer ``i + 1``), or a mapping from layer
        index to a follower index / list of follower indexes
    :param layer_idx_to_follow_offset: mapping from layer index to the offset
        added to the filter indexes when addressing the follower; must contain
        the key 9; only read, never modified here
    :param fc_neurons_per_kernel: number of row groups of an fc follower;
        required (not None) by the joint adjustment of kernel 10
    """
    result = dict()
    # running totals over all layers, only used for the progress message
    number_filters_seen = 0
    num_filters_alike = 0

    # normalise the strategy into a map: layer index -> follower index (or list of them)
    if subsequent_strategy is None:
        subsequent_map = None
    elif subsequent_strategy == 'simple':
        subsequent_map = {idx: (idx + 1) for idx in layer_to_eqcls.keys()}
    else:
        subsequent_map = subsequent_strategy
    if not isinstance(fc_layer_idxes, list):
        fc_layer_idxes = [fc_layer_idxes]

    kernels = model.get_kernel_variables()

    for layer_idx, eqcls in layer_to_eqcls.items():
        kernel_tensor = kernels[layer_idx]
        print('cur kernel name:', kernel_tensor.name)
        bias_tensor = model.get_bias_variable_for_kernel(layer_idx)
        beta_tensor = model.get_beta_variable_for_kernel(layer_idx)
        gamma_tensor = model.get_gamma_variable_for_kernel(layer_idx)
        moving_mean_tensor = model.get_moving_mean_variable_for_kernel(layer_idx)
        moving_variance_tensor = model.get_moving_variance_variable_for_kernel(layer_idx)

        # the kernel may already have been adjusted as the follower of an earlier layer
        if kernel_tensor.name in result:
            kernel_value = result[kernel_tensor.name]
        else:
            kernel_value = model.get_value(kernel_tensor)

        if subsequent_map is None or layer_idx not in subsequent_map:
            # no follower to adjust: just collect the indexes to delete
            indexes_to_delete = []
            for eqcl in eqcls:
                number_filters_seen += len(eqcl)
                if len(eqcl) == 1:
                    continue
                num_filters_alike += len(eqcl)
                indexes_to_delete += eqcl[1:]
        else:
            follows = subsequent_map[layer_idx]
            if not isinstance(follows, list):
                follows = [follows]
            for follow_idx in follows:
                follow_kernel_tensor = kernels[follow_idx]
                if follow_kernel_tensor.name in result:
                    kvf = result[follow_kernel_tensor.name]
                else:
                    kvf = model.get_value(follow_kernel_tensor)
                print('following kernel name: ', follow_kernel_tensor.name, 'origin shape: ', kvf.shape)
                if follow_idx in fc_layer_idxes:
                    # fully-connected follower: fold and delete rows (axis 0) of kvf
                    offset = layer_idx_to_follow_offset.get(layer_idx, 0)
                    if offset > 0:
                        print('offset,', offset)
                    conv_indexes_to_delete = []
                    fc_indexes_to_delete = []

                    if fc_neurons_per_kernel is None:
                        conv_deps = kernel_value.shape[3] + offset
                        corresponding_neurons_per_kernel = kvf.shape[0] // conv_deps
                    else:
                        corresponding_neurons_per_kernel = fc_neurons_per_kernel
                        conv_deps = kvf.shape[0] // corresponding_neurons_per_kernel
                    print('total conv deps:', conv_deps, corresponding_neurons_per_kernel, 'neurons per kernel')

                    # row offsets of the groups of fc rows; filter f owns the rows base + f
                    base = np.arange(offset, corresponding_neurons_per_kernel * conv_deps + offset, conv_deps)
                    for eqcl in eqcls:
                        number_filters_seen += len(eqcl)
                        if len(eqcl) == 1:
                            continue
                        num_filters_alike += len(eqcl)
                        conv_indexes_to_delete += eqcl[1:]
                        for i in eqcl[1:]:
                            fc_indexes_to_delete.append(base + i)
                        to_concat = []
                        for i in eqcl:
                            corresponding_neurons_idxes = base + i
                            to_concat.append(np.expand_dims(kvf[corresponding_neurons_idxes, :], axis=0))
                        # sum the rows of all members and store them in the rows of the kept filter
                        merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
                        reserved_idxes = base + eqcl[0]
                        kvf[reserved_idxes, :] = merged
                    if len(fc_indexes_to_delete) > 0:
                        kvf = delete_or_keep(kvf, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
                    indexes_to_delete = conv_indexes_to_delete
                else:
                    # non-fc follower: fold and delete slices along axis 2 of kvf
                    offset = layer_idx_to_follow_offset.get(layer_idx, 0)
                    indexes_to_delete = []
                    for eqcl in eqcls:
                        number_filters_seen += len(eqcl)
                        if len(eqcl) == 1:
                            continue
                        num_filters_alike += len(eqcl)
                        indexes_to_delete += eqcl[1:]
                        eqc = np.array(eqcl)
                        selected_k_follow = kvf[:, :, eqc + offset, :]
                        aggregated_k_follow = np.sum(selected_k_follow, axis=2)
                        kvf[:, :, eqcl[0] + offset, :] = aggregated_k_follow
                    follow_indexes_to_delete = [offset + p for p in indexes_to_delete]
                    kvf = delete_or_keep(kvf, follow_indexes_to_delete, axis=2)
                # the follower of layers 8 and 9 is rebuilt jointly after the loop
                if layer_idx != 8 and layer_idx != 9:
                    result[follow_kernel_tensor.name] = kvf
                    print('shape of pruned following kernel: ', kvf.shape)

        kernel_value_after_pruned = delete_or_keep(kernel_value, indexes_to_delete, axis=3)
        result[kernel_tensor.name] = kernel_value_after_pruned
        # prune the bias / BN vectors that exist for this kernel
        if bias_tensor is not None:
            bias_value = delete_or_keep(model.get_value(bias_tensor), indexes_to_delete)
            result[bias_tensor.name] = bias_value
        if moving_mean_tensor is not None:
            moving_mean_value = delete_or_keep(model.get_value(moving_mean_tensor), indexes_to_delete)
            result[moving_mean_tensor.name] = moving_mean_value
        if moving_variance_tensor is not None:
            moving_variance_value = delete_or_keep(model.get_value(moving_variance_tensor), indexes_to_delete)
            result[moving_variance_tensor.name] = moving_variance_value
        if beta_tensor is not None:
            beta_value = delete_or_keep(model.get_value(beta_tensor), indexes_to_delete)
            result[beta_tensor.name] = beta_value
        if gamma_tensor is not None:
            gamma_value = delete_or_keep(model.get_value(gamma_tensor), indexes_to_delete)
            result[gamma_tensor.name] = gamma_value
        print('kernel name: ', kernel_tensor.name)
        print(
            'removed merged filters. {} filters seen. {} filters alike. shape of origin kernel {}, shape of pruned kernel {}'
            .format(number_filters_seen, num_filters_alike, kernel_value.shape, kernel_value_after_pruned.shape))

    #   adjust fc: classes of layer 8, plus classes of layer 9 shifted by its follow offset
    conv5_joint_eqcls = []
    for eqcl in layer_to_eqcls[8]:
        conv5_joint_eqcls.append(np.array(eqcl))
    for eqcl in layer_to_eqcls[9]:
        conv5_joint_eqcls.append(np.array(eqcl) + layer_idx_to_follow_offset[9])
    fc_kernel = kernels[10]
    fc_value = model.get_value(fc_kernel)
    #   sum and delete
    fc_joint_indexes_to_delete = []
    corresponding_neurons_per_kernel = fc_neurons_per_kernel
    conv_deps = fc_value.shape[0] // corresponding_neurons_per_kernel
    base = np.arange(0, corresponding_neurons_per_kernel * conv_deps, conv_deps)
    for eqcl in conv5_joint_eqcls:
        if len(eqcl) == 1:
            continue
        for i in eqcl[1:]:
            fc_joint_indexes_to_delete.append(base + i)
        to_concat = []
        for i in eqcl:
            corresponding_neurons_idxes = base + i
            to_concat.append(np.expand_dims(fc_value[corresponding_neurons_idxes, :], axis=0))
        merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
        reserved_idxes = base + eqcl[0]
        fc_value[reserved_idxes, :] = merged
    # Decide on the JOINT list collected just above. The per-layer list
    # `fc_indexes_to_delete` from the loop may be undefined or belong to another layer.
    if len(fc_joint_indexes_to_delete) > 0:
        fc_value = delete_or_keep(fc_value, np.concatenate(fc_joint_indexes_to_delete, axis=0), axis=0)
    result[fc_kernel.name] = fc_value

    # carry over every key variable that was not touched above, then save
    key_variables = model.get_key_variables()
    for var in key_variables:
        if var.name not in result:
            result[var.name] = model.get_value(var)

    print('save {} varialbes to {} after pruning merged filters'.format(len(result), save_file))
    if save_file.endswith('npy'):
        np.save(save_file, result)
    else:
        save_hdf5(result, save_file)


def delete_or_keep(array, idxes, axis=None):
    """Delete the entries ``idxes`` of ``array`` along ``axis``.

    When ``idxes`` is empty the input ``array`` object itself is returned
    (no copy); otherwise the result of ``np.delete`` is returned.
    """
    if len(idxes) == 0:
        # Nothing to remove: hand back the very same object.
        return array
    return np.delete(array, idxes, axis=axis)

#
# #   assume that the filters have been merged and the following variables have been adjusted
# def tfm_prune_merged_filters_and_save(model, layer_to_eqcls, save_np_file, layer_idx_followed_by_fc, layer_idxes_not_adjust_follow):
#     result = dict()
#     number_filters_seen = 0
#     num_filters_alike = 0
#
#     kernels = model.get_kernel_variables()
#
#     for layer_idx, eqcls in layer_to_eqcls.items():
#         kernel_tensor = kernels[layer_idx]
#         bias_tensor = model.get_bias_variable_for_kernel(layer_idx)
#         beta_tensor = model.get_beta_variable_for_kernel(layer_idx)
#         gamma_tensor = model.get_gamma_variable_for_kernel(layer_idx)
#         moving_mean_tensor = model.get_moving_mean_variable_for_kernel(layer_idx)
#         moving_variance_tensor = model.get_moving_variance_variable_for_kernel(layer_idx)
#
#         if kernel_tensor.name in result:
#             kernel_value = result[kernel_tensor.name]
#         else:
#             kernel_value = model.get_value(kernel_tensor)
#
#         if layer_idx in layer_idxes_not_adjust_follow:
#             indexes_to_delete = []
#             for eqcl in eqcls:
#                 number_filters_seen += len(eqcl)
#                 if len(eqcl) == 1:
#                     continue
#                 num_filters_alike += len(eqcl)
#                 indexes_to_delete += eqcl[1:]
#         else:
#             follow_kernel_tensor = kernels[layer_idx + 1]
#             if follow_kernel_tensor.name in result:
#                 kvf = result[follow_kernel_tensor.name]
#             else:
#                 kvf = model.get_value(follow_kernel_tensor)
#
#             if layer_idx == layer_idx_followed_by_fc:  # last conv layer
#                 conv_indexes_to_delete = []
#                 fc_indexes_to_delete = []
#                 assert kvf.shape[0] % kernel_value.shape[3] == 0
#                 last_conv_origin_deps = kernel_value.shape[3]
#                 corresponding_neurons_per_kernel = kvf.shape[0] // kernel_value.shape[3]
#                 base = np.arange(0, corresponding_neurons_per_kernel) * last_conv_origin_deps
#                 for eqcl in eqcls:
#                     number_filters_seen += len(eqcl)
#                     if len(eqcl) == 1:
#                         continue
#                     num_filters_alike += len(eqcl)
#                     conv_indexes_to_delete += eqcl[1:]
#                     for i in eqcl[1:]:
#                         fc_indexes_to_delete.append(base + i)
#                     to_concat = []
#                     for i in eqcl:
#                         corresponding_neurons_idxes = base + i
#                         to_concat.append(np.expand_dims(kvf[corresponding_neurons_idxes, :], axis=0))
#                     merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
#                     reserved_idxes = base + eqcl[0]
#                     kvf[reserved_idxes, :] = merged
#                 if len(fc_indexes_to_delete) > 0:
#                     kvf = delete_or_keep(kvf, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
#                 indexes_to_delete = conv_indexes_to_delete
#             else:
#                 indexes_to_delete = []
#                 for eqcl in eqcls:
#                     number_filters_seen += len(eqcl)
#                     if len(eqcl) == 1:
#                         continue
#                     num_filters_alike += len(eqcl)
#                     indexes_to_delete += eqcl[1:]
#                     eqc = np.array(eqcl)
#                     selected_k_follow = kvf[:, :, eqc, :]
#                     aggregated_k_follow = np.sum(selected_k_follow, axis=2)
#                     kvf[:, :, eqcl[0], :] = aggregated_k_follow
#                 kvf = delete_or_keep(kvf, indexes_to_delete, axis=2)
#             result[follow_kernel_tensor.name] = kvf
#             print('following kernel name: ', follow_kernel_tensor.name)
#             print('shape of following kernel: ', kvf.shape)
#
#         kernel_value_after_pruned = delete_or_keep(kernel_value, indexes_to_delete, axis=3)
#         result[kernel_tensor.name] = kernel_value_after_pruned
#         if bias_tensor is not None:
#             bias_value = delete_or_keep(model.get_value(bias_tensor), indexes_to_delete)
#             result[bias_tensor.name] = bias_value
#         if moving_mean_tensor is not None:
#             moving_mean_value = delete_or_keep(model.get_value(moving_mean_tensor), indexes_to_delete)
#             result[moving_mean_tensor.name] = moving_mean_value
#         if moving_variance_tensor is not None:
#             moving_variance_value = delete_or_keep(model.get_value(moving_variance_tensor), indexes_to_delete)
#             result[moving_variance_tensor.name] = moving_variance_value
#         if beta_tensor is not None:
#             beta_value = delete_or_keep(model.get_value(beta_tensor), indexes_to_delete)
#             result[beta_tensor.name] = beta_value
#         if gamma_tensor is not None:
#             gamma_value = delete_or_keep(model.get_value(gamma_tensor), indexes_to_delete)
#             result[gamma_tensor.name] = gamma_value
#         print('kernel name: ', kernel_tensor.name)
#         print(
#             'removed merged filters. {} filters seen. {} filters alike. shape of origin kernel {}, shape of pruned kernel {}'
#             .format(number_filters_seen, num_filters_alike, kernel_value.shape, kernel_value_after_pruned.shape))
#
#     key_variables = model.get_key_variables()
#     for var in key_variables:
#         if var.name not in result:
#             result[var.name] = model.get_value(var)
#     print('save {} varialbes to {} after pruning merged filters'.format(len(result), save_np_file))
#     np.save(save_np_file, result)


#   items can be a list of tensors or a numpy array
def _weighted_mean(items, weights):
    assert len(items) == len(weights)
    a_weights = np.array(weights)
    weights_sum = np.sum(a_weights)
    if weights_sum == 0:
        normalized = np.zeros_like(a_weights)
    else:
        normalized = a_weights / weights_sum
    sum = 0
    for item, weight in zip(items, normalized):
        sum += item * weight
    return sum


def _get_np_array_from_dict(dict, keys):
    result = np.zeros((len(keys),))
    for i,k in enumerate(keys):
        result[i] = dict[k]
    return result

#   for 4-D kernels
def merge_kernel(target_k, eqcls, weights):
    """Replace every filter of a 4-D tensor by the weighted mean of its class.

    ``eqcls`` is a collection of index groups over the last axis (axis 3) of
    ``target_k``; ``weights`` is indexed by filter index. A group of size one
    keeps its own slice. For larger groups, every member position receives
    the same ``_weighted_mean`` of the member slices. The per-filter slices
    are concatenated back along axis 3 and returned.

    Asserts that the groups cover exactly as many distinct indices as there
    are weights and filters.
    """
    num_filters = target_k.get_shape()[3]
    merged_slices = [0] * num_filters
    seen = set()
    for group in eqcls:
        if len(group) == 1:
            only = group[0]
            seen.add(only)
            merged_slices[only] = tf.expand_dims(target_k[:, :, :, only], axis=3)
            continue
        member_slices = [tf.expand_dims(target_k[:, :, :, idx], axis=3) for idx in group]
        member_weights = [weights[idx] for idx in group]
        seen.update(group)
        shared = _weighted_mean(member_slices, member_weights)
        # All members of the class point at the same merged slice.
        for idx in group:
            merged_slices[idx] = shared
    assert len(seen) == len(weights) and num_filters == len(seen)
    print('{} kernels merged, {} eqcls'.format(num_filters, len(eqcls)))
    return tf.concat(merged_slices, axis=3)

def merge_1d_tensor(target_t, eqcls, weights):
    """Replace every entry of a 1-D tensor by the weighted mean of its class.

    ``eqcls`` is a collection of index groups over axis 0 of ``target_t``;
    ``weights`` is indexed by entry index. A group of size one keeps its own
    entry. For larger groups, every member position receives the
    ``_weighted_mean`` of the member entries. Entries are expanded to rank 1
    and concatenated along axis 0 for the return value.

    Asserts that the groups cover exactly as many distinct indices as there
    are weights and entries.
    """
    num_tensors = target_t.get_shape()[0]
    merged_entries = [0] * num_tensors
    seen = set()
    for group in eqcls:
        if len(group) == 1:
            only = group[0]
            seen.add(only)
            merged_entries[only] = tf.expand_dims(target_t[only], axis=0)
            continue
        member_entries = [target_t[idx] for idx in group]
        member_weights = [weights[idx] for idx in group]
        seen.update(group)
        shared = _weighted_mean(member_entries, member_weights)
        # Each member gets its own expand_dims of the shared mean.
        for idx in group:
            merged_entries[idx] = tf.expand_dims(shared, axis=0)
    assert len(seen) == len(weights) and num_tensors == len(seen)
    print('{} tensors merged, {} eqcls'.format(num_tensors, len(eqcls)))
    return tf.concat(merged_entries, axis=0)

def tfm_merge_gradients(model, origin_grads_and_vars, layer_to_eqcls, layer_to_weights):
    """Swap the gradients of clustered layers for their class-merged versions.

    For every layer index in ``layer_to_eqcls`` the kernel gradient is merged
    with ``merge_kernel``. The bias, gamma and beta gradients are merged with
    ``merge_1d_tensor`` for each of those variables that the model reports as
    not None. Returns a new list of ``(gradient, variable)`` pairs in the
    order of ``origin_grads_and_vars``, with merged gradients substituted
    where available and all other pairs passed through.
    """
    assert len(layer_to_eqcls) == len(layer_to_weights)
    kernels = model.get_kernel_variables()
    merged_by_origin = {}

    def _merge_1d_gradient_of(variable, layer_idx):
        # Variables the model does not have for this layer are simply skipped.
        if variable is None:
            return
        origin_grad = tf_get_gradient_by_var(origin_grads_and_vars, variable)
        merged_by_origin[origin_grad] = merge_1d_tensor(origin_grad, layer_to_eqcls[layer_idx],
            layer_to_weights[layer_idx])

    for layer_idx in layer_to_eqcls.keys():
        kernel_grad = tf_get_gradient_by_var(origin_grads_and_vars, kernels[layer_idx])
        merged_by_origin[kernel_grad] = merge_kernel(kernel_grad, layer_to_eqcls[layer_idx], layer_to_weights[layer_idx])

        _merge_1d_gradient_of(model.get_bias_variable_for_kernel(layer_idx), layer_idx)
        _merge_1d_gradient_of(model.get_gamma_variable_for_kernel(layer_idx), layer_idx)
        _merge_1d_gradient_of(model.get_beta_variable_for_kernel(layer_idx), layer_idx)

    merged_cnt = 0
    result = []
    for grad, var in origin_grads_and_vars:
        if grad in merged_by_origin:
            grad = merged_by_origin[grad]
            merged_cnt += 1
        result.append((grad, var))
    print(merged_cnt, 'gradients merged')
    return result






def calculate_featuremap_weights_by_kernel_norm(kernel_value, exp_order):
    """Weight each filter by the sum of ``|kernel| ** exp_order``.

    The sum runs over axes 0, 1 and 2 of ``kernel_value``, leaving one value
    per index of the remaining (last) axis. Progress is printed to stdout.
    """
    print('feature map weights by kernel norm, exp order ', exp_order)
    powered_magnitude = np.abs(kernel_value) ** exp_order
    per_filter = np.sum(powered_magnitude, axis=(0, 1, 2))
    print('kernel value shape ', kernel_value.shape, 'weights shape ', per_filter.shape)
    return per_filter

def calculate_featuremap_weights_by_mean_response(mean_response, exp_order):
    """Weight each feature map by the sum of ``|mean_response| ** exp_order``.

    The sum runs over axis 0 of ``mean_response``. Progress is printed to
    stdout.
    """
    print('feature map weights by mean response, exp order ', exp_order)
    powered_magnitude = np.abs(mean_response) ** exp_order
    per_map = np.sum(powered_magnitude, axis=0)
    print('mean response shape ', mean_response.shape, 'weights shape', per_map.shape)
    return per_map

def calculate_featuremap_weights_by_apop(apop):
    """Return one weight per feature map: the mean of ``apop[:, i]``.

    The number of feature maps is ``apop.shape[1]``. Progress is printed to
    stdout.
    """
    nb = apop.shape[1]
    print('calculating APoP weights for {} featuremaps'.format(nb))
    # Column-by-column mean, kept as a per-column reduction on purpose.
    column_means = [np.mean(apop[:, column]) for column in range(nb)]
    return np.array(column_means)

def calculate_featuremap_weights(survey_outs):
    """Return one weight per feature map: the mean of its rectified outputs.

    Negative entries of ``survey_outs`` are clamped to zero, then the mean of
    column ``i`` (``[:, i]``) is taken for each of the ``survey_outs.shape[1]``
    feature maps. The input array is not modified. Progress is printed to
    stdout.

    Returns:
        A 1-D numpy array with one weight per feature map.
    """
    nb = survey_outs.shape[1]
    print('calculating weights for {} featuremaps'.format(nb))
    # Clamp once up front. The clamp is idempotent, so repeating it for every
    # column (as was done before) only re-scanned the whole array nb times.
    rectified = np.maximum(survey_outs, 0)
    return np.array([np.mean(rectified[:, i]) for i in range(nb)])


def _weighted_sum(array, a_weight):
    if array.ndim == 4:
        # print(np.sum(np.abs(array[:,:,:,0] - array[:,:,:,1])))
        tiled_weight = np.tile(a_weight, (array.shape[0], array.shape[1], array.shape[2], 1))
        assert array.shape == tiled_weight.shape
        return np.sum(array * tiled_weight, axis=3)
    elif array.ndim == 1:
        # print(np.sum(np.abs(array[0] - array[1])))
        return np.sum(array * a_weight)
    else:
        assert False

def _calculate_eqcl_normalized_weight(eqcl, filter_to_weight):
    filter_to_weight = np.array(filter_to_weight)
    result = np.zeros((len(eqcl),))
    for i, e in enumerate(eqcl):
        result[i] = filter_to_weight[e]
    w_sum = np.sum(result)
    # if w_sum != 0:
    #     result /= w_sum
    base = np.sum(np.square(result))
    if base != 0:
        result = result * w_sum / base
    return np.array(result)


def _adjust_following_conv_kernel(eqcls, kvf, filter_to_weight, offset=0):
    for eqcl in eqcls:
        if len(eqcl) == 1:
            continue
        normalized_weight = _calculate_eqcl_normalized_weight(eqcl, filter_to_weight)
        print(normalized_weight)
        off_eqcl = [offset + e for e in eqcl]
        for idx, e in enumerate(off_eqcl):
            kvf[:,:,e,:] *= normalized_weight[idx]
    print('adjusted following conv kernel, {} eqcls'.format(len(eqcls)))

def _adjust_following_fc_kernel(eqcls, kv, kvf, filter_to_weight, offset=0):
    assert kvf.shape[0] % kv.shape[3] == 0
    last_conv_origin_deps = kv.shape[3]
    corresponding_neurons_per_kernel = kvf.shape[0] // kv.shape[3]
    base = np.arange(0, corresponding_neurons_per_kernel) * last_conv_origin_deps

    for eqcl in eqcls:
        if len(eqcl) == 1:
            continue
        normalized_weight = _calculate_eqcl_normalized_weight(eqcl, filter_to_weight)
        off_eqcl = [offset + e for e in eqcl]
        for idx, e in enumerate(off_eqcl):
            corresponding_neurons_idxes = base + e
            kvf[corresponding_neurons_idxes, :] *= normalized_weight[idx]

    print('adjusted following fc kernel, {} eqcls'.format(len(eqcls)))





# merge kernels, moving_means, moving_variances, gammas, betas
# no biases are allowed!
# NOTE(review): despite the line above, the body does read the bias variable and
# merges it (scaled by t) - confirm which of the two is intended.
def tfm_merge_filters_with_bn(model, eqcls, layer_idx, filter_to_weight=None, adjust_following_kernel=True, followed_by_fc=False):
    """Merge the filters of each equivalence class of one batch-normalised layer.

    For every class in ``eqcls`` with more than one member, the kernel slices
    (last axis), bias, moving mean, moving variance, gamma and beta of all
    members are overwritten with one shared merged value, and the results are
    written back through ``model.batch_set_value``. No filters are removed.

    Args:
        model: object providing the variable getters, ``get_value``,
            ``set_value`` and ``batch_set_value`` used below.
        eqcls: collection of index groups over the last kernel axis.
        layer_idx: index into ``model.get_kernel_variables()``.
        filter_to_weight: per-filter non-negative merge weights, indexed by
            filter index; ``None`` means every filter has weight 1.
        adjust_following_kernel: when true, ``kernels[layer_idx + 1]`` is
            rescaled afterwards and written back with ``model.set_value``.
        followed_by_fc: selects the fully connected variant of that
            adjustment instead of the conv variant.

    Returns:
        None.
    """

    num_filters_seen = 0
    num_filters_alike = 0
    epsilon = 1e-12

    kernels = model.get_kernel_variables()

    origin_kernel = kernels[layer_idx]
    # TODO
    # Kernels whose name contains 'depth' get their last two axes swapped so the
    # merge below can always index axis 3 (presumably depthwise kernels - confirm).
    if 'depth' in origin_kernel.name:
        kernel = tf.transpose(origin_kernel, [0,1,3,2])
    else:
        kernel = origin_kernel
    bias = model.get_bias_variable_for_kernel(layer_idx)
    moving_mean = model.get_moving_mean_variable_for_kernel(layer_idx)
    moving_variance = model.get_moving_variance_variable_for_kernel(layer_idx)
    beta = model.get_beta_variable_for_kernel(layer_idx)
    gamma = model.get_gamma_variable_for_kernel(layer_idx)

    kv, biasv, muv, varv, gammav, betav = model.get_value([kernel, bias, moving_mean, moving_variance, gamma, beta])
    ndim = kv.shape[3]
    # Any value that came back as None is replaced by a neutral default
    # (zeros for bias/mean/beta, ones for variance/gamma).
    if biasv is None:
        biasv = np.zeros(ndim)
    if muv is None:
        muv = np.zeros(ndim)
    if varv is None:
        varv = np.ones(ndim)
    if gammav is None:
        gammav = np.ones(ndim)
    if betav is None:
        betav = np.zeros(ndim)

    # Per-filter scale gamma / (sqrt(var) + epsilon), computed once from the
    # pre-merge values and used for every class below.
    t = gammav / (np.sqrt(varv) + epsilon)

    if filter_to_weight is None:
        filter_to_weight = [1.] * len(muv)
    for w in filter_to_weight:
        assert w >= 0

    for eqcl in eqcls:
        num_filters_seen += len(eqcl)
        if len(eqcl) == 1:
            continue
        num_filters_alike += len(eqcl)
        eqc = np.array(eqcl)

        # Weights of this class, normalised to sum to 1 unless they are all zero.
        a_weight = np.zeros((len(eqcl),), dtype=np.float32)
        for i, idx in enumerate(eqcl):
            a_weight[i] = filter_to_weight[idx]
        weight_sum = np.sum(a_weight)
        if weight_sum != 0:
            a_weight /= weight_sum
        # print('merging weights:', a_weight)

        beta_bar = _weighted_sum(betav[eqc], a_weight)
        betav[eqc] = beta_bar

        # Mean, bias and kernel are scaled by t before being averaged.
        mu_bar = _weighted_sum(t[eqc] * muv[eqc], a_weight)
        muv[eqc] = mu_bar

        bias_bar = _weighted_sum(t[eqc] * biasv[eqc], a_weight)
        biasv[eqc] = bias_bar

        tiled_t = np.tile(t[eqc], (kv.shape[0], kv.shape[1], kv.shape[2], 1))
        k_bar = _weighted_sum(kv[:, :, :, eqc] * tiled_t, a_weight)
        k_bar = np.expand_dims(k_bar, axis=3)
        kv[:,:,:,eqc] = k_bar

        # Variance is combined with squared coefficients (a_weight * t) ** 2.
        square_weight = np.square(a_weight) * np.square(t[eqc])
        var_bar = _weighted_sum(varv[eqc], square_weight)
        varv[eqc] = var_bar

        # gamma is set to sqrt(var_bar), so gamma / sqrt(var) is 1 for the merged
        # filters - presumably because t was already folded in above; confirm.
        gamma_bar = np.sqrt(var_bar)
        gammav[eqc] = gamma_bar


    # Undo the axis swap before writing the kernel back.
    if 'depth' in origin_kernel.name:
        kv_to_assign = np.transpose(kv, [0,1,3,2])
    else:
        kv_to_assign = kv
    model.batch_set_value([bias, beta,moving_mean,origin_kernel,moving_variance,gamma],
                          [biasv, betav,muv,kv_to_assign,varv,gammav])
    print('merging completed! {} eqcls. {} filters seen. {} filters alike. {} filters eliminated'
          .format(len(eqcls),num_filters_seen,num_filters_alike,num_filters_seen - len(eqcls)))
    if adjust_following_kernel:
        following_kernel_variable = kernels[layer_idx+1]
        kvf = model.get_value(following_kernel_variable)
        if followed_by_fc:    #last conv layer
            _adjust_following_fc_kernel(eqcls, kv, kvf, filter_to_weight)
        else:
            _adjust_following_conv_kernel(eqcls, kvf, filter_to_weight)
        model.set_value(following_kernel_variable, kvf)

# merge kernels and biases (for layers without batch normalization)
def tfm_merge_filters_no_bn(model, eqcls, layer_idx, filter_to_weight=None, adjust_following_kernel=True, followed_by_fc=False, subsequent_strategy=None,
                            layer_idx_to_follow_offset=None):
    """Merge the filters of each equivalence class of one layer without BN.

    For every class in ``eqcls`` with more than one member, the kernel slices
    (last axis) and biases of all members are overwritten with their weighted
    mean, and both are written back with ``model.set_value``. No filters are
    removed. Optionally the kernel(s) consuming this layer's output are
    rescaled afterwards.

    Args:
        model: object providing ``get_kernel_variables``, ``get_bias_tensors``,
            ``get_value`` and ``set_value``.
        eqcls: collection of index groups over the last kernel axis.
        layer_idx: index of the layer to merge.
        filter_to_weight: per-filter non-negative merge weights, indexed by
            filter index; ``None`` means every filter has weight 1.
        adjust_following_kernel: when true, rescale the following kernel(s).
        followed_by_fc: use the fully connected variant of that adjustment.
        subsequent_strategy: optional mapping ``layer_idx -> index`` or
            ``layer_idx -> list/tuple of indexes`` naming the following
            kernel(s); ``None`` means ``layer_idx + 1``.
        layer_idx_to_follow_offset: optional mapping ``layer_idx -> offset``
            added to the filter indices when addressing the following kernel;
            ``None`` or a missing entry means offset 0.

    Returns:
        None.
    """

    num_filters_seen = 0
    num_filters_alike = 0

    kernels = model.get_kernel_variables()
    biases = model.get_bias_tensors()

    assert len(kernels) == len(biases)

    kv = model.get_value(kernels[layer_idx])
    bv = model.get_value(biases[layer_idx])

    if filter_to_weight is None:
        filter_to_weight = [1.] * len(bv)
    for w in filter_to_weight:
        assert w >= 0

    for eqcl in eqcls:
        num_filters_seen += len(eqcl)
        if len(eqcl) == 1:
            continue
        num_filters_alike += len(eqcl)
        eqc = np.array(eqcl)

        # Weights of this class, normalised to sum to 1 unless they are all zero.
        a_weight = np.array([filter_to_weight[idx] for idx in eqcl], dtype=np.float32)
        weight_sum = np.sum(a_weight)
        if weight_sum != 0:
            a_weight /= weight_sum

        bias_bar = _weighted_sum(bv[eqc], a_weight)
        bv[eqc] = bias_bar

        k_bar = _weighted_sum(kv[:, :, :, eqc], a_weight)
        k_bar = np.expand_dims(k_bar, axis=3)
        kv[:,:,:,eqc] = k_bar

    model.set_value(biases[layer_idx], bv)
    model.set_value(kernels[layer_idx], kv)
    print('merging completed! {} eqcls. {} filters seen. {} filters alike. {} filters eliminated'
          .format(len(eqcls),num_filters_seen,num_filters_alike,num_filters_seen - len(eqcls)))
    if adjust_following_kernel:
        # The offset map is optional (its default is None); calling .get on None
        # used to crash every call that relied on the defaults.
        if layer_idx_to_follow_offset is None:
            offset = 0
        else:
            offset = layer_idx_to_follow_offset.get(layer_idx, 0)

        if subsequent_strategy is None:
            following_kernel_variables = [kernels[layer_idx+1]]
        else:
            successors = subsequent_strategy[layer_idx]
            if isinstance(successors, (list, tuple)):
                following_kernel_variables = [kernels[i] for i in successors]
            else:
                following_kernel_variables = [kernels[successors]]

        for f in following_kernel_variables:
            kvf = model.get_value(f)
            if followed_by_fc:  # last conv layer
                _adjust_following_fc_kernel(eqcls, kv, kvf, filter_to_weight, offset)
            else:
                _adjust_following_conv_kernel(eqcls, kvf, filter_to_weight, offset)
            model.set_value(f, kvf)









def tfm_aggregate_filters(model, eqcls, layer_idx, method='mean', weights=None):
    """Overwrite the filters of each equivalence class with their plain mean.

    For every class in ``eqcls`` with more than one member, the kernel slices
    (last axis) of the members are replaced by their mean. The same is done
    for the bias, moving mean, gamma and beta of the layer when the model has
    them. The moving variance is instead set to ``sum(var) / len(class) ** 2``.
    Each updated array is written back with ``model.set_value`` right after
    it is changed. ``method`` and ``weights`` are accepted but not used.
    Returns None.
    """
    number_filters_seen = 0
    num_filters_alike = 0
    modified = 0

    kernels = model.get_kernel_variables()
    biases = model.get_bias_tensors()
    moving_means = model.get_moving_mean_variables()
    moving_variances = model.get_moving_variance_variables()
    gammas = model.get_gamma_tensors()
    betas = model.get_beta_tensors()

    def _value_or_none(variables, available):
        # Fetch this layer's value only when the model actually has the variable.
        return model.get_value(variables[layer_idx]) if available else None

    kv = model.get_value(kernels[layer_idx])
    bv = _value_or_none(biases, len(biases) >= len(kernels))
    muv = _value_or_none(moving_means, len(moving_means) > 0)
    varv = _value_or_none(moving_variances, len(moving_variances) > 0)
    gammav = _value_or_none(gammas, len(gammas) > 0)
    betav = _value_or_none(betas, len(betas) > 0)

    def _plain_mean(values, members, group_size):
        return np.mean(values[members])

    def _variance_of_mean(values, members, group_size):
        return (np.sum(values[members])) / (group_size ** 2)

    # (current value, model variables, reducer), in the order they are updated.
    per_filter_params = (
        (bv, biases, _plain_mean),
        (muv, moving_means, _plain_mean),
        (varv, moving_variances, _variance_of_mean),
        (gammav, gammas, _plain_mean),
        (betav, betas, _plain_mean),
    )

    for eqcl in eqcls:
        number_filters_seen += len(eqcl)
        if len(eqcl) == 1:
            continue
        num_filters_alike += len(eqcl)
        members = np.array(eqcl)

        mean_kernel = np.mean(kv[:, :, :, members], axis=3)
        kv[:, :, :, members] = np.expand_dims(mean_kernel, axis=3)
        modified += 1
        model.set_value(kernels[layer_idx], kv)

        for values, variables, reducer in per_filter_params:
            if values is None:
                continue
            values[members] = reducer(values, members, len(eqcl))
            modified += 1
            model.set_value(variables[layer_idx], values)
    print('aggregation completed! {} eqcls. {} filters seen. {} filters alike. {} filters eliminated'.format(len(eqcls), number_filters_seen, num_filters_alike,  number_filters_seen - len(eqcls)))
    print(modified, 'variables modified')

def tfm_aggregate_filters_with_weights(model, eqcls, filter_to_weight, layer_idx, method='mean'):
    """Overwrite the filters of each equivalence class with a weighted mean.

    For every class in ``eqcls`` with more than one member, the kernel slices
    (last axis) and biases of the members are replaced by their mean weighted
    by ``filter_to_weight``. If the weights of a class sum to zero, the merged
    kernel and bias are zero. Only ``method='mean'`` is supported; anything
    else trips an assertion. The kernel and bias are written back with
    ``model.set_value`` once at the end. Returns None.
    """
    number_filters_seen = 0
    num_filters_alike = 0
    kernels = model.get_kernel_variables()
    biases = model.get_bias_tensors()
    assert len(kernels) == len(biases)
    kv = model.get_value(kernels[layer_idx])
    bv = model.get_value(biases[layer_idx])

    def _weighted_average(group):
        # Accumulate weight * value for kernel and bias, then divide by the
        # total weight (or fall back to zeros when that total is zero).
        weighted_k = np.zeros_like(kv[:, :, :, 0])
        weighted_b = 0.
        weight_total = 0.
        for member in group:
            member_weight = filter_to_weight[member]
            weighted_k += kv[:, :, :, member] * member_weight
            weight_total += member_weight
            weighted_b += bv[member] * member_weight
        if weight_total == 0:
            return np.zeros_like(weighted_k), 0
        return weighted_k / weight_total, weighted_b / weight_total

    for eqcl in eqcls:
        number_filters_seen += len(eqcl)
        if len(eqcl) == 1:
            continue
        num_filters_alike += len(eqcl)
        if method == 'mean':
            aggregated_k, aggregated_b = _weighted_average(eqcl)
        else:
            assert False

        members = np.array(eqcl)
        kv[:, :, :, members] = np.expand_dims(aggregated_k, axis=3)
        bv[members] = aggregated_b
    model.set_value(kernels[layer_idx], kv)
    model.set_value(biases[layer_idx], bv)
    print('aggregation with weights completed! {} eqcls. {} filters seen. {} filters alike. {} filters eliminated'.format(len(eqcls), number_filters_seen, num_filters_alike,  number_filters_seen - len(eqcls)))


# multi_survey_file: a map, {layer idx 1: outs 1, layer idx 2: outs 2, ..., 'labels' : labels}
# def direct_agg_prune_on_np(from_np, to_np, multi_survey_file, remains):
#
#     survey_dict = np.load(multi_survey_file).item()
#     source = np.load(from_np).item()
#
#     kernel_weights_dict = extract_kernel_weights_from_np_dict(source)
#     bias_weights_dict = extract_bias_weights_from_np_dict(source)
#     assert len(kernel_weights_dict) == len(bias_weights_dict)
#     nb_conv_layer = len(remains)
#     assert len(survey_dict) - 1 <= nb_conv_layer
#
#     kernel_name_list = list(kernel_weights_dict.keys())
#     kernel_value_list = list(kernel_weights_dict.values())
#     bias_name_list = list(bias_weights_dict.keys())
#     bias_value_list = list(bias_weights_dict.values())
#
#     for layer_idx in range(nb_conv_layer):
#         raw_outs = survey_dict[layer_idx]
#         nb_origin_filters = raw_outs.shape[1]
#         if nb_origin_filters == remains[layer_idx]:
#             continue
#         eqcls = calculate_eqcls_from_raw(raw_outs=raw_outs, num_eqcls=remains[layer_idx])
#
#         # merge
#
#         kv = kernel_value_list[layer_idx]
#         bv = bias_value_list[layer_idx]
#         for eqcl in eqcls:
#             if len(eqcl) == 1:
#                 continue
#             eqc = np.array(eqcl)
#             selected_k = kv[:, :, :, eqc]
#             selected_b = bv[eqc]
#             aggregated_k = np.mean(selected_k, axis=3)
#             aggregated_b = np.mean(selected_b)
#             aggregated_k = np.expand_dims(aggregated_k, axis=3)
#             kv[:, :, :, eqc] = aggregated_k
#             bv[eqc] = aggregated_b
#
#         # cut this layer and next layer
#         origin_kernel_shape = kv.shape
#         kvf = kernel_value_list[layer_idx + 1]
#         if layer_idx == nb_conv_layer - 1:
#             conv_indexes_to_delete = []
#             fc_indexes_to_delete = []
#             assert kvf.shape[0] % kv.shape[3] == 0
#             last_conv_origin_deps = kv.shape[3]
#             corresponding_neurons_per_kernel = kvf.shape[0] // kv.shape[3]
#             base = np.arange(0, corresponding_neurons_per_kernel) * last_conv_origin_deps
#             for eqcl in eqcls:
#                 if len(eqcl) == 1:
#                     continue
#                 conv_indexes_to_delete += eqcl[1:]
#                 for i in eqcl[1:]:
#                     fc_indexes_to_delete.append(base + i)
#                 to_concat = []
#                 for i in eqcl:
#                     corresponding_neurons_idxes = base + i
#                     to_concat.append(np.expand_dims(kvf[corresponding_neurons_idxes, :], axis=0))
#                 merged = np.sum(np.concatenate(to_concat, axis=0), axis=0)
#                 reserved_idxes = base + eqcl[0]
#                 kvf[reserved_idxes, :] = merged
#             kvf = np.delete(kvf, np.concatenate(fc_indexes_to_delete, axis=0), axis=0)
#             kv = np.delete(kv, conv_indexes_to_delete, axis=3)
#             bv = np.delete(bv, conv_indexes_to_delete)
#         else:
#             indexes_to_delete = []
#             for eqcl in eqcls:
#                 if len(eqcl) == 1:
#                     continue
#                 indexes_to_delete += eqcl[1:]
#                 eqc = np.array(eqcl)
#                 selected_k_follow = kvf[:, :, eqc, :]
#                 aggregated_k_follow = np.sum(selected_k_follow, axis=2)
#                 kvf[:, :, eqcl[0], :] = aggregated_k_follow
#             kvf = np.delete(kvf, indexes_to_delete, axis=2)
#             kv = np.delete(kv, indexes_to_delete, axis=3)
#             bv = np.delete(bv, indexes_to_delete)
#
#         kernel_value_list[layer_idx] = kv
#         bias_value_list[layer_idx] = bv
#         kernel_value_list[layer_idx + 1] = kvf
#         print('aggregate and prune filters for layer {}, origin kernel shape {}, now kernel shape{}, next layer kernel shape {}'
#             .format(layer_idx, origin_kernel_shape, kv.shape, kvf.shape))
#
#     result = {}
#     for n, v in zip(kernel_name_list, kernel_value_list):
#         result[n] = v
#     for n, v in zip(bias_name_list, bias_value_list):
#         result[n] = v
#     for k, v in source.items():
#         if k not in result:
#             result[k] = v
#     np.save(to_np, result)

