#include "umoss.h"

#include <cfloat>
#include <cmath>
#include <algorithm>
#include <string>
#include <iostream>

using namespace std;


/* Prepare the policy for a new run.
 *
 * n : horizon passed in by the caller (stored in this->n and later used by
 *     give_reward when computing each arm's index).
 * K : number of arms; arms are identified by 0 .. K-1. A non-positive K
 *     yields no arms.
 *
 * Any arms left over from a previous run are discarded. Each new arm is
 * created with DBL_MAX as its second constructor argument — presumably its
 * initial index, so untried arms outrank any arm with a finite index
 * (TODO confirm against Arm's definition). */
void UMOSS::setup(uint64_t n, int K) {
  this->n = n;
  this->K = K;

  /* empty the arms container from any previous run */
  while (arms.size() > 0) {
    arms.pop();
  }

  /* `i < K` (not `i != K`): with a negative K the old condition never held,
   * so `i` would run into signed-integer overflow (undefined behaviour)
   * instead of simply creating no arms. */
  for (int i = 0; i < K; ++i) {
    arms.push(Arm(i, DBL_MAX));
  }
}
  

void UMOSS::give_reward(double r) {
  Arm a = arms.top();
  arms.pop();

  a.T+=1;
  a.reward+=r;
  
  double ni = n * n / (bias[a.i] * bias[a.i]);
  a.idx = a.reward / a.T + sqrt(alpha / a.T * log(max(1.0, ni / a.T))) - sqrt(1.0 / ni);
  arms.push(a);
}

/* Report which arm to play next: the arm with the highest index, i.e. the
 * element at the top of the heap. The arm is not removed here —
 * give_reward() pops it, updates it and pushes it back.
 * Precondition: the heap is non-empty. */
int UMOSS::get_arm() {
  const auto& best = arms.top();
  return best.i;
}



