
import math
import torch

# Modulus (2**31 - 1) used by arelu to reduce values mod p; also the
# overflow threshold checked by check_overflow and srelu.
p=2147483647
# Default device for tensors created in this module (first GPU if available).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def check_overflow(x):
    """Replace every element of ``x`` that is ``>= p`` with a random value in ``[0, p)``.

    ``x`` is modified in place and also returned for convenience.

    Args:
        x: tensor of values. Integer and floating-point dtypes are both
            accepted; the random replacements are cast to ``x.dtype`` and
            placed on ``x.device``.

    Returns:
        The same tensor object ``x``.
    """
    mask = x >= p
    num = int(mask.sum())
    if num == 0:
        return x
    # Draw from the CPU generator (as before) so seeded runs see the same
    # random stream, then match x's device and dtype: index assignment
    # requires the source to have the destination's dtype, so writing a
    # float32 tensor into a long/float64 tensor would otherwise raise.
    replacement = torch.mul(torch.rand(num), p)
    x[mask] = replacement.to(device=x.device, dtype=x.dtype)
    return x
    

def arelu(x, tr, def_pos):
    """Approximate ReLU by comparing truncated additive shares of ``x``.

    ``x`` is cast to integers with ``x.long()``. It is encoded mod ``p``:
    negative values become ``p + x``. It is then masked with a random
    ``rs`` in ``[0, p)``, giving ``x0 = (x_field + rs) mod p``. The
    complementary share would be ``(p - rs) mod p``.

    Both ``x0`` and ``rs`` are integer-divided by ``t = 2**tr``. An element
    is kept (sign 1) when the truncated mask is smaller than the truncated
    share. Otherwise it is zeroed (sign 0).

    Args:
        x: activation tensor. NOTE(review): values are truncated by
            ``.long()``, so x is presumably already fixed-point scaled --
            confirm with callers.
        tr: number of truncated bits. ``-1`` selects an exact ReLU.
        def_pos: tie rule when the truncated values are equal. ``True``
            keeps the element (treated as positive); ``False`` zeroes it.

    Returns:
        ``(relu, badsign_pos, badsign_neg, pos, neg)``:

        - ``relu``: ``x`` gated by the computed sign.
        - ``badsign_pos``: number of nonzero elements with ``x >= 0`` whose
          computed sign was 0.
        - ``badsign_neg``: number of nonzero elements with ``x < 0`` whose
          computed sign was 1.
        - ``pos``, ``neg``: counts of ``x >= 0`` and ``x < 0``.

        On the ``tr == -1`` path both badsign counts are the int ``0``.

    Raises:
        ValueError: if ``2**tr`` truncates to 0, i.e. ``tr`` is negative
            but not ``-1``.
    """
    pos = torch.sum(x >= 0)
    neg = torch.sum(x < 0)

    if tr == -1:
        # Exact ReLU. NOTE: inplace=True overwrites the caller's tensor.
        relu = torch.nn.functional.relu(x, inplace=True)
        return relu, 0, 0, pos, neg

    t = int(math.pow(2, tr))
    if t < 1:
        raise ValueError(f"tr must be -1 or a non-negative bit count, got {tr}")

    dev = x.device
    xl = x.long()
    # Random mask in [0, p). Drawn from the CPU generator (unchanged, so
    # seeded runs reproduce), then moved next to x to avoid device mismatch.
    rs = torch.floor(torch.mul(torch.rand(xl.shape), p)).to(dev).long()

    # Encode mod p: negatives map to p + x.
    xf = torch.where(xl < 0, xl + p, xl)
    # Masked share x0 = (xf + rs) mod p.
    xr = xf + rs
    x0 = torch.where(xr < p, xr, xr - p)

    # Without wraparound x0 = xf + rs >= rs. Wraparound (x0 < rs) occurs
    # when xf + rs >= p. That is likely for negative inputs (xf close to p)
    # and unlikely for small positive ones, so x0 vs rs estimates the sign.
    # Truncation by t makes nearby values tie; ties follow def_pos.
    x0t = x0 // t
    rst = rs // t

    ones = torch.ones(x.shape, device=dev)
    zeros = torch.zeros(x.shape, device=dev)
    if def_pos:
        sign = torch.where(rst > x0t, zeros, ones)
    else:
        sign = torch.where(rst < x0t, ones, zeros)

    relu = x * sign

    # Compare against the true sign; exact zeros are never counted as errors.
    truesign = torch.where(x < 0, zeros, ones)
    nonzero = x != 0
    badsign_pos = (nonzero & (truesign == 1) & (sign == 0)).sum()
    badsign_neg = (nonzero & (truesign == 0) & (sign == 1)).sum()

    return relu, badsign_pos, badsign_neg, pos, neg

# Stochastic ReLU implementation


def srelu(x, tr, alpha, beta):
    """Stochastic ReLU: randomly zero small positive activations.

    ``x`` is scaled by ``2**(2*14 - alpha - beta)``. Each scaled value ``v``
    with ``0 < v < 2**tr`` is zeroed with probability ``(2**tr - v) / 2**tr``,
    so values closer to zero are dropped more often. All values then pass
    through an exact ReLU, and the result is scaled back down.

    Args:
        x: activation tensor. It is not modified, because scaling creates a
            new tensor.
        tr: number of bits defining the error range ``(0, 2**tr)``.
        alpha, beta: exponents subtracted from ``2*14`` in the scale factor.
            NOTE(review): presumably fixed-point fractional bits -- confirm.

    Returns:
        ``(relu, badsign_pos, badsign_neg, pos, neg)``:

        - ``badsign_pos``: number of positive activations that were zeroed.
        - ``badsign_neg``: always a zero tensor.
        - ``pos``, ``neg``: counts of ``x >= 0`` and ``x < 0``.

    Raises:
        OverflowError: if any scaled value exceeds ``p``.
    """
    pos = torch.sum(x >= 0)
    neg = torch.sum(x < 0)

    scale = math.pow(2, 2 * 14 - alpha - beta)
    x = torch.mul(x, scale)
    # torch.max fails on empty tensors, so only check non-empty input.
    if x.numel() > 0 and (torch.max(x) > p).item():
        raise OverflowError(f"srelu: scaled activation exceeds p={p}")

    # Default-negative case: collect positive activations in the
    # truncation range (0, 2**tr).
    lim = 2 ** tr
    in_range = (x > 0) & (x < lim)
    vals = x[in_range]

    # Error probability per activation, and uniform samples in [0, 1).
    # Samples come from the CPU generator (as before), moved next to x.
    error_prob = (lim - vals) / lim
    prob = torch.rand(torch.numel(vals)).to(x.device)
    flipped = prob < error_prob

    # Zero the activations whose sample fell below their error probability
    # and write them back into the scaled tensor.
    vals[flipped] = 0
    x[in_range] = vals

    relu = torch.where(x >= 0.0, x, torch.zeros(x.shape, device=x.device))
    relu = torch.div(relu, scale)

    badsign_pos = torch.sum(flipped)
    badsign_neg = torch.tensor(0, device=x.device)

    return relu, badsign_pos, badsign_neg, pos, neg


