# -*- coding: utf-8 -*-

import numpy as np
#%%

class Gaptron:
    """Randomized multiclass learner wrapped around a linear base learner.

    Each round, the base learner's weight matrix ``W`` gives a greedy label
    ``argmax(x @ W)``. The prediction is sampled from a mixture of a point
    mass on that label and the uniform distribution over all K outcomes.
    The uniform part gets weight ``max(gapmap(W, x), gamma)``.
    ``update`` forwards the feedback to the base learner with
    ``scaling = 1 / computePt(y)``.

    Parameters
    ----------
    gapmap : callable
        ``gapmap(Wt, x)`` returns the data-dependent mixing weight ``a_t``.
        The mixture is a valid distribution only if this (and ``gamma``)
        lies in [0, 1]. This is assumed, not checked.
    algW : object
        Base learner exposing ``name``, ``Wmat()`` (class vectors in
        columns) and ``update(y, x, loss, scaling)``.
    gamma : float
        Lower bound on the uniform mixing weight.
    beta : float
        Stored as ``self.beta``; not used inside this class.
    outcomes : sequence of int
        The labels. They start with category 0 and end with category K-1.
    domset : sequence of int, optional
        Indices whose normalised indicator is stored in ``domsetvec``.
        When omitted or empty, the bandit scenario is assumed and all
        outcomes are used.
    reveal : sequence of int, optional
        List of revealing actions. When it equals ``outcomes`` (same
        elements, same order), every label is observed and ``computePt``
        returns 1.0.
    domsetdict : dict, optional
        Mapping of the form ``{'0': revealed by, '1': revealed by, ...}``,
        i.e. ``str(label)`` -> list of action indices. Only kept when
        ``domset`` is non-empty.
    """

    def __init__(self, gapmap, algW, gamma, beta, outcomes, domset=None, reveal=None, domsetdict=None):
        self.name = "Gaptron" + "-" + algW.name
        self.gapmap = gapmap
        self.algW = algW
        self.gamma = gamma
        self.beta = beta
        self.inversegamma = 0
        self.outcomes = outcomes  # categories 0 .. K-1, in order
        self.K = len(outcomes)
        # len() rather than truthiness so NumPy arrays are accepted too.
        if domset is None or len(domset) == 0:
            # Bandit scenario: every outcome is in the set.
            self.domset = outcomes
            self.domsetvec = np.ones(self.K) / self.K
            self.domsetdict = []
        else:
            self.domset = domset
            self.domsetvec = np.zeros(self.K)
            self.domsetvec[domset] = 1.0
            self.domsetvec = self.domsetvec / np.sum(self.domsetvec)
            # Fresh container per instance (no shared mutable default).
            self.domsetdict = domsetdict if domsetdict is not None else []
        # Fresh list per instance so mutating one learner's ``reveal``
        # cannot leak into another learner.
        self.reveal = reveal if reveal is not None else []

    def predict(self, x):
        """Sample a label for feature vector ``x``.

        Side effects: stores the sampling distribution in ``self.ptprime``
        and the sampled label in ``self.ytprime``. The ``update`` /
        ``computePt`` calls read ``self.ptprime``.

        Returns
        -------
        The sampled outcome (an element of ``self.outcomes``).
        """
        Wt = self.algW.Wmat()  # class vectors in columns
        ystart = np.argmax(np.dot(x, Wt))
        at = self.gapmap(Wt, x)

        zetat = 1.0
        gammat = self.gamma
        # Weight of the uniform exploration component; never below gamma.
        mix = np.max([at, gammat])
        eystart = np.zeros(self.K)
        eystart[ystart] = 1
        # With zetat fixed at 1.0 the domsetvec term is identically zero.
        self.ptprime = ((1 - mix) * eystart
                        + mix * 1 / self.K * np.ones(self.K)
                        + (1 - zetat) * gammat * self.domsetvec)
        self.ytprime = np.random.choice(self.outcomes, 1, p=self.ptprime)[0]
        return self.ytprime

    def computePt(self, y):
        """Return the importance-weight denominator for true label ``y``.

        * 1.0 in the full-information case (``reveal`` equals ``outcomes``).
        * ``ptprime[y]`` when no ``domsetdict`` is configured.
        * Otherwise, the total ``ptprime`` mass of the actions listed in
          ``domsetdict[str(y)]``.

        Requires ``predict`` to have been called first (reads
        ``self.ptprime``) in the last two cases.
        """
        y = int(y)
        # Compare as plain lists so NumPy arrays and tuples work. Comparing
        # arrays with ``==`` would give an element-wise array, which cannot
        # be used in an ``if``.
        if list(self.reveal) == list(self.outcomes):
            return 1.0
        if not self.domsetdict:
            return self.ptprime[y]
        return np.sum(self.ptprime[self.domsetdict[str(y)]])

    def update(self, y, x, loss):
        """Forward feedback to the base learner with scaling ``1 / Pt``.

        Pass ``y == self.K + 1`` to signal that the correct label is
        unknown; in that case nothing is updated.
        """
        if y != (self.K + 1):
            Pt = self.computePt(y)
            scaling = 1 / Pt
            self.algW.update(y, x, loss, scaling)

#%%



def logmap(Wt, x):
    """Gap map for the softmax/logistic surrogate.

    Scores each class with ``x @ Wt``. It then takes the softmax
    probability of the top-scoring class, computed stably by subtracting
    the maximum score.

    Returns ``1 - p_top`` when that probability is at least 0.5, and
    ``1.0`` otherwise.
    """
    scores = np.dot(x, Wt)
    shifted = scores - np.max(scores)  # avoid overflow in exp
    top = int(np.argmax(scores))
    p_top = np.exp(shifted[top]) / np.sum(np.exp(shifted))
    return (1.0 - p_top) if p_top >= 0.5 else 1.0
    

    
