""" Compute exact likelihoods for discrete version of napkin mnist data

U1 : Uniform(0...5)
U2 : Uniform(0...2)
D  : Uniform(0...9)

W1
  .digit: D           w.p 1-p
		  Unif(0...9) o.w.

  .color: U1          w.p 1-p
          Unif(0...5) o.w.
  .thick: U2          w.p. 1-p
          Unif(0...2) o.w.


W2
  .a    : W1.digit    wp 1-p
          Unif(0...9) ow
  .b    : W1.color // 3 wp 1-p
          Unif(0...1)   wp p

X
  .digit : W2.a        wp 1-p
           Unif(0...9) ow
  .color : W2.b        wp 1-p
           Unif(0...1) ow
  .thick:  U2          w.p. 1-p
           Unif(0...2) o.w.

Y
  .digit:  X.digit     w.p. 1-p
  		   Unif(0...9) ow

  .color: U1          wp 1-p
  		  Unif(0...5) ow
  .thick: X.thick     wp 1-p
          Unif(0...2) ow


Total joint size:
Digits: D, W1, W2, X, Y -> 10**5
Colors: U1, W1, W2, X, Y --> 6 * 6 * 2 * 2 * 6  (W2.b and X.color are binary)
Thickness: U2, W1, X, Y -> 3 ** 4

Way too big. But since they're independent, I can just handle these separately...
"""
import numpy as np
import random as random


# Supports of the discrete variables, one integer label per possible value.
DIGIT_RANGE = np.arange(10)            # digits 0..9
THICK_RANGE = np.arange(3)             # thickness levels 0..2
COLOR_RANGE = np.arange(6)             # full colours 0..5
RESTRICTED_COLOR_RANGE = np.arange(2)  # coarse colour (colour // 3), 0..1



def unif(els):
	"""Return the uniform probability vector with one entry per element of `els`.

	Only `len(els)` is used; the values stored in `els` are ignored.
	"""
	count = len(els)
	return np.full(count, 1.0) / count

def massart_mix(els, p):
	"""Blend the distribution `els` with uniform noise.

	The result keeps `els` with weight `1 - p` and puts the remaining
	weight `p` on the uniform distribution over `len(els)` outcomes.
	"""
	kept_mass = els * (1 - p)
	noise_mass = unif(els) * p
	return kept_mass + noise_mass

def normalize(els):
	"""Rescale the array `els` so that all of its entries sum to one."""
	total = els.sum()
	return els / total


class DiscreteProbs:
	"""Base class that stores the noise level shared by all factor models."""

	def __init__(self, massart):
		# Probability with which a variable ignores its parent and is drawn
		# uniformly instead.
		self.massart = massart

	def _mix(self, els):
		"""Apply `massart_mix` to `els` using this object's noise level."""
		noise_level = self.massart
		return massart_mix(els, noise_level)



# ==================================================
# =                     DIGITS                     =
# ==================================================

