from abc import abstractmethod
from tfm_aggregation import tfm_merge_gradients
import tensorflow as tf
import numpy as np
from tfm_aggregation import tfm_merge_filters_with_bn, tfm_merge_filters_no_bn
from tfm_aggregation import tf_get_gradient_by_var, _weighted_mean

class GradientHandler(object):
    """Interface for objects that rewrite a list of (gradient, variable) pairs.

    Subclasses override :meth:`handle_gradient`. The class does not use
    ``ABCMeta``, so the ``@abstractmethod`` marker is documentation only and
    is not enforced when the class is instantiated.
    """

    @abstractmethod
    def handle_gradient(self, origin_grads_and_vars, device_idx, device):
        """Return the rewritten (gradient, variable) pairs for one device.

        Args:
            origin_grads_and_vars: list of ``(gradient, variable)`` pairs.
            device_idx: index of the replica/device being processed.
            device: device string used for placing any new ops.

        The base implementation does nothing and returns ``None``.
        """

class PreserveGradientHandler(object):
    """Keep only the (gradient, variable) pairs whose variable name contains a keyword.

    Unlike ``GradientHandler`` subclasses, :meth:`handle_gradient` here takes
    only the list of pairs (no device arguments).
    """

    def __init__(self, keyword):
        """
        Args:
            keyword: substring that a variable's ``name`` must contain for its
                (gradient, variable) pair to be kept.
        """
        self.keyword = keyword

    # Concrete implementation, so it is deliberately NOT marked @abstractmethod.
    def handle_gradient(self, origin_grads_and_vars):
        """Return the pairs ``(g, v)`` with ``self.keyword in v.name``, preserving input order.

        Args:
            origin_grads_and_vars: iterable of ``(gradient, variable)`` pairs;
                each variable must expose a ``name`` string.

        Returns:
            A new list with the matching pairs.
        """
        result = [(g, v) for (g, v) in origin_grads_and_vars if self.keyword in v.name]
        print('only {} gradients to update'.format(len(result)))
        return result

