import utils
import numpy as np


def _truncate_to_shortest(rows):
    """Return *rows* stacked as an array, each row cut to the shortest row's length.

    Rows may have different lengths; truncating to the shortest one lets them
    be stacked so element-wise statistics can be taken across rows. The input
    sequence and its rows are not modified.
    """
    # default=0 keeps the empty-input case well-defined (yields an empty array)
    # instead of relying on a large magic sentinel value.
    shortest = min((len(row) for row in rows), default=0)
    return np.array([row[:shortest] for row in rows])


def get_mean_std(test_reward_per_frame, reward_per_frame):
    """Compute element-wise mean and standard deviation across rows.

    Each argument is a sequence of rows (presumably one row per seed/run, per
    the ``*_over_seeds`` naming — confirm against callers). Rows are truncated
    to the length of the shortest row in their own argument, then mean and
    population std (``np.std`` default, ddof=0) are taken along axis 0, i.e.
    across rows.

    Unlike earlier versions, the caller's lists are left untouched.

    Args:
        test_reward_per_frame: Sequence of per-frame test reward rows.
        reward_per_frame: Sequence of per-frame reward rows.

    Returns:
        Tuple ``(test_mean, test_std, reward_mean, reward_std)`` of numpy arrays.
    """
    test_rows = _truncate_to_shortest(test_reward_per_frame)
    reward_rows = _truncate_to_shortest(reward_per_frame)

    return (
        np.average(test_rows, axis=0),
        np.std(test_rows, axis=0),
        np.average(reward_rows, axis=0),
        np.std(reward_rows, axis=0),
    )


def explort_data(
    return_per_frame, test_return_per_frame, result_dir, exploration_type, seed_type
):
    """Write per-frame mean/std of returns across runs to a CSV file.

    The file is obtained from ``utils.get_csv_logger`` under *result_dir* with
    the name ``{exploration_type}_seed_type_{seed_type}.csv`` and receives four
    rows, in this order: test mean, test std, training mean, training std.

    Args:
        return_per_frame: Sequence of per-frame training return rows.
        test_return_per_frame: Sequence of per-frame test return rows.
        result_dir: Directory passed through to ``utils.get_csv_logger``.
        exploration_type: Prefix of the output file name.
        seed_type: Value embedded in the output file name.
    """
    (
        test_reward_mean_over_seeds,
        test_reward_std_over_seeds,
        reward_mean_over_seeds,
        reward_std_over_seeds,
    ) = get_mean_std(test_return_per_frame, return_per_frame)
    result_csv_file, result_csv_logger = utils.get_csv_logger(
        result_dir, f"{exploration_type}_seed_type_{seed_type}.csv"
    )
    # The handle is not returned to the caller, so close it here (close also
    # flushes) instead of leaking the open file.
    # NOTE(review): assumes utils.get_csv_logger hands out a fresh handle per
    # call rather than a shared/cached one — confirm.
    try:
        result_csv_logger.writerow(test_reward_mean_over_seeds)
        result_csv_logger.writerow(test_reward_std_over_seeds)
        result_csv_logger.writerow(reward_mean_over_seeds)
        result_csv_logger.writerow(reward_std_over_seeds)
    finally:
        result_csv_file.close()


def explort_data_one_seed(
    return_per_frame, test_return_per_frame, result_dir, exploration_type, seed
):
    """Write the raw per-frame returns of a single run to a CSV file.

    The file is obtained from ``utils.get_csv_logger`` under *result_dir* with
    the name ``{exploration_type}_seed_{seed}.csv`` and receives two rows:
    first the training returns, then the test returns. Note that this order is
    the reverse of ``explort_data``, which writes test statistics first.

    Args:
        return_per_frame: Iterable of per-frame training returns.
        test_return_per_frame: Iterable of per-frame test returns.
        result_dir: Directory passed through to ``utils.get_csv_logger``.
        exploration_type: Prefix of the output file name.
        seed: Value embedded in the output file name.
    """
    result_csv_file, result_csv_logger = utils.get_csv_logger(
        result_dir, f"{exploration_type}_seed_{seed}.csv"
    )
    # The handle is not returned to the caller, so close it here (close also
    # flushes) instead of leaking the open file.
    # NOTE(review): assumes utils.get_csv_logger hands out a fresh handle per
    # call rather than a shared/cached one — confirm.
    try:
        result_csv_logger.writerow(return_per_frame)
        result_csv_logger.writerow(test_return_per_frame)
    finally:
        result_csv_file.close()
