import json
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import numpy as np
import transformers
from rank_bm25 import BM25Plus
from datasets import load_dataset


# Instruction-style prompt rendered with ``PROMPT_TEMPLATE.format(prompt=...)``.
# The ``{prompt}`` placeholder appears twice: once inside the fenced block of
# the instruction section, and once more directly after the opening fence of
# the response section, so the rendered text ends with the prompt itself.
PROMPT_TEMPLATE = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.

@@ Instruction
Write a solution to the following problem:
```python
{prompt}
```

@@ Response
```python
{prompt}"""


@dataclass
class ModelArguments:
    """Model-selection options handed to ``HfArgumentParser`` in ``main``.

    ``model_name_or_path`` defaults to the CodeLlama 7B Python checkpoint
    identifier; ``peft_model`` defaults to the empty string.
    """

    model_name_or_path: Optional[str] = "codellama/CodeLlama-7b-Python-hf"
    peft_model: Optional[str] = ""


@dataclass
class DataArguments:
    """Data-location options handed to ``HfArgumentParser`` in ``main``."""

    # Annotated Optional because the default is None (the flag may be omitted),
    # which also matches how the sibling argument dataclasses declare
    # nullable string options.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` plus the extra options this script declares.

    Adds ``cache_dir``, ``model_max_length``, ``lora_r``, ``training_split``,
    ``query_idx`` and ``solution_idx``, and sets the default of ``optim`` to
    ``"adamw_torch"``.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    lora_r: int = field(default=8, metadata={"help": "LoRA rank."})
    training_split: int = field(default=0, metadata={"help": "training split."})
    query_idx: int = field(default=0, metadata={"help": "which query example to be used."})
    solution_idx: int = field(default=0, metadata={"help": "which solution to query."})


def test_tokenize_function(examples, tokenizer, solution_idx=0):
    """Build prompt/target strings for a batch and pass them to ``preprocess``.

    Every entry of ``examples["prompt"]`` is rendered through
    ``PROMPT_TEMPLATE`` to form a source string. Every entry of
    ``examples["llm_solutions"]`` is indexed to pick one candidate: the one at
    ``solution_idx`` when that index is below the entry's length, otherwise
    the first one. The target is that candidate followed by a newline and
    ``EOT_TOKEN``.

    NOTE(review): ``EOT_TOKEN`` and ``preprocess`` are neither defined nor
    imported in the visible part of this file -- confirm they are provided
    elsewhere, otherwise calling this function raises ``NameError``.

    Returns:
        Whatever ``preprocess(sources, targets, tokenizer)`` returns.
    """
    def pick(candidates):
        # Fall back to the first candidate when the requested index is past the end.
        chosen = solution_idx if solution_idx < len(candidates) else 0
        return candidates[chosen]

    rendered_prompts = [PROMPT_TEMPLATE.format(prompt=text) for text in examples["prompt"]]

    completions = []
    for candidates in examples["llm_solutions"]:
        completions.append(f"{pick(candidates)}\n{EOT_TOKEN}")

    return preprocess(rendered_prompts, completions, tokenizer)


def read_jsonl_file(file_path):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        file_path: Path of the file to read; each non-blank line must hold
            one JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale, and skip blank/whitespace-only lines (e.g. a trailing empty
    # line), which would otherwise make json.loads raise.
    with open(file_path, 'r', encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def _load_or_compute_bm25_scores(bm25, training_datasets, cache_path):
    """Return one row of BM25 scores per training example, cached at ``cache_path``.

    When ``cache_path`` exists the table is loaded from it; otherwise every
    example's ``input``, ``instruction`` and ``output`` fields are joined with
    newlines, stripped, split on single spaces and scored with
    ``bm25.get_scores``, and the resulting table is pickled to ``cache_path``.

    Raises:
        ValueError: If a cached table's row count differs from the number of
            training examples (for instance a cache written from a subset).
    """
    import pandas as pd

    if Path(cache_path).exists():
        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at a cache file produced by this script.
        scores = pd.read_pickle(cache_path)
        if len(scores) != len(training_datasets):
            raise ValueError(
                f"cached BM25 scores in {cache_path} have {len(scores)} rows but the "
                f"dataset has {len(training_datasets)} examples; delete the cache to recompute"
            )
        return scores

    rows = []
    for ex in training_datasets:
        text = (ex["input"] + "\n" + ex["instruction"] + "\n" + ex["output"]).strip()
        rows.append(bm25.get_scores(text.split(" ")))
    scores = pd.DataFrame(rows)
    scores.to_pickle(cache_path)
    return scores


def main():
    """Label each Open-Platypus training example with its best-matching query.

    Steps:
      1. Parse the command line with ``HfArgumentParser``.
      2. Load the ``garage-bAInd/Open-Platypus`` train split and the first 12
         entries of ``influence/text_queries.jsonl``.
      3. Index the query texts (``query`` + blank line + ``solution``) with
         ``BM25Plus`` and score every training example against them, caching
         the score table in ``influence/text_bm25_scores.pkl``.
      4. Standardise each query's score column (subtract the mean, divide by
         the standard deviation) and take, per example, the column with the
         highest standardised score as its ``split_labels`` value.
      5. Print the number of examples per label and write the labelled
         dataset to ``Open-Platypus-12normalizedbm25clusters.jsonl``.
    """
    from collections import Counter

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    # Parsed only so the command line is validated; none of the parsed values
    # are consumed below.
    parser.parse_args_into_dataclasses()

    # Alternative dataset: "ise-uiuc/Magicoder-OSS-Instruct-75K".
    training_datasets = load_dataset(
        "garage-bAInd/Open-Platypus",
        split="train"
    )
    problem_indices = list(range(12))

    # Alternative query file: 'influence/coding_queries.jsonl'.
    jsonl_file_path = 'influence/text_queries.jsonl'
    test_queries = read_jsonl_file(jsonl_file_path)
    test_queries = [test_queries[idx] for idx in problem_indices]
    test_corpus = [data["query"] + "\n\n" + data["solution"] for data in test_queries]

    # The BM25 index is built over the queries; each training example is then
    # used as the search text, giving one score per query for that example.
    tokenized_test_corpus = [doc.split(" ") for doc in test_corpus]
    bm25 = BM25Plus(tokenized_test_corpus)

    # Alternative cache file: "influence/coding_bm25_scores.pkl".
    bm25_score_save_name = "influence/text_bm25_scores.pkl"
    bm25_data = _load_or_compute_bm25_scores(bm25, training_datasets, bm25_score_save_name)

    # Keep only the columns for the selected queries, then z-score each column
    # so queries with systematically higher raw scores do not dominate.
    selected_queries = bm25_data.columns[:len(problem_indices)]
    bm25_data = bm25_data[selected_queries]
    bm25_data = (bm25_data - bm25_data.mean()) / bm25_data.std()
    split_labels = bm25_data.idxmax(axis=1).tolist()
    training_datasets = training_datasets.add_column("split_labels", split_labels)

    print(f"the problem indices is {problem_indices}")
    # Count labels in a single pass rather than filtering the whole dataset
    # once per label.
    label_counts = Counter(split_labels)
    for i in range(len(problem_indices)):
        print(f"split label {i} has {label_counts[i]} examples")

    save_name = f"{len(problem_indices)}normalizedbm25clusters"
    dataset_save_name = f"Open-Platypus-{save_name}"
    with Path(f"{dataset_save_name}.jsonl").open("w", encoding="utf-8") as f:
        for data_point in training_datasets:
            f.write(f'{json.dumps(data_point)}\n')


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

