From 44ee14294602513e6bc942e14cdcd6d91c50ce5c Mon Sep 17 00:00:00 2001
From: nikto_b
Date: Mon, 22 Dec 2025 17:00:58 +0300
Subject: [PATCH] Fix batch transcribe

---
 pyproject.toml             |  2 +-
 src/gigaam_onnx/asr_abc.py | 80 +++++++++++++++++++++++++-------------
 2 files changed, 55 insertions(+), 27 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 458c578..aa6dc21 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ authors = [
     { name = "nikto_b", email = "niktob560@yandex.ru" }
 ]
 license = "MIT"
-version = "0.1.2"
+version = "0.1.3"
 description = "ONNX wrapper for a GigaAMASR models"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/src/gigaam_onnx/asr_abc.py b/src/gigaam_onnx/asr_abc.py
index 9aff6b8..9281a87 100644
--- a/src/gigaam_onnx/asr_abc.py
+++ b/src/gigaam_onnx/asr_abc.py
@@ -5,6 +5,7 @@
 import numpy as np
 import onnxruntime as ort
 import torch
+from pyexpat import features
 from torch.nn.utils.rnn import pad_sequence
 
 from .preprocess import FeatureExtractor, load_audio
@@ -68,42 +69,69 @@ class ASRABCModel(abc.ABC):
         raise NotImplementedError()
 
     def transcribe_batch(self,
-                         wavs: list[np.ndarray],
+                         all_wavs: list[np.ndarray],
                          join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
-        input_lengths = []
-        processed_wavs = []
-        for wav in wavs:
-            audio_tensor = load_audio(wav)
-            processed = self.preprocessor(
-                audio_tensor.unsqueeze(0),
-                torch.tensor([audio_tensor.shape[-1]])
-            )[0]
+        if len(all_wavs) == 0:
+            return []
+        features = None
+        input_lengths_array = None
+        full_batch_size = len(all_wavs)
+        while len(all_wavs) > 0:
+            wavs = all_wavs[:128]
+            all_wavs = all_wavs[128:]
 
-            if isinstance(processed, torch.Tensor):
-                processed = processed.cpu().numpy()
+            input_lengths = []
+            processed_wavs = []
+            for wav in wavs:
+                audio_tensor = load_audio(wav)
+                processed = self.preprocessor(
+                    audio_tensor.unsqueeze(0),
+                    torch.tensor([audio_tensor.shape[-1]])
+                )[0]
 
-            processed_wavs.append(processed)
-            input_lengths.append(processed.shape[2])
+                if isinstance(processed, torch.Tensor):
+                    processed = processed.cpu().numpy()
 
-        max_length = max(input_lengths)
-        batch_size = len(wavs)
-        features_dim = processed_wavs[0].shape[1]
+                processed_wavs.append(processed)
+                input_lengths.append(processed.shape[2])
 
-        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
+            max_length = max(input_lengths)
+            batch_size = len(wavs)
+            features_dim = processed_wavs[0].shape[1]
 
-        for i, audio in enumerate(processed_wavs):
-            length = audio.shape[2]
-            padded_wavs[i, :, :length] = audio
+            padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
 
-        input_lengths_array = np.array(input_lengths, dtype=np.int64)
+            for i, audio in enumerate(processed_wavs):
+                length = audio.shape[2]
+                padded_wavs[i, :, :length] = audio
 
-        features = self._transcribe_encode_batch(
-            padded_wavs,
-            input_lengths_array
-        )
+            it_input_lengths_array = np.array(input_lengths, dtype=np.int64)
+
+            it_features = self._transcribe_encode_batch(
+                padded_wavs,
+                it_input_lengths_array
+            )
+            if features is None:
+                features = it_features
+            else:
+
+                if it_features.shape[2] > features.shape[2]:
+                    features = np.concatenate([features, np.zeros(
+                        (features.shape[0], features.shape[1], it_features.shape[2] - features.shape[2]))], -1,
+                        dtype=DTYPE)
+                elif it_features.shape[2] < features.shape[2]:
+                    it_features = np.concatenate([it_features, np.zeros(
+                        (it_features.shape[0], it_features.shape[1], features.shape[2] - it_features.shape[2]))], -1,
+                        dtype=DTYPE)
+
+                features = np.concatenate([features, it_features], 0)
+            if input_lengths_array is None:
+                input_lengths_array = it_input_lengths_array
+            else:
+                input_lengths_array = np.concatenate([input_lengths_array, it_input_lengths_array], 0)
 
         if join_batches is None:
-            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
+            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(full_batch_size)]
             return [self.tokenizer.decode(ids) for ids in batch_token_ids]
         else:
             ret = []