import os

import numpy as np
import pandas as pd

# Scale by which residual summaries are divided before printing
# (presumably the noise standard deviation of the data -- confirm).
SD: float = 0.1


def mean_pivotal_ci(x, alpha=0.05, n_bootstrap=10000):
    """Estimate the mean of *x* with a pivotal (basic) bootstrap confidence interval.

    The interval is ``(2*m_hat - q_hi, 2*m_hat - q_lo)``, where ``q_lo`` and
    ``q_hi`` are the ``alpha/2`` and ``1 - alpha/2`` percentiles of the
    bootstrap distribution of the sample mean (see ``bootstrap_mean``).

    Args:
        x: 1-D array-like of samples.
        alpha: Two-sided miscoverage level. The default 0.05 gives the 95%
            interval (2.5th / 97.5th percentiles).
        n_bootstrap: Number of bootstrap resamples passed to ``bootstrap_mean``.

    Returns:
        ``(m_hat, (ci_lo, ci_hi))`` -- the sample mean and the interval bounds.

    Raises:
        ValueError: If ``alpha`` is not strictly between 0 and 1.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError('alpha must be in (0, 1), got {!r}'.format(alpha))
    m_hat = np.mean(x)
    m_samples = bootstrap_mean(x, n_bootstrap=n_bootstrap)
    # alpha/2 expressed as a percentile; for alpha=0.05 this is exactly 2.5 / 97.5.
    lo_pct = 50.0 * alpha
    p_lo, p_hi = np.percentile(m_samples, [lo_pct, 100.0 - lo_pct])
    # Basic bootstrap: reflect the bootstrap quantiles around the point estimate.
    ci_lo, ci_hi = 2 * m_hat - p_hi, 2 * m_hat - p_lo
    return m_hat, (ci_lo, ci_hi)


def bootstrap_mean(x, n_bootstrap=10000, rng=None):
    """Return ``n_bootstrap`` bootstrap replicates of the mean of *x*.

    Each replicate is the mean of ``len(x)`` values drawn from *x* with
    replacement.

    Args:
        x: 1-D array-like of samples.
        n_bootstrap: Number of resamples.
        rng: Optional ``numpy.random.Generator``. When None (the default) the
            global NumPy random state is used, exactly as before. Pass a
            seeded generator for reproducible replicates.

    Returns:
        1-D ``ndarray`` of shape ``(n_bootstrap,)``.
    """
    values = np.asarray(x)
    size = (n_bootstrap, len(values))
    if rng is None:
        samples = np.random.choice(values, size=size)
    else:
        samples = rng.choice(values, size=size)
    # One row per resample; average across the resampled values in each row.
    return np.mean(samples, axis=1)


def _load_binned_residuals(path, bins):
    """Load one predictions CSV and return its binned absolute residuals.

    Keeps rows with ``observed == 0.0`` (presumably the points the model did
    not see -- confirm against whatever writes these CSVs), then adds:

    * ``residual = |y - y_hat|``
    * ``bin = pd.cut(t, bins)``

    Rows where either new column is missing (e.g. ``t`` outside ``bins``) are
    dropped.
    """
    frame = pd.read_csv(path)
    frame = frame[frame['observed'] == 0.0].sort_values(['sample_num', 't'])
    # assign() builds a new frame instead of assigning into a filtered slice,
    # which would trigger pandas' chained-assignment warning.
    frame = frame.assign(
        residual=np.abs(frame['y'] - frame['y_hat']),
        bin=pd.cut(frame['t'], list(bins)),
    )
    return frame.dropna(subset=['bin', 'residual'])


def _check_paired(base, prop):
    """Raise ValueError unless *base* and *prop* hold the same (sample_num, t) rows.

    The residual table pairs the two frames by index label, i.e. the CSV row
    number from ``read_csv``. Both files must therefore list the same points
    in the same row order, or the per-row differences would be meaningless.
    """
    keys = ['sample_num', 't']
    base_keys = base[keys].sort_index()
    prop_keys = prop[keys].sort_index()
    paired = (base_keys.index.equals(prop_keys.index)
              and bool((base_keys.to_numpy() == prop_keys.to_numpy()).all()))
    if not paired:
        raise ValueError(
            'base and prop predictions do not cover the same (sample_num, t) '
            'rows in the same CSV order; residuals cannot be paired by row')


def _format_ci(x):
    """Format the mean of *x* and its pivotal bootstrap CI, each divided by ``SD``."""
    m, (lo, hi) = mean_pivotal_ci(x)
    return '{:.2f} ({:.2f}, {:.2f})'.format(m / SD, lo / SD, hi / SD)


def print_table(base_csv, prop_csv, results='prediction_output',
                bins=(12.0, 16.0, 20.0, 24.0)):
    """Print per-time-bin residual summaries for a base/prop pair of prediction runs.

    For each ``pd.cut`` bin of ``t``, three Series are printed in turn: the
    base residuals, the prop residuals, and their per-row difference
    ``base - prop``. A positive difference means prop has the smaller
    absolute residual. Each entry is ``"mean (ci_lo, ci_hi)"`` from
    ``mean_pivotal_ci``, divided by ``SD``.

    Args:
        base_csv: File name of the baseline predictions, relative to *results*.
        prop_csv: File name of the compared predictions, relative to *results*.
        results: Directory containing both CSVs.
        bins: Bin edges for ``t``. A previously used alternative was
            ``(2.0, 4.0, 6.0, 8.0, 10.0)``.

    Raises:
        ValueError: If the two files' retained rows are not the same
            ``(sample_num, t)`` points in the same CSV row order.
    """
    base = _load_binned_residuals(os.path.join(results, base_csv), bins)
    prop = _load_binned_residuals(os.path.join(results, prop_csv), bins)
    _check_paired(base, prop)

    # Series align on index label, which _check_paired verified identifies the
    # same (sample_num, t) point in both frames.
    residuals = pd.DataFrame({
        'bin': base['bin'],
        'base': base['residual'],
        'prop': prop['residual'],
    })
    residuals['diff'] = residuals['base'] - residuals['prop']

    # observed=False keeps every bin (pandas' historical default for
    # categorical groupers) and silences the pandas>=2.1 FutureWarning.
    grouped = residuals.groupby('bin', observed=False)
    # Select columns with [] -- attribute access `.diff` would resolve to the
    # GroupBy.diff method instead of the 'diff' column.
    for column in ('base', 'prop', 'diff'):
        print(grouped[column].agg(_format_ci))


# Summarize each of the three base/prop prediction runs in order, separating
# consecutive tables with two blank lines.
_PREDICTION_PAIRS = (
    ('base_predictions1.csv', 'prop_predictions1.csv'),
    ('base_predictions2.csv', 'prop_predictions2.csv'),
    ('base_predictions3.csv', 'prop_predictions3.csv'),
)
for _run_idx, (_base_csv, _prop_csv) in enumerate(_PREDICTION_PAIRS):
    if _run_idx:
        print()
        print()
    print_table(_base_csv, _prop_csv)
