Fix batch transcribe

2025-12-22 17:00:58 +03:00
parent a4071a4bdc
commit 44ee142946
2 changed files with 55 additions and 27 deletions


@@ -4,7 +4,7 @@ authors = [
     { name = "nikto_b", email = "niktob560@yandex.ru" }
 ]
 license = "MIT"
-version = "0.1.2"
+version = "0.1.3"
 description = "ONNX wrapper for a GigaAMASR models"
 readme = "README.md"
 requires-python = ">=3.10"


@@ -5,6 +5,7 @@ import numpy as np
 import onnxruntime as ort
 import torch
+from pyexpat import features
 from torch.nn.utils.rnn import pad_sequence

 from .preprocess import FeatureExtractor, load_audio
@@ -68,42 +69,69 @@ class ASRABCModel(abc.ABC):
         raise NotImplementedError()

     def transcribe_batch(self,
-                         wavs: list[np.ndarray],
+                         all_wavs: list[np.ndarray],
                          join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
-        input_lengths = []
-        processed_wavs = []
-        for wav in wavs:
-            audio_tensor = load_audio(wav)
-            processed = self.preprocessor(
-                audio_tensor.unsqueeze(0),
-                torch.tensor([audio_tensor.shape[-1]])
-            )[0]
+        if len(all_wavs) == 0:
+            return []
+        features = None
+        input_lengths_array = None
+        full_batch_size = len(all_wavs)
+        while len(all_wavs) > 0:
+            wavs = all_wavs[:128]
+            all_wavs = all_wavs[128:]

-            if isinstance(processed, torch.Tensor):
-                processed = processed.cpu().numpy()
+            input_lengths = []
+            processed_wavs = []
+            for wav in wavs:
+                audio_tensor = load_audio(wav)
+                processed = self.preprocessor(
+                    audio_tensor.unsqueeze(0),
+                    torch.tensor([audio_tensor.shape[-1]])
+                )[0]

-            processed_wavs.append(processed)
-            input_lengths.append(processed.shape[2])
+                if isinstance(processed, torch.Tensor):
+                    processed = processed.cpu().numpy()

-        max_length = max(input_lengths)
-        batch_size = len(wavs)
-        features_dim = processed_wavs[0].shape[1]
+                processed_wavs.append(processed)
+                input_lengths.append(processed.shape[2])

-        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
+            max_length = max(input_lengths)
+            batch_size = len(wavs)
+            features_dim = processed_wavs[0].shape[1]

-        for i, audio in enumerate(processed_wavs):
-            length = audio.shape[2]
-            padded_wavs[i, :, :length] = audio
+            padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)

-        input_lengths_array = np.array(input_lengths, dtype=np.int64)
+            for i, audio in enumerate(processed_wavs):
+                length = audio.shape[2]
+                padded_wavs[i, :, :length] = audio

-        features = self._transcribe_encode_batch(
-            padded_wavs,
-            input_lengths_array
-        )
+            it_input_lengths_array = np.array(input_lengths, dtype=np.int64)
+            it_features = self._transcribe_encode_batch(
+                padded_wavs,
+                it_input_lengths_array
+            )
+            if features is None:
+                features = it_features
+            else:
+                if it_features.shape[2] > features.shape[2]:
+                    features = np.concatenate([features, np.zeros(
+                        (features.shape[0], features.shape[1], it_features.shape[2] - features.shape[2]))], -1,
+                        dtype=DTYPE)
+                elif it_features.shape[2] < features.shape[2]:
+                    it_features = np.concatenate([it_features, np.zeros(
+                        (it_features.shape[0], it_features.shape[1], features.shape[2] - it_features.shape[2]))], -1,
+                        dtype=DTYPE)
+                features = np.concatenate([features, it_features], 0)
+            if input_lengths_array is None:
+                input_lengths_array = it_input_lengths_array
+            else:
+                input_lengths_array = np.concatenate([input_lengths_array, it_input_lengths_array], 0)

         if join_batches is None:
-            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
+            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(full_batch_size)]
             return [self.tokenizer.decode(ids) for ids in batch_token_ids]
         else:
             ret = []
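
Note on the fix: encoder outputs produced for different 128-utterance chunks can have different time lengths, so the change zero-pads the shorter feature block along the last axis before stacking chunks along the batch axis. A minimal standalone sketch of that step follows; the function name, DTYPE value, and example shapes are illustrative assumptions, not code from the repository.

# Illustrative sketch only: pad two chunks of encoder features to a common
# time length, then stack them along the batch axis, mirroring the shape
# handling introduced in this commit.
import numpy as np

DTYPE = np.float32  # assumed stand-in for the module's DTYPE constant

def pad_and_stack(features, it_features):
    # Zero-pad the shorter block along the time axis (axis 2).
    t_max = max(features.shape[2], it_features.shape[2])
    def pad(x):
        if x.shape[2] == t_max:
            return x
        tail = np.zeros((x.shape[0], x.shape[1], t_max - x.shape[2]), dtype=DTYPE)
        return np.concatenate([x, tail], axis=-1)
    # Stack the padded blocks along the batch axis (axis 0).
    return np.concatenate([pad(features), pad(it_features)], axis=0)

chunk_a = np.ones((128, 64, 40), dtype=DTYPE)  # 128 utterances, 40 encoder frames
chunk_b = np.ones((16, 64, 55), dtype=DTYPE)   # 16 utterances, 55 encoder frames
print(pad_and_stack(chunk_a, chunk_b).shape)   # (144, 64, 55)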