import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_flow
from enum import Enum
import copy
from algorithms.util import one_hot

# Player id of the chance "player" (who picks the number of tricks and deals).
# OhHellState also stores it in its card-ownership array for cards that no
# player currently holds (undealt or already played).
CHANCE_PLAYER_ID = -1

class OhHellGame(object):
	"""Static parameters of an Oh Hell game plus quantities derived from them.

	Args:
		num_players: number of players at the table.
		num_suits: number of suits in the deck.
		num_ranks: number of ranks in each suit.
		bid_made_bonus: points added to a player's score when their tricks won
			equal their bid.
	"""

	def __init__(self, num_players=3, num_suits=3, num_ranks=5, bid_made_bonus=10):
		self.num_players, self.num_suits, self.num_ranks = num_players, num_suits, num_ranks
		self.bid_made_bonus = bid_made_bonus

	def new_initial_state(self):
		"""Return a fresh state at the start of a hand."""
		return OhHellState(self)

	def max_tricks(self):
		"""Return the largest hand size that can be dealt.

		One card is reserved for the turned-up trump; the rest is split evenly
		between the players.
		"""
		dealable = self.num_cards() - 1
		return dealable // self.num_players

	def num_cards(self):
		"""Return the number of cards in the deck."""
		return self.num_ranks * self.num_suits

	def num_actions(self):
		"""Return the size of the action space (one action per card index)."""
		return self.num_cards()

	def num_features(self):
		"""Return the length of the infostate feature vector built by the state."""
		tricks = self.max_tricks()
		cards = self.num_cards()
		players = self.num_players
		phase_bits = len(Phase)
		trump_bits = self.num_suits
		hand_bits = cards
		seat_bits = 2 * players  # own seat + player to move
		bid_bits = players * (tricks + 1)
		trick_bits = tricks * cards
		return phase_bits + trump_bits + hand_bits + seat_bits + bid_bits + trick_bits

	def max_reward(self):
		"""Return the best score a player can obtain in a single hand."""
		return self.bid_made_bonus + self.max_tricks()

	def encode_hand(self, hand):
		"""Return a float vector of length num_cards() counting each card index in ``hand``."""
		counts = np.zeros(self.num_cards())
		for card_index in hand:
			counts[card_index] = counts[card_index] + 1.
		return counts


class Card(object):
	"""A playing card identified by a suit index and a rank index.

	Cards are encoded as the integer ``suit * num_ranks + rank``.
	"""

	def __init__(self, suit, rank):
		self._suit = suit
		self._rank = rank

	def rank(self):
		"""Return the rank index of the card."""
		return self._rank

	def suit(self):
		"""Return the suit index of the card."""
		return self._suit

	@staticmethod
	def from_int(card_int, num_suits, num_ranks):
		"""Decode the integer card index ``card_int`` into a Card.

		``num_suits`` is accepted for API symmetry but is not needed to decode.
		Declared as a staticmethod so it also works when called on an instance.
		"""
		suit, rank = divmod(card_int, num_ranks)
		return Card(suit, rank)

	def to_int(self, num_ranks):
		"""Return the integer card index ``suit * num_ranks + rank``."""
		return self._suit * num_ranks + self._rank

	def __str__(self):
		# suits are shown as letters starting at 'A', ranks as their index
		return chr(ord('A') + self._suit) + str(self._rank)


