import os
import pdb

# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU
os.environ["MXNET_CPU_WORKER_NTHREADS"] = "4"
import mxnet as mx
import numpy as np
from mxnet.gluon import nn

#============== DIV monitor =============#
class DMon(mx.operator.CustomOp):
    """Pass-through custom op that also emits the ratio of two monitor inputs.

    ``in_data[0]`` is forwarded unchanged as output 0; output 1 is
    ``in_data[1] / in_data[2]``.
    """

    def forward(self, is_train, req, in_data, out_data, aux):
        """Write the pass-through tensor and the element-wise ratio."""
        x0 = in_data[0]
        x1 = in_data[1]
        x2 = in_data[2]
        flops = x1/x2
        self.assign(out_data[0], req[0], x0)
        # ``req`` holds one write request per output, so output 1 must use
        # req[1] (the original reused req[0]).
        self.assign(out_data[1], req[1], flops)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Route the gradient of output 0 straight back to ``in_data[0]``.

        NOTE(review): in_grad[1] and in_grad[2] are never assigned here;
        confirm that no caller needs gradients for the monitor inputs.
        """
        self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register("DMon")
class DMonProp(mx.operator.CustomOpProp):
    """Operator properties (argument names, output names, shapes) for ``DMon``."""

    def list_arguments(self):
        """Return the names of the three inputs."""
        return ['data0', 'data1', 'data2']

    def list_outputs(self):
        """Return the names of the two outputs."""
        # An unreachable statement that followed this return (and referenced
        # an undefined ``in_shapes``) has been removed.
        return ['output', 'flops']

    def infer_shape(self, in_shapes):
        """Declare shapes: output 0 mirrors input 0, everything else is ``[1]``."""
        data_shape = [1]
        return [in_shapes[0], data_shape, data_shape], [in_shapes[0], data_shape]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate the ``DMon`` operator."""
        return DMon()
#============== DIV monitor =============#

#============= activation quantization ===============#
class quantizeA(mx.operator.CustomOp):
    """Activation quantizer with a clipping upper bound supplied as an input.

    ``in_data[0]`` is the activation tensor and ``in_data[1]`` holds the
    upper bound (read with ``asscalar()``).
    """

    def __init__(self, num_bits):
        # Stored as float so ``2 ** self._bits`` is computed in floating point.
        self._bits = float(num_bits)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Clip to ``[0, upper_bound]`` and round onto ``2**bits - 1`` steps."""
        x = in_data[0]
        upper_bound = in_data[1].asscalar()
        clipped_x = mx.nd.clip(data=x, a_min=0, a_max=upper_bound)
        x_range = 2 ** self._bits - 1
        # NOTE(review): an upper bound of 0 would divide by zero here.
        scale = x_range/upper_bound
        quantized_x = mx.nd.round(clipped_x * scale) / scale
        self.assign(out_data[0], req[0], quantized_x)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Pass gradients through inside the clip range; the rest go to the bound.

        ``dx`` is ``dy`` where ``0 < x < upper_bound`` and zero elsewhere;
        the bound's gradient is the sum of ``dy`` where ``x >= upper_bound``.
        """
        dy = out_grad[0]
        x = in_data[0]
        upper_bound = in_data[1].asscalar() * mx.nd.ones_like(x)
        dx = mx.nd.greater(x, mx.nd.zeros_like(x)) * mx.nd.lesser(x, upper_bound)
        dx = dx * dy
        d_ub = mx.nd.greater_equal(x, upper_bound)
        d_ub = mx.nd.sum(d_ub * dy)

        self.assign(in_grad[0], req[0], dx)
        # Each input gradient has its own write request; the original reused
        # req[0] for the upper-bound gradient.
        self.assign(in_grad[1], req[1], d_ub)

