import abc
import math

import numpy as np
import onnxruntime as ort
import torch
from torch.nn.utils.rnn import pad_sequence

from .preprocess import FeatureExtractor, load_audio
from .decoding import Tokenizer

# dtype expected by the ONNX encoder inputs
DTYPE = np.float32
MAX_LETTERS_PER_FRAME = 3


class AudioDataset(torch.utils.data.Dataset):
    """
    Helper class for creating batched inputs
    """

    def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
        if len(lst) == 0:
            raise ValueError("AudioDataset cannot be initialized with an empty list")
        assert isinstance(
            lst[0], (str, np.ndarray, torch.Tensor)
        ), f"Unexpected dtype: {type(lst[0])}"
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx):
        item = self.lst[idx]
        if isinstance(item, str):
            # Strings are treated as file paths and loaded from disk.
            wav_tns = load_audio(item)
        elif isinstance(item, np.ndarray):
            wav_tns = torch.from_numpy(item)
        elif isinstance(item, torch.Tensor):
            wav_tns = item
        else:
            raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
        return wav_tns

    @staticmethod
    def collate(wavs):
        """Zero-pad a list of 1-D waveforms into a (batch, max_len) tensor."""
        lengths = torch.tensor([len(wav) for wav in wavs])
        max_len = lengths.max().item()
        wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
        for idx, wav in enumerate(wavs):
            wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
        return wav_tns, lengths


class ASRABCModel(abc.ABC):
    """Abstract base for ONNX-backed ASR models: encoding runs here, decoding is implemented by subclasses."""

    encoder: ort.InferenceSession
    preprocessor: FeatureExtractor
    tokenizer: Tokenizer
    blank_idx: int

    def __init__(
        self,
        encoder: ort.InferenceSession,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
    ):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # The blank token is placed right after the last vocabulary entry.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        """Transcribe a single waveform by delegating to the batched path."""
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray) -> np.ndarray:
        """Run the ONNX encoder on a single preprocessed feature matrix."""
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signal.astype(DTYPE), [input_signal.shape[-1]]],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return enc_features

    def _transcribe_encode_batch(
        self, input_signals: np.ndarray, input_lengths: np.ndarray
    ) -> np.ndarray:
        """Run the ONNX encoder on a padded batch of feature matrices."""
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signals.astype(DTYPE), input_lengths],
            )
        }
        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return outputs

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Decode encoder features into (token ids, frame timings); implemented by subclasses."""
        raise NotImplementedError()

    def transcribe_batch(
        self, wavs: list[np.ndarray], join_batches: list[int] | None = None
    ) -> list[tuple[str, list[int]]]:
        # Preprocess each waveform into a (1, features, frames) matrix and
        # remember its true frame count for padding/unpadding.
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0), torch.tensor([audio_tensor.shape[-1]])
            )[0]
            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()
            processed_wavs.append(processed)
            input_lengths.append(processed.shape[2])

        # Zero-pad all feature matrices to the longest one in the batch.
        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]
        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
        for i, audio in enumerate(processed_wavs):
            length = audio.shape[2]
            padded_wavs[i, :, :length] = audio

        input_lengths_array = np.array(input_lengths, dtype=np.int64)
        features = self._transcribe_encode_batch(padded_wavs, input_lengths_array)

        if join_batches is None:
            # Decode every batch item independently.
            batch_token_ids = [
                self._transcribe_decode(features[i]) for i in range(batch_size)
            ]
            return [self.tokenizer.decode(ids) for ids in batch_token_ids]

        # join_batches groups consecutive items: each group is concatenated
        # along the time axis and decoded as a single utterance.
        ret = []
        start_idx = 0
        for batch_len in join_batches:
            end_idx = start_idx + batch_len
            batch_lengths = input_lengths_array[start_idx:end_idx]
            # Strip the padding from each item before concatenating.
            batch_features_list = [
                features[start_idx + i, :, : batch_lengths[i]]
                for i in range(batch_len)
            ]
            seg_features_2d = np.concatenate(batch_features_list, axis=1)
            seg_features = seg_features_2d[np.newaxis, :, :]
            token_ids, token_timings = self._transcribe_decode(seg_features)
            # Rescale frame-level timings so the last frame maps to the total
            # audio duration in seconds (16 kHz sample rate).
            rate = (sum(map(len, wavs[start_idx:end_idx])) / 16000) / max(token_timings)
            result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
            norm_out_timings = [t * rate for t in out_timings]
            ret.append((result_text, norm_out_timings))
            start_idx = end_idx
        return ret
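

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: it only shows how
    # AudioDataset and its collate helper can be plugged into a torch
    # DataLoader. The random waveforms and batch size below are placeholders
    # for illustration; real callers would pass file paths or recorded audio.
    dummy_wavs = [np.random.randn(16000 * n).astype(np.float32) for n in (1, 2, 3)]
    dataset = AudioDataset(dummy_wavs)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, collate_fn=AudioDataset.collate
    )
    for batch, lengths in loader:
        # batch: (batch_size, max_len) zero-padded waveforms,
        # lengths: true sample count of each waveform before padding.
        print(batch.shape, lengths)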