Fix batch transcribe
This commit is contained in:
@@ -4,7 +4,7 @@ authors = [
|
||||
{ name = "nikto_b", email = "niktob560@yandex.ru" }
|
||||
]
|
||||
license = "MIT"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
description = "ONNX wrapper for a GigaAMASR models"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -5,6 +5,7 @@ import numpy as np
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from pyexpat import features
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .preprocess import FeatureExtractor, load_audio
|
||||
@@ -68,8 +69,17 @@ class ASRABCModel(abc.ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
def transcribe_batch(self,
|
||||
wavs: list[np.ndarray],
|
||||
all_wavs: list[np.ndarray],
|
||||
join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
|
||||
if len(all_wavs) == 0:
|
||||
return []
|
||||
features = None
|
||||
input_lengths_array = None
|
||||
full_batch_size = len(all_wavs)
|
||||
while len(all_wavs) > 0:
|
||||
wavs = all_wavs[:128]
|
||||
all_wavs = all_wavs[128:]
|
||||
|
||||
input_lengths = []
|
||||
processed_wavs = []
|
||||
for wav in wavs:
|
||||
@@ -95,15 +105,33 @@ class ASRABCModel(abc.ABC):
|
||||
length = audio.shape[2]
|
||||
padded_wavs[i, :, :length] = audio
|
||||
|
||||
input_lengths_array = np.array(input_lengths, dtype=np.int64)
|
||||
it_input_lengths_array = np.array(input_lengths, dtype=np.int64)
|
||||
|
||||
features = self._transcribe_encode_batch(
|
||||
it_features = self._transcribe_encode_batch(
|
||||
padded_wavs,
|
||||
input_lengths_array
|
||||
it_input_lengths_array
|
||||
)
|
||||
if features is None:
|
||||
features = it_features
|
||||
else:
|
||||
|
||||
if it_features.shape[2] > features.shape[2]:
|
||||
features = np.concatenate([features, np.zeros(
|
||||
(features.shape[0], features.shape[1], it_features.shape[2] - features.shape[2]))], -1,
|
||||
dtype=DTYPE)
|
||||
elif it_features.shape[2] < features.shape[2]:
|
||||
it_features = np.concatenate([it_features, np.zeros(
|
||||
(it_features.shape[0], it_features.shape[1], features.shape[2] - it_features.shape[2]))], -1,
|
||||
dtype=DTYPE)
|
||||
|
||||
features = np.concatenate([features, it_features], 0)
|
||||
if input_lengths_array is None:
|
||||
input_lengths_array = it_input_lengths_array
|
||||
else:
|
||||
input_lengths_array = np.concatenate([input_lengths_array, it_input_lengths_array], 0)
|
||||
|
||||
if join_batches is None:
|
||||
batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
|
||||
batch_token_ids = [self._transcribe_decode(features[i]) for i in range(full_batch_size)]
|
||||
return [self.tokenizer.decode(ids) for ids in batch_token_ids]
|
||||
else:
|
||||
ret = []
|
||||
|
||||
Reference in New Issue
Block a user