/*
 * Copyright (c) 2017-present, XXX, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "sampler.h"
#include "common/autograd.h"
#include "common/rand.h"
#include "flags.h"
#include <autogradpp/autograd.h>

using namespace common;

/// Builds a Gaussian sampler that keeps one extra dictionary key on top of
/// those handled by cpid::ContinuousGaussianSampler.
///
/// policyKey, stdKey, actionKey and pActionKey are forwarded unchanged to the
/// base-class constructor. policyPlayKey is stored in policyPlayKey_; sample()
/// reads the tensor under that key and uses it as the mean of the distribution
/// the action is drawn from.
CustomGaussianSampler::CustomGaussianSampler(
    const std::string& policyKey,
    const std::string& policyPlayKey,
    const std::string& stdKey,
    const std::string& actionKey,
    const std::string& pActionKey)
    : cpid::ContinuousGaussianSampler(policyKey, stdKey, actionKey, pActionKey),
      policyPlayKey_(policyPlayKey) {}

/// Draws a continuous action and stores it, together with its density, in the
/// input dictionary.
///
/// Reads from `in` (which must hold a dict):
///   - policyKey_:     tensor used as the mean when evaluating the density.
///   - policyPlayKey_: tensor used as the mean when drawing the action.
///   - stdKey_:        standard deviation, either a double/float scalar or a
///                     tensor.
///
/// Writes to `in`:
///   - actionKey_:  sample from normal(policyPlay, std / correlated_steps),
///                  drawn on CPU and moved to the device of the policy tensor.
///   - pActionKey_: common::normalPDF(action, policy, std), i.e. evaluated
///                  with the *unscaled* standard deviation and the policyKey_
///                  mean.
///
/// The policy tensor may have at most 2 dimensions; a 1-dimensional policy is
/// given a leading dimension of size 1 (and so is the play policy).
///
/// Runs under torch::NoGradGuard, so no autograd graph is recorded.
///
/// @throws std::runtime_error if policyKey_, stdKey_ or policyPlayKey_ is
///         absent from the dict, or if the policy has more than 2 dimensions.
/// @return the (modified) input variant.
ag::Variant CustomGaussianSampler::sample(ag::Variant in) {
  torch::NoGradGuard g_;
  auto& dict = in.getDict();
  if (dict.count(policyKey_) == 0) {
    throw std::runtime_error("Policy key not found while sampling action");
  }
  if (dict.count(stdKey_) == 0) {
    throw std::runtime_error(
        "Standard deviation key not found while sampling continuous action");
  }
  // The play policy is read unconditionally below, so validate it the same way
  // as the other inputs. Checked last so that the errors above keep their
  // precedence.
  if (dict.count(policyPlayKey_) == 0) {
    throw std::runtime_error("Play policy key not found while sampling action");
  }

  auto pi = in[policyKey_];
  auto piPlay = in[policyPlayKey_];
  if (pi.dim() > 2) {
    // NOTE(review): if LOG is glog, FATAL terminates the process and the throw
    // below is never reached; it is kept as a fallback -- confirm intent.
    LOG(FATAL) << "Expected at most 2 dimensions, but found " << pi.dim()
               << " in " << common::tensorInfo(pi);
    throw std::runtime_error("Policy doesn't have expected shape");
  }
  if (pi.dim() == 1) {
    pi = pi.unsqueeze(0);
    piPlay = piPlay.unsqueeze(0);
  }

  auto device = pi.options().device();
  // we do sampling on cpu for now
  const auto piPlayCpu = piPlay.to(at::kCPU);
  // The sampling std-dev is divided by the number of correlated steps; the
  // density below is still computed with the undivided value.
  const auto steps = static_cast<float>(FLAGS_correlated_steps);

  ag::Variant& stdVar = dict[stdKey_];
  if (stdVar.isDouble() || stdVar.isFloat()) {
    // Scalar standard deviation.
    const double dev =
        stdVar.isDouble() ? stdVar.getDouble() : stdVar.getFloat();
    torch::Tensor action =
        at::normal(piPlayCpu, dev / steps, Rand::gen()).to(device);
    dict[actionKey_] = ag::Variant(action);
    dict[pActionKey_] = ag::Variant(common::normalPDF(action, pi, dev));
  } else {
    // Tensor standard deviation.
    torch::Tensor dev = in[stdKey_];
    torch::Tensor action =
        at::normal(piPlayCpu, dev.to(at::kCPU) / steps, Rand::gen())
            .to(device);
    dict[actionKey_] = ag::Variant(action);
    dict[pActionKey_] = ag::Variant(common::normalPDF(action, pi, dev));
  }

  return in;
}
