#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Modified based on code from: https://github.com/discovershu/SoRR

import numpy as np
from datetime import datetime
import math

class Logistic_DC_Change_lr:
    """Logistic model fitted on the mean of a ranked range of individual losses.

    The tracked objective is the mean of the per-sample logistic losses whose
    rank (largest first) lies in ``(k2_value, k_value]``.  ``fit`` alternates
    between an outer step, which freezes a subgradient built from the
    ``k2_value`` largest losses on the full training set, and an inner
    mini-batch loop that updates the weights together with a scalar
    threshold.  The step size of outer iteration ``i`` is ``lr_0 / (i + 1)``
    and that iteration runs ``(i + 1) ** 2 * T0`` inner steps.

    Parameters
    ----------
    lr_0 : initial step size.
    num_iter : number of outer iterations.
    T0 : base number of inner steps (see above).
    batch : mini-batch size; capped at the number of training samples.
    fit_intercept : prepend a column of ones to the design matrices.
    verbose : print the training data shapes at the start of ``fit``.
    k_value, k2_value : upper / lower rank bounds of the loss range; the
        objective is only meaningful for ``0 <= k2_value < k_value``.
    seed : value passed to ``np.random.seed`` at the start of ``fit``.
    dataname, Model_name : stored on the instance; not read by this class.

    Attributes filled by ``fit``
    ----------------------------
    theta : weight vector (intercept first when ``fit_intercept``).
    lr : step size of the most recent outer iteration.
    loss_list : objective before training and after every outer iteration.
    time_list : cumulative training seconds after every outer iteration
        (loss evaluation is excluded from the timing).
    datapass : cumulative passes over the training set, aligned with
        ``loss_list``.
    """

    def __init__(
        self,
        lr_0=0.01,
        num_iter=100,
        T0=1000,
        batch=100,
        fit_intercept=True,
        verbose=True,
        k_value=5,
        k2_value=4,
        seed=1234,
        dataname='real',
        Model_name='LogisticRegression',
    ):
        self.lr_0 = lr_0
        self.lr = 0
        self.num_iter = num_iter
        self.T0 = T0
        self.fit_intercept = fit_intercept
        self.verbose = verbose
        self.k_value = k_value
        self.k2_value = k2_value
        self.seed = seed
        self.dataname = dataname
        self.Model_name = Model_name
        self.batch = batch
        self.time_list = []
        self.loss_list = []
        self.datapass = []

    def __add_intercept(self, X):
        """Return ``X`` with a leading column of ones."""
        intercept = np.ones((X.shape[0], 1))
        return np.concatenate((intercept, X), axis=1)

    def __sigmoid(self, z):
        """Return the logistic function ``1 / (1 + exp(-z))`` of ``z``."""
        # exp(-z) overflows to inf for very negative z, which correctly
        # saturates the result to 0.  Silence the overflow report so the call
        # neither warns nor fails under np.seterr(over='raise'); the returned
        # values are unchanged.
        with np.errstate(over='ignore'):
            return 1 / (1 + np.exp(-z))

    def __individualloss(self, h, y):
        """Return the per-sample cross-entropy of probabilities ``h`` vs labels ``y``."""
        h = np.clip(h, 1e-7, 1.0 - 1e-7)  # keep log() finite when h reaches 0 or 1
        return -y * np.log(h) - (1 - y) * np.log(1 - h)

    def __ranked_range_loss(self, X, y):
        """Return the mean of the losses ranked ``k2_value + 1 .. k_value`` (largest first)."""
        losses = self.__individualloss(self.__sigmoid(np.dot(X, self.theta)), y)
        ranked = np.sort(losses)[::-1]
        return np.sum(ranked[self.k2_value:self.k_value]) / (self.k_value - self.k2_value)

    def __top_losses_subgradient(self, X, y):
        """Return a subgradient (w.r.t. theta) of the sum of the ``k2_value`` largest losses, divided by n."""
        if self.k2_value == 0:
            # The sum over an empty top set is identically zero.  Indexing
            # ranked[-1] here would instead pick the smallest loss as the
            # threshold and keep every sample.
            return np.zeros(X.shape[1])
        h = self.__sigmoid(np.dot(X, self.theta))
        losses = self.__individualloss(h, y)
        lamb_1 = np.sort(losses)[::-1][self.k2_value - 1]  # k2-th largest loss
        u = h - y
        # Keep only samples whose loss reaches the threshold (ties included).
        u[(losses - lamb_1 < 0) | (losses == 0)] = 0
        return np.dot(X.T, u) / y.size

    def fit(self, X, y,X_val,y_val,X_test, y_test):
        """Train ``theta`` on ``(X, y)`` and record loss, time and data-pass curves.

        Parameters
        ----------
        X : 2-D array of training features, one row per sample.
        y : 1-D array of training labels, aligned with the rows of ``X``.
        X_val, y_val, X_test, y_test : accepted for interface compatibility.
            The feature matrices receive the intercept column when
            ``fit_intercept`` is set; the training loop does not otherwise
            use them.

        Side effects: reseeds NumPy's global random generator with
        ``self.seed``, sets ``self.theta`` and ``self.lr``, and appends to
        ``self.loss_list``, ``self.time_list`` and ``self.datapass``.
        """
        np.random.seed(self.seed)
        if self.verbose:
            print('X and y shape:', X.shape, y.shape)

        if self.fit_intercept:
            X = self.__add_intercept(X)
            X_val = self.__add_intercept(X_val)
            X_test = self.__add_intercept(X_test)

        # weights initialization
        self.theta = np.random.rand(X.shape[1])
        time_spent = 0
        data_pass = 0
        n = y.size
        index = np.arange(X.shape[0])  # sample indices, reshuffled once per epoch
        lamb_2 = 0  # threshold variable; carried over across outer iterations
        # A batch can never hold more than n samples; with batch > n every
        # step uses the whole (shuffled) training set as a single batch.
        batch_size = min(self.batch, n)
        num_bat_per_epoch = max(1, int(math.floor(n / self.batch)))

        self.loss_list.append(self.__ranked_range_loss(X, y))
        self.datapass.append(0)

        for i in range(self.num_iter):
            start_time = datetime.now()

            # Frozen for the whole inner loop of this outer iteration.
            subgradient = self.__top_losses_subgradient(X, y)

            # update learning rate
            self.lr = self.lr_0 / (i + 1)

            inner_iter = (i + 1) ** 2 * self.T0

            for j in range(inner_iter):
                # sample: walk through the shuffled index batch by batch
                m = j % num_bat_per_epoch
                if m == 0:
                    np.random.shuffle(index)
                index_batch = index[m * batch_size:(m + 1) * batch_size]
                X_batch = X[index_batch, :]
                y_batch = y[index_batch]

                h_2 = self.__sigmoid(np.dot(X_batch, self.theta))
                loss_2 = self.__individualloss(h_2, y_batch)
                v = h_2 - y_batch
                # Samples whose loss is below the current threshold are inactive.
                below = (loss_2 - lamb_2 < 0) | (loss_2 == 0)
                v[below] = 0

                gradient = np.dot(X_batch.T, v) / y_batch.size - subgradient
                # Normalise by the real batch size so the active fraction
                # stays correct when the batch was capped at n.
                gradient_lam_2 = self.k_value / n - np.count_nonzero(~below) / y_batch.size

                lamb_2 -= self.lr * gradient_lam_2
                self.theta -= self.lr * gradient

            end_time = datetime.now()
            time_spent += (end_time - start_time).total_seconds()
            self.time_list.append(time_spent)

            # calculate loss
            self.loss_list.append(self.__ranked_range_loss(X, y))
            data_pass += inner_iter * batch_size
            self.datapass.append(data_pass / n)