class Trick(object):
	"""A single trick: the card played by each player, indexed by player id."""

	def __init__(self, leader, num_players):
		self._leader = leader
		# cards[p] is the card played by player p, or None if p has not played yet
		self.cards = [None] * num_players

	def add(self, player, card):
		"""Record ``card`` as played by ``player``."""
		self.cards[player] = card

	def complete(self):
		"""Return True (a real bool) once every player has played a card."""
		return all(c is not None for c in self.cards)

	def lead_suit(self):
		"""Return the suit of the leader's card, or None if the leader has not played."""
		lead_card = self.cards[self._leader]
		return lead_card.suit() if lead_card is not None else None

	def winner(self, trump_suit):
		"""Return the index of the player who won the trick, or None if it is incomplete.

		The highest trump wins; if no trump was played, the highest card of the
		lead suit wins. Cards of any other suit can never win.
		"""
		if not self.complete():
			return None
		lead_suit = self.cards[self._leader].suit()
		winner = -1
		best_card = None
		for i, c in enumerate(self.cards):
			if c.suit() == trump_suit:
				# a trump beats any non-trump and any lower trump
				if best_card is None or best_card.suit() != trump_suit or best_card.rank() < c.rank():
					winner = i
					best_card = c
			elif c.suit() == lead_suit and (best_card is None or best_card.suit() != trump_suit):
				# lead-suit cards only compete while no trump has been seen
				if best_card is None or best_card.rank() < c.rank():
					winner = i
					best_card = c
		return winner

	def empty(self):
		"""Return True if no card has been played to this trick yet."""
		return all(c is None for c in self.cards)

	def num_cards(self):
		"""Return how many cards have been played to this trick."""
		return sum(1 for c in self.cards if c is not None)

	def __str__(self):
		# cards are listed in play order, starting from the leader
		s = ''
		num_players = len(self.cards)
		for i in range(num_players):
			p = (self._leader + i) % num_players
			s += f' {p}: {str(self.cards[p])}'
		return s



class Phase(Enum):
	"""Phases of a hand, in the order they occur (values start at 1)."""
	SELECT_NUM_TRICKS = 1
	DEAL = 2
	BID = 3
	CARDPLAY = 4
	GAME_OVER = 5


