/*
 * Copyright (c) 2017-present, XXX, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "checkpointer.h"
#include "distributed.h"
#include "synctrainer.h"
#include <experimental/filesystem>
#include <fstream>
#include <iomanip>
#include <limits>

namespace cpid {
namespace dist = distributed;
/// Builds a checkpointer bound to `trainer`.
///
/// The pointer is stored as-is (no ownership transfer is visible here), and
/// the epoch timer is started so that the first speed report in updateDone()
/// measures the time elapsed since construction.
Checkpointer::Checkpointer(SyncTrainer* trainer) : trainer_(trainer) {
  // Reference point for the "updates/s" and "frames/s" figures.
  this->lastEpochStamp_ = hires_clock::now();
}

/// Returns the location of the most recent trainer checkpoint.
///
/// The file name is appended to `checkpointPath_` by plain concatenation; no
/// path separator is inserted, so `checkpointPath_` is expected to already
/// end with one (or be empty).
std::string Checkpointer::getModelPath() const {
  std::string modelPath = checkpointPath_;
  modelPath += "trainer_latest.bin";
  return modelPath;
}

void Checkpointer::updateDone(int updateCount) {
  auto metrics = trainer_->metricsContext();

  if (updateCount % epochLength_ == 0) {
    std::vector<float> sampleCount;
    sampleCount.push_back(metrics->getCounter("sampleCount"));
    if (aggregateMetrics_) {
      dist::allreduce(sampleCount);
    } else {
      // if we don't reduce, use an estimate of the sampleCount
      sampleCount[0] *= dist::globalContext()->size;
    }
    if (dist::globalContext()->rank == 0) {
      LOG(INFO) << "EPOCH " << updateCount / epochLength_ << " done.";
      hires_clock::time_point now = hires_clock::now();
      std::chrono::duration<double, std::milli> dur = now - lastEpochStamp_;
      LOG(INFO) << "Speed: " << double(epochLength_) / (dur.count() / 1000.)
                << " updates/s    "
                << double(sampleCount[0]) / (dur.count() / 1000.)
                << " frames/s";
      lastEpochStamp_ = now;
    }

    auto means = metrics->getMeanEventValues();
    if (printMetricsSummary_) {
      if (printMetricsSummary_) {
        auto minRed = [](float a, float b) { return std::min(a, b); };
        auto maxRed = [](float a, float b) { return std::max(a, b); };
        auto mins = metrics->reduceEventValues(minRed, 1e20);
        auto maxs = metrics->reduceEventValues(maxRed, -1e20);
        if (dist::globalContext()->rank == 0) {
          LOG(INFO) << "Metrics summary:";
        }
        printSummary(means, mins, maxs);
        auto means_inter = metrics->getMeanIntervals();
        auto mins_inter = metrics->reduceIntervals(minRed, 1e20);
        auto maxs_inter = metrics->reduceIntervals(maxRed, -1e20);
        if (dist::globalContext()->rank == 0) {
          LOG(INFO) << "Timings summary:";
        }
        printSummary(means_inter, mins_inter, maxs_inter);
      }
    }
    if (dist::globalContext()->rank == 0) {
      LOG(INFO) << "";
    }
    if (dumpMetrics_) {
      metrics->dumpJson(
          checkpointPath_ + std::to_string(dist::globalContext()->rank) +
          "-epoch_" + std::to_string(updateCount / epochLength_) +
          "-metrics.json");
    }
    if (flushMetrics_) {
      metrics->clear();
    }
    if (dist::globalContext()->rank == 0) {
      ag::save(checkpointPath_ + "trainer_latest.bin", trainer_);
      bool should_save = false;
      double new_perf = means["winrate"];
      if (std::experimental::filesystem::exists(checkpointPath_ + "perf.txt")) {
        std::ifstream old_perf_f(checkpointPath_ + "perf.txt");
        float old_perf;
        old_perf_f >> old_perf;
        should_save = old_perf < new_perf;
      } else {
        should_save = true;
      }
      if (should_save) {
        std::string target =
            checkpointPath_ + "trainer_" + std::to_string(new_perf);
        while (std::experimental::filesystem::exists(target + ".bin")) {
          target += "_" + std::to_string(updateCount / epochLength_);
        }
        std::experimental::filesystem::copy(
            checkpointPath_ + "trainer_latest.bin", target + ".bin");
        if (std::experimental::filesystem::exists(
                checkpointPath_ + "trainer_best.bin")) {
          std::experimental::filesystem::remove(
              checkpointPath_ + "trainer_best.bin");
        }
        std::experimental::filesystem::copy(
            checkpointPath_ + "trainer_latest.bin",
            checkpointPath_ + "trainer_best.bin");

        std::ofstream perf_f(checkpointPath_ + "perf.txt");
        perf_f << new_perf << std::endl;
        ;
      }

    }
  }
}


/// Logs one line per entry of `means`, ordered by metric name, formatted as
/// "<name> <mean> (min: <min> max: <max>)". Only rank 0 writes to the log.
///
/// When `aggregateMetrics_` is set, the means are first combined across
/// workers with reduceMetrics() and the mins/maxs with MIN/MAX allreduces.
/// NOTE(review): those reductions are issued on every rank, so this
/// presumably has to be called by all ranks with the same set of keys.
///
/// @param means per-metric mean values; defines which metrics are printed.
/// @param mins  per-metric minima, matched to `means` by key.
/// @param maxs  per-metric maxima, matched to `means` by key.
void Checkpointer::printSummary(
    std::unordered_map<std::string, float> means,
    std::unordered_map<std::string, float> mins,
    std::unordered_map<std::string, float> maxs) {
  // Fallbacks for a metric that has a mean but no recorded min/max. These are
  // the same +/-1e20 seeds updateDone() passes to the min/max reductions, so
  // they behave as "no value" under the MIN/MAX allreduce below.
  constexpr float kMissingMin = 1e20f;
  constexpr float kMissingMax = -1e20f;

  // Sort by name so that the output order (and the layout of the reduced
  // vectors) does not depend on hash-map iteration order.
  std::vector<std::pair<std::string, float>> sortedMeans(
      means.begin(), means.end());
  std::sort(sortedMeans.begin(), sortedMeans.end());

  const size_t count = sortedMeans.size();
  std::vector<float> values(count);
  std::vector<float> valuesMin(count);
  std::vector<float> valuesMax(count);

  // Pair each mean with its min/max by key rather than by position: the
  // three maps are not guaranteed to have the same key set, and a positional
  // pairing would then mislabel values or write past the end of the buffers.
  for (size_t i = 0; i < count; ++i) {
    const std::string& name = sortedMeans[i].first;
    values[i] = sortedMeans[i].second;

    const auto minIt = mins.find(name);
    valuesMin[i] = (minIt != mins.end()) ? minIt->second : kMissingMin;

    const auto maxIt = maxs.find(name);
    valuesMax[i] = (maxIt != maxs.end()) ? maxIt->second : kMissingMax;
  }

  if (aggregateMetrics_) {
    reduceMetrics(values);
    dist::allreduce(valuesMin, dist::ReduceOp::MIN);
    dist::allreduce(valuesMax, dist::ReduceOp::MAX);
  }

  if (dist::globalContext()->rank == 0) {
    for (size_t i = 0; i < count; ++i) {
      LOG(INFO) << sortedMeans[i].first << " " << values[i]
                << " (min: " << valuesMin[i] << " max: " << valuesMax[i] << ")";
    }
  }
}

/// Combines `values` across all workers, in place.
///
/// With `reduceMax_` set, each entry becomes the maximum over workers.
/// Otherwise the entries are combined with the default allreduce operation
/// and then divided by the number of workers (presumably a sum followed by a
/// division, i.e. a mean - confirm the default op in distributed.h).
///
/// @param values per-metric values of this worker; overwritten with the
///               reduced result.
void Checkpointer::reduceMetrics(std::vector<float>& values) {
  if (reduceMax_) {
    dist::allreduce(values, dist::ReduceOp::MAX);
    return;
  }

  dist::allreduce(values);
  const auto workerCount = dist::globalContext()->size;
  for (size_t idx = 0; idx < values.size(); ++idx) {
    values[idx] /= workerCount;
  }
}
} // namespace cpid
