import logging# , sys

_LOGGER_RT = logging.getLogger(name='optim.sampling_rate')
_LOGGER_SM = logging.getLogger(name='optim.sampler')

from util_py.arg_parsing  import ArgumentRegistry

##### -------------------------------- #####
#    These classes provide samplers; they leave it to
#    someone else to decide the RATE of sampling (the
#    sample size). The sampling-rate classes further
#    below supply those rates.
##### -------------------------------- #####


class AbsSampler:
    """Interface shared by the samplers in this module.

    A sampler hands out sample sets; it does not decide how large they
    should be (that is left to the sampling-rate classes).  Subclasses
    override :meth:`next_sample`.
    """

    def initialize(self):
        """Hook for (re)setting sampler state; the base version does nothing."""
        return None

    def next_sample(self, siz):
        """Return the next sample set of size ``siz``.

        The base class provides no implementation and always raises
        ``NotImplementedError``.
        """
        message = "this is an abstract class defining the interface."
        raise NotImplementedError(message)
    

#class SetRepeater(AbsSampler):
#    r'''
#    This only establishes the interface
#    '''
#        
#    def __init__(self, full_set):
#        self.ret_set = full_set
#    
#    def next_sample(self, siz):
#        r'''
#        Expected return = <sample set>
#        '''
#        
#        return self.ret_set
        

class PermutationSampler(AbsSampler):
    '''
    Draw index samples of size ``ssiz`` from ``range(maxsz)`` by walking
    through a random permutation of that range.

    Consecutive calls consume consecutive slices of the current permutation,
    so within one pass every index is handed out exactly once.  When the
    permutation is used up it is reshuffled in place and the walk restarts
    from its beginning.  A request that straddles the end of a pass is
    completed from the freshly shuffled permutation, so such a sample may
    contain the same index twice.

    NOTE: the returned array is a view into an internal buffer that is
    overwritten by the next call; copy it if it must outlive that call.
    '''

    def __init__(self, maxsz):
        # the permutation we walk through
        self.curr_perm = np.random.permutation(maxsz)
        self.max_size = maxsz

        # preallocated output buffer; next_sample returns (a view of) it
        self.sample_array = np.zeros(maxsz, dtype=int)

        # position of the next unused entry in self.curr_perm
        self.curr_pntr = 0
        # running total of indices handed out over all calls
        self.cum_samples_accessed = 0

    def next_sample(self, ssiz):
        '''
        Return ``ssiz`` indices drawn from ``range(self.max_size)``.

        Parameters
        ----------
        ssiz : int
            Requested sample size, ``0 <= ssiz <= self.max_size``.

        Returns
        -------
        numpy.ndarray of int
            View of the internal buffer holding the sample.  When
            ``ssiz == self.max_size`` the whole range is returned in
            ascending order.

        Raises
        ------
        ValueError
            If ``ssiz`` is larger than ``self.max_size`` or negative.
        '''
        if ssiz > self.max_size:
            errm = "The desired size {} should NOT be larger than the max-size {} allotted at initiation.".format(
                    ssiz, self.max_size)
            _LOGGER_SM.error(errm)
            raise ValueError(errm)
        if ssiz < 0:
            errm = "The desired size {} should NOT be negative.".format(ssiz)
            _LOGGER_SM.error(errm)
            raise ValueError(errm)

        if ssiz == self.max_size:
            # the full set: no need to touch the permutation
            self.sample_array[:] = range(self.max_size)
            self.cum_samples_accessed += ssiz
            return self.sample_array

        # take as much as possible from what is left of the current
        # permutation: entries curr_pntr .. max_size-1, i.e. max_size - curr_pntr
        # of them (the previous "- 1" silently skipped the last entry)
        len_lft = min(self.max_size - self.curr_pntr, ssiz)
        self.sample_array[0:len_lft] = self.curr_perm[self.curr_pntr:(self.curr_pntr + len_lft)]
        self.curr_pntr += len_lft

        if self.curr_pntr == self.max_size:
            # permutation exhausted: reshuffle it in place and start over
            np.random.shuffle(self.curr_perm)
            self.curr_pntr = 0

        if len_lft < ssiz:
            # ran out mid-request: complete the sample from the fresh permutation
            n_rest = ssiz - len_lft
            self.sample_array[len_lft:ssiz] = self.curr_perm[self.curr_pntr:(self.curr_pntr + n_rest)]
            self.curr_pntr += n_rest

        self.cum_samples_accessed += ssiz

        return self.sample_array[0:ssiz]
    
    
    