@mx.operator.register("quantizeA")
class quantizeAProp(mx.operator.CustomOpProp):
    """Operator properties for ``quantizeA``; carries the bit width to the op."""

    def __init__(self, num_bits):
        # The positional True is forwarded to the CustomOpProp constructor.
        super(quantizeAProp, self).__init__(True)
        self._bits = float(num_bits)

    def list_arguments(self):
        """Return the input names: the activations and the upper bound."""
        return ['data0', 'data1']

    def list_outputs(self):
        """Return the single output name."""
        return ['output']

    def infer_shape(self, in_shapes):
        """Output mirrors the first input; the second input has shape ``[1]``."""
        act_shape = in_shapes[0]
        bound_shape = [1]
        return [act_shape, bound_shape], [act_shape]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate ``quantizeA`` with the stored bit width."""
        return quantizeA(self._bits)
    
class quantizeABlock(mx.gluon.HybridBlock):
    """Gluon wrapper around the ``quantizeA`` custom op.

    Owns a parameter ``u`` of shape ``(1,)``, initialised to ``uInit``,
    which is fed to the op as its second input.
    """

    def __init__(self, num_bits, uInit, **kwargs):
        super(quantizeABlock, self).__init__(**kwargs)
        self._bits = float(num_bits)
        self._uInit = float(uInit)
        with self.name_scope():
            self.u = self.params.get(
                'u',
                shape=(1,),
                init=mx.init.Constant(self._uInit),
                wd_mult=1.0,
                lr_mult=1.0)

    def hybrid_forward(self, F, x0, u):
        """Apply the ``quantizeA`` custom op to ``x0`` with bound ``u``."""
        return F.Custom(x0, u, num_bits=self._bits, op_type='quantizeA')
#============= activation quantization ===============#

#=============== Act binarization =================#
class binarizeAct(mx.operator.CustomOp):
    """Binarize activations with ``sign`` and pass gradients straight through."""

    def forward(self, is_train, req, in_data, out_data, aux):
        """Emit ``sign(x)`` multiplied by a scale reduced over axes 1-3."""
        signs = mx.nd.sign(in_data[0])
        # NOTE(review): the scale is the mean of |sign(x)|, not of |x| --
        # confirm this is intended.
        magnitude = mx.nd.abs(signs)
        scale = mx.nd.mean(magnitude, axis=(1, 2, 3), keepdims=True)
        self.assign(out_data[0], req[0], signs * scale)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Copy the output gradient to the input gradient unchanged."""
        upstream = out_grad[0]
        self.assign(in_grad[0], req[0], upstream)

@mx.operator.register("binarizeAct")
class binarizeActProp(mx.operator.CustomOpProp):
    """Operator properties for ``binarizeAct``: one input, one same-shape output."""

    def list_arguments(self):
        """Return the single input name."""
        return ['data0']

    def list_outputs(self):
        """Return the single output name."""
        return ['output']

    def infer_shape(self, in_shapes):
        """The output shape equals the input shape."""
        shape = in_shapes[0]
        return [shape], [shape]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate the ``binarizeAct`` operator."""
        return binarizeAct()
#=============== Act binarization =================#

#=============== weight binarization =================#
class binarizeW(mx.operator.CustomOp):
    """Binarize weights with ``sign`` and pass gradients straight through."""

    def forward(self, is_train, req, in_data, out_data, aux):
        """Emit ``sign(w)`` multiplied by a scale reduced over axes 1-3."""
        signs = mx.nd.sign(in_data[0])
        # NOTE(review): the scale is the mean of |sign(w)|, not of |w| --
        # confirm this is intended.
        magnitude = mx.nd.abs(signs)
        scale = mx.nd.mean(magnitude, axis=(1, 2, 3), keepdims=True)
        self.assign(out_data[0], req[0], signs * scale)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Copy the output gradient to the input gradient unchanged."""
        upstream = out_grad[0]
        self.assign(in_grad[0], req[0], upstream)

