import numpy

class Opt:
    """String identifiers for the supported optimizers.

    A model's ``opt`` attribute holds one of these values; ``Model.update_step``
    compares against them to pick the step method.
    """
    SGD, SAGA, SSVRG = 'sgd', 'saga', 'ssvrg'

class Model:
    """Base class for models trained one training point at a time.

    Subclasses set ``self.opt`` to one of the ``Opt`` values and implement the
    three ``*_step`` methods plus ``loss`` and ``reg_loss``.
    """

    def update_step(self, training_point, step_size, mu):
        """Apply one optimizer step for ``training_point``.

        Dispatches on ``self.opt`` to ``sgd_step``, ``saga_step`` or
        ``ssvrg_step``.

        Returns:
            The number of gradients computed by this step. SGD and SAGA
            always compute one; SSVRG reports its own count.

        Raises:
            ValueError: if ``self.opt`` is not one of the ``Opt`` values.
                Failing here is preferable to silently returning ``None``
                and corrupting the caller's gradient count.
        """
        if self.opt == Opt.SGD:
            self.sgd_step(training_point, step_size, mu)
            return 1
        if self.opt == Opt.SAGA:
            self.saga_step(training_point, step_size, mu)
            return 1
        if self.opt == Opt.SSVRG:
            return self.ssvrg_step(training_point, step_size, mu)
        raise ValueError('unknown optimizer: %r' % (self.opt,))

    def sgd_step(self, training_point, step_size, mu):
        """Perform one SGD step; must be implemented by subclasses."""
        raise NotImplementedError()

    def saga_step(self, training_point, step_size, mu):
        """Perform one SAGA step; must be implemented by subclasses."""
        raise NotImplementedError()

    def ssvrg_step(self, training_point, step_size, mu):
        """Perform one SSVRG step and return the number of gradients computed."""
        raise NotImplementedError()

    def loss(self, data):
        """Return the average loss over ``data``; must be implemented by subclasses."""
        raise NotImplementedError()

    def reg_loss(self, data, mu):
        """Return the average regularized loss over ``data``; must be implemented by subclasses."""
        raise NotImplementedError()
        