class SetPermutationSampler(AbsSampler):
    '''
    Draw samples of size ``ssiz`` from the elements of a fixed set ``fset``.

    The indices come from an internal ``PermutationSampler`` over
    ``range(len(fset))`` (first axis for numpy arrays).  Within one pass
    over the set, each element is therefore handed out once before the
    order is reshuffled.
    '''

    def __init__(self, fset):
        self.full_set = fset
        if isinstance(self.full_set, np.ndarray):
            # sample along the first axis
            self.set_size = self.full_set.shape[0]
        else:
            # any other sized container; len() raises TypeError if unsupported
            self.set_size = len(self.full_set)

        self.perm_samp = PermutationSampler(self.set_size)

    def next_sample(self, ssiz):
        '''
        Return ``ssiz`` elements of the set.

        Containers that accept an integer index array (e.g. numpy arrays)
        are indexed with it directly.  Plain sequences such as ``list`` or
        ``tuple`` reject array indices with a TypeError; for those the
        selected elements are gathered into a ``list`` instead.
        '''
        idx = self.perm_samp.next_sample(ssiz)
        try:
            return self.full_set[idx]
        except TypeError:
            # e.g. list indexed by an ndarray: "only integer scalar arrays ..."
            return [self.full_set[i] for i in idx]
    
##### -------------------------------- #####
#    These classes provide the RATE of sampling
#    (the sample size). The samplers above draw
#    the samples themselves.
##### -------------------------------- #####

        

class AbsSamplingRate(object):
    '''
    Interface for sampling-rate schedules, i.e. objects that decide how
    large the next sample should be.

    Concrete subclasses provide a unique ``get_name()``, register their
    arguments in ``fill_args_registry(arg_reg)``, and implement
    ``next_sample_size(algodata)``.
    '''

    @staticmethod
    def get_name():
        '''Return the identifier used to select this schedule.'''
        raise NotImplementedError("this is an abstract class defining the interface.")

    @staticmethod
    def fill_args_registry(arg_reg):
        '''
        Register this schedule's arguments in ``arg_reg``.

        The base version registers nothing, so schedules without
        parameters need not override it.
        '''
        pass

    def next_sample_size(self, algodata):
        '''Return the size of the next sample, given the algorithm state ``algodata``.'''
        raise NotImplementedError("this is an abstract class defining the interface.")

    def initialize(self):
        '''Hook to (re)set internal state; the base version does nothing.'''
        pass
        
