import os

import numpy as np
import sklearn
import sklearn.base
import sklearn.linear_model
from sklearn.metrics import mean_squared_error


class VectorRegression(sklearn.base.BaseEstimator):
    """Class to perform regression on multiple outputs.

    One clone of ``estimator`` is fitted per column of the 2-D target array.
    Target entries equal to -1 mean "no label" and are excluded, per column,
    from both fitting and scoring.
    """

    def __init__(self, estimator):
        # Prototype estimator; only its clones are ever fitted.
        self.estimator = estimator

    @staticmethod
    def _labeled_rows(y, col):
        """Return the row indices of ``y`` whose value in column ``col`` is not -1."""
        return np.where(y[:, col] != -1)[0]

    def fit(self, x, y):
        """Fit one regressor per column of the 2-D target array ``y``.

        Rows whose target is -1 in a column are left out of that column's fit.
        A column with no labeled rows is passed zero samples, which most
        scikit-learn estimators reject.

        Returns:
            self
        """
        _, num_targets = y.shape
        self.estimators_ = []
        for col in range(num_targets):
            rows = self._labeled_rows(y, col)
            fitted = sklearn.base.clone(self.estimator).fit(x[rows], y[rows, col])
            self.estimators_.append(fitted)
        return self

    def predict(self, x):
        """Return the per-column predictions stacked into shape (n_samples, n_columns)."""
        return np.hstack([est.predict(x)[:, np.newaxis] for est in self.estimators_])

    def score(self, x, y):
        """Return the mean of the per-column ``estimator.score`` over labeled rows.

        Columns that have no labeled rows in ``y`` (e.g. a class that does not
        occur in this split) are skipped instead of being scored on zero samples.
        Returns nan if no column has any labeled row.
        """
        scores = []
        for col, est in enumerate(self.estimators_):
            rows = self._labeled_rows(y, col)
            if rows.size == 0:
                continue
            scores.append(est.score(x[rows], y[rows, col]))
        if not scores:
            return np.nan
        return np.mean(scores)


def _load_labeled_split(eval_dir, split, label_file):
    """Load one split's embeddings and labels and keep only the labeled frames.

    Reads ``{split}_embeds.npy`` and ``{split}_{label_file}`` from ``eval_dir``.
    Frames whose label is -1 are dropped.

    Returns:
        (embeddings, labels) restricted to labeled frames.
    """
    embs = np.load(os.path.join(eval_dir, f'{split}_embeds.npy'))
    labels = np.load(os.path.join(eval_dir, f'{split}_{label_file}'))
    labeled_idx = np.where(labels != -1)[0]  # if labels = -1, haven't labeled these clips
    print(f'Valid {split} labels {len(labeled_idx)} | All {split} data {len(labels)}')
    return embs[labeled_idx], labels[labeled_idx]


def load_embeds_and_labels(save_path, label_all=False):
    """Load train/val embeddings and labels from ``<save_path>/eval``.

    Args:
        save_path: directory that contains the ``eval`` sub-directory.
        label_all: if True read ``*_labels_all.npy``, otherwise ``*_labels_new.npy``.

    Returns:
        (train_embs, train_labels, val_embs, val_labels), each restricted to the
        frames whose label is not -1.
    """
    eval_dir = os.path.join(save_path, 'eval')
    label_file = 'labels_all.npy' if label_all else 'labels_new.npy'
    train_embs, train_labels = _load_labeled_split(eval_dir, 'train', label_file)
    val_embs, val_labels = _load_labeled_split(eval_dir, 'val', label_file)
    return train_embs, train_labels, val_embs, val_labels


def construct_embs_labels_list(embs, labels, video_len_list, modify_embeddings, return_list=False, num_classes=None):
    """Split concatenated frame embeddings/labels into per-video chunks.

    Args:
        embs: 2-D array holding the frames of all videos back to back, in the
            order of ``video_len_list``.
        labels: per-frame class labels aligned with ``embs``.
        video_len_list: number of frames of each video.
        modify_embeddings: if True, append one column ``position * 1e-3`` to each
            video's embeddings, where ``position`` runs 0..video_len-1 within the video.
        return_list: if True, return the per-video lists of embeddings and raw
            class labels without building regression targets.
        num_classes: number of target columns; defaults to ``max(labels) + 1``.

    Returns:
        ``(embs_list, labels_list)`` if ``return_list``; otherwise
        ``(embs, targets)``, the re-concatenated embeddings and the per-frame
        targets produced by ``get_targets_from_labels``.
    """
    embs_list, labels_list = [], []
    start = 0
    for video_len in video_len_list:
        stop = start + video_len
        video_embs = embs[start:stop, :]
        if modify_embeddings:
            # Append each frame's position within its video as a small-scale extra feature.
            position = np.arange(video_len).reshape(-1, 1) * 1e-3
            video_embs = np.concatenate((video_embs, position), axis=1)
        embs_list.append(video_embs)
        labels_list.append(labels[start:stop])
        start = stop
    if return_list:
        return embs_list, labels_list
    if num_classes is None:
        num_classes = int(np.max(labels)) + 1
    targets_list = get_targets_from_labels(labels_list, num_classes=num_classes)
    return np.concatenate(embs_list, axis=0), np.concatenate(targets_list, axis=0)


