/*
 * Copyright (c) 2017-present, XXX, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <chrono>
#include <string>
#include <unordered_map>
#include <vector>

#include <autogradpp/autograd.h>

namespace cpid {

class SyncTrainer;

/// Helper that a SyncTrainer notifies after each update and that exposes
/// options for epoch length, checkpoint location and per-epoch metrics
/// handling (printing, aggregation, flushing, dumping).
///
/// The options below are declared with TORCH_ARG (from autogradpp).
/// NOTE(review): the macro appears to generate an accessor pair plus a backing
/// member and to change the current access specifier, which is why an explicit
/// `protected:` follows the option list -- confirm against the macro
/// definition.
class Checkpointer {
  using hires_clock = std::chrono::steady_clock;

 public:
  /// Creates a checkpointer for the given trainer.
  ///
  /// Marked `explicit` so that a `SyncTrainer*` is never silently converted
  /// into a Checkpointer.
  /// NOTE(review): the trainer is held through a raw pointer (see trainer_),
  /// presumably non-owning, so it should outlive this object -- confirm in the
  /// implementation.
  explicit Checkpointer(SyncTrainer* trainer);

  /// Entry point to be called by trainers, with the current update count.
  void updateDone(int updateCount);

  /// Returns the path where the latest model would be saved. There is no
  /// guarantee that a model has actually been saved there yet.
  std::string getModelPath() const;

  /// Epoch length (in number of updates)
  TORCH_ARG(int, epochLength) = 500;

  /// Where to save everything
  TORCH_ARG(std::string, checkpointPath);

  /// If true, print the mean of the metrics at each epoch
  TORCH_ARG(bool, printMetricsSummary) = true;

  /// If true, the metrics are aggregated over all workers
  TORCH_ARG(bool, aggregateMetrics) = true;

  /// If true, we clear the metrics at the end of the epoch
  TORCH_ARG(bool, flushMetrics) = true;

  /// If true, we dump the json of the metrics at each epoch
  TORCH_ARG(bool, dumpMetrics) = true;

  /// If true, we reduce across nodes using the max operator instead of the
  /// default reduction (NOTE(review): the default operator is not visible in
  /// this header -- confirm in the implementation).
  TORCH_ARG(bool, reduceMax) = true;

 protected:
  /// Prints a summary built from per-metric means, minimums and maximums,
  /// each keyed by metric name.
  /// NOTE(review): the maps are taken by value; switching to const references
  /// would avoid three copies but requires updating the out-of-line
  /// definition at the same time.
  void printSummary(
      std::unordered_map<std::string, float> means,
      std::unordered_map<std::string, float> mins,
      std::unordered_map<std::string, float> maxs);

  /// Reduces `values` in place (taken by non-const reference).
  /// NOTE(review): presumably a cross-worker reduction governed by
  /// aggregateMetrics / reduceMax -- confirm in the implementation.
  void reduceMetrics(std::vector<float>& values);

  /// Trainer this checkpointer is attached to.
  SyncTrainer* trainer_;

  /// Monotonic-clock timestamp; presumably taken when the last epoch ended
  /// (TODO confirm against the implementation).
  hires_clock::time_point lastEpochStamp_;
};
} // namespace cpid
