Remove unused AudioDataset, originally fetched from the upstream GigaAM/utils
This commit is contained in:
@@ -14,44 +14,6 @@ DTYPE = np.float32
|
|||||||
MAX_LETTERS_PER_FRAME = 3
|
MAX_LETTERS_PER_FRAME = 3
|
||||||
|
|
||||||
|
|
||||||
class AudioDataset(torch.utils.data.Dataset):
|
|
||||||
"""
|
|
||||||
Helper class for creating batched inputs
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
|
|
||||||
if len(lst) == 0:
|
|
||||||
raise ValueError("AudioDataset cannot be initialized with an empty list")
|
|
||||||
assert isinstance(
|
|
||||||
lst[0], (str, np.ndarray, torch.Tensor)
|
|
||||||
), f"Unexpected dtype: {type(lst[0])}"
|
|
||||||
self.lst = lst
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.lst)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
item = self.lst[idx]
|
|
||||||
if isinstance(item, str):
|
|
||||||
wav_tns = load_audio(item)
|
|
||||||
elif isinstance(item, np.ndarray):
|
|
||||||
wav_tns = torch.from_numpy(item)
|
|
||||||
elif isinstance(item, torch.Tensor):
|
|
||||||
wav_tns = item
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
|
|
||||||
return wav_tns
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def collate(wavs):
|
|
||||||
lengths = torch.tensor([len(wav) for wav in wavs])
|
|
||||||
max_len = lengths.max().item()
|
|
||||||
wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
|
|
||||||
for idx, wav in enumerate(wavs):
|
|
||||||
wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
|
|
||||||
return wav_tns, lengths
|
|
||||||
|
|
||||||
|
|
||||||
class ASRABCModel(abc.ABC):
|
class ASRABCModel(abc.ABC):
|
||||||
encoder: ort.InferenceSession
|
encoder: ort.InferenceSession
|
||||||
preprocessor: FeatureExtractor
|
preprocessor: FeatureExtractor
|
||||||
|
|||||||
Reference in New Issue
Block a user