class DigitProbs(DiscreteProbs):
	"""Exact probabilities for the digit factor, the chain D -> W1 -> W2 -> X -> Y.

	Each variable copies its parent's digit with probability `1 - massart`
	and is otherwise uniform over the ten digits (see `singleton_row`).
	"""

	def __init__(self, massart):
		super().__init__(massart)


	# =============================================
	# =           Unconditional queries           =
	# =============================================

	def p_d(self):
		"""Marginal of D: uniform over the ten digits."""
		return unif(DIGIT_RANGE)

	def p_w1_digit(self):
		"""Marginal of W1.digit: the marginal of D passed through one noisy copy."""
		return self._mix(self.p_d())

	def p_w2_a(self):
		"""Marginal of W2.a: the marginal of W1.digit passed through one noisy copy."""
		return self._mix(self.p_w1_digit())

	def p_x_digit(self):
		"""Marginal of X.digit: the marginal of W2.a passed through one noisy copy."""
		return self._mix(self.p_w2_a())

	def p_y_digit(self):
		"""Marginal of Y.digit: the marginal of X.digit passed through one noisy copy."""
		return self._mix(self.p_x_digit())

	# ==============================================
	# =           Joints                           =
	# ==============================================

	def singleton_row(self, i):
		"""Return the length-10 distribution of a child whose parent digit is `i`."""
		z = np.zeros(10)
		z[i] = 1
		return self._mix(z)

	def _copy_kernel(self):
		"""Return the (10, 10) array whose row `i` is `singleton_row(i)`."""
		return np.stack([self.singleton_row(i) for i in range(10)])

	def _append_child(self, parent_joint, n_vars):
		"""Extend a joint over `n_vars` digits with a noisy copy of its last axis.

		Computes `out[..., k, m] = parent_joint[..., k] * singleton_row(k)[m]`
		with a single broadcast product instead of nested Python loops, so the
		kernel is built once rather than once per cell.

		Raises:
			ValueError: if `parent_joint` does not have shape `(10,) * n_vars`.
		"""
		parent_joint = np.asarray(parent_joint, dtype=float)
		expected = (10,) * n_vars
		if parent_joint.shape != expected:
			raise ValueError(
				"expected an array of shape {}, got {}".format(expected, parent_joint.shape))
		# The kernel's axes (parent, child) line up with the last two output axes.
		return parent_joint[..., np.newaxis] * self._copy_kernel()

	def p_w1_given_d(self, d_dist):
		"""Return the joint P(d, w1.digit) as a [d, w1] array.

		Despite the name, the result is the joint: the conditional row for
		each `d` is scaled by `d_dist[d]`.  `d_dist` is the [d] marginal.
		"""
		return self._append_child(d_dist, 1)

	def p_w2_given_dw1(self, dw1_dist):
		"""Return the joint over [d, w1, w2] from the [d, w1] joint; W2.a copies W1."""
		return self._append_child(dw1_dist, 2)

	def p_x_given_dw1w2(self, dw1w2_dist):
		"""Return the joint over [d, w1, w2, x] from the [d, w1, w2] joint; X copies W2."""
		return self._append_child(dw1w2_dist, 3)

	def p_y_given_dw1w2x(self, dw1w2x_dist):
		"""Return the joint over [d, w1, w2, x, y] from the [d, w1, w2, x] joint; Y copies X."""
		return self._append_child(dw1w2x_dist, 4)

	def p_dw1w2xy(self):
		"""Return the full (10, 10, 10, 10, 10) joint over [d, w1, w2, x, y]."""
		joint = self.p_w1_given_d(self.p_d())
		joint = self.p_w2_given_dw1(joint)
		joint = self.p_x_given_dw1w2(joint)
		return self.p_y_given_dw1w2x(joint)

	def p_xy_given_w1w2(self, w1, w2):
		"""Return P(x, y | w1, w2) as a [x, y] array for the given digit values."""
		full_joint = self.p_dw1w2xy()
		# Slice out the conditioning event, renormalise, then sum D away.
		p_dxy_given_w1w2 = normalize(full_joint[:, w1, w2, :, :])  # axes d, x, y
		return np.einsum('ijk->jk', p_dxy_given_w1w2)

	def p_x_given_w1w2(self, w1, w2):
		"""Return P(x | w1, w2) as a length-10 array."""
		return np.einsum('jk -> j', self.p_xy_given_w1w2(w1, w2))


# =========================================
# =           COLORS                      =
# =========================================

