library(data.table)
library(glue)
library(parallel)
library(doParallel)
library(Rfast) # provides a (very!) efficient implementation of the NxN distance calculation
library(xtable)

source('./helpers.R')
source('./logic.R')

# ---- Configuration ---------------------------------------------------------
use.saved.data <- TRUE  # FALSE forces every ../data/res.*.RData result to be recomputed below
seed <- 1               # one fixed seed so randomized steps later in the script are reproducible
set.seed(seed)

# ---- Data ------------------------------------------------------------------
# NOTE(review): read_csv() is not attached by the library() calls above;
# presumably helpers.R / logic.R attach readr (or the tidyverse) — confirm.
d2 <- read_csv("../data/gib_data_final.csv")

# Mean of the y column (presumably the outcome base rate). The value is only
# shown when top-level results are echoed, e.g. interactively or via Rscript.
mean(d2$y)

# ---- Table 1: physician discretion vs. GBS threshold rules -----------------

# Thresholds i for the decision rule "admit if total_score > i".
gbs.thresholds <- 0:22

acc.human <- printable_confusion_matrix(d2) %>% mutate(decision.rule = 'Physician Discretion')

# Numeric confusion-matrix summaries per threshold; used only to locate the
# accuracy-maximizing threshold.
raw.acc.gbs <- lapply(gbs.thresholds, function(i) {
  calc_confusion_matrix(d2 %>% mutate(y_hat = total_score > i)) %>% mutate(threshold = i)
}) %>% rbindlist %>% as_tibble

# Display-formatted summaries for the same thresholds, in the same order.
acc.gbs <- lapply(gbs.thresholds, function(i) {
  printable_confusion_matrix(d2 %>% mutate(y_hat = total_score > i)) %>% mutate(threshold = i)
}) %>% rbindlist %>% as_tibble

best.threshold <- raw.acc.gbs$threshold[which.max(raw.acc.gbs$acc)]

# Keep the low thresholds (<= 2) plus the most accurate one. Filtering, rather
# than appending the best row, avoids listing it twice when it is itself <= 2.
# '>' is written as '$>$' because sanitize.text.function below passes text to
# LaTeX verbatim, and a bare '>' renders as an inverted question mark under
# LaTeX's default OT1 font encoding.
acc.gbs.filtered <- acc.gbs %>%
  filter(threshold <= 2 | threshold == best.threshold) %>%
  mutate(decision.rule = paste0('Admit GBS $>$ ', threshold)) %>%
  select(-threshold)

table1 <- rbind(acc.human, acc.gbs.filtered)
table1 <- table1[, c('decision.rule', 'admitted.frac', 'acc', 'tpr', 'tnr')]

colnames(table1) <- c('Decision Rule', 'Fraction Hospitalized', 'Accuracy', 'Sensitivity', 'Specificity')

# booktabs rules: top rule above the header, mid rule below it, bottom rule
# after the last row (hline.after = NULL disables xtable's default \hline's).
hline <- c(-1, 0, nrow(table1))
htype <- c("\\toprule ", "\\midrule ", "\\bottomrule ")

# `type` belongs to print.xtable(); xtable() itself ignores it.
print(xtable(table1,
             caption = "Physician and GBS performance",
             label = 'tab:physician and gbs perf',
             auto = TRUE,
             digits = 2),
      type = "latex",
      include.rownames = FALSE, table.placement = '!htbp', caption.placement = 'top',
      # Identity sanitizer: text (including LaTeX markup such as '$>$') is emitted as-is.
      sanitize.text.function = function(x) {x},
      add.to.row = list(pos = as.list(hline),
                        command = htype),
      hline.after = NULL,
      file = '../figures/table1.tex')


# Largest L evaluated: half the number of rows in d2.
max.L <- floor(nrow(d2) / 2)

