import os
import os.path as osp
import numpy as np
import csv
import argparse
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

def _first_k_per_class(Y):
    """Return indices that keep the first ``k`` samples of every class in ``Y``.

    ``k`` is the size of the smallest class, so the selection is perfectly
    class-balanced. Indices are grouped by class (in ``np.unique`` order) and
    keep their original relative order within each class.
    """
    classes, counts = np.unique(Y, return_counts=True)
    k = counts.min()
    return np.concatenate([np.nonzero(Y == cls)[0][:k] for cls in classes])


def get_feature_dataset(subject, bottleneck, train=True, remove_person_cls=False, cls_balance=False, cls_balance_test=False, fmri_data=False):
    """Load features and largest-object labels for one subject and split.

    Args:
        subject: Subject identifier used in the file paths (e.g. ``subj01``);
            its last character is used as the subject index for fMRI files.
        bottleneck: Bottleneck size used in the feature path. Ignored when
            ``fmri_data`` is true.
        train: Load the train split if true, otherwise the test split.
        remove_person_cls: Drop every sample of the most frequent label of
            the loaded split.
        cls_balance: Balance classes on the train split (no effect when
            ``train`` is false).
        cls_balance_test: Balance classes on the test split (no effect when
            ``train`` is true).
        fmri_data: Load the fMRI array instead of the bottleneck features.

    Returns:
        Tuple ``(X, Y)`` of features and labels after all requested filtering.
    """
    split = 'train' if train else 'test'
    if fmri_data:
        # load the original brain_diffuser fmri data
        train_str = split
        subj_idx = subject[-1]
        X = np.load(f'/storage/user/BrainBitsWIP/data/processed_data/{subject}/nsd_{train_str}_fmriavg_nsdgeneral_sub{subj_idx}.npy')
        # NOTE(review): fixed rescaling constant; its origin is not visible here.
        X = X / 300
    else:
        # load the bottleneck features
        X = np.load(f'brain_diffuser_feats/{subject}/train_single_{bottleneck}/{"train" if train else "test"}_intermediates.npy')
    Y = np.load(f'brain_diffuser_largest_object_labels/{subject}_{"tr" if train else "te"}.npy')

    if train:
        # labels not in test set preprocessed to be -1
        keep = Y >= 0
        X, Y = X[keep], Y[keep]

    if remove_person_cls:
        # The class to drop is the most frequent label of *this* split.
        labs, counts = np.unique(Y, return_counts=True)
        keep = Y != labs[counts.argmax()]
        X, Y = X[keep], Y[keep]

    # Each balancing flag only applies to its own split, so at most one fires.
    if (cls_balance and train) or (cls_balance_test and not train):
        balanced_inds = _first_k_per_class(Y)
        X, Y = X[balanced_inds], Y[balanced_inds]

    return X, Y

def write_metrics(row, csv_fp):
    """Append ``row`` as one CSV record to the file at ``csv_fp``.

    The file is opened in append mode, so it is created if missing and
    existing content is preserved.
    """
    # newline='' lets the csv module control line endings itself.
    with open(csv_fp, 'a', newline='') as handle:
        csv.writer(handle).writerow(row)