class FixedSamplingRate(AbsSamplingRate):
    '''
    Sampling-rate schedule that always returns the same sample size.
    '''

    _name = 'fixed'

    @staticmethod
    def get_name():
        return FixedSamplingRate._name

    _argname_fixed, _argdefval_fixed = _name + '_size', 1

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register the ``fixed_size`` argument in ``arg_reg``.'''
        arg_reg.register_float_arg(FixedSamplingRate._argname_fixed,
                                   'sample size', FixedSamplingRate._argdefval_fixed)

    def __init__(self, arg_reg=None, arg_dict=None, siz=1):
        '''
        Read the size from ``arg_dict`` when both ``arg_reg`` and
        ``arg_dict`` are given; otherwise use ``siz``.
        '''
        if (arg_dict is None) or (arg_reg is None):
            self.size = siz
        else:
            # the argument is registered as a float, but samplers slice with
            # the size, which requires an integer
            self.size = int(arg_dict[arg_reg.get_arg_fullname(FixedSamplingRate._argname_fixed)])

    def next_sample_size(self, algodata):
        '''Return the fixed sample size; ``algodata`` is ignored.'''
        return self.size
    
    
from math import floor, inf
    
class GeometricRate(AbsSamplingRate):
    '''
    Sampling-rate schedule whose sample size grows geometrically with the
    iteration count ``n = algodata.n_itr.value``:

        size(n) = min( maxsize, int(initval * geomfactor ** (max(1, floor(n / stride)) - 1)) )

    (up to floating-point rounding, since the size is updated
    multiplicatively).  ``initialize()`` resets the schedule to ``initval``.
    '''

    _name = 'geometric'

    @staticmethod
    def get_name():
        return GeometricRate._name

    # raw string: \lfloor / \rfloor must not be read as escape sequences
    # ("\r" previously turned into a carriage return in the help text)
    _formula = r'min ( int(initval * geomfactor ^(max(1, \lfloor n_iter/stride \rfloor) - 1) ), maxsize)'
    _argname_initial, _argdefval_initial = _name + '_initial', 1.
    _argname_geomfac, _argdefval_geomfac = _name + '_geomfactor', 1.01
    _argname_stride, _argdefval_stride = _name + '_stride', 1.0
    _argname_max, _argdefval_max = _name + '_max', inf

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register initial size, growth factor, stride and maximum in ``arg_reg``.'''
        arg_reg.register_float_arg(GeometricRate._argname_initial,
                                   'initial size in \n\t{}'.format(GeometricRate._formula),
                                   GeometricRate._argdefval_initial)
        arg_reg.register_float_arg(GeometricRate._argname_geomfac,
                                   'geomfactor in \n\t{}'.format(GeometricRate._formula),
                                   GeometricRate._argdefval_geomfac)
        arg_reg.register_float_arg(GeometricRate._argname_stride,
                                   'stride in \n\t{}'.format(GeometricRate._formula),
                                   GeometricRate._argdefval_stride)
        arg_reg.register_float_arg(GeometricRate._argname_max,
                                   'max in \n\t{}, "inf" for infinity'.format(GeometricRate._formula),
                                   GeometricRate._argdefval_max)

    def __init__(self, arg_reg=None, arg_dict=None, init=1., gfac=1.01, strd=1.0, mx=inf):
        '''
        Read the parameters from ``arg_dict`` when both ``arg_reg`` and
        ``arg_dict`` are given; otherwise use the keyword values.
        '''
        if (arg_dict is None) or (arg_reg is None):
            self.initsize = init
            self.geomfactor = gfac
            self.stride = strd
            self.maxsiz = mx
        else:
            self.initsize = arg_dict[arg_reg.get_arg_fullname(GeometricRate._argname_initial)]
            self.geomfactor = arg_dict[arg_reg.get_arg_fullname(GeometricRate._argname_geomfac)]
            self.stride = arg_dict[arg_reg.get_arg_fullname(GeometricRate._argname_stride)]
            self.maxsiz = arg_dict[arg_reg.get_arg_fullname(GeometricRate._argname_max)]

        # set size/count so next_sample_size also works without an explicit initialize()
        self.initialize()

    def initialize(self):
        '''Reset the schedule to the initial size.'''
        self.size = self.initsize
        self.count = 1

    def next_sample_size(self, algodata):
        '''Return the (integer) size for the iteration ``algodata.n_itr.value``.'''
        t_cnt = max(1, floor(algodata.n_itr.value / self.stride))
        if t_cnt > self.count:
            # one factor per elapsed stride, even if several strides passed since
            # the previous call (e.g. stride < 1 or skipped calls)
            self.size *= self.geomfactor ** (t_cnt - self.count)
            self.count = t_cnt

        # maxsiz may be a float (float argument), but samplers need an int
        return int(min(self.maxsiz, int(self.size)))


from math import sqrt
import numpy as np

# default success probability of the geometric draw: 1 - 1/(2*sqrt(2)) ~= 0.646
_giles_argdefval_geompar = 1. - 1 / (2. * sqrt(2))


