import numpy as np

class Agent(object):
	"""Base class for epsilon-greedy agents driven by per-action Q-values.

	Subclasses provide the value function by implementing `_q_values` and the
	learning rule by implementing `feedback`.
	"""

	def __init__(self, game, epsilon=0.05, learning_rate=0.1):
		"""Store hyper-parameters and per-player bookkeeping.

		Args:
			game: object exposing `num_players` and `max_reward()`.
			epsilon: probability with which `act` picks a uniformly random
				legal action instead of a greedy one.
			learning_rate: step size, stored for use by subclasses.
		"""
		self._num_players = game.num_players
		self.epsilon = epsilon
		self._learning_rate = learning_rate
		self._max_reward = game.max_reward()
		# One list of TD errors per player; read back by `mean_td_err`.
		self._td_errs = [[] for _ in range(self._num_players)]

	def _greedy_indices(self, infostate, player, legal_actions):
		"""Return the positions whose Q-value equals the maximum Q-value."""
		qs = self._q_values(infostate, player, legal_actions)
		if len(qs) == 0:
			return []
		# Compute the maximum once instead of once per compared element.
		best = max(qs)
		return [i for i, q in enumerate(qs) if q == best]

	def act(self, infostate, player, legal_actions):
		"""Sample an action epsilon-greedily.

		With probability `epsilon` a uniformly random element of
		`legal_actions` is returned; otherwise one of the actions with maximal
		Q-value is returned, ties broken uniformly at random.
		"""
		if np.random.rand() < self.epsilon:
			return np.random.choice(legal_actions)
		greedy = self._greedy_indices(infostate, player, legal_actions)
		return legal_actions[np.random.choice(greedy)]

	def action_probabilities(self, infostate, player, legal_actions):
		"""Return the probability with which `act` selects each legal action.

		Every action receives `epsilon / n` from the exploration branch, and
		each greedy action additionally receives `(1 - epsilon) / k`, where
		`n` is the number of legal actions and `k` the number of greedy ones.
		The result therefore sums to one without renormalisation.

		Returns:
			A float array aligned with `legal_actions`.
		"""
		num_actions = len(legal_actions)
		greedy = self._greedy_indices(infostate, player, legal_actions)
		# Exploration in `act` is uniform over ALL legal actions, greedy ones
		# included, so the epsilon mass is shared by every entry.
		probs = np.full(num_actions, self.epsilon / num_actions)
		probs[greedy] += (1. - self.epsilon) / len(greedy)
		return probs

	def mean_td_err(self, player):
		"""Return the mean of the TD errors recorded for `player`.

		Returns NaN when nothing has been recorded yet; this is the value
		`np.mean` yields for an empty list, but without its RuntimeWarning.
		"""
		errs = self._td_errs[player]
		if not errs:
			return float("nan")
		return np.mean(errs)

	def _q_values(self, infostate, player, legal_actions):
		"""Return one Q-value per entry of `legal_actions` (subclass hook)."""
		raise NotImplementedError("Base class")

	def feedback(self, episode, rewards):
		"""Learn from a finished episode (subclass hook)."""
		raise NotImplementedError("Base class")


