import numpy as np
import torch
import tqdm

class ReplayBuffer(object):
	"""Transition store backed by NumPy arrays with uniform random sampling.

	The buffer can be filled online through `add` (ring buffer of capacity
	`max_size`) or in one shot from an offline dataset through
	`convert_D4RL` / `convert_NEORL`, which replace the internal arrays with
	the dataset's own arrays.

	Samplers draw indices with `np.random.randint`, i.e. uniformly and with
	replacement, and return float32 torch tensors placed on `self.device`.
	"""

	def __init__(self, state_dim, action_dim, max_size=int(2e6)):
		"""Allocate zero-filled storage for `max_size` transitions.

		Args:
			state_dim: Length of one state vector.
			action_dim: Length of one action vector.
			max_size: Capacity of the ring buffer used by `add`.
		"""
		self.max_size = max_size
		self.ptr = 0  # next slot `add` will write to
		self.size = 0  # number of rows the samplers may draw from

		self.state = np.zeros((max_size, state_dim))
		self.action = np.zeros((max_size, action_dim))
		self.next_state = np.zeros((max_size, state_dim))
		# Allocated up front so `sample_with_return` also works on a buffer
		# that was filled through `add` rather than a `convert_*` loader.
		self.next_action = np.zeros((max_size, action_dim))
		self.reward = np.zeros((max_size, 1))
		self.returns = np.zeros((max_size, 1))
		self.not_done = np.zeros((max_size, 1))
		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	def add(self, state, action, next_state, reward, done, next_action=None):
		"""Write one transition at the current slot and advance the ring pointer.

		Args:
			state: State vector of length `state_dim`.
			action: Action vector of length `action_dim`.
			next_state: Successor state vector.
			reward: Scalar reward.
			done: 1 (or True) if the transition is terminal, else 0.
			next_action: Optional action taken in `next_state`. When omitted,
				the slot's `next_action` row is left untouched.
		"""
		slot = self.ptr
		# Assigning into a row copies the values, so no explicit copy is needed.
		self.state[slot] = state
		self.action[slot] = action
		self.next_state[slot] = next_state
		self.reward[slot] = reward
		self.not_done[slot] = 1. - done
		if next_action is not None:
			self.next_action[slot] = next_action

		self.ptr = (slot + 1) % self.max_size
		self.size = min(self.size + 1, self.max_size)

	def _to_tensor(self, array):
		"""Convert a NumPy array to a float32 tensor on `self.device`."""
		return torch.FloatTensor(array).to(self.device)

	def sample(self, batch_size):
		"""Draw `batch_size` transitions uniformly with replacement.

		Returns:
			Tuple `(state, action, next_state, reward, not_done)` of tensors.
		"""
		ind = np.random.randint(0, self.size, size=batch_size)
		fields = (
			self.state,
			self.action,
			self.next_state,
			self.reward,
			self.not_done,
		)
		return tuple(self._to_tensor(field[ind]) for field in fields)

	def sample_with_return(self, batch_size):
		"""Like `sample`, but also yields `next_action` and `returns`.

		`returns` holds whatever the last `get_return` call computed (zeros
		if it was never called).

		Returns:
			Tuple `(state, action, next_state, next_action, reward, not_done,
			returns)` of tensors.
		"""
		ind = np.random.randint(0, self.size, size=batch_size)
		fields = (
			self.state,
			self.action,
			self.next_state,
			self.next_action,
			self.reward,
			self.not_done,
			self.returns,
		)
		return tuple(self._to_tensor(field[ind]) for field in fields)

	def process_sample(self, step, batch_size):
		"""Sample transitions from a prefix of the trajectories that grows with `step`.

		Reads `trajectory_cnt`, `trajectory_len`, `trajectory_s`,
		`trajectory_a`, `trajectory_nexts`, `trajectory_r` and
		`trajectory_not_done`. None of these are assigned by this class, so
		the caller must attach them to the instance first.
		NOTE(review): presumably the `trajectory_*` arrays are indexed as
		[trajectory, timestep, ...] (e.g. built with `transform`) — confirm.

		Args:
			step: Progress counter; every 10 steps one more tenth of the
				trajectories becomes eligible.
			batch_size: Number of transitions to draw.

		Returns:
			Tuple `(state, action, next_state, reward, not_done)` of tensors.
		"""
		# At least one trajectory per chunk; otherwise fewer than 10
		# trajectories would give an empty sampling range.
		chunk = max(1, int(self.trajectory_cnt / 10))
		stage = int(step / 10) + 1
		upper = min((stage + 1) * chunk, self.trajectory_cnt)
		tra_ind = np.random.randint(0, upper, size=batch_size)
		# One timestep per chosen trajectory, bounded by that trajectory's length.
		p_ind = np.array(
			[np.random.randint(0, self.trajectory_len[tra], size=1) for tra in tra_ind],
			dtype=np.int64,
		).reshape(-1)
		fields = (
			self.trajectory_s,
			self.trajectory_a,
			self.trajectory_nexts,
			self.trajectory_r,
			self.trajectory_not_done,
		)
		return tuple(self._to_tensor(field[tra_ind, p_ind]) for field in fields)

	def transform(self, list_np):
		"""Stack variable-length 2-D arrays into one zero-padded 3-D array.

		Args:
			list_np: Non-empty list of arrays shaped `(length_i, dim)`; `dim`
				is taken from the first array.

		Returns:
			Array of shape `(len(list_np), max(length_i), dim)` where row
			`i` holds `list_np[i]` followed by zeros.
		"""
		max_len = max(data.shape[0] for data in list_np)
		dim = list_np[0].shape[1]
		padded = np.zeros((len(list_np), max_len, dim))
		for index, data in enumerate(list_np):
			padded[index, :len(data)] = data
		return padded

	def _load_transitions(self, observations, actions, next_observations,
						  rewards, terminals, reward_tune):
		"""Replace the buffer's arrays with the given dataset arrays.

		Shared by `convert_D4RL` and `convert_NEORL`. `ptr` and `max_size`
		are left as they were.
		"""
		self.state = observations
		self.action = actions
		self.next_state = next_observations
		# NOTE(review): this is a view of `action` with the same values, not
		# the action of the following step — confirm this is intended.
		self.next_action = self.action[:, :]
		self.reward = rewards.reshape(-1, 1)
		if reward_tune:
			self.reward = (self.reward - 0.5) * 4.0
		self.not_done = 1. - terminals.reshape(-1, 1)
		# One less than the row count: the last row is never sampled.
		self.size = self.state.shape[0] - 1

	def convert_D4RL(self, dataset, reward_tune):
		"""Load a dataset keyed by 'observations', 'actions',
		'next_observations', 'rewards' and 'terminals'.

		Args:
			dataset: Mapping from those keys to NumPy arrays.
			reward_tune: If truthy, rewards are mapped to `(r - 0.5) * 4.0`.
		"""
		self._load_transitions(
			dataset['observations'],
			dataset['actions'],
			dataset['next_observations'],
			dataset['rewards'],
			dataset['terminals'],
			reward_tune,
		)
		print("Load D4RL dataset finished!")

	def convert_NEORL(self, dataset, reward_tune):
		"""Load a dataset keyed by 'obs', 'action', 'next_obs', 'reward' and 'done'.

		Args:
			dataset: Mapping from those keys to NumPy arrays.
			reward_tune: If truthy, rewards are mapped to `(r - 0.5) * 4.0`.
		"""
		self._load_transitions(
			dataset['obs'],
			dataset['action'],
			dataset['next_obs'],
			dataset['reward'],
			dataset['done'],
			reward_tune,
		)
		print("Load NEORL dataset finished!")

	def normalize_states(self, eps = 1e-3):
		"""Standardise `state` and `next_state` with statistics of `state`.

		The mean and standard deviation are taken over every row of
		`self.state` (not only the first `size` rows).

		Args:
			eps: Added to the standard deviation to avoid division by zero.

		Returns:
			`(mean, std)`, each of shape `(1, state_dim)`; `std` includes `eps`.
		"""
		mean = self.state.mean(0, keepdims=True)
		std = self.state.std(0, keepdims=True) + eps
		self.state = (self.state - mean) / std
		self.next_state = (self.next_state - mean) / std
		return mean, std

	def get_return(self, gamma=0.99):
		"""Fill `self.returns[:size]` with discounted returns, computed backward.

		For each index `i` from `size - 1` down to 0:
		`returns[i] = reward[i] + gamma * returns[i + 1] * not_done[i]`,
		with the value after the last index taken as 0.

		Args:
			gamma: Discount factor.
		"""
		count = self.size
		# A `convert_*` loader can make `size` exceed the `max_size` rows
		# allocated in `__init__`; grow the array instead of overflowing it.
		if self.returns.shape[0] < count:
			self.returns = np.zeros((count, 1))

		# Plain Python floats keep the sequential loop cheap and use the same
		# float64 arithmetic as the array form.
		rewards = np.asarray(self.reward[:count], dtype=np.float64).reshape(-1).tolist()
		not_done = np.asarray(self.not_done[:count], dtype=np.float64).reshape(-1).tolist()

		running = 0.0
		for i in range(count - 1, -1, -1):
			running = rewards[i] + gamma * running * not_done[i]
			self.returns[i, 0] = running
