Fix batch transcribe
This commit is contained in:
@@ -4,7 +4,7 @@ authors = [
|
|||||||
{ name = "nikto_b", email = "niktob560@yandex.ru" }
|
{ name = "nikto_b", email = "niktob560@yandex.ru" }
|
||||||
]
|
]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
version = "0.1.2"
|
version = "0.1.3"
|
||||||
description = "ONNX wrapper for a GigaAMASR models"
|
description = "ONNX wrapper for a GigaAMASR models"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import numpy as np
|
|||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
from pyexpat import features
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from .preprocess import FeatureExtractor, load_audio
|
from .preprocess import FeatureExtractor, load_audio
|
||||||
@@ -68,42 +69,69 @@ class ASRABCModel(abc.ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def transcribe_batch(self,
|
def transcribe_batch(self,
|
||||||
wavs: list[np.ndarray],
|
all_wavs: list[np.ndarray],
|
||||||
join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
|
join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
|
||||||
input_lengths = []
|
if len(all_wavs) == 0:
|
||||||
processed_wavs = []
|
return []
|
||||||
for wav in wavs:
|
features = None
|
||||||
audio_tensor = load_audio(wav)
|
input_lengths_array = None
|
||||||
processed = self.preprocessor(
|
full_batch_size = len(all_wavs)
|
||||||
audio_tensor.unsqueeze(0),
|
while len(all_wavs) > 0:
|
||||||
torch.tensor([audio_tensor.shape[-1]])
|
wavs = all_wavs[:128]
|
||||||
)[0]
|
all_wavs = all_wavs[128:]
|
||||||
|
|
||||||
if isinstance(processed, torch.Tensor):
|
input_lengths = []
|
||||||
processed = processed.cpu().numpy()
|
processed_wavs = []
|
||||||
|
for wav in wavs:
|
||||||
|
audio_tensor = load_audio(wav)
|
||||||
|
processed = self.preprocessor(
|
||||||
|
audio_tensor.unsqueeze(0),
|
||||||
|
torch.tensor([audio_tensor.shape[-1]])
|
||||||
|
)[0]
|
||||||
|
|
||||||
processed_wavs.append(processed)
|
if isinstance(processed, torch.Tensor):
|
||||||
input_lengths.append(processed.shape[2])
|
processed = processed.cpu().numpy()
|
||||||
|
|
||||||
max_length = max(input_lengths)
|
processed_wavs.append(processed)
|
||||||
batch_size = len(wavs)
|
input_lengths.append(processed.shape[2])
|
||||||
features_dim = processed_wavs[0].shape[1]
|
|
||||||
|
|
||||||
padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
|
max_length = max(input_lengths)
|
||||||
|
batch_size = len(wavs)
|
||||||
|
features_dim = processed_wavs[0].shape[1]
|
||||||
|
|
||||||
for i, audio in enumerate(processed_wavs):
|
padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
|
||||||
length = audio.shape[2]
|
|
||||||
padded_wavs[i, :, :length] = audio
|
|
||||||
|
|
||||||
input_lengths_array = np.array(input_lengths, dtype=np.int64)
|
for i, audio in enumerate(processed_wavs):
|
||||||
|
length = audio.shape[2]
|
||||||
|
padded_wavs[i, :, :length] = audio
|
||||||
|
|
||||||
features = self._transcribe_encode_batch(
|
it_input_lengths_array = np.array(input_lengths, dtype=np.int64)
|
||||||
padded_wavs,
|
|
||||||
input_lengths_array
|
it_features = self._transcribe_encode_batch(
|
||||||
)
|
padded_wavs,
|
||||||
|
it_input_lengths_array
|
||||||
|
)
|
||||||
|
if features is None:
|
||||||
|
features = it_features
|
||||||
|
else:
|
||||||
|
|
||||||
|
if it_features.shape[2] > features.shape[2]:
|
||||||
|
features = np.concatenate([features, np.zeros(
|
||||||
|
(features.shape[0], features.shape[1], it_features.shape[2] - features.shape[2]))], -1,
|
||||||
|
dtype=DTYPE)
|
||||||
|
elif it_features.shape[2] < features.shape[2]:
|
||||||
|
it_features = np.concatenate([it_features, np.zeros(
|
||||||
|
(it_features.shape[0], it_features.shape[1], features.shape[2] - it_features.shape[2]))], -1,
|
||||||
|
dtype=DTYPE)
|
||||||
|
|
||||||
|
features = np.concatenate([features, it_features], 0)
|
||||||
|
if input_lengths_array is None:
|
||||||
|
input_lengths_array = it_input_lengths_array
|
||||||
|
else:
|
||||||
|
input_lengths_array = np.concatenate([input_lengths_array, it_input_lengths_array], 0)
|
||||||
|
|
||||||
if join_batches is None:
|
if join_batches is None:
|
||||||
batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
|
batch_token_ids = [self._transcribe_decode(features[i]) for i in range(full_batch_size)]
|
||||||
return [self.tokenizer.decode(ids) for ids in batch_token_ids]
|
return [self.tokenizer.decode(ids) for ids in batch_token_ids]
|
||||||
else:
|
else:
|
||||||
ret = []
|
ret = []
|
||||||
|
|||||||
Reference in New Issue
Block a user