class GilesRate(AbsSamplingRate):
    '''
    Randomized sampling-rate schedule.  Each call draws
    ``N ~ Geometric(geomparam)`` (via ``np.random.geometric``) and returns

        min( 2 ** (minexpon + N) + 1, maxsize )

    The probability mass of the drawn ``N`` is stored in ``self.prob``.
    '''

    _name = 'giles'

    @staticmethod
    def get_name():
        return GilesRate._name

    # help-text description of what next_sample_size() computes
    _formula = 'min ( 2 ** ( min_expon + Geometric( geomparam ) ) + 1, maxsize)'
    _argname_minexp, _argdefval_minexp = _name + '_minexpon', 0
    _argname_geompar, _argdefval_geompar = _name + '_geomparam', _giles_argdefval_geompar
    _argname_max, _argdefval_max = _name + '_max', inf

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register maximum size, minimum exponent and geometric parameter in ``arg_reg``.'''
        arg_reg.register_float_arg(GilesRate._argname_max,
                                   'max in \n\t{}, "inf" for infinity'.format(GilesRate._formula),
                                   GilesRate._argdefval_max)

        arg_reg.register_int_arg(GilesRate._argname_minexp,
                                 'min expon in \n\t{}'.format(GilesRate._formula),
                                 GilesRate._argdefval_minexp)

        arg_reg.register_float_arg(GilesRate._argname_geompar,
                                   'geomparam in \n\t{}'.format(GilesRate._formula),
                                   GilesRate._argdefval_geompar)

    def __init__(self, arg_reg=None, arg_dict=None, minex=1.,
                 gpar=_giles_argdefval_geompar, mx=inf):
        '''
        Read the parameters from ``arg_dict`` when both ``arg_reg`` and
        ``arg_dict`` are given; otherwise use the keyword values.

        NOTE(review): the keyword default ``minex=1.`` differs from the
        registry default ``_argdefval_minexp = 0``; kept unchanged for
        backward compatibility -- confirm which one is intended.
        '''
        if (arg_dict is None) or (arg_reg is None):
            self.minexpon = minex
            self.geomparam = gpar
            self.maxsiz = mx
        else:
            self.minexpon = arg_dict[arg_reg.get_arg_fullname(GilesRate._argname_minexp)]
            self.geomparam = arg_dict[arg_reg.get_arg_fullname(GilesRate._argname_geompar)]
            self.maxsiz = arg_dict[arg_reg.get_arg_fullname(GilesRate._argname_max)]

        # maxsiz is usually a float (inf by default, float argument otherwise);
        # the former '{:6d}' format raised ValueError for it, so use %s
        _LOGGER_RT.info("Giles Sampling rate params: min( %6s, 2**(%4.1f + Geom(%12.8f)) + 1 )",
                        self.maxsiz, self.minexpon, self.geomparam)

    def initialize(self):
        '''Reset the stored probability of the last draw.'''
        self.prob = 0.0

    def next_sample_size(self, algodata):
        '''Draw and return the next sample size; ``algodata`` is ignored.'''
        # Python int: a numpy int64 exponent would make 2**(...) wrap silently
        N_t = int(np.random.geometric(self.geomparam))
        # probability mass of the drawn value (geometric pmf, support 1, 2, ...)
        self.prob = self.geomparam * (1 - self.geomparam) ** (N_t - 1)

        # min(maxsiz - 1, 2**k) + 1 == min(maxsiz, 2**k + 1)
        retval = int(min(self.maxsiz - 1, 2 ** (self.minexpon + N_t)) + 1)

        _LOGGER_RT.debug("Sampled rate: %6d = min( %6s, 2**(%4.1f + %5.2f) + 1 )",
                         retval, self.maxsiz, self.minexpon, N_t)

        return retval
    
# built-in schedules selectable through the 'type' argument; the first one's
# name is passed to the registry as the default choice
SamplingClassList = [FixedSamplingRate, GeometricRate, GilesRate]

# base name of the argument registry and name of the schedule-selection argument
_arg_reg_base = 'samplingrate'
_argname_type = 'type'


def get_sampling_rate_arg_registry(extra_classes=None):
    '''
    Build the ArgumentRegistry for the sampling-rate schedules.

    Registers the string argument ``type`` with the names of all known
    schedules (the first name and the full list are passed to
    ``register_str_arg``, presumably as default and choices -- confirm
    against ArgumentRegistry).  Then every schedule class registers its
    own arguments via ``fill_args_registry``.

    Parameters
    ----------
    extra_classes : iterable of AbsSamplingRate subclasses, optional
        Schedules to offer in addition to ``SamplingClassList``.  Any
        iterable is accepted; it is consumed exactly once.

    Returns
    -------
    ArgumentRegistry
    '''
    # materialize once so a generator is not exhausted by the first pass
    all_classes = list(SamplingClassList)
    if extra_classes is not None:
        all_classes.extend(extra_classes)

    type_choices = [cl.get_name() for cl in all_classes]

    arg_reg = ArgumentRegistry(_arg_reg_base)
    arg_reg.register_str_arg(_argname_type,
                             'which sampling rate schedule to use',
                             type_choices[0],
                             type_choices)

    for cl in all_classes:
        cl.fill_args_registry(arg_reg)

    return arg_reg


def instantiate_sampling_rate(arg_dict, addlcls=None):
    '''
    Create the sampling-rate schedule selected in ``arg_dict``.

    Parameters
    ----------
    arg_dict : dict
        Parsed arguments, keyed by the registry's full argument names.
    addlcls : iterable of AbsSamplingRate subclasses, optional
        Schedules to consider in addition to ``SamplingClassList``.

    Returns
    -------
    AbsSamplingRate
        Instance of the class whose ``get_name()`` equals the ``type``
        argument, constructed from the registry and ``arg_dict``.

    Raises
    ------
    ValueError
        If no known schedule has the requested name.
    '''
    # materialize once: the classes are used for the registry AND the lookup,
    # so a generator must not be consumed by the first use
    extra = list(addlcls) if addlcls is not None else []
    arg_registry = get_sampling_rate_arg_registry(extra)

    rate_name = arg_dict[arg_registry.get_arg_fullname(_argname_type)]

    candidates = list(SamplingClassList) + extra
    for cl in candidates:
        if rate_name == cl.get_name():
            return cl(arg_registry, arg_dict)

    raise ValueError('have not implemented sampling rate \'{}\' yet (known: {}).'.format(
        rate_name, ', '.join(cl.get_name() for cl in candidates)))