def regression_labels_for_class(labels, class_idx):
    """Per-frame progression target of one class within a single video.

    Each frame ``t`` gets ``(t - t_last) / len(labels)``, where ``t_last`` is the
    last frame labeled ``class_idx``; labels are assumed to be ordered. Because
    ``t_last`` is a valid frame index, every such value lies strictly inside
    (-1, 1), so -1 is free to serve as the "class absent" sentinel.

    Returns:
        Float array of length ``len(labels)``; all -1 if ``class_idx`` never occurs.
    """
    num_frames = len(labels)
    if class_idx not in labels:
        return np.full(num_frames, -1.0)
    last_frame = np.argwhere(labels == class_idx)[-1, 0]
    frame_idx = np.arange(num_frames, dtype=float)
    return (frame_idx - last_frame) / num_frames


def get_regression_labels(class_labels, num_classes):
    """Build the progression-target matrix of one video.

    Column ``c`` is ``regression_labels_for_class(class_labels, c)``; the result
    has shape (len(class_labels), num_classes).
    """
    columns = [regression_labels_for_class(class_labels, c) for c in range(num_classes)]
    return np.stack(columns, axis=1)


def get_targets_from_labels(all_class_labels, num_classes):
    """Convert each video's class labels into its progression-target matrix.

    Returns:
        A list with one ``get_regression_labels`` result per entry of
        ``all_class_labels``, all with ``num_classes`` columns.
    """
    return [get_regression_labels(video_labels, num_classes) for video_labels in all_class_labels]


def compute_progression_value(train_embs, train_labels, val_embs, val_labels, train_video_len_list, val_video_len_list, modify_embeddings):
    """Fit per-class linear progression regressors on train and score train and val.

    Args:
        train_embs, train_labels, val_embs, val_labels: concatenated per-frame
            embeddings and class labels of each split.
        train_video_len_list, val_video_len_list: frames per video of each split.
        modify_embeddings: forwarded to ``construct_embs_labels_list``.

    Returns:
        (train_score, val_score): mean over class columns of
        ``LinearRegression.score`` (R^2), computed by ``VectorRegression.score``.
    """
    train_embs_list, train_labels_list = construct_embs_labels_list(
        train_embs, train_labels, train_video_len_list, modify_embeddings, return_list=True)
    val_embs_list, val_labels_list = construct_embs_labels_list(
        val_embs, val_labels, val_video_len_list, modify_embeddings, return_list=True)

    # Train and val targets must share the same column layout: one regressor is
    # fitted per train column, and VectorRegression.score indexes the val targets
    # by those column positions. Deriving the column count separately for each
    # split (as max(labels) + 1) breaks when val lacks the highest class.
    num_classes = int(np.max(train_labels)) + 1
    train_x = np.concatenate(train_embs_list, axis=0)
    train_y = np.concatenate(get_targets_from_labels(train_labels_list, num_classes), axis=0)
    val_x = np.concatenate(val_embs_list, axis=0)
    val_y = np.concatenate(get_targets_from_labels(val_labels_list, num_classes), axis=0)

    lin_model = VectorRegression(sklearn.linear_model.LinearRegression())
    lin_model.fit(train_x, train_y)
    train_score = lin_model.score(train_x, train_y)
    val_score = lin_model.score(val_x, val_y)
    print(f'Phase progression score: train = {train_score:.4f}, val = {val_score:.8f}')
    return train_score, val_score


def _drop_frames(embs, labels, video_len_list, keep):
    """Apply a per-frame boolean mask and recompute each video's frame count.

    ``embs``/``labels`` hold the frames of all videos back to back, in the order
    of ``video_len_list``; ``keep`` is a boolean array aligned with those frames.

    Returns:
        ``(embs[keep], labels[keep], new_video_len_list)``, where the new list holds
        the number of kept frames of each video.
    """
    new_video_len_list = []
    start = 0
    for video_len in video_len_list:
        new_video_len_list.append(int(np.count_nonzero(keep[start:start + video_len])))
        start += video_len
    return embs[keep], labels[keep], new_video_len_list


