
import numpy as np


def sample_fidelity(mf_type, query_proportion):
    """Pick a number of training epochs using the sampler named by ``mf_type``.

    Args:
        mf_type: Name of the sampling scheme. Recognised values are
            'multiplier', 'equal', 'half_108' and 'simple_half'.
        query_proportion: Value forwarded unchanged to the chosen sampler.

    Returns:
        The epoch count produced by the chosen sampler, or 108 when
        ``mf_type`` matches none of the recognised names.
    """
    # Checked in order with ``==`` (not a dict lookup) so any mf_type value,
    # hashable or not, is compared exactly as an if/elif chain would.
    named_samplers = (
        ('multiplier', sample_multiplier_dist),
        ('equal', sample_equal),
        ('half_108', sample_simple_dist),
        ('simple_half', simple_half),
    )
    for scheme, sampler in named_samplers:
        if mf_type == scheme:
            return sampler(query_proportion)
    # Unrecognised scheme: fall back to 108 epochs.
    return 108

def simple_half(query_proportion=1):
    """Return 36 epochs for the first half of the query budget, 108 afterwards.

    Args:
        query_proportion: Position in the query budget; values <= 0.5 map to
            36 epochs, anything else (including NaN) maps to 108.

    Returns:
        36 or 108.
    """
    return 36 if query_proportion <= .5 else 108


def sample_multiplier_dist(query_proportion=1):
    """Randomly draw an epoch budget whose bias shifts with ``query_proportion``.

    Fidelity level ``k`` (k = 0..3) is chosen with probability proportional
    to ``base ** k``, where ``base = 2 ** (2 * query_proportion - 1)``, and
    is mapped to ``4 * 3 ** k`` epochs (4, 12, 36 or 108).
    For ``query_proportion`` in [0, 1], ``base`` lies between 0.5 and 2.
    A base below 1 (query_proportion < 0.5) favours the cheaper levels.
    A base above 1 favours the more expensive ones.

    Randomness comes from numpy's global RNG: one ``np.random.rand()`` call.

    Args:
        query_proportion: Bias control, nominally in [0, 1]. Defaults to 1.

    Returns:
        One of 4, 12, 36, 108.
    """
    num_levels = 4  # hard-coded for the four fidelities in nasbench
    base = 2 ** (2 * query_proportion - 1)
    weights = [base ** level for level in range(num_levels)]
    # Cumulative upper bounds per level, each computed with sum() on a prefix
    # so the float values match a running sum(weights[:k + 1]) exactly.
    upper_bounds = [sum(weights[:level + 1]) for level in range(num_levels)]
    draw = np.random.rand() * sum(weights)
    # First level whose cumulative bound covers the draw; if none does
    # (e.g. NaN weights), the highest level is used.
    chosen = next(
        (level for level, bound in enumerate(upper_bounds) if draw <= bound),
        num_levels - 1,
    )
    return 4 * 3 ** chosen


def sample_simple_dist(query_proportion):
    """Map a query proportion to one of three nasbench epoch budgets.

    Query proportion should be in [0, 1]; the mapping is currently
    hard-coded for nasbench:
        <= 0.25        -> 12 epochs
        (0.25, 0.5]    -> 36 epochs
        otherwise      -> 108 epochs (also for NaN)

    Args:
        query_proportion: Position in the query budget.

    Returns:
        12, 36 or 108.
    """
    # Bands are tested in ascending order; the first upper bound that is not
    # exceeded wins. Values failing every comparison land on 108.
    for upper, epochs in ((.25, 12), (.5, 36)):
        if query_proportion <= upper:
            return epochs
    return 108


def sample_equal(query_proportion):
    """Map a query proportion onto four equal-width bands of epoch budgets.

    The bands are 0.25 wide:
        <= 0.25        -> 4 epochs
        (0.25, 0.5]    -> 12 epochs
        (0.5, 0.75]    -> 36 epochs
        otherwise      -> 108 epochs (also for NaN)

    Args:
        query_proportion: Position in the query budget, nominally in [0, 1].

    Returns:
        4, 12, 36 or 108.
    """
    # Ascending (upper bound, epochs) pairs; the first bound not exceeded wins.
    for upper, epochs in ((.25, 4), (.5, 12), (.75, 36)):
        if query_proportion <= upper:
            return epochs
    return 108