import glob
import random
import os

import torch
import librosa
from torch.utils.data import Dataset
import torch.nn.functional as F

import tools.Whisper as Whisper
from tqdm import tqdm


def get_all_file_paths(directory):
    """Return the path of every file found beneath ``directory``.

    The tree is traversed with ``os.walk``; each returned path is the
    walked directory joined with the file name. Directories themselves are
    not included. A directory that cannot be walked yields an empty list.
    """
    return [
        os.path.join(current_dir, name)
        for current_dir, _subdirs, names in os.walk(directory)
        for name in names
    ]

def load_commonvoice_list(path, split="train"):
    """Read the list of entries for one split from ``<path>/<split>.txt``.

    Args:
        path: Directory containing the ``<split>.txt`` list files.
        split: Name of the split; the file ``<split>.txt`` is read.

    Returns:
        A list with one string per line of the file, with leading and
        trailing whitespace stripped.

    Raises:
        OSError: If the list file cannot be opened.
    """
    file_list_path = os.path.join(path, f"{split}.txt")
    # Expected line counts per split; used only to size the progress bar.
    known_line_counts = {'train': 21101675, 'test': 8000, 'dev': 2000}
    # A split without a known count gets total=None (progress bar without a
    # percentage) rather than failing with KeyError before the file is read.
    total = known_line_counts.get(split)
    with open(file_list_path, "r") as f:
        return [line.strip() for line in tqdm(f, total=total, leave=False)]


class NSynthDataset(Dataset):
    """Dataset yielding fixed-length, single-channel audio clips.

    File names are read from ``<audio_dir>/<split>.txt`` via
    ``load_commonvoice_list`` and resolved relative to ``audio_dir``.

    NOTE(review): the class is named for NSynth while the list loader is
    named for CommonVoice -- confirm which corpus this is meant to serve.
    """

    # Recordings longer than this many seconds are truncated on load.
    _MAX_LOAD_SECONDS = 30
    # Number of zero samples added on each side of the clip for the second
    # item returned by ``__getitem__``.
    _TEACHER_PAD = 50

    def __init__(self, audio_dir, sample_rate=16000, split="train"):
        """Load the file list for ``split`` and set the clip length.

        Args:
            audio_dir: Directory holding ``<split>.txt`` and the audio files.
            sample_rate: Rate passed to ``librosa.load`` when reading audio.
            split: Which list file to read.
        """
        super().__init__()
        self.audio_dir = audio_dir
        self.filenames = load_commonvoice_list(audio_dir, split=split)
        print(len(self.filenames))
        self.sr = sample_rate
        # Every clip is two seconds long at the requested sample rate.
        self.max_len = sample_rate * 2

    def __len__(self):
        """Return the number of listed files."""
        return len(self.filenames)

    def _load_audio(self, filename):
        """Load one file as a ``(1, n_samples)`` tensor, capped in length.

        Args:
            filename: Path relative to ``self.audio_dir``.

        Returns:
            A tensor of shape ``(1, n_samples)`` holding at most
            ``_MAX_LOAD_SECONDS`` seconds of audio.
        """
        filename = os.path.join(self.audio_dir, filename)
        audio, sr = librosa.load(filename, sr=self.sr)
        audio = torch.tensor(audio).flatten().unsqueeze(0)
        max_samples = sr * self._MAX_LOAD_SECONDS
        if audio.shape[1] > max_samples:
            audio = audio[:, :max_samples]
        return audio

    def _clip_audio(self, audio):
        """Return a ``(1, max_len)`` clip taken from ``audio``.

        Audio longer than ``self.max_len`` is cropped at a uniformly random
        start position; shorter (or equal-length) audio is copied to the
        start of a zero tensor, i.e. right-padded with zeros.

        Args:
            audio: Tensor of shape ``(1, n_samples)``.
        """
        n_samples = audio.shape[1]
        if n_samples > self.max_len:
            # random.randint is inclusive on both ends, so the largest valid
            # start is n_samples - max_len (the window ending at the last
            # sample). The previous upper bound was one smaller and could
            # never select that final window.
            st = random.randint(0, n_samples - self.max_len)
            return audio[:, st:st + self.max_len]
        ans = torch.zeros(1, self.max_len)
        ans[:, :n_samples] = audio
        return ans

    def __getitem__(self, index):
        """Return the clip for ``index`` and a zero-padded copy of it.

        Returns:
            A tuple ``(ans, ans_teacher)`` where ``ans`` has shape
            ``(1, max_len)`` and ``ans_teacher`` is the same clip with
            ``_TEACHER_PAD`` zeros added on each side and the leading
            dimension removed, i.e. shape ``(max_len + 2 * _TEACHER_PAD,)``.
        """
        audio = self._load_audio(self.filenames[index])
        ans = self._clip_audio(audio)
        pad = self._TEACHER_PAD
        ans_teacher = F.pad(ans, (pad, pad), 'constant', 0).squeeze(0)
        return ans, ans_teacher