#' Return a generate_p_value() result, using an on-disk cache.
#'
#' The cache file is ../data/res.<tag>.RData and holds a single object named
#' res.<tag>, the same layout the previous copy-pasted code produced, so
#' existing cache files remain valid. The result is recomputed (and the cache
#' overwritten) when the file is missing or `use.saved` is FALSE.
#'
#' @param L          Passed as `L` to generate_p_value().
#' @param tag        Suffix naming both the cache file and the stored object.
#' @param data       Data passed to generate_p_value(); defaults to global d2.
#' @param use.saved  Reuse an existing cache file; defaults to global use.saved.data.
#' @return The object stored under res.<tag>.
load_or_compute_p_value <- function(L, tag, data = d2, use.saved = use.saved.data) {
  obj.name <- paste0('res.', tag)
  path <- file.path('../data', paste0(obj.name, '.RData'))
  # Load/save through a private environment so nothing leaks into the caller's
  # workspace and a file lacking res.<tag> fails loudly in get() below.
  cache <- new.env(parent = emptyenv())
  if (use.saved && file.exists(path)) {
    load(path, envir = cache)
  } else {
    assign(obj.name,
           generate_p_value(data, L = L, K = 1000, loss_type = '01',
                            return.diagnostics = TRUE),
           envir = cache)
    save(list = obj.name, envir = cache, file = path)
  }
  get(obj.name, envir = cache, inherits = FALSE)
}

# Same L values and call order as before, so any RNG consumption inside
# generate_p_value() happens in the same sequence.
res.100 <- load_or_compute_p_value(100, '100')
res.250 <- load_or_compute_p_value(250, '250')
res.500 <- load_or_compute_p_value(500, '500')
res.1000 <- load_or_compute_p_value(1000, '1000')
res.max <- load_or_compute_p_value(max.L, 'max')

#' Summarise one generate_p_value() result as a single Table 2 row.
#'
#' Despite its name, this does not print anything: it returns a one-row
#' data.frame that the caller binds together with other rows.
#'
#' @param res A list as returned by generate_p_value(..., return.diagnostics = TRUE);
#'   reads n_imperfect, possible_changes_01, possible_improvements_01 and p_value.
#' @param L   The L value that produced `res`; echoed into the row.
#' @return One-row data.frame with columns `L`, `mismatched pairs`,
#'   `swaps that increase loss`, `swaps that decrease loss`, `tau`.
print_table <- function(res, L) {
  # [[ ]] instead of $ so a missing field is never silently partial-matched.
  # NROW() counts 0 for a NULL table, where nrow() would return NULL and
  # silently drop the column from the row.
  n.changes <- NROW(res[["possible_changes_01"]])
  n.improvements <- NROW(res[["possible_improvements_01"]])

  data.frame(
    L = L,
    `mismatched pairs` = res[["n_imperfect"]],
    # NOTE(review): every possible change that is not an improvement is counted
    # as loss-increasing; confirm generate_p_value() has no loss-neutral swaps.
    `swaps that increase loss` = n.changes - n.improvements,
    `swaps that decrease loss` = n.improvements,
    tau = res[["p_value"]],
    # Keep the spaced names as written instead of mangling them to e.g. mismatched.pairs.
    check.names = FALSE
  )
}


# ---- Table 2: expertise test for each L ------------------------------------

# One summary row per (result, L) pair, stacked in the listed order.
table2 <- do.call(rbind, Map(print_table,
                             list(res.100, res.250, res.500, res.1000, res.max),
                             list(100, 250, 500, 1000, max.L))) %>%
  # Tiny p-values are shown as "< .001". The '<' is wrapped in math mode
  # because sanitize.text.function below passes text to LaTeX verbatim, and a
  # bare '<' renders as an inverted exclamation mark under the default OT1
  # font encoding.
  mutate(tau = ifelse(tau < .001, '$<$.001', as.character(round(tau, 3))))

colnames(table2) <- c('L', 'mismatched pairs', 'swaps that increase loss', 'swaps that decrease loss', '$\\tau$')

# booktabs rules: top rule above the header, mid rule below it, bottom rule
# after the last row (hline.after = NULL disables xtable's default \hline's).
hline <- c(-1, 0, nrow(table2))
htype <- c("\\toprule ", "\\midrule ", "\\bottomrule ")

# `type` belongs to print.xtable(); xtable() itself ignores it.
print(xtable(table2,
             caption = "Testing for physician expertise",
             label = 'tab:testing for expertise',
             auto = TRUE),
      type = "latex",
      include.rownames = FALSE, table.placement = '!htbp', caption.placement = 'top',
      # Identity sanitizer so the '$\tau$' header and '$<$' reach LaTeX intact.
      sanitize.text.function = function(x) {x},
      add.to.row = list(pos = as.list(hline),
                        command = htype),
      hline.after = NULL,
      file = '../figures/table2.tex')

