Initial commit
src/gigaam_onnx/asr_abc.py (new file, 182 lines added)
@@ -0,0 +1,182 @@
import abc

import numpy as np
import onnxruntime as ort
import torch

from .preprocess import FeatureExtractor, load_audio
from .decoding import Tokenizer

DTYPE = np.float32
# Upper bound on symbols emitted per encoder frame (unused in this file;
# presumably consumed by subclass decoders).
MAX_LETTERS_PER_FRAME = 3


class AudioDataset(torch.utils.data.Dataset):
    """Helper dataset for building batched inputs from paths, arrays, or tensors."""

    def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
        if len(lst) == 0:
            raise ValueError("AudioDataset cannot be initialized with an empty list")
        assert isinstance(
            lst[0], (str, np.ndarray, torch.Tensor)
        ), f"Unexpected dtype: {type(lst[0])}"
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx):
        item = self.lst[idx]
        if isinstance(item, str):
            wav_tns = load_audio(item)  # strings are treated as file paths
        elif isinstance(item, np.ndarray):
            wav_tns = torch.from_numpy(item)
        elif isinstance(item, torch.Tensor):
            wav_tns = item
        else:
            raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
        return wav_tns

    @staticmethod
    def collate(wavs):
        # Right-pad every waveform with zeros to the longest one in the batch.
        lengths = torch.tensor([len(wav) for wav in wavs])
        max_len = lengths.max().item()
        wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
        for idx, wav in enumerate(wavs):
            wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
        return wav_tns, lengths
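

# Usage sketch (illustrative, not part of the original file): batching audio
# with torch's DataLoader, using the static collate helper above as collate_fn.
#
#   loader = torch.utils.data.DataLoader(
#       AudioDataset(["a.wav", "b.wav"]),
#       batch_size=2,
#       collate_fn=AudioDataset.collate,
#   )
#   for wav_batch, lengths in loader:
#       ...  # wav_batch: (B, T) zero-padded waveforms, lengths: (B,)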


class ASRABCModel(abc.ABC):
    encoder: ort.InferenceSession
    preprocessor: FeatureExtractor
    tokenizer: Tokenizer
    blank_idx: int

    def __init__(
        self,
        encoder: ort.InferenceSession,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
    ):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # The blank token id sits right after the last vocabulary entry.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        """Transcribe a single utterance; thin wrapper over transcribe_batch."""
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray) -> np.ndarray:
        # Encoder inputs are matched positionally: first the signal, then its length.
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signal.astype(DTYPE),
                    # onnxruntime feeds must be ndarrays, not plain Python lists
                    np.array([input_signal.shape[-1]], dtype=np.int64),
                ],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return enc_features

    def _transcribe_encode_batch(
        self, input_signals: np.ndarray, input_lengths: np.ndarray
    ) -> np.ndarray:
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signals.astype(DTYPE), input_lengths],
            )
        }
        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return outputs

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Turn encoder features into (token_ids, frame_timings); subclass-specific."""
        raise NotImplementedError()
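
    # A minimal greedy-CTC sketch of a concrete _transcribe_decode (hypothetical
    # subclass, not part of this file; assumes `features` holds per-frame
    # log-probs of shape (vocab_size + 1, time) with blank_idx as the last class):
    #
    #   class CTCModel(ASRABCModel):
    #       def _transcribe_decode(self, features):
    #           best = features.argmax(axis=0)
    #           ids, timings = [], []
    #           prev = self.blank_idx
    #           for t, idx in enumerate(best):
    #               if idx != self.blank_idx and idx != prev:
    #                   ids.append(int(idx))
    #                   timings.append(t)
    #               prev = int(idx)
    #           return ids, timings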

    def transcribe_batch(
        self,
        wavs: list[np.ndarray],
        join_batches: list[int] | None = None,
    ) -> list[tuple[str, list[int]]]:
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            # Preprocessor output has shape (1, features_dim, time).
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0),
                torch.tensor([audio_tensor.shape[-1]]),
            )[0]

            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()

            processed_wavs.append(processed)
            input_lengths.append(processed.shape[2])

        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]

        # Zero-pad each feature matrix along time to the longest one in the batch.
        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
        for i, audio in enumerate(processed_wavs):
            padded_wavs[i, :, : audio.shape[2]] = audio

        input_lengths_array = np.array(input_lengths, dtype=np.int64)

        features = self._transcribe_encode_batch(padded_wavs, input_lengths_array)

        if join_batches is None:
            decoded = [self._transcribe_decode(features[i]) for i in range(batch_size)]
            return [self.tokenizer.decode(ids) for ids in decoded]

        # Otherwise, concatenate consecutive utterances into segments of the given
        # sizes and decode each joined segment as a single sequence.
        ret = []
        start_idx = 0
        for batch_len in join_batches:
            end_idx = start_idx + batch_len
            batch_lengths = input_lengths_array[start_idx:end_idx]

            # Strip time padding from each utterance's encoder output before joining.
            batch_features_list = [
                features[start_idx + i, :, : batch_lengths[i]]
                for i in range(batch_len)
            ]

            # Join along the time axis and restore a leading batch dimension.
            seg_features = np.concatenate(batch_features_list, axis=1)[np.newaxis, :, :]

            token_ids, token_timings = self._transcribe_decode(seg_features)

            # Scale frame indices to seconds, assuming 16 kHz input audio.
            rate = (sum(len(w) for w in wavs[start_idx:end_idx]) / 16000) / max(
                token_timings
            )

            result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
            ret.append((result_text, [t * rate for t in out_timings]))

            start_idx = end_idx

        return ret
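

# End-to-end usage sketch (illustrative; MyModel, encoder_session,
# feature_extractor, and six_wavs are hypothetical names, not part of this file):
#
#   model = MyModel(encoder_session, feature_extractor, tokenizer)
#   text, timings = model.transcribe(wav)
#   joined = model.transcribe_batch(six_wavs, join_batches=[2, 4])
#   # joined[0] covers the first two wavs, joined[1] the remaining four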