@mx.operator.register("binarizeW")
class binarizeWProp(mx.operator.CustomOpProp):
    """Operator properties for ``binarizeW``: one input, one same-shape output."""

    def list_arguments(self):
        """Return the single input name."""
        return ['data0']

    def list_outputs(self):
        """Return the single output name."""
        return ['output']

    def infer_shape(self, in_shapes):
        """The output shape equals the input shape."""
        shape = in_shapes[0]
        return [shape], [shape]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate the ``binarizeW`` operator."""
        return binarizeW()
#=============== weight binarization =================#

#============== channel shuffle function =============#
def channel_shuffle(data, groups):
    """Shuffle axis 1 of a symbol across ``groups`` via reshape/swap/reshape.

    Axis 1 is split into ``(groups, -1)``, the two resulting axes are
    swapped, and they are then merged back into a single axis.
    """
    split_groups = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))
    swapped = mx.sym.swapaxes(split_groups, 1, 2)
    return mx.sym.reshape(swapped, shape=(0, -3, -2))
#============== channel shuffle function =============#

#=============== group channel gating ================#
def group_cg(data, num_filter, kernel, stride, pad, depth, groups, binary_w, 
        quantize_a, tag):
    """Build the full and grouped convolutions that share one weight variable.

    A weight variable ``'wt_' + tag`` of shape
    ``(num_filter, depth, kernel, kernel)`` is created (and passed through
    the ``binarizeW`` custom op when ``binary_w`` is true). It feeds a
    regular convolution. It is also split into ``groups`` chunks along
    axis 0; chunk ``i`` keeps only its ``i``-th slice of axis 1, and the
    slices are concatenated into the weight of a grouped convolution.

    ``quantize_a`` is accepted for interface compatibility but is not used.

    Returns:
        ``(gcnv, cnv)`` -- the grouped convolution and the full convolution.
    """
    wt = mx.sym.Variable('wt_'+tag, shape=(num_filter,depth,kernel,kernel), 
            dtype=np.float32, init=mx.init.Xavier(rnd_type='gaussian', 
                factor_type="in", magnitude=2), lr_mult=1.0, wd_mult=4.0)

    if binary_w:
        wt = mx.sym.Custom(wt, op_type='binarizeW')
    # W*X
    cnv = mx.sym.Convolution(data=data, weight=wt, num_filter=num_filter, 
            kernel=(kernel,kernel), stride=stride, pad=(pad,pad), no_bias=True)
    w = mx.sym.split(data=wt, num_outputs=groups, axis=0)
    gw = [
        mx.sym.slice_axis(w[idx], axis=1, begin=int(idx*depth/groups),
                          end=int((idx+1)*depth/groups))
        for idx in range(groups)
    ]
    # concat takes its inputs as positional arguments, so unpack the list.
    # This replaces a hard-coded ladder that only handled 2, 4, 8 or 16
    # groups and left ``w_group`` unbound for any other value.
    w_group = mx.sym.concat(*gw, dim=0)
    # Wp*Xp
    gcnv = mx.sym.Convolution(data=data, weight=w_group, num_filter=num_filter, 
            kernel=(kernel,kernel), stride=stride, pad=(pad,pad), 
            num_group=groups, no_bias=True)
    return gcnv, cnv
#=============== group channel gating ================#

#======= group channel gating building block =========#
def gSGBCNV(data, nsigma, gamma, kernel, depth, nfltr, stride, pad, no_bias, 
        stop, target, level, tag, groups, binary_w, quantize_act, act_bits, workspace):
    """Assemble one gated convolution block from symbols.

    The block builds a grouped and a full convolution with ``group_cg``,
    batch-normalises both with a shared ``gamma_<tag>`` / ``beta_<tag>``
    pair, applies ReLU, and feeds the results to the gate block selected by
    ``lvl(level, ...)`` together with a second, ``fix_gamma=True``
    batch-norm of the grouped convolution.

    ``nsigma``, ``gamma``, ``nfltr``, ``kernel``, ``depth`` and ``target``
    are forwarded to ``lvl``; ``stride``, ``pad``, ``groups``, ``binary_w``,
    ``quantize_act`` and ``tag`` are forwarded to ``group_cg``. ``no_bias``,
    ``act_bits`` and ``workspace`` are not used in this body.

    Returns:
        A 4-tuple ``(conv, rmcc, r, mac)`` unpacked from the gate's outputs.
    """
    # NOTE(review): x1 is computed but never used below.
    x1 = mx.sym.slice_axis(data=data, axis=1, begin=0, end=int(depth*stop)) #(batch_size, channel, height, width)
    # lvl: selects the dynamic skip level by name
    #   'pixel':   skip at pixel level
    #   'channel': skip at channel level (parameter and parameterless)
    #   'layer':   skip at layer level (parameter and parameterless)
    
    sgb = lvl(level, nsigma, gamma, nfltr, int(kernel), int(depth), float(target))
    sgb.initialize()
    sgb.hybridize()
    
    # group_cg returns (grouped conv, full conv), in that order.
    x1_c, x2_s = group_cg(data, nfltr, kernel, stride, pad, depth, groups, binary_w, quantize_act, tag)
    
    # From here on ``gamma`` is the batch-norm scale variable, no longer the
    # scalar argument that was passed to lvl() above.
    gamma = mx.sym.Variable('gamma_'+tag, shape=(nfltr,), dtype=np.float32, init=mx.init.One(), lr_mult=1.0, wd_mult=1.0)
    beta = mx.sym.Variable('beta_'+tag, shape=(nfltr,), dtype=np.float32, init=mx.init.Zero(), lr_mult=1.0, wd_mult=1.0)

    x1_bn = mx.sym.BatchNorm(data=x1_c, fix_gamma=False, gamma=gamma, beta=beta, momentum=0.9, eps=2e-5)
    x1_act = mx.sym.relu(x1_bn, name=tag+'_act1')
    
    #beta = mx.sym.Variable('beta_'+tag, shape=(nfltr,), dtype=np.float32, init=mx.init.Zero(), lr_mult=0.0, wd_mult=0.0)

    # Second batch-norm of the grouped conv with fix_gamma=True; this is the
    # signal the gate compares against its threshold.
    _x1_bn = mx.sym.BatchNorm(data=x1_c, fix_gamma=True, momentum=0.9, eps=2e-5)
    #_x1_bn = mx.sym.Dropout(data=_x1_bn, p=0.3)
    # set the scale parameter nvar for each output channel
    
    # Reduce the gating signal to match the granularity of the chosen gate.
    if level == "channel":
        _x1_bn = mx.sym.mean(data=_x1_bn, axis=(2,3))
    elif level == "layer":
        _x1_bn = mx.sym.mean(data=_x1_bn, axis=(1,2,3))
    
    # The full-conv branch reuses the same gamma/beta variables as x1_bn.
    x2_bn = mx.sym.BatchNorm(data=x2_s, fix_gamma=False, gamma=gamma, beta=beta, momentum=0.9, eps=2e-5)
    x2_act = mx.sym.relu(x2_bn, name=tag+'_act2')
    stp = mx.sym.ones([1])*float(stop)
    # NOTE(review): the 'channel' and 'layer' gate ops declare 5 and 6
    # outputs respectively, while only 4 are unpacked here -- confirm those
    # levels work with this unpacking.
    conv, rmcc, r, mac = sgb(x1_act, _x1_bn, x2_act, stp)
    if quantize_act:
        # NOTE(review): ``quantA`` is not defined in the visible part of this
        # file; confirm where it comes from.
        return quantA(conv), rmcc, r, mac
    else:
        return conv, rmcc, r, mac
#======= group channel gating building block =========#

#============== FLOPs reduction monitor =============#
class DivMon(mx.operator.CustomOp):
    """Pass-through custom op that also emits a FLOPs-reduction figure.

    Output 0 forwards ``in_data[0]``; output 1 is
    ``(1 - in_data[1] / in_data[2]) * (1 - stop)``.
    """

    def __init__(self, stop):
        self._stop = float(stop)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Write the pass-through tensor and the reduction figure."""
        x0 = in_data[0]
        x1 = in_data[1]
        x2 = in_data[2]
        # Scalar arithmetic replaces ``mx.nd.ones(x1.shape)``, which was
        # allocated without an explicit context and so was not tied to the
        # device that holds x1.
        flops = (1 - x1/x2)*(1-self._stop)
        self.assign(out_data[0], req[0], x0)
        # Each output has its own write request; the original reused req[0].
        self.assign(out_data[1], req[1], flops)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Route the gradient of output 0 straight back to ``in_data[0]``.

        NOTE(review): in_grad[1] and in_grad[2] are never assigned here.
        """
        self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register("DivMon")
class DivMonProp(mx.operator.CustomOpProp):
    """Operator properties for ``DivMon``; carries ``stop`` to the operator."""

    def __init__(self, stop):
        super(DivMonProp, self).__init__(True)
        self._stop = float(stop)

    def list_arguments(self):
        """Return the names of the three inputs."""
        return ['data0', 'data1', 'data2']

    def list_outputs(self):
        """Return the names of the two outputs."""
        return ['psum', 'flops']

    def infer_shape(self, in_shapes):
        """Output 0 mirrors input 0; the other tensors have shape ``[1]``."""
        passthrough_shape = in_shapes[0]
        scalar_shape = [1]
        input_shapes = [passthrough_shape, scalar_shape, scalar_shape]
        output_shapes = [passthrough_shape, scalar_shape]
        return input_shapes, output_shapes

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate ``DivMon`` with the stored ``stop`` value."""
        return DivMon(self._stop)