class TabularQAgent(Agent):
	"""Q-learning agent backed by one lookup table per player.

	Each table maps `str(infostate)` to an array holding one value per legal
	action of that infostate.
	"""

	def __init__(self, game, epsilon=0.05, learning_rate=0.1):
		"""Set up the base agent and one empty table per player."""
		super().__init__(game, epsilon, learning_rate)
		self._qs = [dict() for _ in range(game.num_players)]

	def _q_values(self, infostate, player, legal_actions):
		"""Return the value array for `infostate`, creating it if unseen.

		A fresh entry is zero-initialised with `len(legal_actions)` slots.
		The returned array is the stored one, not a copy.
		"""
		table = self._qs[player]
		key = str(infostate)
		if key not in table:
			table[key] = np.zeros(len(legal_actions))
		return table[key]

	def feedback(self, episode, rewards):
		"""Update every player's table from a finished episode.

		Args:
			episode: indexable by player; `episode[player]` is a sequence of
				`(state, action)` pairs, where each state provides
				`get_infostate(player)` and `get_legal_actions(player)`.
			rewards: indexable by player; the final reward of each player.

		Steps are processed from last to first. The last step is pulled
		towards the player's reward, earlier steps towards the best current
		value of the following step.
		"""
		for player in range(self._num_players):
			steps = episode[player]
			last = len(steps) - 1
			# The range stops before index 0, so the first recorded step is
			# never updated.
			# NOTE(review): confirm that skipping steps[0] is intended.
			for idx in range(last, 0, -1):
				state, action = steps[idx]
				if idx == last:
					target = rewards[player]
				else:
					successor, _ = steps[idx + 1]
					successor_values = self._q_values(
						successor.get_infostate(player),
						player,
						successor.get_legal_actions(player))
					target = max(successor_values)
				key = str(state.get_infostate(player))
				legal_actions = state.get_legal_actions(player)
				slot = legal_actions.index(action)
				# `values` is the stored array, so the update below writes
				# straight into the table.
				values = self._q_values(key, player, legal_actions)
				td_err = target - values[slot]
				self._td_errs[player].append(td_err)
				values[slot] += self._learning_rate * td_err


class LinearQAgent(Agent):
	"""Q-learning agent with a linear value function per player.

	Each player owns a weight matrix of shape
	`(game.num_features(), game.num_actions())`; the Q-values of an infostate
	are the product of the transposed matrix with the infostate vector.
	"""

	def __init__(self, game, epsilon=0.05, learning_rate=0.1):
		"""Set up the base agent and one zero weight matrix per player."""
		super().__init__(game, epsilon, learning_rate)
		self._weights = [np.zeros((game.num_features(), game.num_actions()))
						 for _ in range(self._num_players)]

	def _q_values(self, infostate, player, legal_actions):
		"""Return the Q-values of `legal_actions` for `infostate`.

		`legal_actions` is used to index the action axis of the Q-vector, so
		its entries act as action indices.
		"""
		qs = np.matmul(self._weights[player].T, infostate)
		return qs[legal_actions]

	def feedback(self, episode, rewards, scale=True):
		"""Update every player's weights from a finished episode.

		Args:
			episode: indexable by player; `episode[player]` is a sequence of
				`(state, action)` pairs, where each state provides
				`get_infostate(player)` and `get_legal_actions(player)`.
			rewards: indexable by player; the final reward of each player.
				It is never modified.
			scale: when true, rewards are divided by the game's maximum
				reward before being used as targets.

		Steps are processed from last to first. The last step is pulled
		towards the player's (optionally scaled) reward, earlier steps towards
		the best current Q-value of the following step.
		"""
		if scale:
			# Scale a private float copy: an in-place `/=` would rescale the
			# caller's array on every call and fails outright for lists and
			# integer arrays.
			rewards = np.asarray(rewards, dtype=float) / self._max_reward
		for player in range(self._num_players):
			ep = episode[player]
			last = len(ep) - 1
			# The range stops before index 0, so the first recorded step is
			# never updated.
			# NOTE(review): confirm that skipping ep[0] is intended.
			for i in range(last, 0, -1):
				state, action = ep[i]
				if i == last:
					target = rewards[player]
				else:
					next_state, _ = ep[i + 1]
					target = max(self._q_values(next_state.get_infostate(player), player,
												next_state.get_legal_actions(player)))
				infostate = state.get_infostate(player)
				legal_actions = state.get_legal_actions(player)
				action_index = legal_actions.index(action)
				q = self._q_values(infostate, player, legal_actions)[action_index]
				td_err = target - q
				self._td_errs[player].append(td_err)
				# Only the column of the taken action is adjusted, along the
				# direction of the infostate vector.
				self._weights[player][:, action] += self._learning_rate * td_err * np.array(infostate)