class BaseLinearPredictor:
    """Builds the predictor matrix and regression targets for linear models.

    On construction the predictor table is fetched from ``data_loader``,
    optionally filtered by sentence position and by ``pos-<tag>`` entries in
    ``pred_lst``, centred per column (and optionally scaled), and stored as
    ``self.X_df`` together with the original row labels ``self.samp_idxs``.

    ``integrate_voltage`` and ``compute_electrode_signals`` read
    ``self.event_df``, ``self.neural_data`` and ``self.sig_mat``, none of
    which are assigned in this class; they are presumably provided by a
    subclass or set externally -- TODO confirm.
    """

    # Samples per second (presumably Hz); used to turn durations into sample
    # counts in integrate_voltage and in the peak threshold computation.
    samp_rate = 2048

    # Column whose lower-cased values are matched against the suffix of
    # 'pos-<tag>' entries in pred_lst.
    pos_col = 'pos'
    # Column compared against ``time_back`` when blanking 'prev_1_' features.
    diff_col = 'word_diff'

    def __init__(self, args, data_loader, pred_lst):
        """Copy settings from ``args`` and build the predictor dataframe.

        Args:
            args: object exposing ``window_length``, ``window_delta``,
                ``reload``, ``anno_source``, ``time_back``, ``k_back``,
                ``unit_variance``, ``sentence_position``, ``permute_targets``
                and ``clear_cache`` attributes.
            data_loader: object exposing ``get_predictors()`` and ``out_dir``
                (plus the signal accessors used by the target methods).
            pred_lst: names of predictor columns; entries containing
                ``'pos-'`` are treated as row filters, not columns.
        """
        self.duration = args.window_length
        self.delta = args.window_delta
        self.reload = args.reload
        self.anno_source = args.anno_source
        self.time_back = args.time_back
        self.k_back = args.k_back
        self.pred_lst = pred_lst
        self.data_loader = data_loader
        self.unit_variance = args.unit_variance
        self.sentence_position = args.sentence_position
        self.permute_targets = args.permute_targets
        self.verbose = False #NOTE hardcode
        self.clear_cache = args.clear_cache
        self.out_dir = data_loader.out_dir
        # Built last: it reads data_loader, sentence_position, k_back,
        # time_back and unit_variance, all of which are set above.
        self.X_df, self.samp_idxs = self.get_predictors_dataframe(pred_lst)

    @staticmethod
    def get_model_str(col_lst, y_str):
        """Return a formula string ``'<y> ~ <c1> + <c2> + ...'``.

        Args:
            col_lst: column names, including ``y_str``. The list is not
                modified.
            y_str: name of the response column.

        Returns:
            The formula string; just ``y_str`` when there are no predictors.

        Raises:
            ValueError: if ``y_str`` is not in ``col_lst``.
        """
        # Work on a copy so the caller's list keeps the response column.
        predictors = list(col_lst)
        predictors.remove(y_str)
        if not predictors:
            return y_str
        return y_str + ' ~ ' + ' + '.join(predictors)

    def get_r_model_str(self, col_lst, y_str):
        """Return the formula with a trailing ``trial`` term as ``(1|trial)``.

        The last ``len('trial')`` characters of the plain formula are cut
        off and replaced, so this assumes ``'trial'`` is the final entry of
        ``col_lst`` -- NOTE(review): not checked here.
        """
        return self.get_model_str(col_lst, y_str)[:-len('trial')] + '(1|trial)'

    def remove_prev_features_by_time(self, df):
        """Return a copy of ``df`` with stale 'prev_*' features removed.

        When ``k_back > 0``: rows whose ``diff_col`` exceeds ``time_back``
        get NaN in every column containing ``'prev_1_'``; then, for each
        ``i`` in ``1..k_back-1``, the ``'prev_{i+1}_'`` columns are
        overwritten with the ``'prev_{i}_'`` columns shifted down by
        ``i + 1`` rows. Finally every row containing a NaN is dropped.
        """
        tmp_df = df.copy(deep=True)
        if self.k_back > 0:
            prev_1_cols = [c for c in tmp_df.columns if 'prev_1_' in c]
            tmp_df.loc[tmp_df[self.diff_col] > self.time_back, prev_1_cols] = np.nan
            # NOTE(review): the shift amount is i+1 rather than 1 -- confirm
            # this matches how the prev_k columns are meant to be derived.
            for i in range(1, self.k_back):
                tmp_df[[c for c in tmp_df.columns if 'prev_{}_'.format(i+1) in c]] = \
                    tmp_df[[c for c in tmp_df.columns if 'prev_{}_'.format(i) in c]].shift(i+1)
        return tmp_df.dropna()

    def get_predictors_dataframe(self, pred_lst):
        """Build the centred predictor matrix.

        Args:
            pred_lst: predictor column names; entries containing ``'pos-'``
                select rows whose lower-cased ``pos_col`` equals the text
                after the first ``'-'`` and are removed from the column list.

        Returns:
            ``(X_df, samp_idxs)``: the predictor frame with a fresh
            0..n-1 index, and the row labels those rows had before the
            index was reset.
        """
        pred_df = self.data_loader.get_predictors()
        pred_df = pred_df.copy() #To avoid warnings about setting on a copy of a slice
        # A row is an offset when the next row is an onset; the last row is
        # always marked as an offset.
        offset_list = pred_df.is_onset.shift(-1).to_list()
        offset_list[-1] = 1
        pred_df["is_offset"] = offset_list
        if self.sentence_position == "off":
            pred_df = pred_df[pred_df["is_offset"] == 1]
        elif self.sentence_position == "on":
            pred_df = pred_df[pred_df["is_onset"] == 1]
        elif self.sentence_position == "mid":
            pred_df = pred_df[(pred_df["is_onset"] != 1) & (pred_df["is_offset"] != 1)]
        else:
            assert self.sentence_position == "all", \
                "sentence_position must be one of 'off', 'on', 'mid', 'all'"

        pred_lst = np.array(pred_lst)
        pos_feature_idx = np.where(['pos-' in p for p in pred_lst])[0]
        has_pos_filter = len(pos_feature_idx) > 0
        if has_pos_filter:
            pos_feature_lst = np.array([pred_lst[i].split('-')[1] for i in pos_feature_idx])
            pred_lst = np.delete(pred_lst, pos_feature_idx)
            pred_df = pred_df.loc[pred_df[self.pos_col].str.lower().isin(pos_feature_lst)]

        pred_df = self.remove_prev_features_by_time(pred_df)

        # Centre every predictor on its own column mean. The unfiltered path
        # used to subtract np.nanmean(frame), which is a single scalar over
        # all columns, and skipped the unit-variance scaling entirely.
        col_means = pred_df[pred_lst].mean()
        if has_pos_filter:
            assert not col_means.isnull().any()
        X_df = (pred_df[pred_lst] - col_means).dropna()
        if self.unit_variance:
            X_df = X_df / (X_df.std() + EPSILON)
        if has_pos_filter:
            # Keep the tag column next to the numeric predictors.
            X_df[self.pos_col] = pred_df[self.pos_col]

        samp_idxs = X_df.index
        X_df = X_df.reset_index(drop=True)
        return X_df, samp_idxs

    def integrate_voltage(self, elec, method='mean'):
        """Reduce ``self.neural_data`` over each event's sample window.

        For every row of ``self.event_df`` the window starts at ``est_trig``
        and spans ``(end - start) * samp_rate`` samples.

        Args:
            elec: sequence of electrodes; only its length is used here (more
                than one entry averages the result over axis 0).
                NOTE(review): rows of ``neural_data`` are not selected by
                ``elec`` -- confirm that is intended.
            method: ``'mean'`` or ``'sum'`` over the window.

        Returns:
            Array with one column per event; reduced over the first axis
            when ``len(elec) > 1``.

        Raises:
            ValueError: if ``method`` is not ``'mean'`` or ``'sum'``.
        """
        # Validate once up front rather than on every loop iteration.
        if method not in ('mean', 'sum'):
            raise ValueError('Method {} not implemanted'.format(method))
        triggers = self.event_df.est_trig.values.astype(int)
        deltas = ((self.event_df.end.values - self.event_df.start.values) * self.samp_rate).astype(int)
        voltage_arr = np.zeros((self.neural_data.shape[0], len(triggers)))
        for i, (t, d) in enumerate(zip(triggers, deltas)):
            window = self.neural_data[:, t:t + d]
            if method == 'mean':
                voltage_arr[:, i] = window.mean(axis=1)
            else:
                voltage_arr[:, i] = window.sum(axis=1)
        if len(elec) > 1:
            voltage_arr = voltage_arr.mean(axis=0)
        return voltage_arr

    def compute_electrode_signals(self, elec):
        """Return ``self.sig_mat[elec]``, averaged over axis 0 for several
        electrodes, with singleton dimensions squeezed out."""
        elec_y = self.sig_mat[elec]
        if len(elec) > 1:
            elec_y = elec_y.mean(axis=0)
        return np.squeeze(elec_y)

    def peak_time_target(self, elec, peak_thresh=0):
        """Return peak sample indices above a threshold.

        Args:
            elec: sequence of electrode indices.
            peak_thresh: threshold, converted to samples as
                ``peak_thresh * samp_rate // 1000`` (presumably
                milliseconds -- TODO confirm).

        Returns:
            ``(peak_idx[keep], keep)`` where ``keep`` holds the positions
            whose peak index exceeds the threshold.
        """
        if len(elec) > 1:
            # NOTE(review): ``[elec, self.samp_idxs]`` pairs the two index
            # sequences element-wise under NumPy rules; confirm the loader's
            # return type supports the intended electrode-by-sample selection.
            window_data = self.data_loader.get_signal_windows(self.duration, self.delta, self.reload)[elec, self.samp_idxs]
            peak_idx = window_data.mean(axis=0).argmax(axis=1)
        else:
            # The former np.load(target_time_data_file) call referenced a name
            # not defined in this class and its result was discarded; removed.
            peak_idx = self.data_loader.get_sample_peak_times(self.duration, self.delta, self.reload)[elec, self.samp_idxs]
        idxs = np.where(peak_idx > peak_thresh * self.samp_rate // 1000)
        return peak_idx[idxs], idxs[0]

    def peak_amplitude_target(self, elec, peak_thresh=0):
        """Return peak amplitudes for samples whose peak index is above a
        threshold.

        Args and threshold conversion are the same as ``peak_time_target``.

        Returns:
            ``(peak_amp[keep], keep)`` where ``keep`` holds the positions
            whose peak index exceeds the threshold.
        """
        if len(elec) > 1:
            window_data = self.data_loader.get_signal_windows(self.duration, self.delta, self.reload)[elec, self.samp_idxs]
            elec_sig = window_data.mean(axis=0)
            peak_idx = elec_sig.argmax(axis=1)
            peak_amp = elec_sig.max(axis=1)
        else:
            peak_idx = self.data_loader.get_sample_peak_times(self.duration, self.delta, self.reload)[elec, self.samp_idxs]
            peak_amp = self.data_loader.get_sample_peak_amplitudes(self.duration, self.delta, self.reload)[elec, self.samp_idxs]
        idxs = np.where(peak_idx > peak_thresh * self.samp_rate // 1000)
        return peak_amp[idxs], idxs[0]

    def get_target(self, elec, target, peak_thresh):
        """Dispatch to the requested target builder.

        Args:
            elec: sequence of electrode indices.
            target: ``'time'``, ``'amp'`` or ``'integrate'``.
            peak_thresh: forwarded to the peak-based targets.

        Returns:
            ``(values, idxs)``: target values and the positions of the
            samples they belong to.

        Raises:
            ValueError: for an unknown ``target`` name.
        """
        if target == 'time':
            values, idxs = self.peak_time_target(elec, peak_thresh)
        elif target == 'amp':
            values, idxs = self.peak_amplitude_target(elec, peak_thresh)
        elif target == 'integrate':
            # integrate_voltage returns a single array (no thresholding), so
            # every event is kept; unpacking it directly would fail or
            # silently split the array.
            values = self.integrate_voltage(elec)
            idxs = np.arange(values.shape[-1])
        else:
            raise ValueError("Target named '{}' isn't implemanted".format(target))
        return values, idxs


class GLMPredictorR(BaseLinearPredictor):
    """Fits an ``Lm`` model of a neural target on the prepared predictors."""

    def __init__(self, args, data_loader, pred_lst):
        """Forward all arguments unchanged to ``BaseLinearPredictor``."""
        super(GLMPredictorR, self).__init__(args, data_loader, pred_lst)

    def run(self, elec, target_type, peak_thresh=0):
        """Fit and return a model for the given electrode(s).

        Args:
            elec: electrode indices passed on to ``get_target``.
            target_type: target name passed on to ``get_target``.
            peak_thresh: threshold passed on to ``get_target``.

        Returns:
            The fitted ``Lm`` model.
        """
        target, idxs = self.get_target(elec, target_type, peak_thresh)

        # Permutation control: shuffle the targets in place before they are
        # attached to the design matrix.
        if self.permute_targets:
            random.shuffle(target)

        design = self.X_df.copy().loc[idxs]
        design['target'] = target

        formula = self.get_model_str(design.columns.tolist(), 'target')
        model = Lm(formula, data=design)

        chatty = self.verbose
        if chatty:
            banner = '#' * 35
            print('\n', banner, self.data_loader.get_channel_label_by_index(elec), banner)
        model.fit(verbose=chatty, summarize=chatty, no_warnings=(not chatty))
        return model

class GLMRunner:
    """Holds the configuration and task for a GLM run.

    ``model`` and ``criterion`` are stored but, per the original notes, not
    used. The working directory at construction time is recorded as
    ``exp_dir``.
    """

    def __init__(self, cfg, task, model, criterion):
        self.cfg, self.task = cfg, task
        # NOTE: neither of these is used
        self.model, self.criterion = model, criterion
        self.exp_dir = os.getcwd()

