--[[
	experiments for nips 2017 paper @ lognormal @ maze domain
]]--

---------------------------------------------

-- Choose one of the three modes below by leaving exactly one assignment
-- uncommented. The dispatch at the end of this file compares `mode` against
-- these exact strings, so the text must not be altered.
--local mode = "manually interact with maze"
--local mode = "compute optimal policy"
local mode = "get distribution of estimated values"

-- Experiment parameters shared by every function in this file.
local params = {
	-- A move fails when math.random() > pForward (see Step).
	pForward = 0.6,
	
	--typeFail = "still", 	-- type of failure: stand still with 1-pForward, or
	typeFail = "random", 	-- type of failure: move randomly with 1-pForward

	rGoal = 1,
	rateFlag = 100, -- goal reward = rGoal * rateFlag^(#obtainedFlags)
	gamma = 0.9,	-- discount factor used by the value-iteration routines
	
	numRuns = 1000,	-- number of runs per sample size
	errVI = 1e-8,	-- iteration stops once the largest update is below this

	numActions = 4,

	-- Fields filled in at runtime. They are deliberately not listed as bare
	-- identifiers: inside a table constructor a bare name is a positional
	-- value read from a (here undefined) global, not a field declaration.
	--   pathMaze  : path of the maze file currently loaded
	--   pathPi    : path of the policy file for that maze
	--   numStates : set by LoadMaze to 8 * (number of non-wall cells)
}

-- Sample sizes to sweep: 10, 20, ..., 200.
-- One entry is used for a whole set of runs; within a run that many
-- transitions are sampled for every (state, action) pair.
local numSamples = {}
for count = 10, 200, 10 do
	numSamples[#numSamples + 1] = count
end

---------------------------------------------

-- Number of samples drawn per (state, action) pair in the current set of runs.
local currentNumSample

-- Description of the currently loaded maze; every field is filled in by
-- LoadMaze. The fields are listed as comments only: a bare identifier inside
-- a table constructor is a positional value read from a global, not a field
-- declaration.
--   map       : map[y][x], one character per cell
--   xMax, yMax: extent of the grid
--   yxToPosId : yxToPosId[y][x] -> position id (absent for wall cells)
--   posIdToX  : position id -> x
--   posIdToY  : position id -> y
--   startX    : x of the cell marked "S"
--   startY    : y of the cell marked "S"
local maze = {}


--- Load a maze file and rebuild the shared `maze` tables from it.
-- The file is executed with dofile and must return an array of strings, one
-- string per row; character x of row y is the cell at (x, y). Every cell
-- other than "X" receives a position id (numbered row by row), and the cell
-- marked "S" is recorded as the start. All rows are assumed to be as long as
-- the first one.
-- Side effect: sets params.numStates to 8 * (number of positions).
-- @param pathMaze path of the Lua file that returns the maze rows
local function LoadMaze(pathMaze)
	local rawMap = dofile(pathMaze)
	maze.map = {}
	maze.posIdToX = {}
	maze.posIdToY = {}
	maze.yxToPosId = {}
	maze.yMax = #rawMap
	-- The width is the length of a row, not the number of rows; using the
	-- row count is only correct for square mazes.
	maze.xMax = (rawMap[1] and #rawMap[1]) or 0
	-- Forget the start of any previously loaded maze so that a file without
	-- an "S" cell cannot silently reuse stale coordinates.
	maze.startX = nil
	maze.startY = nil
	
	-- load map & count numStates
	local n = 0
	for y = 1, maze.yMax do
		local row = rawMap[y]
		maze.map[y] = {}
		maze.yxToPosId[y] = {}
		for x = 1, maze.xMax do
			local cell = row:sub(x, x)
			maze.map[y][x] = cell
			if cell ~= "X" then
				n = n + 1
				maze.posIdToX[n] = x
				maze.posIdToY[n] = y
				maze.yxToPosId[y][x] = n
				if cell == "S" then
					maze.startX = x
					maze.startY = y
				end
			end
		end
	end
	params.numStates = n * 8 	-- n positions, 3 flags, so n * 2^3 states in total
end

--- Split a 1-based state id into its position id and flag mask.
-- Inverse of PosFlagToStateId.
-- @param stateId 1-based state id
-- @return position id (1-based), flag mask (0..7)
local function StateIdToPosFlag(stateId)
	local zeroBased = stateId - 1
	local flag = zeroBased % 8
	return (zeroBased - flag) / 8 + 1, flag
end

--- Combine a position id and a flag mask into a 1-based state id.
-- Each position owns eight consecutive state ids, one per flag mask.
-- @param pos  position id (1-based)
-- @param flag flag mask (0..7)
-- @return state id
local function PosFlagToStateId(pos, flag)
	local firstStateOfPos = 8 * (pos - 1) + 1
	return firstStateOfPos + flag
end

-- Direction tables indexed by action/outcome number (layout sketched below).
-- dirX/dirY hold the x/y offset of each direction: 1 decreases y, 2 decreases
-- x, 3 increases x, 4 increases y, and 5 leaves the position unchanged.
-- dirStr is the character used when a policy is printed.
local dirX = {0, -1, 1, 0, 0}
local dirY = {-1, 0, 0, 1, 0}
local dirStr = {"^", "<", ">", "v", "o"}
--[[
	   1
	2 (5) 3
	   4
]]--

-- Flag bit that is set when the agent acts from a cell marked "1", "2" or "3".
local flagBitOfCell = { ["1"] = 4, ["2"] = 2, ["3"] = 1 }

--- Resolve one move: apply walls, goal/trap resets and flag pick-up.
-- The effect depends on the cell the agent is currently standing on:
--   "G": reward rGoal * rateFlag^(number of set flags), flags cleared,
--        agent sent to the start cell;
--   "T": no reward, flags cleared, agent sent to the start cell;
--   "1"/"2"/"3": the corresponding flag bit is set and the move proceeds.
-- A move into an "X" cell leaves the agent where it is.
-- @param x, y        current cell
-- @param f           current flag mask
-- @param xNew, yNew  attempted destination cell
-- @return next state id, reward
local function MazeFunction(x, y, f, xNew, yNew)
	-- Walls block the move.
	if maze.map[yNew][xNew] == "X" then
		xNew, yNew = x, y
	end

	local here = maze.map[y][x]
	local reward = 0
	local flags = f

	if here == "G" or here == "T" then
		if here == "G" then
			-- One factor of rateFlag for every flag bit that is set.
			local multiplier = 1
			local mask = 4
			while mask >= 1 do
				if bit.band(f, mask) > 0 then
					multiplier = multiplier * params.rateFlag
				end
				mask = mask / 2
			end
			reward = params.rGoal * multiplier
		end
		-- Goal and trap both clear the flags and send the agent to the start.
		flags = 0
		xNew, yNew = maze.startX, maze.startY
	elseif flagBitOfCell[here] then
		flags = bit.bor(f, flagBitOfCell[here])
	end

	return PosFlagToStateId(maze.yxToPosId[yNew][xNew], flags), reward
end

-- Cell reached from (x, y) by direction `dir`, clamped to the grid bounds.
local function ClampedMove(x, y, dir)
	local cx = math.min(math.max(x + dirX[dir], 1), maze.xMax)
	local cy = math.min(math.max(y + dirY[dir], 1), maze.yMax)
	return cx, cy
end

--- Sample one transition of the maze.
-- With typeFail == "still" a failed move leaves the agent in place; with
-- typeFail == "random" a failed move is replaced by one of the four
-- directions drawn uniformly. A move fails when math.random() > pForward.
-- @param state  current state id
-- @param action direction index (1..4)
-- @return next state id, reward
local function Step(state, action)
	local pos, f = StateIdToPosFlag(state)
	local x, y = maze.posIdToX[pos], maze.posIdToY[pos]

	local xNew, yNew
	local failure = params.typeFail

	if failure == "still" then
		if math.random() <= params.pForward then
			xNew, yNew = ClampedMove(x, y, action)
		else
			xNew, yNew = x, y
		end
	elseif failure == "random" then
		if math.random() > params.pForward then
			action = math.random(4)
		end
		xNew, yNew = ClampedMove(x, y, action)
	end

	return MazeFunction(x, y, f, xNew, yNew)
end

--- Build the exact transition model of one (state, action) pair.
-- Enumerates the five possible outcomes (four directions plus standing
-- still), assigns each its probability under params.typeFail, and merges
-- outcomes that lead to the same successor state.
-- @param state  state id
-- @param action direction index (1..4)
-- @return trans   table: successor state id -> probability
-- @return rewards table: successor state id -> reward
-- @return adj     array of distinct successor ids, with the count in adj.n
local function GetTrueTransition(state, action)
	local pos, f = StateIdToPosFlag(state)
	local x, y = maze.posIdToX[pos], maze.posIdToY[pos]

	local trans, rewards, adj = {}, {}, {n = 0}

	for outcome = 1, 5 do
		-- Probability that `outcome` is the direction actually executed.
		local pr = 0
		if params.typeFail == "still" then
			if outcome == 5 then
				pr = 1 - params.pForward
			elseif outcome == action then
				pr = params.pForward
			end
		elseif params.typeFail == "random" then
			if outcome == action then
				pr = 0.25*(1 - params.pForward) + params.pForward
			elseif outcome ~= 5 then
				pr = 0.25*(1 - params.pForward)
			end
		end

		if pr > 0 then
			local xNew = math.min(math.max(x + dirX[outcome], 1), maze.xMax)
			local yNew = math.min(math.max(y + dirY[outcome], 1), maze.yMax)
			local successor, r = MazeFunction(x, y, f, xNew, yNew)

			if trans[successor] then
				-- Several outcomes can end in the same state (walls, borders).
				trans[successor] = trans[successor] + pr
			else
				adj.n = adj.n + 1
				adj[adj.n] = successor
				trans[successor] = pr
				rewards[successor] = r
			end
		end
	end

	return trans, rewards, adj
end

--- Render a flag mask as a three-character string of 0s and 1s,
-- with bit 4 first, then bit 2, then bit 1.
-- @param flag flag mask (0..7)
-- @return string such as "101"
local function FlagToBits(flag)
	local function digit(mask)
		if bit.band(flag, mask) > 0 then
			return "1"
		end
		return "0"
	end
	return digit(4) .. digit(2) .. digit(1)
end

--- Interactive mode: draw maze 1 on stdout and move the agent with actions
-- typed on stdin. The agent is drawn as "@". The loop ends when stdin
-- reaches end of file.
local function MainTestMaze()
	params.pathMaze = "mazes/1.lua"
	LoadMaze(params.pathMaze)
	
	local state = PosFlagToStateId(maze.yxToPosId[maze.startY][maze.startX], 0)
	local reward = 0
	
	while true do
		local posNow, flagNow = StateIdToPosFlag(state)
		local xNow = maze.posIdToX[posNow]
		local yNow = maze.posIdToY[posNow]
		for y = 1, maze.yMax do
			for x = 1, maze.xMax do
				if x == xNow and y == yNow then
					io.write("@")
				else
					io.write(maze.map[y][x])
				end
			end
			io.write("\n")
		end
		io.write("flag (000~111): "..FlagToBits(flagNow))
		
		-- Read whole lines and convert them explicitly. io.read("*number")
		-- returns nil on non-numeric input or end of file, which would make
		-- the range comparison raise, and it leaves the offending text
		-- unread. Fractional values are rejected because the action is used
		-- as a table index.
		local action
		repeat
			io.write("\nSelect action (1~4) > ")
			local line = io.read("*l")
			if line == nil then
				-- End of input: leave interactive mode.
				io.write("\n")
				return
			end
			action = tonumber(line)
		until action
			and action >= 1
			and action <= params.numActions
			and action == math.floor(action)
		
		state, reward = Step(state, action)
		io.write("reward: ", reward, "\n--------\n")
	end
end

--- In-place value iteration over a sparse transition model.
-- Sweeps all states repeatedly, updating Q and V as it goes, until the
-- largest change of any Q entry within a sweep is below params.errVI.
-- Actions with an empty successor list are skipped; a state whose actions
-- are all skipped keeps its current V.
-- @param T        T[s][a][s2]: transition probability
-- @param R        R[s][a][s2]: reward
-- @param V        V[s]: state values, updated in place
-- @param Q        Q[s][a]: action values, updated in place
-- @param adjTable adjTable[s][a]: array of successor ids, count in field n
local function ValueIteration(T, R, V, Q, adjTable)
	local gamma = params.gamma
	local delta
	repeat
		delta = 0
		for s = 1, params.numStates do
			local best = -1e20
			for a = 1, params.numActions do
				local successors = adjTable[s][a]
				if successors.n > 0 then
					local backup = 0
					for i = 1, successors.n do
						local s2 = successors[i]
						backup = backup + T[s][a][s2]*(R[s][a][s2] + gamma*V[s2])
					end

					delta = math.max(delta, math.abs(Q[s][a] - backup))
					Q[s][a] = backup
					best = math.max(best, backup)
				end
			end
			-- -1e20 is the "no action evaluated" sentinel.
			if best > -1e20 then
				V[s] = best
			end
		end
	until delta < params.errVI
end

-- Variant of value iteration that walks T[s][a] with pairs() instead of
-- using an explicit successor list: tidier (no adjTable) but super slow.
-- @param T T[s][a][s2]: transition probability
-- @param R R[s][a][s2]: reward
-- @param V V[s]: state values, updated in place
-- @param Q Q[s][a]: action values, updated in place
local function ValueIterationSuperSlow(T, R, V, Q)
	local converged = false
	while not converged do
		local delta = 0
		for s = 1, params.numStates do
			local best = -1e20
			for a = 1, params.numActions do
				local backup = 0
				for s2, prob in pairs(T[s][a]) do
					backup = backup + prob*(R[s][a][s2] + params.gamma*V[s2])
				end
				delta = math.max(delta, math.abs(Q[s][a] - backup))
				Q[s][a] = backup
				best = math.max(best, backup)
			end
			if best > -1e20 then
				V[s] = best
			end
		end
		converged = delta < params.errVI
	end
end

--- Iterative evaluation of a fixed policy on a sparse transition model.
-- Only V is updated; Q is NOT computed. Sweeps until the largest change of
-- any V entry within a sweep is below params.errVI. States whose policy
-- action has no recorded successor keep their current value.
-- @param T        T[s][a][s2]: transition probability
-- @param R        R[s][a][s2]: reward
-- @param V        V[s]: state values, updated in place
-- @param adjTable adjTable[s][a]: array of successor ids, count in field n
-- @param pi       pi[s]: action taken in state s
local function ValueEstimateGivenPolicy(T, R, V, adjTable, pi)
	local gamma = params.gamma
	local delta
	repeat
		delta = 0
		for s = 1, params.numStates do
			local a = pi[s]
			local successors = adjTable[s][a]
			if successors.n > 0 then
				local backup = 0
				for i = 1, successors.n do
					local s2 = successors[i]
					backup = backup + T[s][a][s2]*(R[s][a][s2] + gamma*V[s2])
				end

				delta = math.max(delta, math.abs(V[s] - backup))
				V[s] = backup
			end
		end
	until delta < params.errVI
end

--- Extract the greedy policy from Q, save it as a Lua file and print it.
-- The file written to fOptimal has the form "local pi={a1,a2,...}" followed
-- by "return pi". Ties are resolved in favour of the lowest action index.
-- One map per flag mask is then printed on stdout, with every non-wall cell
-- replaced by the arrow of its greedy action.
-- @param fOptimal open file handle that receives the policy
-- @param Q        Q[s][a]: action values
-- @return pi      pi[s]: greedy action of state s
local function WriteAndPrintPolicy(fOptimal, Q)
	local pi = {}

	fOptimal:write("local pi={")
	for s = 1, params.numStates do
		local best = 1
		for a = 2, params.numActions do
			if Q[s][a] > Q[s][best] then
				best = a
			end
		end
		pi[s] = best
		fOptimal:write(best .. ',')
	end
	fOptimal:write("}\nreturn pi\n")

	-- Start from a copy of the maze so that walls keep their own character.
	local piMap = {}
	for y = 1, maze.yMax do
		local row = {}
		for x = 1, maze.xMax do
			row[x] = maze.map[y][x]
		end
		piMap[y] = row
	end

	for f = 0, 7 do
		io.write("flag: "..FlagToBits(f).."\n")

		-- Overwrite each position with the arrow chosen under flag mask f.
		for s = 1, params.numStates do
			local pos, flag = StateIdToPosFlag(s)
			if flag == f then
				local px, py = maze.posIdToX[pos], maze.posIdToY[pos]
				piMap[py][px] = dirStr[pi[s]]
			end
		end

		for y = 1, maze.yMax do
			local row = piMap[y]
			for x = 1, maze.xMax do
				io.write(row[x])
			end
			io.write("\n")
		end
		io.write("\n---\n")
	end
	io.write("===\n")

	return pi
end

--- Debugging aid: print every entry of T1 that differs from T2.
-- Only keys present in T1 are visited. Each reported line shows the source
-- state, the action, the successor state (both states decoded as
-- "id(position,flag)") and the two probabilities.
-- @param T1 T1[s][a][s2]: transition probabilities
-- @param T2 T2[s][a][s2]: transition probabilities to compare against
local function CompareT(T1, T2) -- for debugging
	-- Format a state id as "id(position,flag)".
	local function describe(stateId)
		local p, f = StateIdToPosFlag(stateId)
		return stateId.."("..p..","..f..")"
	end

	print("Comparing T ...", T1, T2)
	local numCompared = 0
	for s = 1, params.numStates do
		for a = 1, params.numActions do
			for s2, prob1 in pairs(T1[s][a]) do
				numCompared = numCompared + 1
				local prob2 = T2[s][a][s2]
				if prob1 ~= prob2 then
					print(describe(s), a, describe(s2), prob1, prob2)
				end
			end
		end
	end
	print("Compared " .. numCompared .. " pairs of T(s,a,s')")
end

--- Build a transition model of the loaded maze and solve it once.
-- The model is either the exact one (GetTrueTransition) or an empirical one
-- estimated from currentNumSample calls to Step per (state, action) pair.
--   * fOptimal given: run value iteration, write/print the greedy policy,
--     and, if fOut is also given, write the value of the start state to it.
--   * only fOut given: load the policy from params.pathPi, evaluate it on
--     the model, and write the value of the start state to fOut.
-- @param fOut                  open file handle for the start value, or nil
-- @param fOptimal              open file handle for the policy, or nil
-- @param flagUseTrueTransition truthy to use the exact model
local function Run(fOut, fOptimal, flagUseTrueTransition)
	local V, Q, T, R = {}, {}, {}, {}
	local adjTable, counts = {}, {}

	-- Allocate the per-state / per-action containers.
	for s = 1, params.numStates do
		V[s] = 0
		Q[s], T[s], R[s] = {}, {}, {}
		adjTable[s], counts[s] = {}, {}
		for a = 1, params.numActions do
			Q[s][a] = 0
			T[s][a] = {}
			R[s][a] = {}
			counts[s][a] = {}
			adjTable[s][a] = {n = 0}
		end
	end

	-- Fill in the model.
	for s = 1, params.numStates do
		for a = 1, params.numActions do
			if flagUseTrueTransition then
				T[s][a], R[s][a], adjTable[s][a] = GetTrueTransition(s,a)
			else
				local seen = adjTable[s][a]
				local visits = counts[s][a]
				local numDraws = currentNumSample

				for _ = 1, numDraws do
					local s2, r = Step(s, a)
					if visits[s2] then
						visits[s2] = visits[s2] + 1
						-- Report a reward that differs from the first one seen.
						if R[s][a][s2] ~= r then
							local p, f = StateIdToPosFlag(s)
							local p2, f2 = StateIdToPosFlag(s2)
							print(s.."("..p..","..f..")", a, s2.."("..p2..","..f2..")", r, R[s][a][s2])
						end
					else
						visits[s2] = 1
						seen.n = seen.n + 1
						seen[seen.n] = s2
						R[s][a][s2] = r
					end
				end

				-- Empirical probability = visit count / number of draws.
				for k = 1, seen.n do
					local s2 = seen[k]
					T[s][a][s2] = visits[s2] / numDraws
				end
			end
		end
	end

	-- Write the value of the start state (no flags collected) to fOut.
	local function writeStartValue()
		local start = PosFlagToStateId(
			maze.yxToPosId[maze.startY][maze.startX],
			0
		)
		fOut:write(V[start] .. '\n')
	end

	if fOptimal then -- for obtaining optimal policy
		ValueIteration(T, R, V, Q, adjTable)
		WriteAndPrintPolicy(fOptimal, Q)
		if fOut then
			writeStartValue()
		end
	elseif fOut then -- for obtaining distribution of estimated values
		local pi = dofile(params.pathPi)
		ValueEstimateGivenPolicy(T, R, V, adjTable, pi)
		writeStartValue()
	end
end

--- "compute optimal policy" mode.
-- For mazes 1..20 under mazes/: load the maze, solve the exact model with
-- value iteration, write the greedy policy to mazes/<id>.pi.lua and the
-- value of the start state to mazes/<id>.V.txt.
-- Raises an error naming the file if an output file cannot be opened.
local function MainGetOptimals()
	local pathPrefix = "mazes/"

	for mapID = 1, 20 do
		params.pathMaze = pathPrefix..mapID..".lua"
		params.pathPi = pathPrefix..mapID..".pi.lua"
		LoadMaze(params.pathMaze)

		-- io.open returns nil plus a message on failure; assert turns that
		-- into an immediate error instead of a later nil-index error
		-- inside Run.
		local fOptimal = assert(io.open(params.pathPi, "w"))
		local fOut = assert(io.open(pathPrefix..mapID..".V.txt", "w"))
		
		-- The exact model is deterministic, so a single run is enough.
		Run(fOut, fOptimal, true)

		fOptimal:close()
		fOut:close()
	end
end

--- "get distribution of estimated values" mode.
-- For mazes 1..20 and for every sample size in numSamples, performs
-- params.numRuns runs. Each run estimates the model from samples, evaluates
-- the policy stored in mazes/<id>.pi.lua and appends the value of the start
-- state to logsMaze/maze<id>.<timestamp>_<samples>.v.txt. The parameters in
-- effect are dumped to logsMaze/maze<id>.<timestamp>.params.txt.
-- Raises an error naming the file if a log file cannot be opened.
local function MainGetDistribution()
	local pathPrefix = "mazes/"
	local dateStr = os.date("%y%m%d_%H%M%S")
	
	for mapID = 1, 20 do
		params.pathMaze = pathPrefix..mapID..".lua"
		params.pathPi = pathPrefix..mapID..".pi.lua"
		LoadMaze(params.pathMaze)

		local pathLogsPrefix = "logsMaze/maze"..mapID.."."..dateStr
		-- note: io.open will NOT create the logsMaze folder if it does not
		-- exist; the assert below then stops with the reason reported by
		-- io.open instead of failing later on a nil file handle.

		local fLog = assert(io.open(pathLogsPrefix..".params.txt", "w"))
		for k, v in pairs(params) do
			fLog:write(k .. " = " .. tostring(v) .. "\n")
		end
		fLog:close()

		for _, v in ipairs(numSamples) do
			currentNumSample = v
			local fOut = assert(io.open(pathLogsPrefix.."_"..currentNumSample..".v.txt", "w"))
			for _ = 1, params.numRuns do
				Run(fOut)
			end
			fOut:close()
		end
	end
end

-- Seed the generator from the clock and discard the first four draws
-- (presumably to skip poorly mixed initial values; the file gives no reason).
math.randomseed(os.time())
for _ = 1, 4 do
	math.random()
end

-- Dispatch on the selected mode; an unrecognised mode does nothing.
local entryPoints = {
	["manually interact with maze"] = MainTestMaze,
	["compute optimal policy"] = MainGetOptimals,
	["get distribution of estimated values"] = MainGetDistribution,
}
local entryPoint = entryPoints[mode]
if entryPoint then
	entryPoint()
end


