using GerryChain
using Plots, Shapefile
using Printf
using JLD
using SparseArrays
using Random
using DataFrames
using CSV

### METADATA ###
# Precinct dual graph (JSON) and the node attribute used as each precinct's population.
SHAPEFILE_PATH = "Shapefiles/WI/WI_dual_graph_2020.json"
POPULATION_COL = "TOTPOP"

# Node attributes holding the Democratic ("blue") and Republican ("red") vote counts.
BLUE_VOTES = "PRES16D"
RED_VOTES = "PRES16R"

# The ensemble is read from "<ENSEMBLE_FILENAME>.jld"; each map has NUM_DISTRICTS districts.
# NUM_MAPS is meant to be the ensemble's size (END_MAP defaults to it) — keep them in sync.
ENSEMBLE_FILENAME = "../../Redistricting_via_Local_Fairness/Real_Ensembles/WICD"
NUM_DISTRICTS = 8
NUM_MAPS = 1

# Output prefix: map i is written to "<DISTRICTING_PATH>i.csv".
DISTRICTING_PATH = "../../Redistricting_via_Local_Fairness/Real_Ensembles/WICD_"

# Inclusive range of maps to process; by default the whole ensemble.
START_MAP = 1
END_MAP = NUM_MAPS

# Both limits default to one slot per district across every selected map; only lower them
# if performance becomes a problem.
#   SAMPLE_SIZE:           how many districts to sample before moving on
#   TERMINATION_THRESHOLD: how many deviating groups must be found before terminating a map
# NOTE(review): neither appears to be read by the export loop in this script — confirm
# whether another script includes this file before removing them.
SAMPLE_SIZE = TERMINATION_THRESHOLD = NUM_DISTRICTS * (END_MAP - START_MAP + 1)

# Build the GerryChain graph from the dual-graph file, passing POPULATION_COL as the
# population attribute.
graph = BaseGraph(SHAPEFILE_PATH, POPULATION_COL)

# The ensemble of plans is stored under the "maps" key of "<ENSEMBLE_FILENAME>.jld".
println("Loading the ensemble from file...")
partitions = load("$(ENSEMBLE_FILENAME).jld", "maps")
println("Successfully loaded the ensemble from file.")

# For every map in START_MAP:END_MAP, write one CSV that lists each precinct, the district
# it belongs to, and that district's signed partisanship.
#
# Signed partisanship, with D and R the district totals of BLUE_VOTES and RED_VOTES:
#   -D / (D + R)  when Democrats win outright (negative  => "Blue" district)
#    R / (D + R)  otherwise, ties included     (positive => "Red" district)
# The magnitude is the winner's share of the two-party vote.

"""
    _count_two_party_votes(attributes, precincts, blue_col, red_col)

Return `(dem_votes, rep_votes)`: the sums of `attributes[p][blue_col]` and
`attributes[p][red_col]` over every precinct `p` in `precincts`.
"""
function _count_two_party_votes(attributes, precincts, blue_col, red_col)
    dem_votes = 0
    rep_votes = 0
    for precinct in precincts
        dem_votes += attributes[precinct][blue_col]
        rep_votes += attributes[precinct][red_col]
    end
    return dem_votes, rep_votes
end

"""
    _signed_partisanship(dem_votes, rep_votes)

Return `-dem_votes / total` when Democrats win outright, otherwise `rep_votes / total`
(ties count as Republican wins), where `total = dem_votes + rep_votes`.
Return `NaN` when `total` is zero, because no winner's share exists.
"""
function _signed_partisanship(dem_votes, rep_votes)
    total = dem_votes + rep_votes
    iszero(total) && return NaN
    return dem_votes > rep_votes ? -dem_votes / total : rep_votes / total
end

"""
    _districting_table(attributes, partition, map_index, num_districts, blue_col, red_col)

Build the `Precinct` / `District` / `Partisanship` table for one map, sorted by precinct.
`partition.dist_nodes[d]` supplies the precincts of district `d`; `map_index` is used only
in the warning emitted for districts with no two-party votes.
"""
function _districting_table(attributes, partition, map_index, num_districts,
                            blue_col, red_col)
    precinct_col = Int[]
    district_col = Int[]
    partisanship_col = Float64[]
    for district_index in 1:num_districts
        precincts = partition.dist_nodes[district_index]
        dem_votes, rep_votes = _count_two_party_votes(attributes, precincts,
                                                      blue_col, red_col)
        partisanship = _signed_partisanship(dem_votes, rep_votes)
        if isnan(partisanship)
            # Previously this NaN (from 0/0) was written silently; make it visible.
            @warn "Map $map_index, district $district_index has no $blue_col/$red_col votes; writing NaN partisanship."
        end
        # Every precinct of the district shares the district-level partisanship.
        for precinct in precincts
            push!(precinct_col, precinct)
            push!(district_col, district_index)
            push!(partisanship_col, partisanship)
        end
    end
    table = DataFrame(
        "Precinct" => precinct_col,
        "District" => district_col,
        "Partisanship" => partisanship_col,
    )
    sort!(table, [:Precinct])
    return table
end

# Fail before any CSV is written if the requested maps are not all in the ensemble
# (an empty range, START_MAP > END_MAP, is allowed and simply processes nothing).
if START_MAP <= END_MAP && (START_MAP < 1 || END_MAP > length(partitions))
    throw(ArgumentError(
        "START_MAP:END_MAP = $START_MAP:$END_MAP is outside the ensemble's " *
        "$(length(partitions)) maps",
    ))
end

for cur_map_index = START_MAP:END_MAP
    districting_table = _districting_table(graph.attributes, partitions[cur_map_index],
                                           cur_map_index, NUM_DISTRICTS,
                                           BLUE_VOTES, RED_VOTES)
    CSV.write(string(DISTRICTING_PATH, cur_map_index, ".csv"), districting_table)
end
println("Finished outputing dataframes to CSV file.")