def main():
    """Fit and evaluate a logistic-regression probe for each bottleneck size.

    Configuration is read from the module-level ``opt`` namespace. For every
    entry of ``opt.bottleneck_dims`` the train/test features are loaded, a
    ``LogisticRegression`` model is fitted (optionally through a grid search
    over ``C`` or ``max_iter``), and the test accuracy is appended to
    ``test.csv`` in a results directory whose name encodes the active options.
    """
    # Encode every non-default option in the output directory name so that
    # different configurations never share a results file.
    suffixes = (
        (opt.fmri_data, '_fmri_data'),
        (opt.remove_person_cls, '_remove_person_cls'),
        (opt.cls_balance, '_cls_balance'),
        (opt.cls_balance_test, '_cls_balance_test'),
        (opt.weighted_loss, '_weighted_loss'),
        (opt.normalize, '_normalized'),
        (opt.grid_search_l2_reg, '_grid_search_l2_reg'),
        (opt.solver == 'saga', '_saga'),
        (opt.max_iter != 100, f'_max_iter={opt.max_iter}'),
        (opt.grid_search_max_iter, '_grid_search_max_iter'),
    )
    opt.fp_details += ''.join(tag for enabled, tag in suffixes if enabled)

    root_dir = osp.join('brain_diffuser_results',
                        'logistic_regression' + opt.fp_details,
                        opt.subject)
    os.makedirs(root_dir, exist_ok=True)
    csv_fp = osp.join(root_dir, 'test.csv')

    # NOTE(review): the file is opened in append mode, so re-running the same
    # configuration adds another header row to an existing test.csv.
    write_metrics(['bottleneck_dim', 'accuracy', 'max_iter'], csv_fp)

    for bottleneck_dim in opt.bottleneck_dims:
        # set up the dataset
        X_tr, Y_tr = get_feature_dataset(opt.subject, bottleneck_dim,
                                         remove_person_cls=opt.remove_person_cls,
                                         cls_balance=opt.cls_balance,
                                         fmri_data=opt.fmri_data)
        X_te, Y_te = get_feature_dataset(opt.subject, bottleneck_dim,
                                         train=False,
                                         remove_person_cls=opt.remove_person_cls,
                                         cls_balance=opt.cls_balance,
                                         cls_balance_test=opt.cls_balance_test,
                                         fmri_data=opt.fmri_data)

        if opt.normalize:
            # Fit the scaler on train only, then apply the same transform to test.
            scaler = StandardScaler()
            X_tr = scaler.fit_transform(X_tr)
            X_te = scaler.transform(X_te)

        if opt.weighted_loss:
            # Weight each class by the inverse of its relative train frequency.
            vals, counts = np.unique(Y_tr, return_counts=True)
            counts = counts / counts.sum()
            weights = 1 / counts
            weights_dict = {v: w for v, w in zip(vals, weights)}
            # model = LogisticRegression(penalty='l2', class_weight=weights_dict, max_iter=opt.max_iter, solver=opt.solver) # TODO uncomment
            model = LogisticRegression(penalty=None, class_weight=weights_dict, max_iter=opt.max_iter, solver=opt.solver)
        else:
            model = LogisticRegression(penalty='l2', max_iter=opt.max_iter, solver=opt.solver)

        # Iteration budget reported in the CSV. It defaults to the configured
        # value and is only overridden by the max_iter grid search; previously
        # it was undefined outside that branch and the write below raised
        # NameError.
        max_iter_best = opt.max_iter

        if opt.grid_search_l2_reg:
            param_grid = {'C': [1e-3, 1e-2, 1e-1, 1e0, 1e1]}
            grid_search = GridSearchCV(model, param_grid, cv=5)
            grid_search.fit(X_tr, Y_tr)
            model = grid_search.best_estimator_
            print(f"Best C value: {grid_search.best_params_['C']}")
        elif opt.grid_search_max_iter:
            param_grid = {'max_iter': [10, 25, 50, 75, 100]}
            grid_search = GridSearchCV(model, param_grid, cv=5)
            grid_search.fit(X_tr, Y_tr)
            model = grid_search.best_estimator_
            max_iter_best = grid_search.best_params_['max_iter']
            print(f"Best max_iter value: {max_iter_best}")
        else:
            model.fit(X_tr, Y_tr)

        y_pred = model.predict(X_te)
        acc = np.mean(y_pred == Y_te)

        write_metrics([bottleneck_dim, acc, max_iter_best], csv_fp)
        print(f'subject={opt.subject} | bottleneck_dim={bottleneck_dim} | acc={acc}')

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse``'s ``type=bool`` treats every non-empty string (including
    ``"False"``) as true; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays false, matching the old ``bool("")`` behaviour.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Build the command-line parser for the logistic-regression probe script.

    Boolean options accept an explicit value (``--normalize True``) or can be
    given bare (``--normalize``) to enable them.
    """
    parser = argparse.ArgumentParser()

    def add_flag(name, help_text):
        # nargs='?' + const=True lets the flag be used with or without a value.
        parser.add_argument(name, type=str2bool, nargs='?', const=True,
                            default=False, help=help_text)

    # Both are needed to build any path; without them the run cannot proceed.
    parser.add_argument("--subject", type=str, required=True,
                        help="subj01 or subj02  or subj05  or subj07")
    parser.add_argument("--bottleneck_dims", nargs='+', type=int, required=True,
                        help="1 5 10 etc.")
    parser.add_argument("--fp_details", type=str, default='',
                        help="notes for the file path")
    add_flag("--remove_person_cls",
             "balance dataset by removing most popular class")
    add_flag("--cls_balance",
             "select first k of each class, k=min(class_counts)")
    add_flag("--normalize",
             "normalize the bottleneck intermediates")
    add_flag("--cls_balance_test",
             "balance the test set")
    add_flag("--weighted_loss",
             "weight the loss of by class frequency")
    add_flag("--grid_search_l2_reg",
             "optimize l2 weight hyperparameter")
    parser.add_argument("--max_iter", type=int, default=100,
                        help="number of solver iterations")
    add_flag("--fmri_data",
             "load the brain data instead of the bottleneck")
    parser.add_argument("--solver", type=str, default='lbfgs',
                        help="lbfgs or saga")
    add_flag("--grid_search_max_iter",
             "optimize the max_iter hyperparameter")
    return parser


if __name__ == '__main__':

    # ``opt`` is intentionally module-level: main() reads it as a global.
    opt = build_parser().parse_args()

    main()