import random
import json
from dataclasses import dataclass
from typing import Optional
import re
import pandas as pd
import numpy as np

# Pool of candidate object nouns for the shopping-list prompts.
# Order matters for reproducibility: __main__ calls random.sample on this list
# right after random.seed, so reordering or editing entries changes the
# generated dataset. Entries are lowercase; the "a"/"an" choice in
# generate_items_text keys off each entry's first letter.
all_simple_objects = [
    "pencil",
    "notebook",
    "pen",
    "cup",
    "plate",
    "jug",
    "mug",
    "puzzle",
    "textbook",
    "leash",
    "necklace",
    "bracelet",
    "bottle",
    "ball",
    "envelope",
    "lighter",
    "bowl",
    'apple',
    'pear',
    'banana',
    'orange',
    'steak'
]

# Phrase inserted between the "... I will buy <items>." sentence and
# " I will get the" in Example.__str__. The template places it directly after
# the period, so non-empty entries carry their own leading space.
transitions = [' First,', '', ' When I go,', ' I think']
# Prefix placed before " when I go to the store" in Example.__str__;
# "" makes the prompt begin directly with " when I go ...".
subject_starts = [" Today,", " Tonight,", " Tomorrow,", ""]

def generate_items_text(chosen_objects):
    """Render object names as an English list with indefinite articles.

    Examples:
        ``[]`` -> ``""``
        ``["pen"]`` -> ``"a pen"``
        ``["apple", "pen"]`` -> ``"an apple and a pen"``
        ``["pen", "cup", "orange"]`` -> ``"a pen, a cup, and an orange"``
        (three or more items use an Oxford comma).

    Args:
        chosen_objects: Sequence of object names, in the order to list them.

    Returns:
        The rendered list text.
    """
    vowels = frozenset("aeiou")
    # obj[:1] (not obj[0]) so an empty name cannot raise; lower() so that a
    # capitalised noun such as "Apple" still gets "an".
    items = [
        ("an " if obj[:1].lower() in vowels else "a ") + obj
        for obj in chosen_objects
    ]

    if not items:
        return ""
    if len(items) == 1:
        # Previously produced ", and a pen" because items[:-1] was empty.
        return items[0]
    if len(items) == 2:
        return " and ".join(items)
    return ", ".join(items[:-1]) + ", and " + items[-1]

def rand_bool():
    """Flip a coin using the module-level ``random`` generator.

    Returns ``True`` when the draw from ``random.random()`` exceeds 0.5,
    ``False`` otherwise.
    """
    draw = random.random()
    return draw > 0.5

_NUMBER_WORDS = (
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight",
    "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
    "sixteen", "seventeen", "eighteen", "nineteen",
)
_TENS_WORDS = (
    "", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy",
    "eighty", "ninety",
)


def textify_number(n):
    """Spell an integer in 0..99 as lowercase English words.

    For example, 3 -> "three" and 21 -> "twenty-one". Values outside 0..99
    fall back to their digit string.
    """
    if 0 <= n < 20:
        return _NUMBER_WORDS[n]
    if 20 <= n < 100:
        tens, ones = divmod(n, 10)
        if ones == 0:
            return _TENS_WORDS[tens]
        return f"{_TENS_WORDS[tens]}-{_NUMBER_WORDS[ones]}"
    return str(n)


def gen_count_scores(true_count, max_number=6):
    """Build a one-hot score table over the counts ``0..max_number``.

    Each count appears twice, as digits ("3") and as a word ("three"). Both
    keys for ``true_count`` score 1 and every other key scores 0. Keys are
    inserted digit-then-word in ascending count order.

    Args:
        true_count: The count that should receive score 1.
        max_number: Largest count to include (inclusive).

    Returns:
        Dict mapping each digit/word key to 0 or 1.
    """
    scores = {}
    for i in range(max_number + 1):
        score = 1 if i == true_count else 0
        scores[str(i)] = score
        scores[textify_number(i)] = score
    return scores

