import os
import copy
import torch
import wandb

from models.iql.value import StateActionValueNetwork, StateValueNetwork


def loss(diff, expectile=0.8):
    """Return the element-wise asymmetric (expectile-weighted) squared error.

    Entries of ``diff`` that are strictly positive are scaled by
    ``expectile``; every other entry is scaled by ``1 - expectile``. The
    scale is then multiplied by ``diff ** 2``. No reduction is applied, so
    the caller decides how to aggregate (e.g. ``.mean()``).

    Args:
        diff: Tensor of residuals.
        expectile: Weight applied where ``diff > 0``.

    Returns:
        Tensor of weighted squared residuals, one entry per entry of ``diff``.
    """
    squared_error = diff**2
    asymmetric_weight = torch.where(diff > 0, expectile, 1 - expectile)
    return asymmetric_weight * squared_error


# Prefer a CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class IQL(object):
    """Value-function learner for Implicit Q-Learning (IQL).

    Holds a state-action value network (which returns two Q estimates per
    call), a slowly-updated target copy of it, and a state value network.
    The online Q network and the V network each have their own Adam
    optimizer. One call to :meth:`train` samples a batch and performs a V
    update, a Q update, and a soft target update.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        expectile=0.7,
        discount=0.99,
        tau=0.005,
        hidden_dim=256,
        q_hiddens=2,
        v_hiddens=2,
        layernorm=False,
        lr=3e-4,
    ):
        """Build the networks, the target copy, and the optimizers.

        Args:
            state_dim: Forwarded to both network constructors.
            action_dim: Forwarded to the state-action network constructor.
            expectile: Expectile passed to ``loss`` in the V update.
            discount: Multiplier on the next-state value in the Q target.
            tau: Interpolation factor of the soft target update.
            hidden_dim: Forwarded to both network constructors.
            q_hiddens: Forwarded to the state-action network constructor.
            v_hiddens: Forwarded to the state value network constructor.
            layernorm: Forwarded to both network constructors.
            lr: Learning rate used for both Adam optimizers.
        """
        self.sa_vf = StateActionValueNetwork(state_dim, action_dim, hidden_dim, q_hiddens, layernorm).to(device)
        # The target starts as an exact copy and is only moved by update_target().
        self.sa_vf_target = copy.deepcopy(self.sa_vf)
        self.sa_vf_optimizer = torch.optim.Adam(self.sa_vf.parameters(), lr=lr)

        self.s_vf = StateValueNetwork(state_dim, hidden_dim, v_hiddens, layernorm).to(device)
        self.s_vf_optimizer = torch.optim.Adam(self.s_vf.parameters(), lr=lr)

        self.discount = discount
        self.tau = tau

        self.total_it = 0
        self.expectile = expectile

    def update_v(self, states, actions, log_to_wb=False):
        """Take one optimizer step on the state value network.

        The regression target is the element-wise minimum of the two target
        Q estimates; the residual ``q - v`` is penalized with the
        expectile-weighted squared error ``loss`` and averaged.

        Args:
            states: Batch of states fed to both networks.
            actions: Batch of actions fed to the target Q network.
            log_to_wb: When true, log the loss and mean value to wandb at
                step ``self.total_it``.
        """
        # Targets come from the target network and must not carry gradients.
        with torch.no_grad():
            q1, q2 = self.sa_vf_target(states, actions)
            q = torch.minimum(q1, q2)

        v = self.s_vf(states)
        state_value_loss = loss(q - v, self.expectile).mean()

        self.s_vf_optimizer.zero_grad()
        state_value_loss.backward()
        self.s_vf_optimizer.step()

        if log_to_wb:
            # Log plain floats so the logger never holds graph-attached tensors.
            logs = {
                "IQL training/state_value_loss": state_value_loss.item(),
                "IQL training/state_value": v.mean().item(),
            }
            wandb.log(logs, step=self.total_it)

    def update_q(self, states, actions, rewards, next_states, not_dones, log_to_wb=False):
        """Take one optimizer step on the state-action value network.

        The target is ``rewards + discount * not_dones * V(next_states)``,
        computed without gradients. The loss is the mean of the summed
        squared errors of both Q estimates against that target.

        Args:
            states: Batch of states.
            actions: Batch of actions.
            rewards: Batch of rewards.
            next_states: Batch of successor states fed to the V network.
            not_dones: Batch of multipliers applied to the bootstrapped value.
            log_to_wb: When true, log the loss and mean Q values to wandb at
                step ``self.total_it``.
        """
        # NOTE(review): rewards, not_dones, and the V/Q outputs are combined
        # by broadcasting; presumably all are (batch, 1) - confirm against the
        # replay buffer and network definitions.
        with torch.no_grad():
            next_v = self.s_vf(next_states)
            target_q = rewards + self.discount * not_dones * next_v

        q1, q2 = self.sa_vf(states, actions)
        state_action_value_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()

        self.sa_vf_optimizer.zero_grad()
        state_action_value_loss.backward()
        self.sa_vf_optimizer.step()

        if log_to_wb:
            # Log plain floats so the logger never holds graph-attached tensors.
            logs = {
                "IQL training/state_action_value_loss": state_action_value_loss.item(),
                "IQL training/state_action_value1": q1.mean().item(),
                "IQL training/state_action_value2": q2.mean().item(),
            }
            wandb.log(logs, step=self.total_it)

    def update_target(self):
        """Move each target Q parameter toward its online counterpart.

        Applies ``target = tau * online + (1 - tau) * target`` in place.
        """
        # In-place writes to parameters must happen outside autograd tracking.
        with torch.no_grad():
            for param, target_param in zip(self.sa_vf.parameters(), self.sa_vf_target.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def train(self, replay_buffer, batch_size=256, log_to_wb=False):
        """Run one training iteration: V update, Q update, target update.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns the
                tuple ``(state, action, next_state, reward, not_done)``.
            batch_size: Number passed to ``replay_buffer.sample``.
            log_to_wb: Forwarded to the V and Q updates.
        """
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Update
        self.update_v(state, action, log_to_wb)
        self.update_q(state, action, reward, next_state, not_done, log_to_wb)
        self.update_target()

    def save(self, model_dir):
        """Write network and optimizer state dicts into ``model_dir``.

        Each file is named ``<component>_<total_it>.pth``. The directory is
        created if it does not exist yet.

        Args:
            model_dir: Destination directory; an empty string means the
                current working directory.
        """
        # An empty path means "current directory", which makedirs rejects.
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        components = {
            "sa_vf": self.sa_vf,
            "sa_vf_target": self.sa_vf_target,
            "sa_vf_optimizer": self.sa_vf_optimizer,
            "s_vf": self.s_vf,
            "s_vf_optimizer": self.s_vf_optimizer,
        }
        for name, component in components.items():
            path = os.path.join(model_dir, f"{name}_{self.total_it}.pth")
            torch.save(component.state_dict(), path)