import gcsfs
import zstandard
import ujson as json
import time
import tarfile
import codecs
from functools import reduce
import jsonlines
import io
from zipfile import ZipFile
import gzip
from math import ceil
import mmap
import multiprocessing as mp
# Select the 'spawn' start method for multiprocessing at import time.
# force=True asks multiprocessing to override a start method that was
# already chosen earlier in this interpreter.
try:
   mp.set_start_method('spawn', force=True)
   print("spawned")
except RuntimeError:
   # Best effort: if the start method cannot be changed, carry on with
   # whatever method is currently in effect.
   pass

from pathlib import Path

# File-name suffixes that are treated as readable data files.
VALID_EXTENSIONS = [
    'openwebtext.tar.xz',
    '_data.xz',
    '.dat.zst',
    '.jsonl',
    '.jsonl.zst',
    '.jsonl.zst.tar',
    '.json.zst',
    '.txt',
    '.zip',
    '.tar.gz',
    '.json.gz',
    '.gz',
]


def has_valid_extension(file):
    """Return True if ``file`` ends with any suffix listed in VALID_EXTENSIONS."""
    # str.endswith accepts a tuple of candidate suffixes; the tuple is built on
    # each call so later changes to VALID_EXTENSIONS are still honoured.
    return file.endswith(tuple(VALID_EXTENSIONS))


def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='text'):
    """Yield the text (and optionally the metadata) of every record in a JSONL stream.

    Args:
        jsonl_reader: iterable of decoded JSON lines; each item is either an
            object indexable by ``key`` or a bare string.
        get_meta: when true, yield ``(text, meta)`` pairs instead of just the
            text. ``meta`` is the record's ``'meta'`` entry, or ``{}`` if absent.
        autojoin_paragraphs: when true, a text value that is a list is joined
            into a single string using ``para_joiner``.
        para_joiner: separator used when joining paragraph lists.
        key: name of the field that holds the text.

    Yields:
        The text of each record, or a ``(text, meta)`` tuple if ``get_meta``.
    """
    for record in jsonl_reader:
        if not isinstance(record, str):
            content = record[key]
            if autojoin_paragraphs and isinstance(content, list):
                content = para_joiner.join(content)
            yield (content, record.get('meta', {})) if get_meta else content
        else:
            # Legacy "naive" jsonl: the line is the document string itself and
            # carries no metadata, so metadata cannot be requested for it.
            assert not get_meta
            yield record


class GCReader:
    """Stream documents out of ``.jsonl.zst`` files through a ``gcsfs`` filesystem.

    ``in_path`` may be a single file path, a directory path, or a list of
    such paths. Only entries whose name passes ``has_valid_extension`` are
    considered, and only ``.jsonl.zst`` files are actually streamed.
    """

    def __init__(self, in_path):
        """Remember ``in_path`` and open a read-only GCS filesystem handle."""
        self.in_path = in_path
        self.fs = gcsfs.GCSFileSystem(access='read_only')

    def stream_data(self, get_meta=False, threaded=False, jsonl_key="text"):
        """Yield every document found under ``self.in_path``.

        Args:
            get_meta: when true, yield ``(text, meta)`` tuples instead of text.
            threaded: when true, read in a separate ``multiprocessing`` process
                and hand the records over through a bounded queue.
            jsonl_key: name of the JSON field that holds the document text.

        Raises:
            FileNotFoundError: (non-threaded mode) nothing readable was found.
            RuntimeError: (threaded mode) the worker process exited before it
                signalled the end of the stream.
        """
        if not threaded:
            yield from self._stream_data(get_meta, jsonl_key)
            return

        from queue import Empty

        q = mp.Queue(1000)
        p = mp.Process(target=self._stream_data_threaded, args=(q, get_meta, jsonl_key))
        p.start()
        finished = False
        try:
            while True:
                try:
                    # Poll with a timeout so a crashed worker cannot block us forever.
                    res = q.get(timeout=0.5)
                except Empty:
                    if p.is_alive():
                        continue
                    # The worker is gone. It may have flushed records between
                    # the timeout and the liveness check, so look once more
                    # before concluding that the stream was cut short.
                    try:
                        res = q.get_nowait()
                    except Empty:
                        raise RuntimeError(
                            f"Reader worker process exited with code {p.exitcode} "
                            "before finishing the stream"
                        ) from None
                if res is None:
                    # Sentinel sent by the worker: the stream is complete.
                    finished = True
                    break
                yield res
        finally:
            # If the consumer abandons the generator, the worker may be blocked
            # on a full queue; stop it rather than leave it running.
            if not finished and p.is_alive():
                p.terminate()
            p.join()

    def _stream_data_threaded(self, q, get_meta=False, jsonl_key="text"):
        """Worker entry point: push every record onto ``q``, then a ``None`` sentinel."""
        for data in self._stream_data(get_meta, jsonl_key):
            q.put(data)
        q.put(None)

    def _stream_data(self, get_meta=False, jsonl_key="text"):
        """Yield records from each file in turn, tracking the current file in ``self.f_name``.

        Raises:
            FileNotFoundError: no file with a valid extension was found.
        """
        self.f_name = ""
        files = self.listdir_or_file(self.in_path)
        if not files:
            raise FileNotFoundError(f"No valid file(s) found in {self.in_path}")
        for f in files:
            self.f_name = f
            if f.endswith('.jsonl.zst'):
                yield from self.read_jsonl_zst(f, get_meta, key=jsonl_key)
            else:
                # shouldn't be reached
                print(f'Skipping {f} as streaming for that filetype is not implemented')

    def read_jsonl_zst(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'):
        """Decompress ``file`` as zstd on the fly and yield its JSONL records.

        The remaining arguments are forwarded to ``handle_jsonl``.
        """
        with self.fs.open(file, 'rb') as fh:
            dctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(dctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key)

    def _listdir_or_file(self, x):
        """Expand ``x`` into a list of paths.

        A list is expanded entry by entry in sorted order (each entry is
        already filtered by extension); an empty list yields an empty list.
        A file yields itself, and a directory yields its sorted listing.

        Raises:
            FileNotFoundError: ``x`` is neither a file nor a directory.
        """
        if isinstance(x, list):
            collected = []
            for entry in sorted(x):
                collected.extend(self.listdir_or_file(entry))
            return collected
        if self.fs.isfile(x):
            return [x]
        elif self.fs.isdir(x):
            return sorted(self.fs.ls(x))
        else:
            raise FileNotFoundError(f"{x} not found")

    def listdir_or_file(self, x):
        """Return the paths under ``x`` whose names have a valid extension."""
        return list(filter(has_valid_extension, self._listdir_or_file(x)))