#============== FLOPs reduction monitor =============#

#============== pixel level gate with gluon block =============#
class pSGate(mx.operator.CustomOp):
    """Pixel-level gate choosing between two convolution results.

    Inputs (see ``pSGateProp.list_arguments``): ``data0`` (Wp*Xp), ``data1``
    (BN(Wp*Xp), the gating signal), ``data2`` (Wp*Xp+Wr*Xr), ``stop`` and
    the per-channel threshold ``th``.
    """

    def __init__(self, gamma, f, depth, target):
        self._gamma  = float(gamma)
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Select ``x1`` where the gating signal exceeds the threshold, else ``x0``.

        Also emits ``rmcc`` (remaining MAC cost), ``r`` (mean of the gate
        mask) and ``mac`` (``f*f*depth`` times the per-sample output size).
        """
        # get the conv results from both parts
        x0 = in_data[0] # Wp*Xp
        _x0 = in_data[1] # BN(Wp*Xp)
        x1 = in_data[2] # Wp*Xp+Wr*Xr
        th = in_data[4]
        stop = in_data[3]

        # Scalar arithmetic (as lSGate already does) avoids allocating
        # ``mx.nd.ones(...)`` without a context that matches ``th``.
        th = th+self._target
        th = mx.nd.broadcast_to(th.reshape([th.shape[0],1,1]), shape=(x0.shape[1], x0.shape[2], x0.shape[3]))
        # using step function for the forward path
        mask1 = mx.nd.broadcast_greater(_x0, th)
        mask0 = 1 - mask1
        out = x1*mask1+x0*mask0
        # fraction of output positions where the gate is open
        r = mx.nd.mean(mask1)
        rr = 1-r
        # mac = f^2*d*out.size
        mac = self._f*self._f*self._depth*out.shape[1]*out.shape[2]*out.shape[3]
        # the remaining computation cost in MAC operations
        rmcc = mac-rr*(1-stop)*mac
        # One write request per output (the original used req[0] for all).
        self.assign(out_data[0], req[0], out)
        self.assign(out_data[1], req[1], rmcc)
        self.assign(out_data[2], req[2], r)
        self.assign(out_data[3], req[3], mac)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Back-propagate through the gate.

        ``x0``/``x1`` receive ``dy`` masked by the hard gate; the gating
        signal and the threshold receive a gradient through a sigmoid
        surrogate with slope ``gamma``. No gradient is written for ``stop``
        (``in_grad[3]``).
        """
        x0 = in_data[0]
        _x0 = in_data[1]
        x1 = in_data[2]
        th = in_data[4]
        dy0 = out_grad[0]

        th = th+self._target
        th = mx.nd.broadcast_to(th.reshape([th.shape[0],1,1]), shape=(_x0.shape[1], _x0.shape[2], _x0.shape[3]))
        g_x0 = dy0*(x0-x1)

        # Based on Stochastic Neurons (Bengio et al., 2013), it is better to
        # ignore the derivative of the sigmoid
        power = self._gamma*mx.nd.broadcast_sub(th, _x0)

        # s(x, Delta)
        mask0 = mx.nd.Activation(data=power, act_type='sigmoid')
        # J - mask0
        mask1 = 1 - mask0
        # gradient of the soft threshold function - h
        sh = self._gamma*mask0*mask1
        _gx0 = -g_x0*sh
        # gradient of threshold
        gh = mx.nd.sum(-_gx0, axis=(0,2,3))
        mmask1 = mx.nd.broadcast_greater(_x0, th)
        mmask0 = 1 - mmask1
        # gradient of x1
        gx1 = dy0*mmask1
        # gradient of x0
        gx0 = dy0*mmask0

        # One write request per input gradient (the original used req[0]).
        self.assign(in_grad[0], req[0], gx0)
        self.assign(in_grad[1], req[1], _gx0)
        self.assign(in_grad[2], req[2], gx1)
        self.assign(in_grad[4], req[4], gh)

