Fix batch transcribe

This commit is contained in:
2025-12-22 17:00:58 +03:00
parent a4071a4bdc
commit 44ee142946
2 changed files with 55 additions and 27 deletions

View File

@@ -4,7 +4,7 @@ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }
] ]
license = "MIT" license = "MIT"
version = "0.1.2" version = "0.1.3"
description = "ONNX wrapper for a GigaAMASR models" description = "ONNX wrapper for a GigaAMASR models"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@@ -5,6 +5,7 @@ import numpy as np
import onnxruntime as ort import onnxruntime as ort
import torch import torch
from pyexpat import features
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from .preprocess import FeatureExtractor, load_audio from .preprocess import FeatureExtractor, load_audio
@@ -68,42 +69,69 @@ class ASRABCModel(abc.ABC):
raise NotImplementedError() raise NotImplementedError()
def transcribe_batch(self, def transcribe_batch(self,
wavs: list[np.ndarray], all_wavs: list[np.ndarray],
join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]: join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
input_lengths = [] if len(all_wavs) == 0:
processed_wavs = [] return []
for wav in wavs: features = None
audio_tensor = load_audio(wav) input_lengths_array = None
processed = self.preprocessor( full_batch_size = len(all_wavs)
audio_tensor.unsqueeze(0), while len(all_wavs) > 0:
torch.tensor([audio_tensor.shape[-1]]) wavs = all_wavs[:128]
)[0] all_wavs = all_wavs[128:]
if isinstance(processed, torch.Tensor): input_lengths = []
processed = processed.cpu().numpy() processed_wavs = []
for wav in wavs:
audio_tensor = load_audio(wav)
processed = self.preprocessor(
audio_tensor.unsqueeze(0),
torch.tensor([audio_tensor.shape[-1]])
)[0]
processed_wavs.append(processed) if isinstance(processed, torch.Tensor):
input_lengths.append(processed.shape[2]) processed = processed.cpu().numpy()
max_length = max(input_lengths) processed_wavs.append(processed)
batch_size = len(wavs) input_lengths.append(processed.shape[2])
features_dim = processed_wavs[0].shape[1]
padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE) max_length = max(input_lengths)
batch_size = len(wavs)
features_dim = processed_wavs[0].shape[1]
for i, audio in enumerate(processed_wavs): padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
length = audio.shape[2]
padded_wavs[i, :, :length] = audio
input_lengths_array = np.array(input_lengths, dtype=np.int64) for i, audio in enumerate(processed_wavs):
length = audio.shape[2]
padded_wavs[i, :, :length] = audio
features = self._transcribe_encode_batch( it_input_lengths_array = np.array(input_lengths, dtype=np.int64)
padded_wavs,
input_lengths_array it_features = self._transcribe_encode_batch(
) padded_wavs,
it_input_lengths_array
)
if features is None:
features = it_features
else:
if it_features.shape[2] > features.shape[2]:
features = np.concatenate([features, np.zeros(
(features.shape[0], features.shape[1], it_features.shape[2] - features.shape[2]))], -1,
dtype=DTYPE)
elif it_features.shape[2] < features.shape[2]:
it_features = np.concatenate([it_features, np.zeros(
(it_features.shape[0], it_features.shape[1], features.shape[2] - it_features.shape[2]))], -1,
dtype=DTYPE)
features = np.concatenate([features, it_features], 0)
if input_lengths_array is None:
input_lengths_array = it_input_lengths_array
else:
input_lengths_array = np.concatenate([input_lengths_array, it_input_lengths_array], 0)
if join_batches is None: if join_batches is None:
batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)] batch_token_ids = [self._transcribe_decode(features[i]) for i in range(full_batch_size)]
return [self.tokenizer.decode(ids) for ids in batch_token_ids] return [self.tokenizer.decode(ids) for ids in batch_token_ids]
else: else:
ret = [] ret = []