import numpy as np

from utils import TargetFunction

class L1Norm(TargetFunction):
    """Target f(x) = ||x||_1 (sum of absolute values), minimized at the origin.

    The start point is the all-ones vector of length ``d``.
    """

    def __init__(self, d: int = 1, seed: int = 1212) -> None:
        # Start point and labels are assigned before the base-class constructor
        # runs, preserving the original statement order relative to super().
        self.x0 = np.ones(d)
        self.name = "l1"
        self.title = "L1 Norm"
        origin = np.zeros(d)
        super().__init__(origin, seed)

    def __call__(self, x):
        """Return the L1 norm of ``x``."""
        return np.linalg.norm(x, 1)
    
    
class HuberLoss(TargetFunction):
    """Huber loss applied to the Euclidean norm of ``x``.

        f(x) = 0.5 * ||x||_2^2                    if ||x||_2 <= delta
               delta * ||x||_2 - 0.5 * delta^2    otherwise

    The two pieces agree at ||x||_2 == delta. Minimized at the origin; the
    start point is the all-ones vector of length ``d``.
    """

    def __init__(self, d: int = 1, delta: float = 10.0, seed: int = 1212) -> None:
        self.x0 = np.ones(d)
        self.name = "huber_loss"
        self.title = "Huber Loss"
        # Radius at which the loss switches from quadratic to linear growth.
        self.delta = delta
        super().__init__(np.zeros(d), seed)

    def __call__(self, x):
        radius = np.linalg.norm(x, ord=2)
        if radius <= self.delta:
            # Quadratic regime inside the ball of radius delta.
            return 0.5 * np.dot(x, x)
        # Linear regime outside the ball (also reached when radius is NaN).
        return self.delta * radius - 0.5 * self.delta**2

class InfinityNorm(TargetFunction):
    """Target f(x) = max_i |x_i| (Chebyshev norm), minimized at the origin.

    The start point is built from ``self.rnd_state``, which is only available
    after the base-class constructor has run, so it is assigned after super().
    """

    def __init__(self, d: int = 1, seed: int = 1212) -> None:
        self.name = "inf_norm"
        self.title = "Infinity Norm"
        super().__init__(np.zeros(d), seed)
        # Assumes rnd_state.rand returns samples in [0, 1) (NumPy RandomState
        # style), giving a start point in [5, 10) per coordinate.
        self.x0 = self.rnd_state.rand(d) * 5.0 + 5.0

    def __call__(self, x):
        """Return the largest absolute entry of ``x``."""
        return np.abs(x).max()
    
    
class TotalVariation(TargetFunction):
    """1-D total variation f(x) = sum_i |x_{i+1} - x_i|.

    Any constant vector attains the minimum value 0; the origin is used as
    ``x_star``.
    """

    def __init__(self, d: int = 1, seed: int = 1212) -> None:
        # NOTE(review): "tw" looks like a typo for "tv"; left unchanged because
        # callers may key results or file names on it.
        self.name = "tw"
        self.title = "Total Variation"
        super().__init__(np.zeros(d), seed)
        # Assumes rnd_state.rand samples [0, 1): start point lies in [5, 10).
        self.x0 = self.rnd_state.rand(d) * 5.0 + 5.0

    def __call__(self, x):
        """Return the sum of absolute differences between neighbouring entries."""
        jumps = np.abs(np.diff(x))
        return jumps.sum()
    
    
