Remove unused AudioDataset originally fetched from an upstream GigaAM/utils
This commit is contained in:
@@ -14,44 +14,6 @@ DTYPE = np.float32
|
||||
MAX_LETTERS_PER_FRAME = 3
|
||||
|
||||
|
||||
class AudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Helper class for creating batched inputs
|
||||
"""
|
||||
|
||||
def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
|
||||
if len(lst) == 0:
|
||||
raise ValueError("AudioDataset cannot be initialized with an empty list")
|
||||
assert isinstance(
|
||||
lst[0], (str, np.ndarray, torch.Tensor)
|
||||
), f"Unexpected dtype: {type(lst[0])}"
|
||||
self.lst = lst
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lst)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.lst[idx]
|
||||
if isinstance(item, str):
|
||||
wav_tns = load_audio(item)
|
||||
elif isinstance(item, np.ndarray):
|
||||
wav_tns = torch.from_numpy(item)
|
||||
elif isinstance(item, torch.Tensor):
|
||||
wav_tns = item
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
|
||||
return wav_tns
|
||||
|
||||
@staticmethod
|
||||
def collate(wavs):
|
||||
lengths = torch.tensor([len(wav) for wav in wavs])
|
||||
max_len = lengths.max().item()
|
||||
wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
|
||||
for idx, wav in enumerate(wavs):
|
||||
wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
|
||||
return wav_tns, lengths
|
||||
|
||||
|
||||
class ASRABCModel(abc.ABC):
|
||||
encoder: ort.InferenceSession
|
||||
preprocessor: FeatureExtractor
|
||||
|
||||
Reference in New Issue
Block a user