class MatrixFactorization(Model):
    """Matrix factorization model trained on observed entries.

    The prediction for entry ``(i, j)`` is ``numpy.dot(L[i,:], R[:,j])``.
    Training points are ``(i, j, val)`` triples.  ``opt`` (an ``Opt`` value)
    selects which ``*_step`` method ``update_step`` dispatches to.
    """

    def __init__(self, init_L, init_R, opt, SSVRG_M=1000, SSVRG_b=3,
                 SSVRG_k=32):
        """Copy the initial factors and set up the per-optimizer state.

        Args:
            init_L: initial left factor (indexed by row ``i``); copied.
            init_R: initial right factor (indexed by column ``j``); copied.
            opt: optimizer identifier, one of the ``Opt`` values.
            SSVRG_M: number of SSVRG update steps between snapshots.
            SSVRG_b: factor by which the SSVRG snapshot sample count grows
                after each snapshot phase.
            SSVRG_k: number of samples averaged for the first SSVRG snapshot
                gradient.
        """
        self.L = numpy.copy(init_L)
        self.R = numpy.copy(init_R)
        self.opt = opt

        # SAGA state: last gradient seen for each (i, j) and their running sums.
        self.tableL = {}
        self.tableR = {}
        self.table_sumL = numpy.zeros(self.L.shape)
        self.table_sumR = numpy.zeros(self.R.shape)

        # SSVRG state.
        self.b = SSVRG_b
        self.k_size = SSVRG_k
        self.flag_full = True    # in phase 1 computing the full (stable) gradient, o.w. in phase 2 doing update steps
        self.fullL = numpy.zeros(self.L.shape)
        self.fullR = numpy.zeros(self.R.shape)
        # The snapshot starts at the initial iterate. An all-zero snapshot
        # would make every stale gradient (pred_err * staleR / staleL) zero,
        # so the first phase 1 would spend k_size gradient evaluations on
        # nothing and the first phase 2 would reduce to plain SGD.
        self.staleL = numpy.copy(self.L)
        self.staleR = numpy.copy(self.R)
        self.t1 = 0              # progress in phase 1
        self.t2 = 0              # progress in phase 2
        self.M = SSVRG_M

    def sgd_step(self, training_point, step_size, mu):
        """One SGD step on the observed entry ``(i, j, val)``.

        Row ``L[i,:]`` and column ``R[:,j]`` are updated simultaneously: both
        gradients use the pre-step values.  The shrinkage ``1 - step_size*mu``
        is applied only to the touched row and column.
        """
        (i, j, val) = training_point
        pred_err = numpy.dot(self.L[i,:], self.R[:,j]) - val
        # Compute the new row first but assign it last, so the R update
        # below still sees the old L[i,:].
        L_temp      = (1-step_size*mu)*self.L[i,:] - step_size*(pred_err*self.R[:,j])
        self.R[:,j] = (1-step_size*mu)*self.R[:,j] - step_size*(pred_err*self.L[i,:])
        self.L[i,:] = L_temp

    def saga_step(self, training_point, step_size, mu):
        """One SAGA step on the observed entry ``(i, j, val)``.

        Replaces the stored gradient for ``(i, j)`` (zero if unseen) with the
        fresh one.  The averaged stored gradient ``table_sum / m`` is
        subtracted from all of ``L`` and ``R``, where ``m`` is the number of
        entries stored before this step (at least 1).
        """
        (i, j, val) = training_point
        pred_err = numpy.dot(self.L[i,:], self.R[:,j]) - val
        gL = pred_err*self.R[:,j]
        gR = pred_err*self.L[i,:]
        m = len(self.tableL) if len(self.tableL) != 0 else 1

        if (i, j) in self.tableL:
            alphaL = self.tableL[(i, j)]
            alphaR = self.tableR[(i, j)]
        else:
            alphaL = numpy.zeros(gL.shape)
            alphaR = numpy.zeros(gR.shape)

        self.L[i,:] = (1-step_size*mu)*self.L[i,:] - step_size*(gL - alphaL)
        self.R[:,j] = (1-step_size*mu)*self.R[:,j] - step_size*(gR - alphaR)

        self.L -= step_size*(1./m)*self.table_sumL
        self.R -= step_size*(1./m)*self.table_sumR

        # Swap the stored gradient and keep the running sums in sync.
        self.tableL[(i, j)] = gL
        self.tableR[(i, j)] = gR
        self.table_sumL[i,:] += gL - alphaL
        self.table_sumR[:,j] += gR - alphaR

    def ssvrg_step(self, training_point, step_size, mu):
        """One SSVRG step on the observed entry ``(i, j, val)``.

        Alternates between two phases.

        Phase 1 (``flag_full``) adds ``1/k_size`` of the gradient at the
        snapshot ``(staleL, staleR)`` into ``fullL``/``fullR`` without moving
        the model.  After ``k_size`` samples, ``k_size`` is multiplied by ``b``
        and phase 2 begins.

        Phase 2 applies the variance-reduced update
        ``g(current) - g(snapshot) + full``.  After ``M`` steps the current
        model becomes the new snapshot, ``full`` is reset and phase 1 begins.

        Returns:
            The number of gradients computed: 1 in phase 1, 2 in phase 2.
        """
        (i, j, val) = training_point

        if self.flag_full:
            pred_err_stale = numpy.dot(self.staleL[i,:], self.staleR[:,j]) - val
            self.fullL[i,:] = self.fullL[i,:] + (1./self.k_size)*pred_err_stale*self.staleR[:,j]
            self.fullR[:,j] = self.fullR[:,j] + (1./self.k_size)*pred_err_stale*self.staleL[i,:]

            self.t1 += 1
            if self.t1 == self.k_size:
                self.t1 = 0
                self.k_size = self.b*self.k_size
                self.flag_full = False

            return 1
        else:
            pred_err = numpy.dot(self.L[i,:], self.R[:,j]) - val
            pred_err_stale = numpy.dot(self.staleL[i,:], self.staleR[:,j]) - val
            gL = pred_err*self.R[:,j]
            gR = pred_err*self.L[i,:]
            oL = pred_err_stale*self.staleR[:,j]
            oR = pred_err_stale*self.staleL[i,:]

            self.L[i,:] = (1-step_size*mu)*self.L[i,:] - step_size*(gL - oL)
            self.R[:,j] = (1-step_size*mu)*self.R[:,j] - step_size*(gR - oR)

            self.L -= step_size*self.fullL
            self.R -= step_size*self.fullR

            self.t2 += 1
            if self.t2 == self.M:
                self.t2 = 0
                self.flag_full = True
                self.fullL = numpy.zeros(self.L.shape)
                self.fullR = numpy.zeros(self.R.shape)
                self.staleL = numpy.copy(self.L)
                self.staleR = numpy.copy(self.R)

            return 2

    def loss(self, data):
        """Mean squared prediction error over ``(i, j, val)`` triples; 0 for empty ``data``."""
        if len(data) == 0:
            return 0
        return sum( (val-numpy.dot(self.L[i,:], self.R[:,j]))**2 for (i, j, val) in data )/len(data)

    def reg_loss(self, data, mu):
        """Mean of squared error plus ``mu * (|L[i,:]|^2 + |R[:,j]|^2)`` per entry; 0 for empty ``data``."""
        if len(data) == 0:
            return 0
        return sum( (val-numpy.dot(self.L[i,:], self.R[:,j]))**2 + mu*(numpy.dot(self.L[i,:], self.L[i,:]) + numpy.dot(self.R[:,j], self.R[:,j])) for (i, j, val) in data )/len(data)
        