class MergeGradientHandler(GradientHandler):
    """Rewrite gradients so that filters in the same cluster are merged and pulled together.

    ``layer_to_eqcls`` maps a layer index (an index into this device's list of
    kernel variables) to a list of clusters; each cluster is a list of filter
    indices (the last kernel axis; axis 2 for depthwise kernels). For every
    such layer the handler

    * averages the gradients of the filters in each cluster
      (:meth:`merge_gradient_v2`), and
    * adds ``l2 * w_i + diff * (w_i - mean_{j in cluster} w_j)`` for every
      filter ``w_i`` (:meth:`add_decay_to_merged_gradient_v2`).

    Per-channel vectors (BN gamma/beta) are handled according to
    ``st_decay_on_vecs``.
    """

    # Accepted values of ``st_decay_on_vecs``.
    _ST_DECAY_MODES = ('same', 'propor', 'none', 'cancelbeta', 'slowgamma', 'slowgammabeta')

    def __init__(self, model, layer_to_eqcls, weight_decay, diff_factor,
                 exclude_l2_decay_keywords=None, bn_layer_to_eqcls=None, l2_decay_on_vecs=False, follow_dict=None, st_decay_on_vecs=True):
        """
        Args:
            model: object exposing ``get_kernel_variables``, ``get_gamma_variables``,
                ``get_beta_variables`` and the per-kernel ``get_{gamma,beta,bias}_variable_for_kernel``.
            layer_to_eqcls: dict ``layer_idx -> list of clusters`` (lists of filter indices).
            weight_decay: L2 factor applied to kernels.
            diff_factor: strength of the pull of each filter towards its cluster mean.
            exclude_l2_decay_keywords: a single substring; kernels whose variable
                name contains it get no L2 decay (only the cluster pull).
            bn_layer_to_eqcls: optional dict ``bn_layer_idx -> clusters``. When
                given, BN vectors are handled per BN layer instead of per kernel.
            l2_decay_on_vecs: must currently be False (asserted while merging).
            follow_dict: optional dict ``layer_idx -> layer_idx`` of layers that
                share merge/decay matrices.
            st_decay_on_vecs: one of ``_ST_DECAY_MODES``. Booleans are accepted:
                ``True`` -> ``'same'`` (decay vectors like kernels), ``False`` -> ``'none'``.
        """
        # The signature default is the boolean True, which the mode check used
        # to reject; map booleans onto the equivalent modes instead.
        if isinstance(st_decay_on_vecs, bool):
            st_decay_on_vecs = 'same' if st_decay_on_vecs else 'none'
        assert st_decay_on_vecs in self._ST_DECAY_MODES, \
            'unknown st_decay_on_vecs mode: {}'.format(st_decay_on_vecs)
        self.layer_to_eqcls = layer_to_eqcls
        self.model = model
        self.weight_decay = weight_decay
        self.diff_factor = diff_factor
        self.exclude_l2_decay_keywords = exclude_l2_decay_keywords
        self.bn_layer_to_eqcls = bn_layer_to_eqcls
        self.l2_decay_on_vecs = l2_decay_on_vecs
        self.st_decay_on_vecs = st_decay_on_vecs
        self.follow_dict = follow_dict
        # device -> layer key -> cached tf.constant matrices.
        self.device_to_layer_to_mergemat = {}
        # device -> (layer_idx, l2_factor, diff_factor) -> cached tf.constant matrix.
        self.device_to_layer_to_decaymat = {}
        self.device_to_layer_to_propomat = {}

    def handle_gradient(self, origin_grads_and_vars, device_idx, device):
        """Return ``origin_grads_and_vars`` with merged/decayed gradients substituted."""
        return self.merge_and_decay_grads_v2(self.model, origin_grads_and_vars, self.layer_to_eqcls, device_idx, device)

    @staticmethod
    def _belongs_to_device(name, device_idx):
        """Return True if a variable called *name* belongs to replica *device_idx*.

        Device 0 also accepts names that do not start with ``'v'`` (compatible
        with old tfm_builder models). NOTE(review): this is a plain prefix test,
        so ``'v1'`` also matches ``'v10...'`` and device 0 rejects un-prefixed
        names that happen to start with ``'v'``; confirm the naming scheme.
        """
        if device_idx > 0:
            return name.startswith('v{}'.format(device_idx))
        return name.startswith('v0') or not name.startswith('v')

    def _vec_diff_factor(self, var):
        """Return the cluster-pull factor for vector variable *var* under the current mode.

        ``'slowgamma'`` slows gammas and ``'slowgammabeta'`` slows gammas and betas
        to ``0.1 * diff_factor``; every other case uses ``diff_factor``.
        """
        if self.st_decay_on_vecs == 'slowgamma' and 'gamma' in var.name:
            diff = self.diff_factor * 0.1
            print('slow decay on gamma {}, the factor is {}'.format(var.name, diff))
        elif self.st_decay_on_vecs == 'slowgammabeta' and ('gamma' in var.name or 'beta' in var.name):
            diff = self.diff_factor * 0.1
            print('slow decay on gamma/beta {}, the factor is {}'.format(var.name, diff))
        else:
            diff = self.diff_factor
        return diff

    @staticmethod
    def _build_merge_matrix(num_filters, eqcls):
        """Return a float32 (num_filters, num_filters) matrix averaging rows within each cluster.

        Row ``i`` of ``M @ G`` is the mean of the rows of ``G`` in ``i``'s cluster;
        singleton clusters map to identity rows, unclustered rows to zero.
        """
        mat = np.zeros((num_filters, num_filters), dtype=np.float32)
        for eqc in eqcls:
            if not eqc:
                continue
            mat[np.ix_(eqc, eqc)] = 1.0 / len(eqc)
        return mat

    @staticmethod
    def _build_decay_matrix(num_filters, eqcls, l2_factor, diff_factor):
        """Return a float32 matrix ``D`` with ``(D @ W)_i = l2 * W_i + diff * (W_i - mean_cluster(W))``."""
        mat = np.zeros((num_filters, num_filters), dtype=np.float32)
        for eqc in eqcls:
            for ee in eqc:
                mat[ee, ee] = l2_factor + diff_factor
                for p in eqc:
                    mat[ee, p] += -diff_factor / len(eqc)
        return mat

    def make_betas_proportional(self, model, eqcls, diff_factor, origin_kv, origin_grads_and_vars, origin_g_to_decayed_g):
        """Deprecated: rewrite the BN beta gradient so betas track their cluster leader scaled by gamma.

        Not called by any active code in this file (the ``'propor'`` mode
        asserts). Stores the new beta gradient in ``origin_g_to_decayed_g``.
        """
        gamma_variable = model.get_gamma_variable_for_kernel(origin_kv)
        beta_variable = model.get_beta_variable_for_kernel(origin_kv)
        beta_g = tf_get_gradient_by_var(origin_grads_and_vars, beta_variable)
        num_filters = gamma_variable.get_shape().as_list()[0]
        # None marks slots not (yet) covered by any cluster.
        slice_list = [None] * num_filters
        for eqcl in eqcls:
            eqc = sorted(eqcl)
            leader = eqc[0]
            slice_list[leader] = tf.expand_dims(beta_g[leader], 0)
            for ee in eqc[1:]:
                ratio = gamma_variable[ee] / gamma_variable[leader]
                slice_list[ee] = tf.expand_dims(
                    0.5 * (beta_g[ee] + ratio * beta_g[leader])
                    - diff_factor * (ratio * beta_variable[leader] - beta_variable[ee]), 0)
        # Identity check: comparing a tensor with ``!=`` is elementwise (or an
        # error) on newer TF, so it cannot tell whether a slot was filled.
        assert all(s is not None for s in slice_list)
        origin_g_to_decayed_g[beta_g] = tf.concat(slice_list, axis=0)

    def merge_and_decay_grads_v2(self, model, origin_grads_and_vars, layer_to_eqcls, device_idx, device):
        """Return ``origin_grads_and_vars`` with merged and decayed gradients substituted.

        Kernel gradients of every layer in ``layer_to_eqcls`` are merged and
        decayed; BN vectors are then handled according to ``st_decay_on_vecs``
        (per kernel, or per BN layer when ``bn_layer_to_eqcls`` is set). Pairs
        whose gradient was not rewritten are passed through unchanged, in order.
        """
        kernels = [k for k in model.get_kernel_variables() if self._belongs_to_device(k.name, device_idx)]
        print('the kernels to merge and decay are: ', [k.name for k in kernels])
        # original gradient tensor -> replacement gradient tensor
        origin_g_to_decayed_g = {}

        def md_1d(layer_key, eqcls, var, l2):
            """Merge and decay the gradient of vector *var* (no-op when *var* is None)."""
            if var is None:
                return
            if l2 > 0:
                print('--------- l2 decay {} on vector param: '.format(l2), var.name)
            origin_g = tf_get_gradient_by_var(origin_grads_and_vars, var)
            # Treat the vector as a (n, 1) matrix so the layer matrices apply.
            reshaped_g = tf.expand_dims(origin_g, 1)
            merged_g = self.merge_gradient_v2(layer_key, reshaped_g, eqcls, device=device)
            diff = self._vec_diff_factor(var)
            decayed_g = self.add_decay_to_merged_gradient_v2(layer_key, merged_g, tf.expand_dims(var, 1),
                                                             eqcls, l2, diff, device=device)
            # Squeeze only the added axis so a length-1 vector keeps its rank.
            origin_g_to_decayed_g[origin_g] = tf.squeeze(decayed_g, axis=1)

        def cancel_beta(eqcls, beta_order, beta_variable):
            """Zero the gradient of clustered betas and decay them by ``diff_factor`` instead."""
            if beta_variable is None:
                return
            beta_len = beta_variable.get_shape().as_list()[-1]
            print('prepare the cancel vector for beta variable {} which length is {}'.format(beta_variable.name, beta_len))
            zero_grad_vector = np.ones(beta_len, dtype=np.float32)
            decay_vector = np.zeros(beta_len, dtype=np.float32)
            beta_g = tf_get_gradient_by_var(origin_grads_and_vars, beta_variable)
            print(beta_g, beta_variable)
            for eqc in eqcls:
                if len(eqc) > 1:
                    zero_grad_vector[np.array(eqc)] = 0.0
                    decay_vector[np.array(eqc)] = self.diff_factor
            with tf.device(device):
                gpu_zero_grad_vec = tf.constant(zero_grad_vector, dtype=tf.float32, name='beta_zerograd_vec_{}'.format(beta_order))
                gpu_decay_vec = tf.constant(decay_vector, dtype=tf.float32, name='beta_decay_vec_{}'.format(beta_order))
                origin_g_to_decayed_g[beta_g] = beta_g * gpu_zero_grad_vec + gpu_decay_vec * beta_variable

        # L2 factor for vectors depends only on configuration.
        vec_weight_decay = self.weight_decay if self.l2_decay_on_vecs else 0

        for layer_idx, eqcls in layer_to_eqcls.items():
            #   handle kernel
            origin_kv = kernels[layer_idx]
            origin_kg = tf_get_gradient_by_var(origin_grads_and_vars, origin_kv)
            # TODO ugly hack for mobilenet depthwise kernel only: swap the last two
            # axes so the clustered axis is last.
            is_depthwise = 'depthwise' in origin_kv.name
            if is_depthwise:
                kv = tf.transpose(origin_kv, [0, 1, 3, 2])
                kg = tf.transpose(origin_kg, [0, 1, 3, 2])
            else:
                kv = origin_kv
                kg = origin_kg

            kv_shape = kv.get_shape()
            if len(kv_shape) == 4:
                # Flatten to (num_filters, -1): one row per index of the last axis.
                reshaped_kg = tf.transpose(tf.reshape(kg, (-1, kv_shape[3])), [1, 0])
                reshaped_k = tf.transpose(tf.reshape(kv, (-1, kv_shape[3])), [1, 0])
            else:
                assert len(kv_shape) == 2
                print('handling gradients for fc kernel: ', kv.name)
                reshaped_kg = tf.transpose(kg, [1, 0])
                reshaped_k = tf.transpose(kv, [1, 0])

            merged_kg = self.merge_gradient_v2(layer_idx, reshaped_kg, eqcls, device=device)
            # Match against the variable's own name: for depthwise kernels ``kv``
            # is a transpose op whose name does not contain the variable name.
            if self.exclude_l2_decay_keywords is not None and self.exclude_l2_decay_keywords in origin_kv.name:
                kernel_l2 = 0
            else:
                kernel_l2 = self.weight_decay
            decayed_kg = self.add_decay_to_merged_gradient_v2(layer_idx, merged_kg, reshaped_k,
                                                              eqcls, kernel_l2, self.diff_factor, device=device)

            restored_kg = tf.reshape(tf.transpose(decayed_kg, [1, 0]), kv_shape)
            if is_depthwise:
                restored_kg = tf.transpose(restored_kg, [0, 1, 3, 2])
            origin_g_to_decayed_g[origin_kg] = restored_kg

            assert not self.l2_decay_on_vecs    # TODO 20190201, because cancel_beta cannot handle l2 on vecs

            #   handle other vectors
            if self.bn_layer_to_eqcls is not None:
                # BN vectors are handled per BN layer after this loop.
                continue

            mode = self.st_decay_on_vecs
            if mode in ('same', 'slowgamma', 'slowgammabeta'):
                print('---------- old version: we merge and decay bias, gamma, and betas just as kernels ---------')
                gamma_variable = model.get_gamma_variable_for_kernel(origin_kv)
                if gamma_variable is not None:
                    print('got gamma variable: ', gamma_variable.name)
                md_1d(layer_idx, eqcls, gamma_variable, vec_weight_decay)
                beta_variable = model.get_beta_variable_for_kernel(origin_kv)
                if beta_variable is not None:
                    print('got beta variable: ', beta_variable.name)
                md_1d(layer_idx, eqcls, beta_variable, vec_weight_decay)
            elif mode == 'none':
                print('------------ we do not decay vecs -------------')
            elif mode in ('prop', 'propor'):
                print('---------- deprecated version: we seek to make betas proportional ---------')
                assert False
            elif mode == 'cancelbeta':
                print('---------- new version: we seek to ignore gammas and cancel betas ---------')
                bias_variable = model.get_bias_variable_for_kernel(origin_kv)
                assert bias_variable is None
                beta_variable = model.get_beta_variable_for_kernel(origin_kv)
                cancel_beta(eqcls, layer_idx, beta_variable)
            else:
                assert False

        #   handle separate bn parameters
        if self.bn_layer_to_eqcls is not None:
            # Same device filter as for kernels, so old un-prefixed models work on device 0.
            beta_vars = [v for v in model.get_beta_variables() if self._belongs_to_device(v.name, device_idx)]
            gamma_vars = [v for v in model.get_gamma_variables() if self._belongs_to_device(v.name, device_idx)]

            if self.st_decay_on_vecs == 'cancelbeta':
                for bn_layer_idx, bn_eqcls in self.bn_layer_to_eqcls.items():
                    print('cancel the separate beta {}, the name is {}'.format(bn_layer_idx, beta_vars[bn_layer_idx].name))
                    cancel_beta(bn_eqcls, bn_layer_idx, beta_vars[bn_layer_idx])
            else:
                # NOTE(review): every non-'cancelbeta' mode (including 'none') merges
                # and decays the separate BN vectors here; confirm that is intended.
                for bn_layer_idx, bn_eqcls in self.bn_layer_to_eqcls.items():
                    # Offset keeps BN matrices apart from kernel-layer matrices in the caches.
                    bn_key = bn_layer_idx + 10000
                    md_1d(bn_key, bn_eqcls, beta_vars[bn_layer_idx], vec_weight_decay)
                    md_1d(bn_key, bn_eqcls, gamma_vars[bn_layer_idx], vec_weight_decay)

        result = []
        merged_cnt = 0
        for g, v in origin_grads_and_vars:
            if g in origin_g_to_decayed_g:
                result.append((origin_g_to_decayed_g[g], v))
                merged_cnt += 1
            else:
                result.append((g, v))
        print(merged_cnt, 'gradients merged')
        return result

    def merge_and_add_l2_or_diff_gradient_kernel(self, target_k_grad, target_k_var, eqcls, weights, l2_factor, diff_factor):
        """Legacy per-slice version of merge+decay for 4-D kernels (clusters on the last axis).

        Each cluster member gets the weighted mean gradient; non-leaders also
        get ``diff_factor * (w_e - w_leader)``. NOTE(review): the L2 term of
        non-leaders uses the leader's weights rather than their own, unlike
        singletons and the v2 matrices; confirm intent. Not called in this file.
        """
        num_filters = int(target_k_grad.get_shape()[3])
        result_list = [0] * num_filters
        kernels_seen = set()
        for eqcl in eqcls:
            if len(eqcl) == 1:
                kernels_seen.add(eqcl[0])
                result_list[eqcl[0]] = tf.expand_dims(target_k_grad[:, :, :, eqcl[0]] + l2_factor * target_k_var[:, :, :, eqcl[0]], axis=3)
            else:
                eqcl_tensors = []
                eqcl_weights = []
                for e in eqcl:
                    eqcl_tensors.append(tf.expand_dims(target_k_grad[:, :, :, e], axis=3))
                    eqcl_weights.append(weights[e])
                    kernels_seen.add(e)
                mean = _weighted_mean(eqcl_tensors, eqcl_weights)
                sorted_eqcl = sorted(eqcl)
                result_list[sorted_eqcl[0]] = mean + tf.cast(l2_factor, tf.float32) * tf.expand_dims(target_k_var[:, :, :, sorted_eqcl[0]], axis=3)
                for e in sorted_eqcl[1:]:
                    diff_term = tf.expand_dims(target_k_var[:, :, :, e] - target_k_var[:, :, :, sorted_eqcl[0]], axis=3)
                    result_list[e] = mean + tf.cast(diff_factor, tf.float32) * diff_term \
                                     + tf.cast(l2_factor, tf.float32) * tf.expand_dims(target_k_var[:, :, :, sorted_eqcl[0]], axis=3)
        assert len(kernels_seen) == len(weights) and num_filters == len(kernels_seen)
        print('{} kernels merged, {} eqcls'.format(num_filters, len(eqcls)))
        return tf.concat(result_list, axis=3)

    def merge_and_add_l2_or_diff_gradient_1d_tensor(self, target_t_grad, target_t_var, eqcls, weights,
                                                    diff_factor):
        """Legacy per-element merge+diff for 1-D tensors (no L2). Not called in this file."""
        num_filters = int(target_t_grad.get_shape()[0])
        result_list = [0] * num_filters
        kernels_seen = set()
        for eqcl in eqcls:
            if len(eqcl) == 1:
                kernels_seen.add(eqcl[0])
                result_list[eqcl[0]] = tf.expand_dims(target_t_grad[eqcl[0]], axis=0)
            else:
                eqcl_tensors = []
                eqcl_weights = []
                for e in eqcl:
                    eqcl_tensors.append(target_t_grad[e])
                    eqcl_weights.append(weights[e])
                    kernels_seen.add(e)
                mean = _weighted_mean(eqcl_tensors, eqcl_weights)
                sorted_eqcl = sorted(eqcl)
                result_list[sorted_eqcl[0]] = tf.expand_dims(mean, axis=0)
                for e in sorted_eqcl[1:]:
                    result_list[e] = tf.expand_dims(mean + tf.cast(diff_factor, tf.float32) * (target_t_var[e] - target_t_var[sorted_eqcl[0]]), axis=0)
        assert len(kernels_seen) == len(weights) and num_filters == len(kernels_seen)
        print('{} kernels merged, {} eqcls'.format(num_filters, len(eqcls)))
        return tf.concat(result_list, axis=0)

    def merge_gradient_v2(self, layer_idx, target_t_grad, eqcls, device):
        """Return ``M @ target_t_grad`` where ``M`` averages rows within each cluster.

        ``target_t_grad`` is 2-D with one row per filter. ``M`` is built once
        per (device, layer key) and shared with the layer named in ``follow_dict``.
        """
        layer_to_mergemat = self.device_to_layer_to_mergemat.setdefault(device, dict())
        has_leader = self.follow_dict is not None and layer_idx in self.follow_dict
        if layer_idx in layer_to_mergemat:
            gpu_trans_mat = layer_to_mergemat[layer_idx]
        elif has_leader and self.follow_dict[layer_idx] in layer_to_mergemat:
            gpu_trans_mat = layer_to_mergemat[self.follow_dict[layer_idx]]
        else:
            num_filters = int(target_t_grad.get_shape()[0])
            merge_trans_mat = self._build_merge_matrix(num_filters, eqcls)
            with tf.device(device):
                gpu_trans_mat = tf.constant(merge_trans_mat, dtype=tf.float32, name='merge_trans_var_layer{}'.format(layer_idx))
            layer_to_mergemat[layer_idx] = gpu_trans_mat
            if has_leader:
                layer_to_mergemat[self.follow_dict[layer_idx]] = gpu_trans_mat
        with tf.device(device):
            merged_gradient = tf.matmul(gpu_trans_mat, target_t_grad)
        return merged_gradient

    def add_decay_to_merged_gradient_v2(self, layer_idx, merged_gradient, target_t_var, eqcls, l2_factor, diff_factor, device):
        """Return ``merged_gradient + D @ target_t_var`` (see ``_build_decay_matrix``).

        ``D`` is cached per device under ``(layer_idx, l2_factor, diff_factor)``.
        The factors are part of the key because one layer index is used with
        different factors (a kernel and its BN vectors, slowed gamma decay);
        keying on the index alone would silently reuse the first matrix built.
        """
        layer_to_decaymat = self.device_to_layer_to_decaymat.setdefault(device, dict())
        key = (layer_idx, l2_factor, diff_factor)
        has_leader = self.follow_dict is not None and layer_idx in self.follow_dict
        leader_key = (self.follow_dict[layer_idx], l2_factor, diff_factor) if has_leader else None
        if key in layer_to_decaymat:
            gpu_trans_mat = layer_to_decaymat[key]
        elif has_leader and leader_key in layer_to_decaymat:
            gpu_trans_mat = layer_to_decaymat[leader_key]
        else:
            num_filters = int(target_t_var.get_shape()[0])
            decay_trans_mat = self._build_decay_matrix(num_filters, eqcls, l2_factor, diff_factor)
            with tf.device(device):
                gpu_trans_mat = tf.constant(decay_trans_mat, dtype=tf.float32, name='decaytrans_var_layer{}'.format(layer_idx))
            layer_to_decaymat[key] = gpu_trans_mat
            if has_leader:
                layer_to_decaymat[leader_key] = gpu_trans_mat

        with tf.device(device):
            decayed_gradient = merged_gradient + tf.matmul(gpu_trans_mat, target_t_var)
        return decayed_gradient





