import numpy as np
from abc import ABCMeta, abstractmethod

class LossFunction(metaclass=ABCMeta):
    """Interface for the per-sample losses consumed by the SGD helpers below.

    Each method takes predictions ``pred_y`` and targets ``true_y``. All four
    are abstract, so a subclass can only be instantiated once it overrides
    every one of them.
    """

    @abstractmethod
    def loss(self, pred_y, true_y):
        """Return the training loss for each prediction."""

    @abstractmethod
    def grad(self, pred_y, true_y):
        """Return the first derivative of ``loss`` with respect to ``pred_y``."""

    @abstractmethod
    def hess(self, pred_y, true_y):
        """Return the second derivative of ``loss`` with respect to ``pred_y``."""

    @abstractmethod
    def eval(self, pred_y, true_y):
        """Return the per-sample evaluation metric (may differ from ``loss``)."""
    
class SquareLoss(LossFunction):
    """Squared-error loss ``0.5 * (pred_y - true_y) ** 2``."""

    def loss(self, pred_y, true_y):
        """Return ``0.5 * (pred_y - true_y) ** 2`` element-wise."""
        return 0.5 * (pred_y - true_y) ** 2

    def grad(self, pred_y, true_y):
        """Return the residual ``pred_y - true_y``, the derivative of ``loss``."""
        return pred_y - true_y

    def hess(self, pred_y, true_y):
        """Return the constant second derivative 1, shaped like ``pred_y``.

        ``np.shape`` (rather than ``.size``) keeps the result the same shape
        as ``grad``'s output and also accepts plain Python scalars.
        """
        return np.ones(np.shape(pred_y))

    def eval(self, pred_y, true_y):
        """Return the squared error ``(pred_y - true_y) ** 2`` (no 0.5 factor)."""
        return (pred_y - true_y) ** 2
    
class LogisticLoss(LossFunction):
    """Binary logistic (cross-entropy) loss on raw scores.

    ``pred_y`` is treated as a logit and mapped through the sigmoid. The
    formulas assume ``true_y`` holds 0/1 labels; boolean label arrays are
    accepted as well.
    """

    @staticmethod
    def _sigmoid(z):
        """Return ``1 / (1 + exp(-z))`` element-wise.

        For very negative ``z``, ``exp(-z)`` overflows to ``inf`` and the
        expression still evaluates to the correct limit 0.0. The overflow
        warning is therefore expected and is silenced.
        """
        with np.errstate(over='ignore'):
            return 1 / (1 + np.exp(-z))

    def loss(self, pred_y, true_y):
        """Return the per-sample cross-entropy ``-y log q - (1 - y) log(1 - q)``.

        ``q = sigmoid(pred_y)`` is clipped to ``[1e-8, 1 - 1e-8]`` so ``log``
        never sees 0. This caps each term at ``-log(1e-8)`` (about 18.42).
        """
        # Cast labels to float: numpy refuses unary minus on boolean arrays.
        y = np.asarray(true_y, dtype=float)
        q = np.clip(self._sigmoid(pred_y), 1e-8, 1 - 1e-8)
        return -y * np.log(q) - (1 - y) * np.log(1 - q)

    def grad(self, pred_y, true_y):
        """Return ``sigmoid(pred_y) - true_y``, the derivative of the unclipped loss."""
        return self._sigmoid(pred_y) - true_y

    def hess(self, pred_y, true_y):
        """Return ``q * (1 - q)`` with ``q = sigmoid(pred_y)``; ``true_y`` is unused."""
        q = self._sigmoid(pred_y)
        return q * (1 - q)

    def eval(self, pred_y, true_y):
        """Return the 0/1 error: 1 where ``pred_y > 0`` disagrees with a 0/1 label.

        The prediction mask is converted to float so boolean labels do not
        hit numpy's unsupported bool - bool subtraction.
        """
        return (np.asarray(pred_y > 0, dtype=float) - true_y) ** 2
    