def progression(save_path, train_video_len_list, val_video_len_list, modify_labels=False, modify_embeddings=False):
    """Compute the phase-progression score on the saved train/val embeddings.

    Args:
        save_path: directory passed to ``load_embeds_and_labels``.
        train_video_len_list, val_video_len_list: frames per video, matching the
            labeled frames returned by ``load_embeds_and_labels``.
        modify_labels: if True, average two evaluations: one without label-2
            frames, and one without label-0 frames with the remaining labels
            shifted down by one.
        modify_embeddings: forwarded to ``compute_progression_value`` when
            ``modify_labels`` is False.

    Returns:
        (train_score, val_score)
    """
    train_embs, train_labels, val_embs, val_labels = load_embeds_and_labels(save_path, label_all=False)

    if not modify_labels:
        return compute_progression_value(train_embs, train_labels,
                                         val_embs, val_labels,
                                         train_video_len_list, val_video_len_list,
                                         modify_embeddings)

    keep_train_no0 = train_labels != 0
    keep_train_no2 = train_labels != 2  # do not consider frames with label = 2
    keep_val_no0 = val_labels != 0
    keep_val_no2 = val_labels != 2

    print(f'Train {len(train_labels)}, {np.count_nonzero(keep_train_no0)} frames label != 0, {np.count_nonzero(keep_train_no2)} frames label != 2')
    print(f'Val {len(val_labels)}, {np.count_nonzero(keep_val_no0)} frames label != 0, {np.count_nonzero(keep_val_no2)} frames label != 2')

    # Dropping frames shortens the videos they belonged to, so the per-video
    # lengths must be recomputed. Reusing the original lists would make the
    # per-video slicing downstream cut across video boundaries.
    # NOTE(review): modify_embeddings is not applied on this branch (kept as in
    # the original code) -- confirm this is intended.
    tr_embs, tr_labels, tr_lens = _drop_frames(train_embs, train_labels, train_video_len_list, keep_train_no2)
    va_embs, va_labels, va_lens = _drop_frames(val_embs, val_labels, val_video_len_list, keep_val_no2)
    train_score1, val_score1 = compute_progression_value(tr_embs, tr_labels, va_embs, va_labels,
                                                         tr_lens, va_lens, modify_embeddings=False)

    tr_embs, tr_labels, tr_lens = _drop_frames(train_embs, train_labels, train_video_len_list, keep_train_no0)
    va_embs, va_labels, va_lens = _drop_frames(val_embs, val_labels, val_video_len_list, keep_val_no0)
    train_score2, val_score2 = compute_progression_value(tr_embs, tr_labels - 1, va_embs, va_labels - 1,
                                                         tr_lens, va_lens, modify_embeddings=False)

    train_score = 0.5 * train_score1 + 0.5 * train_score2
    val_score = 0.5 * val_score1 + 0.5 * val_score2
    return train_score, val_score


def progression_semi(save_path, train_video_len_list, num_episodes=10, labeled_ratio_list=(0.5,)):
    """Few-shot progression: fit on a random subset of train videos, score on the rest.

    For each ratio in ``labeled_ratio_list``, runs ``num_episodes`` random splits
    of the training videos into a fitted part (``int(ratio * num_videos)`` videos)
    and a held-out part, and records the held-out score of each episode.

    Args:
        save_path: directory passed to ``load_embeds_and_labels``.
        train_video_len_list: frames per training video.
        num_episodes: number of random splits per ratio.
        labeled_ratio_list: fractions of videos used for fitting.

    Returns:
        List with, per ratio, the mean held-out score over episodes.

    Raises:
        ValueError: if a ratio leaves the fitted or the held-out part empty.
    """
    train_embs, train_labels, _, _ = load_embeds_and_labels(save_path, label_all=False)
    train_embs_list, train_labels_list = construct_embs_labels_list(train_embs, train_labels, train_video_len_list,
                                                                    modify_embeddings=False, return_list=True)
    # return_list=True yields raw per-frame class labels, but VectorRegression
    # expects the 2-D per-class progression targets, so convert them first.
    num_classes = int(np.max(train_labels)) + 1
    train_targets_list = get_targets_from_labels(train_labels_list, num_classes)
    num_videos = len(train_embs_list)

    results_labeled_ratio = []
    for labeled_ratio in labeled_ratio_list:
        num_labeled = int(labeled_ratio * num_videos)
        if not 0 < num_labeled < num_videos:
            raise ValueError(f'labeled_ratio={labeled_ratio} leaves an empty labeled or held-out split '
                             f'for {num_videos} videos')
        results = np.zeros(num_episodes, dtype=float)
        for episode in range(num_episodes):
            perm = np.random.permutation(num_videos)
            labeled, held_out = perm[:num_labeled], perm[num_labeled:]

            embs_labeled = np.concatenate([train_embs_list[v] for v in labeled], axis=0)
            targets_labeled = np.concatenate([train_targets_list[v] for v in labeled], axis=0)
            embs_held_out = np.concatenate([train_embs_list[v] for v in held_out], axis=0)
            targets_held_out = np.concatenate([train_targets_list[v] for v in held_out], axis=0)

            lin_model = VectorRegression(sklearn.linear_model.LinearRegression())
            lin_model.fit(embs_labeled, targets_labeled)
            train_score = lin_model.score(embs_labeled, targets_labeled)
            train_score_prop = lin_model.score(embs_held_out, targets_held_out)
            print(f'Run {episode} train score {train_score:.4f}, train score propagated {train_score_prop:.4f}')
            results[episode] = train_score_prop
        results_labeled_ratio.append(np.mean(results))
        print(f'{labeled_ratio} few shot progress score {np.mean(results):.4f} +- {np.std(results):.4f}')

    return results_labeled_ratio