@mx.operator.register("pSGate")  # registered under the name "pSGate"
class pSGateProp(mx.operator.CustomOpProp):
    """Operator properties for the pixel-level gate ``pSGate``."""

    def __init__(self, gamma, f, depth, target):
        super(pSGateProp, self).__init__(True)
        self._gamma = float(gamma)
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def list_arguments(self):
        """Return the five input names."""
        return ['data0', 'data1', 'data2', 'stop', 'th']

    def list_outputs(self):
        """Return the four output names."""
        return ['output', 'rmcc', 'r', 'mac']

    def infer_shape(self, in_shapes):
        """Declare shapes from the first input.

        The three data inputs and the main output share the first input's
        shape, ``th`` has one entry per element of axis 1, and ``stop`` plus
        the monitor outputs have shape ``[1]``.
        """
        full = in_shapes[0]
        per_channel = (full[1],)
        scalar = [1]
        input_shapes = [full, full, full, scalar, per_channel]
        output_shapes = [full, scalar, scalar, scalar]
        return input_shapes, output_shapes

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate ``pSGate`` with the stored hyper-parameters."""
        return pSGate(self._gamma, self._f, self._depth, self._target)

class pSGateBlock(mx.gluon.HybridBlock):
    """Gluon wrapper around the ``pSGate`` custom op.

    Owns the threshold parameter ``h`` of shape ``(c,)``, initialised to the
    constant ``nsigma``.
    """

    def __init__(self, nsigma, gamma, c, f, depth, target, **kwargs):
        super(pSGateBlock, self).__init__(**kwargs)
        self._nsigma = float(nsigma)
        self._gamma = float(gamma)
        self._c, self._f = c, f
        self._depth, self._target = depth, target
        with self.name_scope():
            self.h = self.params.get(
                'h',
                shape=(self._c,),
                init=mx.init.Constant(self._nsigma),
                wd_mult=1.0,
                lr_mult=1.0)

    def hybrid_forward(self, F, x0, _x0, x1, s, h):
        """Invoke the ``pSGate`` custom op with the stored hyper-parameters."""
        op_kwargs = dict(gamma=self._gamma, f=self._f, depth=self._depth,
                         target=self._target, op_type='pSGate')
        return F.Custom(x0, _x0, x1, s, h, **op_kwargs)
#============== pixel level gate with gluon block =============#

#============== channel level gate with gluon block =============#
class cSGate(mx.operator.CustomOp):
    """Channel-level gate choosing between two convolution results.

    Inputs (see ``cSGateProp.list_arguments``): ``data0``, ``data1`` (the
    gating signal, one value per sample and channel), ``data2``, ``stop``
    and the per-channel threshold ``th``.
    """

    def __init__(self, gamma, f, depth, target):
        self._gamma  = float(gamma)
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Select ``x1`` for channels whose gating signal exceeds the threshold.

        Also emits ``rmcc``, ``r``, ``mac`` and the mean shifted threshold.
        """
        # get the conv results from both parts
        x0 = in_data[0]
        _x0 = in_data[1]
        x1 = in_data[2]
        th = in_data[4]
        stop = in_data[3]
        # Scalar arithmetic (as lSGate already does) avoids allocating
        # ``mx.nd.ones(...)`` without a context that matches ``th``.
        th = th+self._target

        # using step function for the forward path
        mask1 = mx.nd.broadcast_greater(_x0, th)
        mask1 = mx.nd.broadcast_to(mask1.reshape([mask1.shape[0],mask1.shape[1],1,1]), shape=x0.shape)
        mask0 = 1 - mask1
        out = x1*mask1+x0*mask0

        # fraction of output positions where the gate is open
        r = mx.nd.mean(mask1)
        rr = 1-r
        # mac = f^2*d*out.size
        mac = self._f*self._f*self._depth*out.shape[1]*out.shape[2]*out.shape[3]
        # the remaining computation cost in MAC operations
        rmcc = mac-rr*(1-stop)*mac

        # One write request per output (the original used req[0] for all).
        self.assign(out_data[0], req[0], out)
        self.assign(out_data[1], req[1], rmcc)
        self.assign(out_data[2], req[2], r)
        self.assign(out_data[3], req[3], mac)
        self.assign(out_data[4], req[4], mx.nd.mean(th))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Back-propagate through the gate using a sigmoid surrogate.

        Unlike the forward pass, ``x0``/``x1`` gradients are weighted by the
        soft sigmoid masks. No gradient is written for ``stop``
        (``in_grad[3]``).
        """
        # get the conv results from both parts
        x0 = in_data[0]
        _x0 = in_data[1]
        x1 = in_data[2]
        th = in_data[4]
        dy0 = out_grad[0]

        th = th+self._target
        power = self._gamma*mx.nd.broadcast_sub(th, _x0)

        # s(x, Delta)
        mask0 = mx.nd.Activation(data=power, act_type='sigmoid')
        mask0 = mx.nd.broadcast_to(mask0.reshape([mask0.shape[0],mask0.shape[1],1,1]),shape=x0.shape)
        # J - mask0
        mask1 = 1 - mask0
        g_x0 = dy0*(x0-x1)
        # gradient of the soft threshold function - h
        sh = self._gamma*mask0*mask1
        _gx0 = -g_x0*sh
        # collapse the spatial axes so the result matches the gating signal
        _gx0 = mx.nd.sum(_gx0, axis=(2,3))
        # gradient of threshold
        gh = mx.nd.sum(-_gx0, axis=0)
        # gradient of x1
        gx1 = dy0*mask1
        # gradient of x0
        gx0 = dy0*mask0

        # One write request per input gradient (the original used req[0]).
        self.assign(in_grad[0], req[0], gx0)
        self.assign(in_grad[1], req[1], _gx0)
        self.assign(in_grad[2], req[2], gx1)
        self.assign(in_grad[4], req[4], gh)

@mx.operator.register("cSGate")  # registered under the name "cSGate"
class cSGateProp(mx.operator.CustomOpProp):
    """Operator properties for the channel-level gate ``cSGate``."""

    def __init__(self, gamma, f, depth, target):
        super(cSGateProp, self).__init__(True)
        self._gamma = float(gamma)
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def list_arguments(self):
        """Return the five input names."""
        return ['data0', 'data1', 'data2', 'stop', 'th']

    def list_outputs(self):
        """Return the five output names."""
        return ['output', 'rmcc', 'r', 'mac', 'h']

    def infer_shape(self, in_shapes):
        """Declare shapes from the first input.

        The gating signal uses the first two axes of the data shape, ``th``
        has one entry per element of axis 1, and ``stop`` plus the monitor
        outputs have shape ``[1]``.
        """
        full = in_shapes[0]
        gate_signal = (full[0], full[1],)  # batch and output channels
        per_channel = (full[1],)
        scalar = [1]
        input_shapes = [full, gate_signal, full, scalar, per_channel]
        output_shapes = [full, scalar, scalar, scalar, scalar]
        return input_shapes, output_shapes

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate ``cSGate`` with the stored hyper-parameters."""
        return cSGate(self._gamma, self._f, self._depth, self._target)