class OhHellState(object):
	"""Mutable state of one hand of Oh Hell.

	A hand moves through the ``Phase`` values in order:
	- Chance picks the number of tricks.
	- Chance deals the turned-up trump card, then the players' hands
	  round-robin starting at player 0.
	- Every player bids once, starting with player 0.
	- The tricks are played out. Player 0 leads the first trick, and the
	  winner of each trick leads the next.

	``self._cards[i]`` records who holds card ``i``:
	- a player index,
	- ``num_players`` for the trump card, or
	- ``CHANCE_PLAYER_ID`` for cards that are undealt or already played.
	"""

	def __init__(self, game):
		self._game = game
		self._player_to_move = CHANCE_PLAYER_ID
		self._terminal = False
		self._phase = Phase.SELECT_NUM_TRICKS
		self._cards = np.zeros(game.num_ranks * game.num_suits) + CHANCE_PLAYER_ID
		self._current_trick = None
		# card indices in the order they were dealt (trump card first)
		self._initial_deal = []
		# player receiving the next dealt card
		self._to_deal = 0
		# -1 marks a player who has not bid yet
		self._bids = np.zeros(game.num_players) - 1.
		# one (player, infostate, action, legal_actions, phase) tuple per action played
		self._infostate_history = []
		self._tricks = []
		self._trump = None
		self._num_tricks = -1
		self._tricks_won = np.zeros(game.num_players)
		# per suit: number of cards seen face up (trump card and played cards)
		self._num_cards_revealed = [0] * game.num_suits

	def get_game(self):
		return self._game

	def get_player_to_move(self):
		return self._player_to_move

	def get_phase(self):
		return self._phase

	def get_legal_actions(self, player):
		"""Return the legal actions for ``player`` in the current phase.

		The meaning of an action depends on the phase:
		- SELECT_NUM_TRICKS: a number of tricks.
		- DEAL: a card index to deal.
		- BID: a bid.
		- CARDPLAY: a card index from the player's hand.

		Returns an empty list when ``player`` is not the player to move.

		Raises:
			ValueError: if ``player`` is to move and the game is over.
		"""
		actions = []
		if player != self._player_to_move:
			return actions

		if self._phase == Phase.SELECT_NUM_TRICKS:
			actions = list(range(1, self._game.max_tricks() + 1))
		elif self._phase == Phase.DEAL:
			actions = [x for x in range(self._game.num_ranks * self._game.num_suits) if self._cards[x] == CHANCE_PLAYER_ID]
		elif self._phase == Phase.BID:
			if np.sum(self._bids < 0) == 1:
				# Last bidder: total bids may not equal the number of tricks.
				# The +1 cancels the -1 placeholder of the missing bid.
				bid_sum = np.sum(self._bids) + 1
				actions = [x for x in range(self._num_tricks + 1) if bid_sum + x != self._num_tricks]
			else:
				actions = list(range(self._num_tricks + 1))
		elif self._phase == Phase.CARDPLAY:
			hand = [idx for idx, x in enumerate(self._cards) if x == player]
			lead = self._current_trick.lead_suit()
			actions = []
			if lead is not None:
				# must follow the lead suit when able
				actions = [idx for idx in hand
				           if Card.from_int(idx, self._game.num_suits, self._game.num_ranks).suit() == lead]
			if lead is None or len(actions) == 0:
				actions = hand
		else:
			raise ValueError("Game Over: no legal actions.")
		return actions

	def play(self, player, action):
		"""Apply ``action`` for ``player`` and advance the state.

		Records the player's infostate and legal actions, taken before the
		action is applied, in the history together with the action.

		Raises:
			ValueError: if ``action`` is not legal for ``player``.
		"""
		legal_actions = self.get_legal_actions(player)
		if action not in legal_actions:
			raise ValueError("Illegal Action")

		self._infostate_history.append((player, self.get_infostate(player), action, legal_actions, self._phase))
		game = self._game
		if self._phase == Phase.SELECT_NUM_TRICKS:
			self._num_tricks = action
			self._phase = Phase.DEAL
		elif self._phase == Phase.DEAL:
			if self._trump is None:
				# the first card dealt is turned up as trump and held by no player
				self._cards[action] = game.num_players
				self._trump = Card.from_int(action, game.num_suits, game.num_ranks)
				self._num_cards_revealed[self._trump.suit()] += 1
			else:
				self._cards[action] = self._to_deal
				self._to_deal = (self._to_deal + 1) % game.num_players
				# dealing ends once every player holds num_tricks cards (+1 for the trump)
				if np.sum(self._cards >= 0) > game.num_players * self._num_tricks:
					self._player_to_move = 0
					self._phase = Phase.BID
			self._initial_deal.append(action)
		elif self._phase == Phase.BID:
			self._bids[player] = action
			if np.all(self._bids >= 0):
				self._phase = Phase.CARDPLAY
				self._current_trick = Trick(0, game.num_players)
				self._player_to_move = 0
			else:
				self._player_to_move = (self._player_to_move + 1) % game.num_players
		elif self._phase == Phase.CARDPLAY:
			c = Card.from_int(action, game.num_suits, game.num_ranks)
			self._current_trick.add(player, c)
			self._num_cards_revealed[c.suit()] += 1
			if self._current_trick.complete():
				winner = self._current_trick.winner(self._trump.suit())
				self._tricks_won[winner] += 1
				self._tricks.append(self._current_trick)
				if len(self._tricks) == self._num_tricks:
					# _current_trick keeps pointing at this (already stored) last trick
					self._terminal = True
					self._phase = Phase.GAME_OVER
				else:
					self._current_trick = Trick(winner, game.num_players)
					self._player_to_move = winner
			else:
				self._player_to_move = (self._player_to_move + 1) % game.num_players
			# the card leaves the hand
			self._cards[action] = CHANCE_PLAYER_ID
		else:
			raise ValueError("Game Over: action cannot be played.")

	def terminal(self):
		return self._terminal

	def history(self):
		return self._infostate_history

	def trump(self):
		return self._trump

	def score(self):
		"""Return per-player scores: tricks won, plus the bonus when tricks won equal the bid."""
		scores = np.zeros(self._game.num_players)
		scores += self._tricks_won
		for p in range(self._game.num_players):
			if self._bids[p] == self._tricks_won[p]:
				scores[p] += self._game.bid_made_bonus
		return scores

	def get_infostate_string(self, player):
		"""Return a human-readable view of what ``player`` can see."""
		s = f'Trump: {str(self._trump)} \nHand {player}:'
		for idx, c in enumerate(self._cards):
			if c == player:
				s += f' {str(Card.from_int(idx, self._game.num_suits, self._game.num_ranks))}'
		s += '\nBids:'
		for p in range(self._game.num_players):
			s += f' {self._bids[p]}'
		s += '\nTricks:\n'
		for t in self._tricks:
			s += str(t) + '\n'
		if self._current_trick is not None and not self._current_trick.empty() and not self._current_trick.complete():
			s += str(self._current_trick) + '\n'
		return s

	def get_infostate(self, player):
		"""Return the flat feature list describing what ``player`` knows.

		Returns an empty list until the trump card has been dealt. Otherwise the
		features are laid out in this order:
		- phase one-hot;
		- trump suit one-hot;
		- the player's hand counts;
		- the player's seat one-hot;
		- the player-to-move one-hot;
		- one bid one-hot per player;
		- one card-count vector per trick slot: completed tricks, then the
		  trick in progress, then zero padding up to max_tricks.
		"""
		features = []
		if self._trump is None:
			return features
		game = self._game
		max_tricks = game.max_tricks()
		# NOTE(review): Phase values start at 1, so GAME_OVER's value equals
		# len(Phase); confirm algorithms.util.one_hot accepts that index.
		features.extend(one_hot(self._phase.value, len(Phase)))  # len(Phase)
		features.extend(one_hot(self._trump.suit(), game.num_suits))  # num_suits
		features.extend(game.encode_hand(self.get_hand(player)))  # num_cards
		features.extend(one_hot(player, game.num_players))  # num_players: own seat
		features.extend(one_hot(self._player_to_move, game.num_players))  # num_players: player to move
		for p in range(game.num_players):  # num_players * (max_tricks + 1)
			# NOTE(review): players who have not bid yet have bid -1.0 here; how
			# that is encoded depends on one_hot - confirm.
			features.extend(one_hot(self._bids[p], max_tricks + 1))
		in_progress = self._current_trick is not None and not self._current_trick.complete()
		for i in range(max_tricks):  # max_tricks * num_cards
			if i < len(self._tricks):
				features.extend(game.encode_hand([c.to_int(game.num_ranks) for c in self._tricks[i].cards]))
			elif i == len(self._tricks) and in_progress:
				features.extend(game.encode_hand([c.to_int(game.num_ranks) for c in self._current_trick.cards if c is not None]))
			else:
				features.extend(np.zeros(game.num_cards()))
		return features

	def get_public_state_string(self):
		"""Return a human-readable view of the information visible to everyone."""
		s = f'Trump: {str(self._trump)}'
		s += '\nBids:'
		for p in range(self._game.num_players):
			s += f' {self._bids[p]}'
		s += '\nTricks:\n'
		for t in self._tricks:
			s += str(t) + '\n'
		if self._current_trick is not None and not self._current_trick.empty() and not self._current_trick.complete():
			s += str(self._current_trick) + '\n'
		return s

	def _trick_history(self):
		"""Return the completed tricks followed by the current trick, each exactly once.

		After the last trick completes, ``self._current_trick`` still references
		it although it was already appended to ``self._tricks``. The current
		trick is therefore only added when it is a different object, so the
		final trick is never counted twice.
		"""
		tricks = list(self._tricks)
		current = self._current_trick
		if current is not None and (not tricks or tricks[-1] is not current):
			tricks.append(current)
		return tricks

	def get_voids(self, player):
		"""Return the sorted list of suits ``player`` is known not to hold.

		A player is void in a suit if either:
		- they failed to follow it as the lead suit, or
		- all num_ranks cards of that suit have been seen face up.
		"""
		voids = set()
		for t in self._trick_history():
			c = t.cards[player]
			lead = t.lead_suit()
			if c is not None and c.suit() != lead:
				voids.add(lead)
		for s in range(self._game.num_suits):
			if self._num_cards_revealed[s] >= self._game.num_ranks:
				voids.add(s)
		return sorted(voids)

	def get_played_cards(self, player):
		"""Return the integer indices of the cards ``player`` has played, in trick order."""
		num_ranks = self._game.num_ranks
		return [t.cards[player].to_int(num_ranks) for t in self._trick_history() if t.cards[player] is not None]

	def num_unknown_cards_in_suit(self, suit):
		"""Return how many cards of ``suit`` are neither the trump card nor already played."""
		played_cards = []
		for p in range(self._game.num_players):
			played_cards.extend(self.get_played_cards(p))
		count = self._game.num_ranks
		if self._trump.suit() == suit:
			count -= 1
		for c in played_cards:
			if Card.from_int(c, self._game.num_suits, self._game.num_ranks).suit() == suit:
				count -= 1
		return count

	def num_unknown_cards_in_hand(self, player):
		"""Return how many cards ``player`` currently holds."""
		return sum(1 for c in self._cards if c == player)

	# max flow construction implementation
	def generate_history_in_public_state(self):
		"""Build one full history consistent with this state's public information.

		Unknown cards are assigned by solving a max-flow problem on this graph:
		- source -> player: capacity is the number of cards the player holds;
		- player -> suit: capacity num_cards(), or no edge if the player is
		  known to be void in the suit;
		- suit -> sink: capacity is the number of cards of the suit that are
		  neither the trump card nor played.

		Cards a player was seen playing are dealt to that player. The number
		of tricks, the trump card, the bids and the card play are replayed from
		this state's history.

		Returns:
			A new OhHellState whose history has the same length as this one's.

		Raises:
			ValueError: (from ``play``) if no suitable card can be dealt at some step.
		"""
		num_players = self._game.num_players
		num_suits = self._game.num_suits
		num_ranks = self._game.num_ranks
		num_nodes = num_players + num_suits + 2
		source = num_nodes - 2
		sink = num_nodes - 1
		max_capacity = self._game.num_cards()
		m = np.zeros((num_nodes, num_nodes), dtype=np.int32)
		for suit in range(num_suits):
			# cards of the suit not yet seen: capacity to sink
			m[num_players + suit, sink] = self.num_unknown_cards_in_suit(suit)
		for player in range(num_players):
			# edges connecting source to player vertices
			m[source, player] = self.num_unknown_cards_in_hand(player)
			voids = self.get_voids(player)
			for suit in range(num_suits):
				if suit not in voids:
					# flow between player and suit is max if player can have suit
					m[player, num_players + suit] = max_capacity
		flow = maximum_flow(csr_matrix(m), source, sink).flow.toarray()

		# rebuild a history from the flow
		s = OhHellState(self._game)
		# number of tricks and trump are public and always the first two actions
		s.play(CHANCE_PLAYER_ID, self._infostate_history[0][2])
		s.play(CHANCE_PLAYER_ID, self._infostate_history[1][2])
		must_deal = [[] for _ in range(num_players)]
		for t in self._trick_history():
			for i, c in enumerate(t.cards):
				if c is not None:
					must_deal[i].append(c.to_int(num_ranks))
		played_cards = [self._trump.to_int(num_ranks)]
		for p in range(num_players):
			played_cards.extend(self.get_played_cards(p))
		while len(s.history()) < len(self.history()):
			to_move = s.get_player_to_move()
			action = None
			if s.get_phase() == Phase.BID or s.get_phase() == Phase.CARDPLAY:
				# bids and card play are public: replay them
				action = self._infostate_history[len(s.history())][2]
			elif must_deal[s._to_deal]:
				action = must_deal[s._to_deal].pop()
			else:
				# deal an unseen card from the first suit the flow assigns to this player
				for suit in range(num_suits):
					suit_index = num_players + suit
					if flow[s._to_deal, suit_index] <= 0:
						continue
					action = next((c for c in s.get_legal_actions(to_move)
					               if c not in played_cards and
					               Card.from_int(c, num_suits, num_ranks).suit() == suit), None)
					if action is not None:
						flow[s._to_deal, suit_index] -= 1
						break
			s.play(to_move, action)
		return s

	def generate_histories_in_public_state(self):
		"""Enumerate every history consistent with this state's public information.

		Deals are enumerated depth-first while the public bids and card play are
		replayed. A state is kept only if its public state string matches this
		state's. The search is exhaustive over deals, so its cost grows
		combinatorially with the number of unknown cards.

		Returns:
			A list of OhHellState objects.
		"""
		num_players = self._game.num_players
		s = OhHellState(self._game)
		# number of tricks is public
		s.play(CHANCE_PLAYER_ID, self._infostate_history[0][2])
		# The trump card is public too. Replaying it also means the per-player
		# deal checks below never run on the trump, which no player receives.
		if len(self._infostate_history) > 1:
			s.play(CHANCE_PLAYER_ID, self._infostate_history[1][2])

		played_cards_to_player = {}
		must_deal = [[] for _ in range(num_players)]
		voids = [[] for _ in range(num_players)]
		for t in self._trick_history():
			lead_suit = t.lead_suit()
			for i, c in enumerate(t.cards):
				if c is None:
					continue
				card_int = c.to_int(self._game.num_ranks)
				played_cards_to_player[card_int] = i
				must_deal[i].append(card_int)
				# a player who did not follow the lead suit is void in it
				if c.suit() != lead_suit and lead_suit not in voids[i]:
					voids[i].append(lead_suit)

		def can_deal(state, a):
			"""Return whether dealing card ``a`` to ``state._to_deal`` can still be consistent."""
			player = state._to_deal
			owner = played_cards_to_player.get(a)
			# a card seen played by someone else cannot be dealt to this player
			if owner is not None and owner != player:
				return False
			# a player void in a suit can only have been dealt the cards of that
			# suit they were seen playing
			suit = Card.from_int(a, self._game.num_suits, self._game.num_ranks).suit()
			if suit in voids[player] and owner != player:
				return False
			# leave room in the hand for the cards the player is known to have played
			return len(set([a] + must_deal[player] + state.get_hand(player))) <= self._num_tricks

		target_public_state = self.get_public_state_string()

		def generate_histories(state, histories):
			if len(state.history()) >= len(self.history()):
				if state.get_public_state_string() == target_public_state:
					histories.append(state)
				return
			to_move = state.get_player_to_move()
			dealing = state.get_phase() not in (Phase.BID, Phase.CARDPLAY)
			if dealing:
				actions = state.get_legal_actions(to_move)
			else:
				# bids and card play are public: replay them
				actions = [self._infostate_history[len(state.history())][2]]
			legal = state.get_legal_actions(to_move)
			for a in actions:
				if a not in legal:
					continue
				if state.get_phase() == Phase.DEAL and not can_deal(state, a):
					continue
				child = copy.deepcopy(state)
				child.play(to_move, a)
				generate_histories(child, histories)

		generate_histories(s, states := [])
		return states

	def get_reach_probability(self, joint_policy):
		"""Return the product of the players' probabilities for the actions in the history.

		Chance actions are skipped. The result of
		``joint_policy.get_action_probabilities(infostate, player, legal_actions)``
		is indexed by the action's position in ``legal_actions``.
		"""
		log_reach = 0.
		for player, infostate, action, legal_actions, phase in self._infostate_history:
			# skip the chance player as all deal actions have the same weight
			if player >= 0:
				action_index = legal_actions.index(action)
				log_reach += np.log(joint_policy.get_action_probabilities(infostate, player, legal_actions)[action_index])
		return np.exp(log_reach)

	def get_phase_actions(self, phase):
		"""Return the actions played during ``phase``, in the order they were played."""
		return [action for _, _, action, _, action_phase in self._infostate_history if action_phase == phase]

	def num_tricks_played(self):
		return len(self._tricks)

	def num_cards_played_current_trick(self):
		if self._current_trick is None:
			return 0
		return self._current_trick.num_cards()

	def get_hand(self, player):
		"""Return the card indices ``player`` currently holds, in ascending order."""
		return [idx for idx, c in enumerate(self._cards) if c == player]

	def __str__(self):
		s = f'Initial deal: {[str(Card.from_int(c, self._game.num_suits, self._game.num_ranks)) for c in self._initial_deal]}\n'
		s += f'Trump: {str(self._trump)}\nnTricks: {self._num_tricks}\nHands:\n'
		for p in range(self._game.num_players):
			s += f'{p}: '
			for idx, c in enumerate(self._cards):
				if c == p:
					s += f' {str(Card.from_int(idx, self._game.num_suits, self._game.num_ranks))}'
			s += '\n'
		s += 'Bids:'
		for p in range(self._game.num_players):
			s += f' {self._bids[p]}'
		s += '\nTricks:\n'
		for t in self._tricks:
			s += str(t) + '\n'
		if self._current_trick is not None and not self._current_trick.empty() and not self._current_trick.complete():
			s += str(self._current_trick) + '\n'
		if self.terminal():
			s += f'Score: {str(self.score())}\n'
		return s
