import numpy as np
import matplotlib.pyplot as plt
from IPython import embed


class Estimator():
    """Confidence interval for the value of a target policy under one learner's model.

    The interval is centred on the learner's predicted value (``learner.theta``
    applied to the learner's features of the action chosen by ``learner_eval``).
    Its half-width is built from an elliptical bonus
    ``sqrt(v^T Sig^{-1} v * d)`` maximised over actions, where ``Sig`` is the
    regularised Gram matrix of ``learner.Phi``.
    """

    def __init__(self, learner, all_phis_valid):
        """
        Args:
            learner: object exposing ``Phi`` (training design matrix), ``d``
                (feature dimension), ``K`` (number of actions) and ``theta``
                (coefficient vector).
            all_phis_valid: validation features indexed as
                ``[action, sample, feature]``. This is inferred from the indexing
                in ``stats``; TODO confirm against the caller.
        """
        self.learner = learner
        self.all_phis_valid = all_phis_valid
        Phi = self.learner.Phi
        # Gram matrix plus a ridge term whose strength equals the dimension d.
        self.Sig = Phi.T.dot(Phi) + learner.d * np.eye(learner.d)
        self.inv = np.linalg.inv(self.Sig)

    def stats(self, learner_eval, bonus_scale=.1, width=2):
        """Return ``(lower, upper)`` bounds on the value of ``learner_eval``'s policy.

        Args:
            learner_eval: policy to evaluate; must expose ``d`` and
                ``action(phis)``, which returns an action index given the
                ``(actions, learner_eval.d)`` feature slice for one sample.
            bonus_scale: multiplier on the per-action elliptical bonus
                (the original hard-coded value is 0.1).
            width: number of mean bonuses on each side of the mean prediction
                (the original hard-coded value is 2).

        Returns:
            Tuple ``(mean - width * mean_bonus, mean + width * mean_bonus)``.
            With an empty validation set, both entries are NaN
            (``np.mean`` of an empty list).
        """
        d = self.learner.d
        n_actions = self.learner.K
        preds = []
        bonuses = []

        for i in range(self.all_phis_valid.shape[1]):
            phis_learner = self.all_phis_valid[:, i, :d]
            phis_learner_eval = self.all_phis_valid[:, i, :learner_eval.d]

            # Quadratic form v^T Sig^{-1} v for every action row at once.
            action_feats = phis_learner[:n_actions]
            quad = np.sum(action_feats.dot(self.inv) * action_feats, axis=1)
            # Worst-case (largest) uncertainty bonus over actions for this sample.
            bonuses.append(np.max(bonus_scale * np.sqrt(quad * d)))

            a = learner_eval.action(phis_learner_eval)
            preds.append(self.learner.theta.dot(phis_learner[a]))

        mean = np.mean(preds)
        mean_bonus = np.mean(bonuses)

        return mean - width * mean_bonus, mean + width * mean_bonus



class Slope():
    """Select among per-learner interval estimates by pairwise interval overlap.

    One ``Estimator`` is built per learner. For a policy, ``estimate`` returns
    the midpoint of the first interval (in learner order) that intersects
    every later interval.
    """

    def __init__(self, learners, all_phis_valid):
        """
        Args:
            learners: sequence of learners. Presumably ordered from simplest to
                most complex; TODO confirm with the caller.
            all_phis_valid: validation features passed to each ``Estimator``.
        """
        self.learners = learners
        self.ests = [Estimator(learner, all_phis_valid) for learner in learners]

    def estimate(self, learner_eval):
        """Return the midpoint of the first interval overlapping all later ones.

        Raises:
            ValueError: if there are no estimators to choose from.
        """
        if not self.ests:
            raise ValueError("Slope.estimate requires at least one estimator")

        intervals = [est.stats(learner_eval) for est in self.ests]

        for i, (lower_i, upper_i) in enumerate(intervals):
            # Closed intervals [a, b] and [c, d] intersect iff a <= d and c <= b.
            # This also covers interval i strictly containing interval j,
            # which the previous endpoint-in-interval test missed.
            overlaps_all_later = all(
                lower_i <= upper_j and lower_j <= upper_i
                for lower_j, upper_j in intervals[i + 1:]
            )
            if overlaps_all_later:
                print("Using i: " + str(i))
                return np.mean([lower_i, upper_i])
        # Not reached: the last interval has no later intervals, so all()
        # over an empty sequence is True and it is always accepted.


class LossEstimator():
    """Score a linear learner by its mean squared error on a fixed validation set."""

    def __init__(self, phis_valid_mu, r_valid_mu):
        """
        Args:
            phis_valid_mu: 2-D validation feature matrix (rows are samples).
            r_valid_mu: validation targets, one per row of ``phis_valid_mu``.
        """
        self.phis_valid_mu = phis_valid_mu
        self.r_valid_mu = r_valid_mu

    def estimate(self, learner):
        """Return the MSE of ``learner`` on the stored validation data.

        Only the first ``learner.d`` feature columns are used, dotted with
        ``learner.theta``.
        """
        features = self.phis_valid_mu[:, :learner.d]
        residuals = features.dot(learner.theta) - self.r_valid_mu
        return np.mean(np.square(residuals))




class LossTestEstimator():
    """Penalised model selection over learners using validation MSE.

    Each learner ``k`` gets a validation error ``errs[k]`` (MSE of its linear
    predictions) and a complexity penalty ``pens[k] = d_k / n``.
    """

    def __init__(self, n, learners, phis_valid_mu, r_valid_mu):
        """
        Args:
            n: sample size used to scale the penalty ``d / n``.
            learners: sequence of learners exposing ``d`` and ``theta``.
            phis_valid_mu: 2-D validation feature matrix (rows are samples).
            r_valid_mu: validation targets, one per row of ``phis_valid_mu``.
        """
        self.phis_valid_mu = phis_valid_mu
        self.r_valid_mu = r_valid_mu
        self.learners = learners
        self.n = n

        self.errs = np.array([self._mse(learner) for learner in self.learners])
        self.pens = np.array([float(learner.d) / float(self.n)
                              for learner in self.learners])

        # Call-form print works identically on Python 2 (single argument) and 3.
        print("errors: " + str(self.errs))
        print("penalties: " + str(self.pens))

    def _mse(self, learner):
        """Mean squared validation error of one learner's linear predictions."""
        preds = self.phis_valid_mu[:, :learner.d].dot(learner.theta)
        return np.mean(np.square(preds - self.r_valid_mu))

    def select(self):
        """Return the index of the selected learner.

        Starting from learner 0, the current choice ``khat`` is advanced
        whenever some later learner ``k`` has
        ``errs[k] < errs[khat] - 2 * pens[k]``. The scan then restarts from the
        new ``khat``.
        """
        cont = True
        khat = 0
        while cont:
            cont = False

            for k in range(khat + 1, len(self.learners)):
                if self.errs[k] < self.errs[khat] - 2 * self.pens[k]:
                    cont = True
                    # NOTE(review): this advances khat by one rather than
                    # jumping to k; confirm that this step-wise update is
                    # intended.
                    khat = khat + 1
                    break

            print("Updating: " + str(khat))

        return khat