class LogisticRegression(Model):
    """Binary logistic regression over sparse features.

    Training points are ``(i, x, y)``:

    - ``i`` identifies the sample and is used as the SAGA table key.
    - ``x`` is a mapping from feature index to value.
    - ``y`` is the label. ``zero_one_loss`` compares ``y`` against
      predictions in ``{-1, 1}``.

    The per-sample loss is ``log(1 + exp(-y * w.x))``.
    """

    def __init__(self, init_w, opt, SSVRG_M=1000, SSVRG_b=3, SSVRG_k=32):
        """Copy the initial weights and set up the per-optimizer state.

        Args:
            init_w: initial weight vector; copied, never modified.
            opt: optimizer identifier, one of the ``Opt`` values.
            SSVRG_M: number of SSVRG update steps between snapshots.
            SSVRG_b: factor by which the SSVRG snapshot sample count grows
                after each snapshot phase.
            SSVRG_k: number of samples averaged for the first SSVRG snapshot
                gradient.
        """
        self.w = numpy.copy(init_w)
        self.opt = opt

        # SAGA state: last scalar gradient factor per sample, and the running
        # sum of the stored gradients.
        self.table = {}
        self.table_sum = numpy.zeros(self.w.shape)

        # SSVRG state.
        self.b = SSVRG_b
        self.k_size = SSVRG_k
        self.flag_full = True    # in phase 1 computing the full (stable) gradient, o.w. in phase 2 doing update steps
        self.full = numpy.zeros(self.w.shape)
        # The snapshot starts at the initial iterate rather than at zero, so
        # the first variance-reduction anchor is the point we start from.
        self.stale = numpy.copy(self.w)
        self.t1 = 0              # progress in phase 1
        self.t2 = 0              # progress in phase 2
        self.M = SSVRG_M

    # dot product with a sparse feature vector
    def dot_product(self, x):
        """Return ``w . x`` for the sparse feature mapping ``x``."""
        return sum(self.w[k]*v for (k, v) in x.items())

    def dot_product_stale(self, x):
        """Return ``stale . x`` (the SSVRG snapshot) for the sparse mapping ``x``."""
        return sum(self.stale[k]*v for (k, v) in x.items())

    # apply regularization penalty to only sparse coordinates, better performance
    def sgd_step(self, training_point, step_size, mu):
        """One SGD step; only the coordinates present in ``x`` are touched."""
        (i, x, y) = training_point
        # p = 1 / (1 + exp(y * w.x)); the loss gradient is -p * y * x.
        p = 1./(1 + numpy.exp(y*self.dot_product(x)))

        for (k, v) in x.items():
            self.w[k] = (1-step_size*mu)*self.w[k] - step_size*(-1*p*y*v)

    def saga_step(self, training_point, step_size, mu):
        """One SAGA step for sample ``i``.

        Stores the scalar factor ``g = -p*y`` per sample; the sample gradient
        is ``g * x``.  The averaged stored gradient ``table_sum / m`` is
        subtracted from all of ``w``, where ``m`` is the number of samples
        stored before this step (at least 1).
        """
        (i, x, y) = training_point
        p = 1./(1 + numpy.exp(y*self.dot_product(x)))
        g = -1*p*y
        alpha = self.table.get(i, 0)
        m = len(self.table) if len(self.table) != 0 else 1

        for (k, v) in x.items():
            self.w[k] = (1-step_size*mu)*self.w[k] - step_size*(g-alpha)*v

        self.w -= step_size*(1./m)*self.table_sum

        # Swap the stored gradient and keep the running sum in sync.
        self.table[i] = g
        for (k, v) in x.items():
            self.table_sum[k] += (g-alpha)*v

    def ssvrg_step(self, training_point, step_size, mu):
        """One SSVRG step.

        Phase 1 (``flag_full``) adds ``1/k_size`` of the gradient at the
        snapshot ``stale`` into ``full`` without moving ``w``.  After
        ``k_size`` samples, ``k_size`` is multiplied by ``b`` and phase 2
        begins.

        Phase 2 applies ``g(w) - g(stale)`` on the coordinates of ``x``, then
        subtracts ``step_size * full`` from all of ``w``.  After ``M`` steps
        ``w`` becomes the new snapshot, ``full`` is reset and phase 1 begins.

        Returns:
            The number of gradients computed: 1 in phase 1, 2 in phase 2.
        """
        (i, x, y) = training_point

        if self.flag_full:
            p_stale = 1./(1 + numpy.exp(y*self.dot_product_stale(x)))

            for (k, v) in x.items():
                self.full[k] += -1*p_stale*y*v * (1./self.k_size)

            self.t1 += 1
            if self.t1 == self.k_size:
                self.t1 = 0
                self.k_size = self.b*self.k_size
                self.flag_full = False

            return 1
        else:
            p = 1./(1 + numpy.exp(y*self.dot_product(x)))
            p_stale = 1./(1 + numpy.exp(y*self.dot_product_stale(x)))

            for (k, v) in x.items():
                self.w[k] = (1-step_size*mu)*self.w[k] - step_size*(-1*y*v)*(p-p_stale)

            self.w -= step_size*self.full

            self.t2 += 1
            if self.t2 == self.M:
                self.t2 = 0
                self.flag_full = True
                self.full = numpy.zeros(self.w.shape)
                self.stale = numpy.copy(self.w)

            return 2

    def loss(self, data):
        """Mean logistic loss over ``data``; 0 for empty ``data``.

        ``logaddexp(0, z)`` equals ``log(1 + exp(z))``.  Unlike the naive
        form, it does not overflow to ``inf`` for large ``z``.
        """
        if len(data) == 0:
            return 0

        return sum( numpy.logaddexp(0, -1*y*self.dot_product(x)) for (i, x, y) in data )/len(data)

    def reg_loss(self, data, mu):
        """Mean of logistic loss plus ``0.5*mu*sum(w[k]**2)`` over each sample's features.

        Returns 0 for empty ``data``.
        """
        if len(data) == 0:
            return 0

        return sum( numpy.logaddexp(0, -1*y*self.dot_product(x)) + 0.5*mu*sum(self.w[k]**2 for k in x) for (i, x, y) in data )/len(data)

    def zero_one_loss(self, data):
        """Fraction of samples whose sign prediction (1 if w.x > 0 else -1) differs from y.

        Returns 0 for empty ``data``, consistent with ``loss`` and ``reg_loss``.
        """
        if len(data) == 0:
            return 0
        return sum( (1 if self.dot_product(x) > 0 else -1) != y for (i, x, y) in data )*1.0/len(data)