def fit_with_sgd(x, y, x_val, y_val, loss_func, alpha=0.1, lr=0.1, decay=True, num_epoch=1, batch_size=10, seed=0, skip_index=None):
    """Train a linear model with mini-batch SGD and record every step.

    Parameters
    ----------
    x, y : training features ``(n, d)`` and targets with ``y.size == n``.
    x_val, y_val : data used to compute the logged validation error.
    loss_func : object providing ``grad`` and ``eval`` (e.g. a LossFunction).
    alpha : L2 strength; ``alpha * a`` is added to every step's gradient.
    lr : initial step size.
    decay : if true, ``lr`` is multiplied by ``sqrt(c / (c + 1))`` before
        step ``c``, so step ``t`` (1-based) uses ``lr / sqrt(t + 1)``.
    num_epoch : number of passes over the training data.
    batch_size : target batch size. Each epoch's permutation is split into
        ``max(1, floor(n / batch_size))`` nearly equal batches.
    seed : base seed. Epoch ``e`` reseeds NumPy's *global* RNG with
        ``seed + e``, so global random state is altered by this call.
    skip_index : training indices to leave out (counterfactual run). The
        remaining samples' gradient is still divided by the full batch size.

    Returns
    -------
    a : final parameter vector of length ``d``.
    info : one dict per step with keys:
        ``'index'`` (the full batch, before skipping), ``'alpha'``,
        ``'lr'`` (step size used), ``'params'`` (``[copy of a before the
        step]``) and ``'err'`` (mean validation ``eval`` before the step).
    """
    assert x.shape[0] == y.size
    if skip_index is None:
        skip_index = []
    n, d = x.shape

    a = np.zeros(d)
    info = []
    c = 1
    # np.array_split rejects 0 sections, which floor(n / batch_size) gives
    # whenever n < batch_size; fall back to a single batch in that case.
    num_batches = max(1, int(np.floor(n / batch_size)))
    for epoch in range(num_epoch):
        np.random.seed(seed + epoch)
        batches = np.array_split(np.random.permutation(n), num_batches)
        for batch in batches:
            b = batch.size
            err = np.mean(loss_func.eval(x_val.dot(a), y_val))
            if decay:
                lr *= np.sqrt(c / (c + 1))
                c += 1
            info.append({'index': batch, 'alpha': alpha, 'lr': lr, 'params': [a.copy(), ], 'err': err})
            if np.intersect1d(batch, skip_index).size > 0:
                batch = np.setdiff1d(batch, skip_index)
            # NOTE: a batch whose members are all skipped makes no update at
            # all, not even the alpha * a decay applied on partial batches.
            if batch.size == 0:
                continue
            xb = x[batch, :]
            # Divide by the original batch size b, not the reduced size, so
            # a skipped sample simply drops its term from the average.
            g = loss_func.grad(xb.dot(a), y[batch]).dot(xb) / b
            g += alpha * a
            a -= lr * g
    return a, info

def infer_linear_influence(x, y, u, info, loss_func, alpha=0.1):
    """Accumulate per-sample linear influence by replaying SGD steps backwards.

    Walks ``info`` (step records as produced by ``fit_with_sgd``) from the
    last step to the first. For a step with batch S, step size lr,
    regularisation alpha and pre-step parameters a_t, every j in S receives

        lr * (grad_j * (x_j . u) + alpha * (a_t . u)) / |S|

    after which ``u`` is propagated back through that step:

        u <- (1 - alpha * lr) * u - lr * (sum_{j in S} h_j (x_j . u) x_j) / |S|

    Here grad_j / h_j are ``loss_func.grad`` / ``loss_func.hess`` evaluated
    at ``x_j . a_t``. ``u`` is presumably the gradient of some target quantity
    with respect to the final parameters (TODO confirm against the caller).
    The caller's ``u`` array is never modified in place.

    Parameters
    ----------
    x, y : training features ``(n, d)`` and targets, indexed by the batches.
    u : direction vector of length ``d``.
    info : list of step dicts with ``'index'``, ``'lr'``, ``'params'`` and
        optionally ``'alpha'``.
    loss_func : object providing ``grad`` and ``hess``.
    alpha : fallback regularisation strength, used only for step records
        without an ``'alpha'`` key (``fit_with_sgd`` always records one).

    Returns
    -------
    inf_o : array of length ``n``; samples never in a batch keep 0.
    """
    n = x.shape[0]
    inf_o = np.zeros(n)
    for step in reversed(info):
        idx = np.asarray(step['index'])
        # An empty batch made no SGD update; dividing by its size would also
        # turn u into nan and poison every earlier step.
        if idx.size == 0:
            continue
        step_alpha = step.get('alpha', alpha)
        lr = step['lr']
        a_t = step['params'][0]

        xb = x[idx, :]
        pred = xb.dot(a_t)
        xb_u = xb.dot(u)

        # Per-sample gradient term, vectorised over the batch. np.add.at
        # accumulates correctly even if an index repeats within the batch.
        # NOTE(review): each sample is credited with alpha * a_t / |S| of the
        # regularisation step, while fit_with_sgd keeps the full alpha * a
        # decay on partially skipped batches; confirm which counterfactual
        # is intended.
        g = loss_func.grad(pred, y[idx])
        contrib = lr * (g * xb_u + step_alpha * a_t.dot(u)) / idx.size
        np.add.at(inf_o, idx, contrib)

        # Hessian-vector product of the batch loss: sum_j h_j (x_j . u) x_j / |S|.
        h = loss_func.hess(pred, y[idx])
        hu = (h * xb_u).dot(xb) / idx.size
        u = (1 - step_alpha * lr) * u - lr * hu
    return inf_o