class LogSumExp(TargetFunction):
    """Log-sum-exp target f(x) = log(sum_i exp(x_i)).

    NOTE(review): LSE is increasing in every coordinate and unbounded below
    (it tends to -inf as all x_i -> -inf), so it has no finite minimizer.
    ``x_star`` is a placeholder (1e-5 in every coordinate) kept as-is; confirm
    how callers use it.
    """

    def __init__(self, d: int = 1, seed: int = 1212) -> None:
        x_star = np.full(d, 1e-5)
        self.name = "lse"
        self.title = "Log-Sum-Exp Function"
        super().__init__(x_star, seed)
        # Start point from the base-class RNG; assuming rand samples [0, 1),
        # each coordinate lies in [1, 10).
        self.x0 = self.rnd_state.rand(d) * (10.0 - 1) + 1

    def __call__(self, x):
        """Return log(sum(exp(x))) without overflow.

        Uses the identity LSE(x) = m + log(sum(exp(x - m))) with m = max(x).
        Every shifted exponent is <= 0, so exp cannot overflow. The naive form
        returns inf once any x_i exceeds about 709.

        Edge cases match the naive formula:
        - Empty input returns -inf (log of an empty sum).
        - A non-finite maximum is returned directly: +inf and nan propagate,
          and all entries -inf gives -inf.
        """
        x = np.asarray(x)
        if x.size == 0:
            return -np.inf
        m = np.max(x)
        if not np.isfinite(m):
            # Avoid inf - inf = nan in the shift below.
            return m
        return m + np.log(np.sum(np.exp(x - m)))


class SparseGroupLasso(TargetFunction):
    """Sum of per-group L2 norms over interleaved groups, plus an optional L1 term.

    Coordinate ``i`` belongs to group ``i % n_groups``. The target is

        f(x) = sum_g ||x_g||_2 + l1_weight * ||x||_1

    With the defaults (``n_groups=3``, ``l1_weight=0.0``) this is exactly the
    plain group-lasso penalty this class has always computed. A positive
    ``l1_weight`` adds the within-group sparsity term of the sparse group
    lasso. For ``l1_weight >= 0`` the origin is a minimizer.

    Raises:
        ValueError: if ``n_groups < 1``.
    """

    def __init__(self, d: int = 1, seed: int = 1212, n_groups: int = 3,
                 l1_weight: float = 0.0) -> None:
        if n_groups < 1:
            raise ValueError(f"n_groups must be >= 1, got {n_groups}")
        x_star = np.zeros(d)
        self.name = "sg_lasso"
        self.title = "Sparse Group Lasso"
        super().__init__(x_star, seed)
        # Start point from the base-class RNG; assuming rand samples [0, 1),
        # each coordinate lies in [1, 10).
        self.x0 = self.rnd_state.rand(d) * (10.0 - 1) + 1
        self.l1_weight = l1_weight
        # range(g, d, n_groups) lists exactly the indices i < d with
        # i % n_groups == g. Groups are empty when d < n_groups; an empty
        # group contributes norm 0.
        self.idx_group = [list(range(g, d, n_groups)) for g in range(n_groups)]

    def __call__(self, x):
        # asarray so that list inputs support fancy indexing by index lists.
        x = np.asarray(x)
        total = np.sum([np.linalg.norm(x[group]) for group in self.idx_group])
        if self.l1_weight:
            total = total + self.l1_weight * np.linalg.norm(x, ord=1)
        return total
    
class ElasticNet(TargetFunction):
    """Elastic-net penalty f(x) = alpha * ||x||_1 + 0.5 * beta * ||x||_2^2.

    For non-negative ``alpha`` and ``beta`` the origin is a minimizer. The
    start point is 12.0 in every coordinate.
    """

    def __init__(self, d: int = 1, seed: int = 1212, alpha: float = 0.5,
                 beta: float = 0.5) -> None:
        x_star = np.zeros(d)
        self.name = "enet"
        self.title = "Elastic Net"
        # Penalty weights, configurable via keywords. The defaults equal the
        # previously hard-coded 0.5 values, so existing callers are unaffected.
        self.alpha = alpha  # weight of the L1 term
        self.beta = beta    # weight of the squared-L2 term (with the 0.5 factor)
        self.x0 = np.full(d, 12.0)
        super().__init__(x_star, seed)

    def __call__(self, x):
        l1_term = self.alpha * np.linalg.norm(x, ord=1)
        l2_term = 0.5 * self.beta * np.dot(x, x)
        return l1_term + l2_term