@dataclass
class Example:
    """One "laundry list" prompt: a shopping list followed by a partial recall.

    ``str(example)`` lists every object, then repeats every object except the
    query object (``objects[query_idx]``) in shuffled order. The text ends with
    "and then the", so the query object is the expected continuation.
    """

    start: str  # prefix before " when I go to the store" (may be "")
    transition: str  # text between the "I will buy ..." sentence and " I will get the"
    objects: list  # object names in shopping-list order
    query_idx: int  # index into ``objects`` of the object left out of the recall
    invalid_obj: Optional[str] = None  # object NOT in ``objects``; used as label by with_label

    def copy(self, start, transition, objects, query_idx=0, invalid_obj=None):
        """Build a new ``Example`` from the given fields.

        ``start``/``transition`` are coerced to ``str`` and ``query_idx`` to
        ``int``. ``objects`` is copied into a fresh list.
        """
        return Example(str(start), str(transition), list(objects), int(query_idx), invalid_obj)

    def self_copy(self):
        """Return an independent copy of this example (``objects`` list copied)."""
        return self.copy(self.start, self.transition, self.objects, self.query_idx, self.invalid_obj)

    def __str__(self):
        """Render the prompt text.

        NOTE: the recalled objects are shuffled with the module-level
        ``random`` generator. Each call may return a different string and
        advances the global RNG state, which the seeded generation in
        ``__main__`` depends on.
        """
        query_object = self.objects[self.query_idx]
        items_text = generate_items_text(self.objects)
        recalled = [obj for obj in self.objects if obj != query_object]
        random.shuffle(recalled)
        template = (
            "{start} when I go to the store, I will buy {items}."
            "{transition} I will get the {first_items}{oxford_comma} and then the"
        )
        return template.format(
            start=self.start,
            items=items_text,
            transition=self.transition,
            first_items=", the ".join(recalled),
            # Oxford comma only when the recall lists two or more objects.
            oxford_comma="," if len(self.objects) > 2 else "",
        )

    def __repr__(self):
        # NOTE(review): repr renders (and therefore shuffles / consumes RNG)
        # just like str; kept for backward compatibility.
        return str(self)

    def get_label(self, uppercase=False):
        """Return the expected completion: ``" "`` + the query object.

        The object is lowercased, or title-cased when ``uppercase`` is true.
        """
        label = self.objects[self.query_idx].lower()
        if uppercase:
            label = label.title()
        return " " + label

    def with_label(self, uppercase=False, override_idx=None):
        """Return the prompt text followed by its label.

        The label is ``invalid_obj`` when it is set and no ``override_idx`` is
        given. Otherwise it is ``objects[override_idx]`` if an override is
        given, else ``objects[query_idx]``. It is prefixed with a space and
        title-cased when ``uppercase`` is true, lowercased otherwise.
        """
        if override_idx is None and self.invalid_obj is not None:
            label = self.invalid_obj
        else:
            idx = self.query_idx if override_idx is None else override_idx
            label = self.objects[idx]
        label = label.title() if uppercase else label.lower()
        return str(self) + " " + label

    def set_query(self, idx):
        """Make ``objects[idx]`` the query (omitted) object."""
        assert idx < len(self.objects)
        self.query_idx = idx

    def min_pair_obj(self, manual_idx=None, manual_obj=None):
        """Return a copy of this example that queries a different object.

        Pass ``manual_idx`` or ``manual_obj`` but not both. With neither, a
        random index other than ``query_idx`` is chosen. ``invalid_obj`` is
        not carried over to the copy.
        """
        if manual_obj in self.objects:
            manual_idx = self.objects.index(manual_obj)
        idx = manual_idx
        assert idx != self.query_idx
        if idx is None:
            candidates = [i for i in range(len(self.objects)) if i != self.query_idx]
            idx = random.choice(candidates)
        return self.copy(self.start, self.transition, self.objects, idx)

    def n_objs(self, n=3):
        """Return ``min_pair_obj`` copies querying each of the first ``n`` indices.

        ``query_idx`` itself is skipped. ``n`` is clamped to ``len(objects)``
        so every copy has a valid query index.
        """
        limit = min(n, len(self.objects))
        return [self.min_pair_obj(manual_idx=i) for i in range(limit) if i != self.query_idx]

    def min_pair_color(self, manual_idx=None, color_choice=None):
        """Unsupported: ``Example`` has no colour fields.

        This method appears to be carried over from a colour-based variant of
        the task. Its previous body read ``self.colors`` and failed on every
        call, so it now raises an explicit error instead.
        """
        raise NotImplementedError(
            "min_pair_color is not supported: Example has no colour attributes"
        )

    def _unused_objects(self):
        """Objects from ``all_simple_objects`` absent from this example, in pool order.

        A list comprehension (not a set difference) keeps the order stable
        across processes, so seeded random choices are reproducible.
        """
        return [obj for obj in all_simple_objects if obj not in self.objects]

    def min_pair_invalid(self, obj_choice=None):
        """Return a copy whose ``invalid_obj`` is an object not in the list.

        Args:
            obj_choice: The invalid object to use; chosen at random from the
                unused pool when ``None``.
        """
        obj = obj_choice
        if obj is None:
            obj = random.choice(self._unused_objects())

        assert obj not in self.objects and obj is not None

        return self.copy(self.start, self.transition, self.objects, self.query_idx, obj)

    def n_invalid(self, n):
        """Return up to ``n`` ``min_pair_invalid`` copies using distinct unused objects."""
        candidates = self._unused_objects()
        random.shuffle(candidates)
        return [self.min_pair_invalid(obj) for obj in candidates[:n]]

    def get_viable_preds(self, uppercase=False):
        """Return every listed object as a space-prefixed candidate completion."""
        if uppercase:
            return [' ' + label.title() for label in self.objects]
        return [' ' + label for label in self.objects]
    