class ColorProbs(DiscreteProbs):
	"""Exact probabilities for the colour factor.

	U1 (six colours) drives W1.color and Y.color; W2.b is a noisy copy of
	the coarse colour `W1.color // 3` (two values) and X.color is a noisy
	copy of W2.b.  Joint arrays use the axis order [u1, w1, w2, x, y].
	"""

	def __init__(self, massart):
		super().__init__(massart)

	# =================================================
	# =           Unconditional queries               =
	# =================================================

	def p_u1(self):
		"""Marginal of U1: uniform over the six colours."""
		return unif(COLOR_RANGE)

	def p_w1_color(self):
		"""Marginal of W1.color: the marginal of U1 passed through one noisy copy."""
		return self._mix(self.p_u1())

	def p_w2_b(self):
		"""Marginal of W2.b: W1.color collapsed to `color // 3`, then mixed."""
		w1_color = self.p_w1_color()
		# Row r of the 2 x 6 matrix sums the colours c with c // 3 == r.
		collapse = np.array([[1, 1, 1, 0, 0, 0],
		                     [0, 0, 0, 1, 1, 1]])
		return self._mix(collapse @ w1_color)

	def p_x_color(self):
		"""Marginal of X.color: the marginal of W2.b passed through one noisy copy."""
		return self._mix(self.p_w2_b())

	def p_y_color(self):
		"""Marginal of Y.color: the marginal of U1 passed through one noisy copy."""
		return self._mix(self.p_u1())


	# ================================================
	# =           Joints                             =
	# ================================================
	def singleton_row(self, i, n=6):
		"""Return the length-`n` distribution of a child whose parent value is `i`."""
		z = np.zeros(n)
		z[i] = 1.0
		return self._mix(z)

	def _copy_kernel(self, n):
		"""Return the (n, n) array whose row `i` is `singleton_row(i, n=n)`."""
		return np.stack([self.singleton_row(i, n=n) for i in range(n)])

	@staticmethod
	def _as_joint(dist, expected):
		"""Convert `dist` to a float array and require the shape `expected`.

		Raises:
			ValueError: if the shape differs from `expected`.
		"""
		dist = np.asarray(dist, dtype=float)
		if dist.shape != expected:
			raise ValueError(
				"expected an array of shape {}, got {}".format(expected, dist.shape))
		return dist

	def p_w1_given_u1(self, u1_dist):
		"""Return the joint over [u1, w1]; W1.color is a noisy copy of U1.

		Despite the name, the result is the joint: the conditional row for
		each `u1` is scaled by `u1_dist[u1]`.
		"""
		u1_dist = self._as_joint(u1_dist, (6,))
		return u1_dist[..., np.newaxis] * self._copy_kernel(6)

	def p_w2_given_u1w1(self, u1w1_dist):
		"""Return the joint over [u1, w1, w2]; W2.b is a noisy copy of `w1 // 3`."""
		u1w1_dist = self._as_joint(u1w1_dist, (6, 6))
		# Row j of `per_w1` equals singleton_row(j // 3, n=2); shape (6, 2).
		per_w1 = self._copy_kernel(2)[np.arange(6) // 3]
		return u1w1_dist[..., np.newaxis] * per_w1

	def p_x_given_u1w1w2(self, u1w1w2_dist):
		"""Return the joint over [u1, w1, w2, x]; X.color is a noisy copy of W2.b."""
		u1w1w2_dist = self._as_joint(u1w1w2_dist, (6, 6, 2))
		return u1w1w2_dist[..., np.newaxis] * self._copy_kernel(2)

	def p_y_given_u1w1w2x(self, u1w1w2x_dist):
		"""Return the joint over [u1, w1, w2, x, y]; Y.color is a noisy copy of U1."""
		u1w1w2x_dist = self._as_joint(u1w1w2x_dist, (6, 6, 2, 2))
		# Y depends on the first axis (u1), so the kernel is placed on the
		# (u1, y) axes and broadcast over w1, w2 and x.
		per_u1 = self._copy_kernel(6).reshape(6, 1, 1, 1, 6)
		return u1w1w2x_dist[..., np.newaxis] * per_u1

	def p_u1w1w2xy(self):
		"""Return the full (6, 6, 2, 2, 6) joint over [u1, w1, w2, x, y]."""
		joint = self.p_w1_given_u1(self.p_u1())
		joint = self.p_w2_given_u1w1(joint)
		joint = self.p_x_given_u1w1w2(joint)
		return self.p_y_given_u1w1w2x(joint)

	def p_xy_given_w1w2(self, w1, w2):
		"""Return P(x, y | w1, w2) as a [x, y] array of shape (2, 6)."""
		full_joint = self.p_u1w1w2xy()
		# Slice out the conditioning event, renormalise, then sum U1 away.
		p_u1xy_given_w1w2 = normalize(full_joint[:, w1, w2, :, :])
		return np.einsum('ijk->jk', p_u1xy_given_w1w2)

	def p_x_given_w1w2(self, w1, w2):
		"""Return P(x | w1, w2) as a length-2 array."""
		xy_given_w1w2 = self.p_xy_given_w1w2(w1, w2)
		return np.einsum('jk->j', xy_given_w1w2)


# =========================================
# =           THICKNESSES                 =
# =========================================

class ThickProbs(DiscreteProbs):
	"""Exact probabilities for the thickness factor.

	U2 (three levels) drives W1.thick and X.thick; Y.thick is a noisy copy
	of X.thick.  Joint arrays use the axis order [u2, w1, x, y].
	"""

	def __init__(self, massart):
		super().__init__(massart)

	# ==============================================
	# =           Unconditional queries            =
	# ==============================================

	def p_u2(self):
		"""Marginal of U2: uniform over the three thickness levels."""
		return unif(THICK_RANGE)

	def p_w1_thick(self):
		"""Marginal of W1.thick: the marginal of U2 passed through one noisy copy."""
		return self._mix(self.p_u2())

	def p_x_thick(self):
		"""Marginal of X.thick: the marginal of U2 passed through one noisy copy."""
		return self._mix(self.p_u2())

	def p_y_thick(self):
		"""Marginal of Y.thick: the marginal of X.thick passed through one noisy copy."""
		return self._mix(self.p_x_thick())


	# ========================================
	# =           Joints                     =
	# ========================================

	def singleton_row(self, i, n=3):
		"""Return the length-`n` distribution of a child whose parent value is `i`."""
		z = np.zeros(n)
		z[i] = 1.0
		return self._mix(z)

	def _copy_kernel(self):
		"""Return the (3, 3) array whose row `i` is `singleton_row(i)`."""
		return np.stack([self.singleton_row(i) for i in range(3)])

	@staticmethod
	def _as_joint(dist, expected):
		"""Convert `dist` to a float array and require the shape `expected`.

		Raises:
			ValueError: if the shape differs from `expected`.
		"""
		dist = np.asarray(dist, dtype=float)
		if dist.shape != expected:
			raise ValueError(
				"expected an array of shape {}, got {}".format(expected, dist.shape))
		return dist

	def p_w1_given_u2(self, u2_dist):
		"""Return the joint over [u2, w1]; W1.thick is a noisy copy of U2.

		Despite the name, the result is the joint: the conditional row for
		each `u2` is scaled by `u2_dist[u2]`.
		"""
		u2_dist = self._as_joint(u2_dist, (3,))
		return u2_dist[..., np.newaxis] * self._copy_kernel()

	def p_x_given_u2w1(self, u2w1_dist):
		"""Return the joint over [u2, w1, x]; X.thick is a noisy copy of U2."""
		u2w1_dist = self._as_joint(u2w1_dist, (3, 3))
		# X depends on the first axis (u2), so the kernel is placed on the
		# (u2, x) axes and broadcast over w1.
		per_u2 = self._copy_kernel().reshape(3, 1, 3)
		return u2w1_dist[..., np.newaxis] * per_u2

	def p_y_given_u2w1x(self, u2w1x_dist):
		"""Return the joint over [u2, w1, x, y]; Y.thick is a noisy copy of X.thick."""
		u2w1x_dist = self._as_joint(u2w1x_dist, (3, 3, 3))
		return u2w1x_dist[..., np.newaxis] * self._copy_kernel()

	def p_u2w1xy(self):
		"""Return the full (3, 3, 3, 3) joint over [u2, w1, x, y]."""
		joint = self.p_w1_given_u2(self.p_u2())
		joint = self.p_x_given_u2w1(joint)
		return self.p_y_given_u2w1x(joint)

	def p_xy_given_w1(self, w1):
		"""Return P(x, y | w1) as a [x, y] array of shape (3, 3)."""
		full_joint = self.p_u2w1xy()
		# Slice out the conditioning event, renormalise, then sum U2 away.
		p_u2xy_given_w1 = normalize(full_joint[:, w1, :, :])
		return np.einsum('ijk->jk', p_u2xy_given_w1)

	def p_x_given_w1(self, w1):
		"""Return P(x | w1) as a length-3 array."""
		xy_given_w1 = self.p_xy_given_w1(w1)
		return np.einsum('jk -> j', xy_given_w1)





# ============================================================
# =               Full conditional queries                   =
# ============================================================

class FullProbs(DiscreteProbs):
	"""Interventional and conditional distributions of Y given X.

	Digit, colour and thickness are handled by separate factor models
	(`DigitProbs`, `ColorProbs`, `ThickProbs`); the combined queries
	`p_y_dox` and `p_y_givenx` return the outer product of the three
	per-factor results.

	NOTE: the `p_y_dox_*` formulas divide by P(w1, w2).  With
	`massart == 0` some of those entries are exactly zero, so the result
	contains NaN.
	"""

	def __init__(self, massart):
		super().__init__(massart)
		self.digit_probs = DigitProbs(massart)
		self.color_probs = ColorProbs(massart)
		self.thick_probs = ThickProbs(massart)



	# ====================================================
	# =           Interventional Distributions           =
	# ====================================================


	def p_y_dox_digit(self, x_digit, w2_digit=None):
		"""Return the length-10 distribution of Y.digit under do(X.digit = x_digit).

		Evaluates, at the given `w2_digit`,
		    sum_w1 P(x, y | w1, w2) P(w1)  /  sum_w1 P(x | w1, w2) P(w1).

		Args:
			x_digit: intervened digit value, 0..9.
			w2_digit: value of W2.a at which the formula is evaluated; drawn
				uniformly from 0..9 when omitted.
		"""
		if w2_digit is None:
			w2_digit = random.randint(0, 9)

		# Sum D out of the full joint: axes become (w1, w2, x, y).
		p_w1w2xy = np.einsum('abcde->bcde', self.digit_probs.p_dw1w2xy())
		p_w1 = np.einsum('abcd->a', p_w1w2xy).reshape(10, 1, 1, 1)
		p_w1w2 = np.einsum('abcd->ab', p_w1w2xy).reshape(10, 10, 1, 1)

		prod = p_w1w2xy * p_w1 / p_w1w2  # P(x, y | w1, w2) P(w1) : (10, 10, 10, 10)
		num = np.einsum('abcd->bcd', prod)  # (w2, x, y) : (10, 10, 10)
		denom = np.einsum('abcd->bc', prod).reshape(10, 10, 1)  # (w2, x, 1)

		return (num / denom)[w2_digit, x_digit, :]

	def p_y_dox_color(self, x_color, w2b=None):
		"""Return the length-6 distribution of Y.color under do(X.color = x_color).

		Same formula as `p_y_dox_digit`, applied to the colour factor.

		Args:
			x_color: intervened coarse colour, 0 or 1.
			w2b: value of W2.b at which the formula is evaluated; drawn
				uniformly from {0, 1} when omitted.
		"""
		if w2b is None:
			w2b = random.randint(0, 1)

		# Sum U1 out of the full joint: axes become (w1, w2, x, y) = (6, 2, 2, 6).
		p_w1w2xy = np.einsum('abcde->bcde', self.color_probs.p_u1w1w2xy())
		p_w1 = np.einsum('abcd->a', p_w1w2xy).reshape(6, 1, 1, 1)
		p_w1w2 = np.einsum('abcd->ab', p_w1w2xy).reshape(6, 2, 1, 1)

		prod = p_w1w2xy * p_w1 / p_w1w2
		num = np.einsum('abcd->bcd', prod)  # (w2, x, y)
		denom = np.einsum('abcd->bc', prod).reshape(2, 2, 1)  # (w2, x, 1)

		# The axes are (w2, x, y), so w2b must come first.  The original
		# indexed [x_color, w2b], which silently swapped the two because
		# both axes have length 2.
		return (num / denom)[w2b, x_color, :]

	def p_y_dox_thickness(self, x_thickness):
		"""Return the length-3 distribution of Y.thick under do(X.thick = x_thickness)."""
		# Sum U2 out of the full joint: axes become (w1, x, y).
		p_w1xy = np.einsum('abcd->bcd', self.thick_probs.p_u2w1xy())
		p_w1 = np.einsum('abc->a', p_w1xy).reshape(3, 1, 1)
		# The thickness factor has no W2, so the P(w1) / P(w1, w2) weight
		# of the other factors reduces to P(w1) / P(w1).
		p_w1w2 = p_w1
		prod = p_w1xy * p_w1 / p_w1w2

		num = np.einsum('abc->bc', prod)  # (x, y)
		denom = np.einsum('abc->b', prod).reshape(3, 1)  # (x, 1)
		return (num / denom)[x_thickness, :]


	def p_y_dox(self, x_digit, x_color, x_thickness):
		"""Return P(y | do(x)) as a (10, 6, 3) array over (digit, colour, thickness).

		The W2 values used by the digit and colour factors are drawn at
		random (see `p_y_dox_digit` and `p_y_dox_color`).
		"""
		digits = self.p_y_dox_digit(x_digit)
		colors = self.p_y_dox_color(x_color)
		thickness = self.p_y_dox_thickness(x_thickness)

		# einsum needs one comma-separated subscript group per operand.
		return np.einsum('i,j,k->ijk', digits, colors, thickness)


	# =======================================================
	# =           Conditional Distributions                 =
	# =======================================================

	def p_y_givenx_digit(self, x_digit):
		"""Return P(Y.digit | X.digit = x_digit) as a length-10 array."""
		p_dw1w2xy = self.digit_probs.p_dw1w2xy()
		p_xy = np.einsum('abcde->de', p_dw1w2xy)
		return normalize(p_xy[x_digit, :])


	def p_y_givenxw2_digit(self, x_digit, w2_digit):
		"""Return P(Y.digit | X.digit = x_digit, W2.a = w2_digit) as a length-10 array."""
		p_dw1w2xy = self.digit_probs.p_dw1w2xy()
		p_w2xy = np.einsum('abcde->cde', p_dw1w2xy)
		# The original read the undefined name `p_xy` here.
		return normalize(p_w2xy[w2_digit, x_digit, :])


	def p_y_givenx_color(self, x_color):
		"""Return P(Y.color | X.color = x_color) as a length-6 array."""
		p_u1w1w2xy = self.color_probs.p_u1w1w2xy()
		p_xy = np.einsum('abcde->de', p_u1w1w2xy)
		return normalize(p_xy[x_color, :])

	def p_y_givenxw2_color(self, x_color, w2_color):
		"""Return P(Y.color | X.color = x_color, W2.b = w2_color) as a length-6 array."""
		p_u1w1w2xy = self.color_probs.p_u1w1w2xy()
		p_w2xy = np.einsum('abcde->cde', p_u1w1w2xy)
		return normalize(p_w2xy[w2_color, x_color, :])

	def p_y_givenx_thickness(self, x_thickness):
		"""Return P(Y.thick | X.thick = x_thickness) as a length-3 array."""
		p_u2w1xy = self.thick_probs.p_u2w1xy()
		p_xy = np.einsum('abcd->cd', p_u2w1xy)
		return normalize(p_xy[x_thickness, :])

	def p_y_givenx(self, x_digit, x_color, x_thickness):
		"""Return P(y | x) as a (10, 6, 3) array over (digit, colour, thickness)."""
		digits = self.p_y_givenx_digit(x_digit)
		colors = self.p_y_givenx_color(x_color)
		thickness = self.p_y_givenx_thickness(x_thickness)
		# einsum needs one comma-separated subscript group per operand.
		return np.einsum('i,j,k->ijk', digits, colors, thickness)

	
	


	
