from dataclasses import dataclass
from typing import Callable, NewType
import numpy.typing as npt
import numpy as np

# Signature of every reward function in this module:
# (feature vector, integer state) -> scalar reward.
RewardFun = Callable[[npt.NDArray, int], float]


@dataclass
class Task:
    """A natural-language prioritization command with its two reward functions."""

    # Human-readable description of the prioritization goal.
    command: str
    # Reward function stored under the name `base`.
    base: RewardFun
    # Reward function stored under the name `human`.
    human: RewardFun

# Multiplier on the task-specific bonus term of the "slightly ..." tasks.
SLIGHTLY_MODIFIER = 2.0
# Not referenced in this section.
# NOTE(review): identical to SLIGHTLY_MODIFIER. If "heavily" tasks are meant
# to weight the bonus more strongly, this may be a copy-paste slip -- confirm.
HEAVILY_MODIFIER = 2.0

 
def if_(x: bool):
    """Return 1.0 when ``x`` is truthy and 0.0 otherwise.

    Acts as a 0/1 indicator so truth tests can be multiplied into the
    arithmetic reward expressions.
    """
    return float(bool(x))

def _slight_focus_task(command: str, weight: Callable[[npt.NDArray], float]) -> Task:
    """Build a "slightly focus" Task from a command and a per-task weight.

    Both rewards share one shape: a state term plus a bonus of
    ``SLIGHTLY_MODIFIER * weight(feature)``. The bonus is only paid when
    ``state`` is truthy. ``base`` uses a linear state term (``state * 0.1``)
    and ``human`` uses a quadratic one (``state**2 * 0.1``). Generating both
    from one ``weight`` keeps the two variants from drifting apart.

    Args:
        command: natural-language description of the prioritization goal.
        weight: maps a feature vector to this task's bonus weight.

    Returns:
        A new Task whose ``base`` and ``human`` rewards follow the pattern above.
    """
    def base(feature: npt.NDArray, state: int) -> float:
        return state * 0.1 + if_(state) * SLIGHTLY_MODIFIER * weight(feature)

    def human(feature: npt.NDArray, state: int) -> float:
        return state**2 * 0.1 + if_(state) * SLIGHTLY_MODIFIER * weight(feature)

    return Task(command, base, human)


task_commands = [
    _slight_focus_task(
        "While still prioritizing all, slightly focus on the oldest by age distribution.",
        # Raw feature value (not passed through if_), unlike most tasks below.
        lambda feature: feature[11],
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly focus on the bottom 12.5% of the income_bracket distribution.",
        # Raw feature value (not passed through if_).
        # NOTE(review): the "lowest income_bracket groups" task uses features
        # 35-37. If 35 is the lowest bracket, this index may be off by one --
        # confirm against the feature layout.
        lambda feature: feature[36],
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly focus on those who speak Hindi.",
        lambda feature: if_(feature[12]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly weight those who have had low education.",
        lambda feature: if_(feature[16]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly focus on both the youngest and oldest by age.",
        lambda feature: if_(feature[11] or feature[7]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly prefer the income bracket bounds for the middle 40% of the population.",
        lambda feature: if_(feature[38] or feature[39] or feature[40]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly favor those women who do not own their own phone.",
        lambda feature: if_(feature[24] or feature[25]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly prioritize impoverished younger mothers by combining the distributions of 'age' and 'education'.",
        lambda feature: if_(feature[7] and feature[16]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly advantage those who prefer being called after 7PM 'slot' registered at an NGO.",
        lambda feature: if_(feature[31] and feature[32]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly focus on those Marathi-speakers with middle-aged mothers.",
        lambda feature: if_(feature[13] and (feature[9] or feature[10])),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly emphasize beneficiaries who likely work early in the morning and late at night.",
        # NOTE(review): a later task describes feature[28] as the
        # 12:30pm-3:30pm call slot, which is not "late at night". Confirm
        # whether this is intentional or whether a later slot index was meant.
        lambda feature: if_(feature[26] or feature[28]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly weight the lowest income_bracket groups, the absolute lowest earners in the population.",
        lambda feature: if_(feature[35] or feature[36] or feature[37]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly advantage those who prefer being called before 10:30am 'slot' and are registered at an NGO.",
        lambda feature: if_(feature[26] and feature[32]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly advantage those who prefer being called between 10:30am-12:30pm and are registered at an NGO.",
        lambda feature: if_(feature[27] and feature[32]),
    ),
    _slight_focus_task(
        "While still prioritizing all, slightly advantage those who prefer being called between 12:30pm-3:30pm and are registered at an NGO.",
        lambda feature: if_(feature[28] and feature[32]),
    ),
]

def shaped_wrapper(reward_fun: RewardFun) -> RewardFun:
    """Wrap ``reward_fun`` so that its output is squared.

    Args:
        reward_fun: reward function called as ``reward_fun(feature, state)``.

    Returns:
        A new reward function with the same call signature that returns
        ``reward_fun(feature, state) ** 2``.
    """
    def shaped_reward(feature: npt.NDArray, state: int) -> float:
        unshaped = reward_fun(feature, state)
        return unshaped ** 2

    return shaped_reward


# Square every task's `human` reward. This rebinds `human` on the very Task
# objects held in `task_commands`, so `task_commands` and `TASKS` end up
# referencing the same (shaped) tasks.
for command in task_commands:
    command.human = shaped_wrapper(command.human)
TASKS = list(task_commands)