class Callback(object):
    """Base class for training-loop hooks; every hook does nothing by default.

    Subclasses override only the hooks they need.
    """

    def before_train(self):
        """Hook for the start of training; no-op here."""

    def after_train(self):
        """Hook for the end of training; no-op here."""

    def before_step(self, step):
        """Hook invoked before training step *step*; no-op here."""

    def after_step(self, step):
        """Hook invoked after training step *step*; no-op here."""

class CallbackList(object):
    """Forward every training-loop hook to a sequence of callbacks, in order."""

    def __init__(self, callback_list):
        """Store the callbacks.

        Args:
            callback_list: a ``list`` of callbacks, or a single callback. Note
                the exact type check: anything that is not exactly a ``list``
                (including tuples and list subclasses) is wrapped as one callback.
        """
        is_plain_list = type(callback_list) is list
        self.callback_list = callback_list if is_plain_list else [callback_list]

    def _broadcast(self, hook_name, *hook_args):
        """Invoke ``hook_name(*hook_args)`` on each callback in registration order."""
        for callback in self.callback_list:
            hook = getattr(callback, hook_name)
            hook(*hook_args)

    def before_train(self):
        self._broadcast('before_train')

    def after_train(self):
        self._broadcast('after_train')

    def before_step(self, step):
        self._broadcast('before_step', step)

    def after_step(self, step):
        self._broadcast('after_step', step)