class cSGateBlock(mx.gluon.HybridBlock):
    """Gluon wrapper around the ``cSGate`` custom op.

    Owns the threshold parameter ``h`` of shape ``(c,)``, initialised to the
    constant ``nsigma``.
    """

    def __init__(self, nsigma, gamma, c, f, depth, target, **kwargs):
        super(cSGateBlock, self).__init__(**kwargs)
        self._nsigma = float(nsigma)
        self._gamma = float(gamma)
        self._c, self._f = c, f
        self._depth, self._target = depth, target
        with self.name_scope():
            self.h = self.params.get(
                'h',
                shape=(self._c,),
                init=mx.init.Constant(self._nsigma),
                wd_mult=1.0,
                lr_mult=1.0)

    def hybrid_forward(self, F, x0, _x0, x1, s, h):
        """Invoke the ``cSGate`` custom op with the stored hyper-parameters."""
        op_kwargs = dict(gamma=self._gamma, f=self._f, depth=self._depth,
                         target=self._target, op_type='cSGate')
        return F.Custom(x0, _x0, x1, s, h, **op_kwargs)
#============== channel level gate with gluon block =============#

#============== layer level gate with gluon block =============#
class lSGate(mx.operator.CustomOp):
    """Layer-level gate choosing between two convolution results per sample.

    Inputs (see ``lSGateProp.list_arguments``): ``data0``, ``data1`` (the
    gating signal, one value per sample), ``data2``, ``stop`` and a
    single-element threshold ``th``.
    """

    def __init__(self, gamma, f, depth, target):
        self._gamma  = float(gamma)
        # The original line here was a malformed chained assignment that
        # raised TypeError on construction and never set ``_f``, although
        # forward() reads it. It now matches pSGate and cSGate.
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Select ``x1`` for samples whose gating signal exceeds the threshold.

        Also emits ``p`` (gate-open count per channel), ``c`` (per-pixel MAC
        cost), ``r`` (fraction of open positions), ``mac`` and the mean
        shifted threshold.
        """
        # get the conv results from both parts
        x0 = in_data[0]
        _x0 = in_data[1]
        x1 = in_data[2]
        th = in_data[4]
        stop = in_data[3]
        th = th+self._target

        # using step function for the forward path
        mask1 = mx.nd.broadcast_greater(_x0, th)
        mask1 = mx.nd.broadcast_to(mask1.reshape([mask1.shape[0],1,1,1]), shape=x0.shape)
        # Scalar arithmetic avoids allocating ``mx.nd.ones(...)`` without a
        # context that matches ``mask1``.
        mask0 = 1 - mask1
        out = x1*mask1+x0*mask0

        # number of open positions per channel, and the overall open fraction
        p = mx.nd.sum(mask1, axis=(0,2,3))
        r = mx.nd.sum(p)/mask1.size

        # the cost in MAC operations for each output pixel
        c = self._f*self._f*(1-stop)*self._depth
        # mac = f^2*d*out.size
        mac = self._f*self._f*self._depth*out.shape[1]*out.shape[2]*out.shape[3]

        # One write request per output (the original used req[0] for all).
        self.assign(out_data[0], req[0], out)
        self.assign(out_data[1], req[1], p)
        self.assign(out_data[2], req[2], c)
        self.assign(out_data[3], req[3], r)
        self.assign(out_data[4], req[4], mac)
        self.assign(out_data[5], req[5], mx.nd.mean(th))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Back-propagate through the gate using a sigmoid surrogate.

        ``x0``/``x1`` gradients are weighted by the soft sigmoid masks. No
        gradient is written for ``stop`` (``in_grad[3]``).
        """
        x0 = in_data[0]
        _x0 = in_data[1]
        x1 = in_data[2]
        th = in_data[4]
        dy0 = out_grad[0]

        th = th+self._target
        power = self._gamma*mx.nd.broadcast_sub(th, _x0)

        # s(x, Delta)
        mask0 = mx.nd.Activation(data=power, act_type='sigmoid')
        mask0 = mx.nd.broadcast_to(mask0.reshape([mask0.shape[0],1,1,1]),shape=x0.shape)
        # J - mask0
        mask1 = 1 - mask0
        g_x0 = dy0*(x0-x1)
        # gradient of the soft threshold function - h
        sh = self._gamma*mask0*mask1
        _gx0 = -g_x0*sh
        # collapse all non-batch axes so the result matches the gating signal
        _gx0 = mx.nd.sum(_gx0, axis=(1,2,3))
        # gradient of threshold
        gh = mx.nd.sum(-_gx0, axis=0)
        # gradient of x1
        gx1 = dy0*mask1
        # gradient of x0
        gx0 = dy0*mask0

        # One write request per input gradient (the original used req[0]).
        self.assign(in_grad[0], req[0], gx0)
        self.assign(in_grad[1], req[1], _gx0)
        self.assign(in_grad[2], req[2], gx1)
        self.assign(in_grad[4], req[4], gh)


