# GigaAM-onnx/src/gigaam_onnx/asr_abc.py
import abc

import numpy as np
import onnxruntime as ort
import torch

from .preprocess import FeatureExtractor, load_audio
from .decoding import Tokenizer

# Numeric type fed to the ONNX encoder.
DTYPE = np.float32
MAX_LETTERS_PER_FRAME = 3

class AudioDataset(torch.utils.data.Dataset):
    """Helper dataset for building batched inputs.

    Items may be file paths, numpy arrays, or torch tensors; each is
    returned as a waveform tensor ready for collation.
    """

    def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
        if len(lst) == 0:
            raise ValueError("AudioDataset cannot be initialized with an empty list")
        assert isinstance(
            lst[0], (str, np.ndarray, torch.Tensor)
        ), f"Unexpected dtype: {type(lst[0])}"
        self.lst = lst

    def __len__(self) -> int:
        return len(self.lst)

    def __getitem__(self, idx: int) -> torch.Tensor:
        item = self.lst[idx]
        if isinstance(item, str):
            # Strings are treated as file paths.
            wav_tns = load_audio(item)
        elif isinstance(item, np.ndarray):
            wav_tns = torch.from_numpy(item)
        elif isinstance(item, torch.Tensor):
            wav_tns = item
        else:
            raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
        return wav_tns

    @staticmethod
    def collate(wavs: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Zero-pad 1-D waveforms into a (batch, max_len) tensor plus lengths."""
        lengths = torch.tensor([len(wav) for wav in wavs])
        max_len = int(lengths.max().item())
        wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
        for idx, wav in enumerate(wavs):
            wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
        return wav_tns, lengths
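
# A minimal usage sketch (the DataLoader wiring below is illustrative and not
# part of this module; file names are placeholders):
#
#     dataset = AudioDataset(["a.wav", "b.wav"])
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=2, collate_fn=AudioDataset.collate
#     )
#     for padded_wavs, lengths in loader:
#         ...  # padded_wavs: (batch, max_len), lengths: (batch,)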

class ASRABCModel(abc.ABC):
    """Base class for ONNX ASR models: shared preprocessing, encoding, decoding."""

    encoder: ort.InferenceSession
    preprocessor: FeatureExtractor
    tokenizer: Tokenizer
    blank_idx: int

    def __init__(
        self,
        encoder: ort.InferenceSession,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
    ):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # The blank token index sits just past the last vocabulary entry.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        """Transcribe a single waveform; thin wrapper over transcribe_batch."""
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray) -> np.ndarray:
        """Run the ONNX encoder on a single preprocessed feature tensor."""
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signal.astype(DTYPE),
                    # The length input must be a numpy array, not a Python list;
                    # int64 is assumed to match the exported graph.
                    np.array([input_signal.shape[-1]], dtype=np.int64),
                ],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return enc_features

    def _transcribe_encode_batch(
        self, input_signals: np.ndarray, input_lengths: np.ndarray
    ) -> np.ndarray:
        """Run the ONNX encoder on a padded batch of feature tensors."""
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signals.astype(DTYPE), input_lengths],
            )
        }
        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return outputs
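
    # Both encode helpers bind inputs by the positional order of the graph's
    # input nodes. To see what a given encoder actually expects, inspect the
    # session (assuming `sess` is an ort.InferenceSession):
    #
    #     for node in sess.get_inputs():
    #         print(node.name, node.shape, node.type)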

    @abc.abstractmethod
    def _transcribe_decode(self, features: np.ndarray) -> tuple[list[int], list[int]]:
        """Decode encoder features into (token_ids, frame_timings)."""
        raise NotImplementedError
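
    # A subclass might implement greedy CTC decoding roughly like this (a
    # sketch only, assuming `log_probs` of shape (time, vocab + 1) has been
    # obtained from the model's decoder head; the real decoders live in the
    # concrete model classes):
    #
    #     ids = log_probs.argmax(axis=-1)
    #     tokens, timings, prev = [], [], self.blank_idx
    #     for t, idx in enumerate(ids):
    #         if idx != self.blank_idx and idx != prev:
    #             tokens.append(int(idx))
    #             timings.append(t)
    #         prev = idx
    #     return tokens, timings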

    def transcribe_batch(
        self,
        wavs: list[np.ndarray],
        join_batches: list[int] | None = None,
    ) -> list[tuple[str, list[float]]]:
        """Transcribe a batch of waveforms.

        If ``join_batches`` is given, consecutive groups of that many
        utterances are concatenated along the time axis after encoding and
        decoded as one segment, with output timings rescaled to seconds.
        """
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            # The preprocessor returns (features, feature_lengths); keep the features.
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0),
                torch.tensor([audio_tensor.shape[-1]]),
            )[0]
            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()
            processed_wavs.append(processed)
            input_lengths.append(processed.shape[2])

        # Zero-pad feature tensors (batch, feat_dim, time) to the longest utterance.
        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]
        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
        for i, audio in enumerate(processed_wavs):
            padded_wavs[i, :, : audio.shape[2]] = audio

        input_lengths_array = np.array(input_lengths, dtype=np.int64)
        features = self._transcribe_encode_batch(padded_wavs, input_lengths_array)

        if join_batches is None:
            decoded = [self._transcribe_decode(features[i]) for i in range(batch_size)]
            return [self.tokenizer.decode(ids_and_timings) for ids_and_timings in decoded]

        ret = []
        start_idx = 0
        for batch_len in join_batches:
            end_idx = start_idx + batch_len
            # Strip padding from each utterance before concatenating along time.
            batch_lengths = input_lengths_array[start_idx:end_idx]
            batch_features_list = [
                features[start_idx + i, :, : batch_lengths[i]]
                for i in range(batch_len)
            ]
            seg_features = np.concatenate(batch_features_list, axis=1)[np.newaxis, :, :]
            token_ids, token_timings = self._transcribe_decode(seg_features)
            # Rescale frame-index timings to seconds; 16 kHz audio is assumed,
            # and max() is guarded against an empty timing list.
            total_seconds = sum(len(w) for w in wavs[start_idx:end_idx]) / 16000
            rate = total_seconds / max(token_timings) if token_timings else 0.0
            result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
            ret.append((result_text, [t * rate for t in out_timings]))
            start_idx = end_idx
        return ret
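
# End-to-end usage sketch (hypothetical wiring; the concrete model class and
# the feature-extractor/tokenizer construction live elsewhere in the package):
#
#     sess = ort.InferenceSession("encoder.onnx")
#     model = SomeConcreteModel(sess, FeatureExtractor(...), Tokenizer(...))
#     text, timings = model.transcribe(wav_16khz)
#     results = model.transcribe_batch(list_of_wavs, join_batches=[2, 3])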