# Mirrors the template in Example.__str__:
# "{start} when I go to the store, I will buy {items}.{transition} I will get
#  the {first_items}{oxford_comma} and then the"
_EXAMPLE_PATTERN = re.compile(
    r"(?P<start>.*?) when I go to the store, I will buy (?P<items>[^.]*)\."
    r"(?P<transition>.*?) I will get the (?P<recalled>.*?),? and then the"
)


def parse_example_string(s):
    """Recover the fields of a prompt rendered by ``Example.__str__``.

    Args:
        s: Prompt text such as
            " Today, when I go to the store, I will buy a pen, a cup, and an
            apple. First, I will get the apple, the pen, and then the".

    Returns:
        ``(start, transition, items, query_idx)``. ``items`` holds the object
        names in shopping-list order without articles. ``query_idx`` is the
        index of the first listed object missing from the "I will get the ..."
        recall.

    Raises:
        ValueError: if ``s`` does not follow the prompt template, or no listed
            object is missing from the recall.
    """
    match = _EXAMPLE_PATTERN.fullmatch(s)
    if match is None:
        raise ValueError(f"not a laundry-list prompt: {s!r}")

    # The list is "a X", "a X and a Y", or "a X, a Y, and an Z". Split on the
    # separators (", and " is tried before ", "), then drop each leading
    # "a "/"an " article.
    raw_items = re.split(r",? and |, ", match.group("items"))
    items = [re.sub(r"^an? ", "", item, count=1) for item in raw_items]

    # Recalled objects are joined with ", the " (the first "the" is part of
    # the template); filter out the empty string a one-object prompt yields.
    recalled = {obj for obj in match.group("recalled").split(", the ") if obj}

    missing = [i for i, item in enumerate(items) if item not in recalled]
    if not missing:
        raise ValueError(f"every listed object is recalled; no query in {s!r}")
    return match.group("start"), match.group("transition"), items, missing[0]

if __name__ == "__main__":
    import os

    # Round-trip check for parse_example_string (disabled; uncomment to run):
    # ex_string = " Today, when I go to the store, I will buy a steak, a bracelet, a mug, a jug, a lighter, a ball, an apple, a puzzle, a pear, and a bowl. I think I will get the lighter, the apple, the mug, the puzzle, the pear, the bracelet, the jug, the steak, the ball, and then the"
    # start, transition, items, query_idx = parse_example_string(ex_string)
    # print(start in subject_starts, transition in transitions, all([item in all_simple_objects for item in items]), query_idx)
    # ex = Example(start, transition, items, query_idx)
    # print(ex)

    # List lengths to generate; earlier runs used list(range(2, 21)) or
    # [2, 3, 4, 5, 6, 7, 8, 10, 15, 20].
    n_objs = [12]
    seeds = [1, 2, 3, 4, 5]
    n_per_seed = 50
    data = []
    # The order of RNG calls below (sample, choice, choice, choice, then the
    # shuffle inside str(ex)) fixes the generated dataset for each seed.
    for seed in seeds:
        random.seed(seed)
        for _ in range(n_per_seed):
            # Sample n_objs[-1] objects once (assumed to be the largest length);
            # shorter lengths reuse a prefix of this sample.
            chosen_objects = random.sample(all_simple_objects, n_objs[-1])
            start, trans = random.choice(subject_starts), random.choice(transitions)
            for n_obj in n_objs:
                query_idx = random.choice(range(n_obj))
                ex = Example(start, trans, chosen_objects[:n_obj], query_idx)
                data.append({'text': str(ex), 'n_objs': n_obj, 'query_idx': query_idx, 'objects': ex.objects})

    # Examples generated per list length (250 with the settings above).
    n_examples = len(seeds) * n_per_seed
    # NOTE(review): "bjs" looks like a typo for "objs"; kept so existing
    # consumers still find the file under the same name.
    write_path = f'datasets/laundry_list_{n_examples}_{n_objs[0]}bjs.json'
    os.makedirs(os.path.dirname(write_path), exist_ok=True)
    with open(write_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(len(data))
    print(f"Data written to {write_path}")
    