@mx.operator.register("lSGate")  # registered under the name "lSGate"
class lSGateProp(mx.operator.CustomOpProp):
    """Operator properties for the layer-level gate ``lSGate``."""

    def __init__(self, gamma, f, depth, target):
        super(lSGateProp, self).__init__(True)
        self._gamma = float(gamma)
        self._f = float(f)
        self._depth = float(depth)
        self._target = float(target)

    def list_arguments(self):
        """Return the five input names."""
        return ['data0', 'data1', 'data2', 'stop', 'th']

    def list_outputs(self):
        """Return the six output names."""
        return ['output', 'p', 'c', 'r', 'mac', 'h']

    def infer_shape(self, in_shapes):
        """Declare shapes from the first input.

        The gating signal has one entry per element of axis 0, ``p`` has one
        entry per element of axis 1, and ``stop``, ``th`` and the remaining
        monitor outputs have shape ``[1]``.
        """
        full = in_shapes[0]
        gate_signal = (full[0],)
        threshold = [1]
        per_channel = (full[1],)
        scalar = [1]
        input_shapes = [full, gate_signal, full, scalar, threshold]
        output_shapes = [full, per_channel, scalar, scalar, scalar, scalar]
        return input_shapes, output_shapes

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Instantiate ``lSGate`` with the stored hyper-parameters."""
        return lSGate(self._gamma, self._f, self._depth, self._target)

class lSGateBlock(mx.gluon.HybridBlock):
    """Gluon wrapper around the ``lSGate`` custom op.

    Owns a single-element threshold parameter ``h`` initialised to the
    constant ``nsigma``; ``c`` is stored but does not size the parameter.
    """

    def __init__(self, nsigma, gamma, c, f, depth, target, **kwargs):
        super(lSGateBlock, self).__init__(**kwargs)
        self._nsigma = float(nsigma)
        self._gamma = float(gamma)
        self._c, self._f = c, f
        self._depth, self._target = depth, target
        with self.name_scope():
            self.h = self.params.get(
                'h',
                shape=(1,),
                init=mx.init.Constant(self._nsigma),
                wd_mult=1.0,
                lr_mult=1.0)

    def hybrid_forward(self, F, x0, _x0, x1, s, h):
        """Invoke the ``lSGate`` custom op with the stored hyper-parameters."""
        op_kwargs = dict(gamma=self._gamma, f=self._f, depth=self._depth,
                         target=self._target, op_type='lSGate')
        return F.Custom(x0, _x0, x1, s, h, **op_kwargs)
#============== layer level gate with gluon block =============#

#============== helper function for choosing gate =============#
def lvl(level, nsigma, gamma, nfltr, kernel, depth, target):
    """Return the gate block for ``level`` ('pixel', 'channel' or 'layer').

    All three gate blocks are constructed on every call and the requested
    one is returned; an unknown ``level`` raises ``KeyError``.
    """
    gate_args = (nsigma, gamma, nfltr, kernel, depth, target)
    # Built eagerly and in this order, exactly as before, so any
    # construction side effects stay the same.
    gates = {}
    gates['pixel'] = pSGateBlock(*gate_args)
    gates['channel'] = cSGateBlock(*gate_args)
    gates['layer'] = lSGateBlock(*gate_args)
    return gates[level]
#============== helper function for choosing gate =============#
