import logging
from .line_search import get_linesearch_args_registry
# from util_py.arg_parsing  import ArgumentRegistry
from .abs_optim import AbsOptim

_LOGGER = logging.getLogger(name='optim.detgd')

class DeterministicGradientDescent(AbsOptim):
    """Deterministic descent optimizer built from a line search and a direction modifier.

    Each call to :meth:`step` evaluates the objective (and its derivatives),
    copies the scaled gradient into ``descent_dirn``, lets the direction
    modifier transform that direction, asks the line search for a step
    length, moves the iterate along the direction, and reports the iteration
    through ``algo_state``.

    NOTE(review): ``self.objective``, ``self.algo_state`` and
    ``self.step_dirn_mult`` are used here but not assigned in this class;
    presumably ``AbsOptim.__init__`` sets them — confirm.
    """

    @staticmethod
    def get_args_registries(stct_xtra=None, ls_xtra=None):
        """Return the argument registries used to configure this optimizer.

        :param stct_xtra: forwarded unchanged to the base class's
            ``get_arg_registries``.
        :param ls_xtra: forwarded unchanged to
            ``get_linesearch_args_registry``.
        :returns: the registry list produced by the base class, with one
            unified line-search registry appended to it (in place).
        """
        # NOTE(review): this method is spelled ``get_args_registries`` but
        # delegates to the base class's ``get_arg_registries``. If callers
        # look up the base-class spelling polymorphically, this override is
        # never reached and the line-search registry is silently omitted —
        # confirm which name callers rely on.
        arg_list = super(DeterministicGradientDescent,
                         DeterministicGradientDescent).get_arg_registries(stct_xtra)

        # A single, unified registry covers all line-search arguments.
        arg_list.append(get_linesearch_args_registry(ls_xtra))

        return arg_list

    def __init__(self, objective, stopcrit, algstate, lnsrch, dirnmod, arg_dict=None,
                 ismx=False):
        """Create the optimizer.

        :param objective: objective wrapper. This class calls its
            ``get_blank_of_itersize``, ``evaluate_fn_and_derivatives``,
            ``copy_gradient_to`` and ``add_step_along_dirn`` methods.
        :param stopcrit: stopping criterion, passed to the base class.
        :param algstate: algorithm-state object, passed to the base class.
        :param lnsrch: line search. ``find_steplength(dirn)`` must return a
            2-tuple whose first item is the step length.
        :param dirnmod: direction modifier. The return value of
            ``modify_direction(dirn)`` is ignored, so it must update the
            direction in place.
        :param arg_dict: optional argument dictionary, passed to the base
            class.
        :param ismx: passed to the base class. Presumably selects
            maximization instead of minimization — confirm in AbsOptim.
        """
        super(DeterministicGradientDescent, self).__init__(
            objective, stopcrit, algstate, arg_dict, ismx)

        self.__line_search = lnsrch
        self.__dirn_modifier = dirnmod

        # Buffer with the iterate's shape. It is reused by every step to hold
        # the search direction, rather than being reallocated each time.
        self.descent_dirn = self.objective.get_blank_of_itersize()

        # Sample count reported to algo_state on every step. It is always 1
        # here (presumably one full, deterministic evaluation per step).
        self.reported_n_samp = 1

    def set_dirn_modifier(self, dm):
        """Replace the direction modifier used by subsequent calls to :meth:`step`."""
        self.__dirn_modifier = dm

    def step(self):
        """Perform one descent iteration.

        The objective is evaluated at the current iterate *before* the step
        is taken. The value returned is therefore the training loss at the
        start of this iteration, which is what callers traditionally expect.

        :returns: the value produced by
            ``objective.evaluate_fn_and_derivatives()``.
        """
        # Evaluate the function and compute its gradient (e.g. via autograd
        # back-prop). Higher derivatives are also evaluated if the objective
        # supports them; otherwise nothing extra happens.
        objval = self.objective.evaluate_fn_and_derivatives()

        # Copy the gradient into the direction buffer, scaled by
        # step_dirn_mult (presumably -1.0 for minimization, so the direction
        # points downhill).
        self.objective.copy_gradient_to(self.descent_dirn, self.step_dirn_mult)

        # Transform the direction in place (e.g. Newton or quasi-Newton).
        self.__dirn_modifier.modify_direction(self.descent_dirn)

        # Choose the step length (learning rate). The line search's second
        # return value (possibly a function-evaluation count) is not used.
        steplength, _nfunc = self.__line_search.find_steplength(self.descent_dirn)

        # Move the iterate by `steplength` along the chosen direction.
        self.objective.add_step_along_dirn(steplength, self.descent_dirn)

        self.algo_state.output_current(self.reported_n_samp, objval, steplength,
                                       self.descent_dirn)
        return objval
