import torch
import random
from omegaconf import DictConfig
import gym
from gym import spaces
import numpy as np
import matplotlib.pyplot as plt

class Probabilistic1dEnvironment(gym.Env):
    """Single-step Gym environment with a stochastic, action-dependent reward.

    The agent chooses a scalar action in ``[-1, 1]``. With probability
    ``probability(action)`` it receives ``reward_value(action)``; otherwise it
    receives 0. Larger actions pay out more often but pay less: the payout
    probability rises from e^-2 to 1, while the payout value falls from 4 to 1.
    Every episode ends after one step. The observation returned by ``step``
    is the reward that was just received.

    Follows the pre-0.26 Gym API: ``reset()`` returns only the observation and
    ``step()`` returns ``(observation, reward, done, info)``.
    """

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
        self.state = 0.0

    def step(self, action):
        """Sample a reward for ``action`` and end the episode.

        Args:
            action: array-like (or scalar) whose first flattened element is the
                action. Values outside the action space are clipped to it.

        Returns:
            ``(observation, reward, done, info)``:
            - ``observation`` is a float32 array of shape ``(1,)`` holding the reward.
            - ``reward`` is either ``reward_value(action)`` or 0.
            - ``done`` is always ``True``.
            - ``info`` is an empty dict.
        """
        # Clip into the action space. Above 1, probability() would exceed 1 and
        # np.random.choice would raise ValueError.
        # Converting to a Python float keeps the probability math in float64.
        # A float32 action can otherwise yield float32 probabilities (NumPy >= 2
        # promotion rules) whose sum misses 1 by more than np.random.choice
        # tolerates.
        a = float(np.asarray(action).reshape(-1)[0])
        a = min(max(a, float(self.action_space.low[0])), float(self.action_space.high[0]))

        prob = self.probability(a)
        reward_value = self.reward_value(a)
        reward = np.random.choice([reward_value, 0.0], p=[prob, 1.0 - prob])

        # The observation is the reward itself.
        self.state = reward
        done = True  # every episode lasts exactly one step

        # Match the float32 dtype declared by observation_space.
        return np.array([self.state], dtype=np.float32), reward, done, {}

    def reset(self):
        """Start a new episode and return the initial observation ``[0.0]`` as float32."""
        self.state = 0.0
        return np.array([self.state], dtype=np.float32)

    @staticmethod
    def probability(action):
        """Probability that the reward is paid out: ``exp(action - 1)``.

        Increases monotonically with the action, from e^-2 at ``action = -1``
        to 1 at ``action = 1``. Works elementwise on arrays.
        """
        return np.exp(action - 1)

    @staticmethod
    def reward_value(action):
        """Reward paid on success: ``4 ** ((1 - action) / 2)``, i.e. ``2 ** (1 - action)``.

        Decreases monotonically with the action, from 4 at ``action = -1``
        to 1 at ``action = 1``. Works elementwise on arrays.
        """
        return np.exp(np.log(4) * (-action + 1) / 2)
    

class Probabilistic1dDataset(torch.utils.data.IterableDataset):
    """Endless stream of single-step transitions from ``Probabilistic1dEnvironment``.

    Each item comes from resetting the environment, sampling an action from
    its action space, and taking one step. Items are tuples of torch tensors:
    ``(observation, action, reward, nonterminal)``.

    Args:
        cfg: experiment configuration. Stored on the instance but not read by
            this class.
        split: split name. Accepted for interface compatibility and currently
            unused.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        super().__init__()
        self.cfg = cfg
        self.env = Probabilistic1dEnvironment()

    def __iter__(self):
        while True:
            # Each episode is exactly one step long, so start fresh every time.
            self.env.reset()
            sampled_action = self.env.action_space.sample()
            next_obs, rew, is_done, _info = self.env.step(sampled_action)
            yield (
                torch.from_numpy(next_obs).float(),
                torch.tensor(sampled_action).float(),
                torch.tensor(rew).float(),
                torch.tensor([not is_done], dtype=torch.bool),
            )


if __name__ == "__main__":
    from itertools import islice
    from unittest.mock import MagicMock

    # --- Plot the analytic reward model as a function of the action. ---
    env = Probabilistic1dEnvironment()

    action_grid = np.linspace(-1, 1, 100)
    # probability() and reward_value() are elementwise NumPy expressions.
    probabilities = env.probability(action_grid)
    reward_values = env.reward_value(action_grid)
    expected_rewards = probabilities * reward_values

    # All panels put the action on the y-axis; only the leftmost one shows it.
    panels = [
        (probabilities, "Probability of Reward", "Probability"),
        (reward_values, "Reward Value", "Value"),
        (expected_rewards, "Expected Reward", "Expected Reward"),
    ]
    plt.figure(figsize=(18, 6))
    for idx, (values, title, xlabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.plot(values, action_grid)
        plt.title(title)
        plt.xlabel(xlabel)
        if idx == 1:
            plt.ylabel("Action")
        else:
            plt.gca().get_yaxis().set_visible(False)
    plt.tight_layout()
    plt.show()

    # --- Empirically check the dataset's sample statistics. ---
    n_samples = 100_000
    dataset = Probabilistic1dDataset(MagicMock())
    # The dataset is infinite; islice takes exactly n_samples items.
    samples = list(islice(dataset, n_samples))
    observations, actions, rewards, _nonterminals = map(torch.stack, zip(*samples))

    print(f"Collected {len(samples)} samples.")
    for name, values in (
        ("Observations", observations),
        ("Actions", actions),
        ("Rewards", rewards),
    ):
        print(f"{name}: {torch.mean(values):.4f} ± {torch.std(values):.4f}")
