import numpy as np

import logging
_LOGGER = logging.getLogger(name='wdro.func_approx')


from wavelets import Marr,wavelet
import matplotlib.pyplot as plt


class FApprox:
    r'''
    Function approximation of the (ridge-sum) form

        f(x) = \sum_{i=1}^I f_i(\theta_i^T x)

    where
        f_i       are one-dimensional Lipschitz functions, and
        \theta_i  are directions fixed a priori and supplied to this class
                  at initialization, i = 1, ..., I.

    Every f_i is modelled as a wavelet expansion. The class provides methods
    to evaluate f, its gradient and Hessian with respect to x, and to fill
    per-wavelet basis values into a flat weight-indexed vector.
    '''

    def __init__(self, thetas, rng_set, lvl, widen_factor=1.25):
        r'''
        Build one wavelet per direction theta_v.

        Parameters
        ----------
        thetas : 2-D array, shape (n_thetas, dim)
            Row v is the direction theta_v.
        rng_set : sequence of length n_thetas
            rng_set[v] is a [lo, hi] pair (anything supporting ``.copy()``
            and indexing at 0 and 1) giving the range of theta_v^T x the
            wavelet must cover. The pair is widened by ``widen_factor``
            (see ``_widen_range``) before the wavelet is built; rng_set
            itself is not modified.
        lvl :
            Granularity level, forwarded to ``wavelet`` and stored as
            ``grain_lvl``.
        widen_factor : float, optional
            Scale used to push each range endpoint outward (default 1.25).
            Values > 1 widen the range.
        '''
        self.thetas = thetas
        self.dim = self.thetas.shape[1]
        self.n_thetas = self.thetas.shape[0]
        self.grain_lvl = lvl

        self.wavelets = []
        # Total number of wavelet weights across all directions; the flat
        # weight vector used by copy_weights_to / set_weights_from has this
        # length.
        self.n_weights = 0

        # outer_thetas[v] = theta_v theta_v^T, cached for the Hessian.
        self.outer_thetas = np.zeros((self.n_thetas, self.dim, self.dim))

        for v in range(self.n_thetas):
            _LOGGER.info("for theta ndx %s, range is %s ", v, rng_set[v])
            rngs = self._widen_range(rng_set[v], widen_factor)
            _LOGGER.info("modded range to %s ", rngs)

            self.wavelets.append(wavelet(rngs, Marr(), lvl))
            self.n_weights += self.wavelets[-1].n_weights

            self.outer_thetas[v] = np.outer(self.thetas[v], self.thetas[v])

    @staticmethod
    def _widen_range(rng, factor=1.25):
        '''
        Return a widened copy of the [lo, hi] pair ``rng``.

        For factor > 1, each endpoint moves outward: a positive ``lo`` is
        divided by ``factor`` and a non-positive one multiplied by it; a
        negative ``hi`` is divided and a non-negative one multiplied.
        Endpoints equal to zero stay at zero. ``rng`` is not modified.
        '''
        out = rng.copy()
        # Assigning a scaled float back into an integer ndarray would
        # silently truncate it (e.g. -1 * 1.25 -> -1), so promote first.
        if isinstance(out, np.ndarray) and not np.issubdtype(out.dtype, np.floating):
            out = out.astype(float)
        lo, hi = out[0], out[1]
        out[0] = lo / factor if lo > 0. else lo * factor
        out[1] = hi / factor if hi < 0. else hi * factor
        return out

    def randomize_weights(self, rng=1.):
        '''Overwrite every wavelet's weights with draws from U(-rng, rng).'''
        for wav in self.wavelets:
            wav.weights[:] = np.random.uniform(-rng, rng, size=(wav.n_weights,))

    def initiate_weights_from_emprdistn(self, empr_set):
        r'''
        Initialize each wavelet's weights from an empirical sample.

        ``empr_set`` holds one sample point per row. Wavelet v receives the
        projections theta_v^T x of all rows.
        '''
        theta_xs = self.thetas @ empr_set.T
        for wav, empvals in zip(self.wavelets, theta_xs):
            wav.init_weights_from_emprdistn(empvals)

    def get_reasonable_wavelet_limits(self):
        '''
        Return an (n_thetas, 2) array whose row v is
        ``wavelets[v].guess_reasonable_funclimits()``.
        '''
        retval = np.zeros((self.n_thetas, 2))
        for v, wav in enumerate(self.wavelets):
            retval[v, :] = wav.guess_reasonable_funclimits()
        return retval

    def fill_value_thetax(self, theta_x, retval):
        '''
        Add sum_i f_i(theta_x[i]) into ``retval`` in place and return it.

        ``theta_x[i]`` holds the projections onto theta_i, as produced by
        ``self.thetas @ x.T``.
        '''
        for wav, proj in zip(self.wavelets, theta_x):
            retval += wav.get_value(proj)
        return retval

    def get_value(self, x):
        '''
        Evaluate f at every row of ``x``.

        ``x`` is a 2-D array whose rows are the points x_1, ..., x_n. The
        result is a 1-D array of length n with result[j] = f(x_j).
        '''
        retval = np.zeros(x.shape[0])
        theta_x = self.thetas @ x.T
        return self.fill_value_thetax(theta_x, retval)

    def fill_gradient_wrt_x(self, x, retval):
        r'''
        Add the gradient of f at a single point into ``retval`` and return it.

            \nabla_x f(x) = \sum_i f_i'(\theta_i^T x) \theta_i

        ``x`` must be ONE point, i.e. a 1-D array of length ``dim``.
        ``retval`` is updated in place and should have shape (dim,).
        '''
        assert len(x.shape) == 1, "x must be a single 1-D point"
        theta_x = self.thetas @ x
        for wav, theta, proj in zip(self.wavelets, self.thetas, theta_x):
            retval += wav.get_deriv(proj) * theta
        return retval

    def fill_hessian_wrt_x(self, x, retval):
        r'''
        Add the Hessian of f at a single point into ``retval`` and return it.

            \nabla_x^2 f(x) = \sum_i f_i''(\theta_i^T x) \theta_i \theta_i^T

        ``x`` must be ONE point, i.e. a 1-D array of length ``dim``.
        ``retval`` is updated in place and should have shape (dim, dim).
        '''
        assert len(x.shape) == 1, "x must be a single 1-D point"
        theta_x = self.thetas @ x
        for wav, outer, proj in zip(self.wavelets, self.outer_thetas, theta_x):
            # get_second_deriv is called with a length-1 array (get_deriv
            # above gets a scalar); presumably it returns shape (1,), which
            # broadcasts against the (dim, dim) outer product.
            retval += wav.get_second_deriv(np.array([proj])) * outer
        return retval

    def _weight_slices(self):
        '''Yield (wavelet, slice) pairs locating each wavelet's weights in
        the flat weight vector, in wavelet order.'''
        offset = 0
        for wav in self.wavelets:
            end = offset + wav.n_weights
            yield wav, slice(offset, end)
            offset = end

    def fill_gradient_wrt_weights(self, x, retval, mult):
        '''
        Let each wavelet write its basis contribution into ``retval``.

        Wavelet v is handed theta_v^T x together with the offset of its own
        segment in the flat weight vector (same layout as
        ``copy_weights_to``). Presumably ``fill_basis`` writes the
        basis values scaled by ``mult``, i.e. the gradient of f with respect
        to the weights; confirm against ``wavelet.fill_basis``.
        '''
        theta_x = self.thetas @ x.T
        for (wav, seg), proj in zip(self._weight_slices(), theta_x):
            wav.fill_basis(proj, retval, seg.start, mult)

    def copy_weights_to(self, retval):
        '''Copy all wavelet weights, concatenated in wavelet order, into
        the flat array ``retval`` (length ``n_weights``).'''
        for wav, seg in self._weight_slices():
            retval[seg] = wav.weights

    def set_weights_from(self, retval):
        '''Inverse of ``copy_weights_to``: load each wavelet's weights from
        its segment of the flat array ``retval``.'''
        for wav, seg in self._weight_slices():
            wav.weights[:] = retval[seg]

    def plot_fcn(self, filepre, empr_sets):
        '''
        Plot every f_v together with the projections of each dataset.

        ``empr_sets`` is a sequence of datasets (one sample point per row).
        If ``filepre`` is None each figure is shown interactively. Otherwise
        figure v is saved as ``<filepre>-th[v].png`` at 200 dpi.
        '''
        # theta_xses[v][e] holds theta_v^T x for every x in dataset e.
        theta_xses = [[] for _ in range(self.n_thetas)]
        for ecnt, empr_set in enumerate(empr_sets):
            theta_x = self.thetas @ empr_set.T
            for v in range(self.n_thetas):
                theta_xses[v].append(theta_x[v])
                _LOGGER.debug("theta_xes[%s][%s] = %s", v, ecnt, theta_x[v])

        for v, (wav, projections) in enumerate(zip(self.wavelets, theta_xses)):
            _, fg, _ = wav.plot_wavelet(projections, False)
            if filepre is None:
                plt.show()
            else:
                fg.savefig('{}-th[{}].png'.format(filepre, v), dpi=200)
            plt.close(fg)


