import abc
import math

import numpy as np
import onnxruntime as ort
import torch
from torch.nn.utils.rnn import pad_sequence

from .preprocess import FeatureExtractor, load_audio
from .decoding import Tokenizer

DTYPE = np.float32
MAX_LETTERS_PER_FRAME = 3


class AudioDataset(torch.utils.data.Dataset):
    """
    Helper dataset for creating batched inputs.

    Items may be file paths, NumPy arrays, or torch tensors; each item is
    returned as a waveform tensor. The static ``collate`` method can be used
    as a DataLoader ``collate_fn`` to zero-pad waveforms to a common length.
    """

    def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
        if len(lst) == 0:
            raise ValueError("AudioDataset cannot be initialized with an empty list")
        assert isinstance(
            lst[0], (str, np.ndarray, torch.Tensor)
        ), f"Unexpected dtype: {type(lst[0])}"
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx):
        item = self.lst[idx]
        if isinstance(item, str):
            wav_tns = load_audio(item)
        elif isinstance(item, np.ndarray):
            wav_tns = torch.from_numpy(item)
        elif isinstance(item, torch.Tensor):
            wav_tns = item
        else:
            raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
        return wav_tns

    @staticmethod
    def collate(wavs):
        # Zero-pad all waveforms in the batch to the length of the longest one.
        lengths = torch.tensor([len(wav) for wav in wavs])
        max_len = lengths.max().item()
        wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
        for idx, wav in enumerate(wavs):
            wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
        return wav_tns, lengths
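
# A minimal usage sketch (not part of the original API surface): pairing the
# dataset with ``AudioDataset.collate`` as a DataLoader collate_fn. The file
# names below are placeholders.
#
#     dataset = AudioDataset(["first.wav", "second.wav"])
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=2, collate_fn=AudioDataset.collate
#     )
#     for wav_batch, lengths in loader:
#         ...  # wav_batch: (batch, max_len) padded waveforms, lengths: (batch,)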


class ASRABCModel(abc.ABC):
    encoder: ort.InferenceSession
    preprocessor: FeatureExtractor
    tokenizer: Tokenizer
    blank_idx: int

    def __init__(self,
                 encoder: ort.InferenceSession,
                 preprocessor: FeatureExtractor,
                 tokenizer: Tokenizer):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # The blank token id is placed right after the last vocabulary entry.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray):
        # Map the positional inputs (features, length) onto the ONNX graph's input names.
        # The length is passed as an int64 array, since onnxruntime expects NumPy inputs.
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signal.astype(DTYPE),
                    np.array([input_signal.shape[-1]], dtype=np.int64),
                ],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]

        return enc_features

    def _transcribe_encode_batch(self,
                                 input_signals: np.ndarray,
                                 input_lengths: np.ndarray) -> np.ndarray:
        # Map the positional inputs (padded feature batch, per-item lengths) onto
        # the ONNX graph's input names.
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signals.astype(DTYPE),
                    input_lengths,
                ],
            )
        }

        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()],
            enc_inputs
        )[0]

        return outputs

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Decode encoder features into (token_ids, token_timings); implemented by subclasses."""
        raise NotImplementedError()
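
    # For illustration only (not part of this module): a subclass might implement
    # a greedy CTC-style decode roughly like the sketch below, collapsing repeated
    # and blank frames while keeping each emitted token's frame index. The
    # ``features`` layout and the argmax axis are assumptions about the encoder output.
    #
    #     def _transcribe_decode(self, features):
    #         logits = features[0] if features.ndim == 3 else features  # (vocab + blank, time)
    #         best = logits.argmax(axis=0)
    #         token_ids, token_timings = [], []
    #         prev = self.blank_idx
    #         for frame, idx in enumerate(best):
    #             if idx != self.blank_idx and idx != prev:
    #                 token_ids.append(int(idx))
    #                 token_timings.append(frame)
    #             prev = idx
    #         return token_ids, token_timings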

    def transcribe_batch(self,
                         wavs: list[np.ndarray],
                         join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
        # Extract features for each waveform independently.
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0),
                torch.tensor([audio_tensor.shape[-1]])
            )[0]

            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()

            processed_wavs.append(processed)
            input_lengths.append(processed.shape[2])

        # Zero-pad the feature matrices to a common time length and encode them as one batch.
        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]

        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
        for i, audio in enumerate(processed_wavs):
            length = audio.shape[2]
            padded_wavs[i, :, :length] = audio

        input_lengths_array = np.array(input_lengths, dtype=np.int64)

        features = self._transcribe_encode_batch(
            padded_wavs,
            input_lengths_array
        )

        if join_batches is None:
            # Decode every item separately.
            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
            return [self.tokenizer.decode(ids) for ids in batch_token_ids]
        else:
            # Decode consecutive groups of items as single joined segments and
            # rescale the token timings over the joined audio.
            ret = []
            start_idx = 0

            for batch_len in join_batches:
                end_idx = start_idx + batch_len
                batch_lengths = input_lengths_array[start_idx:end_idx]

                # Strip the padding frames from each item in this segment.
                batch_features_list = [
                    features[start_idx + i, :, :batch_lengths[i]]
                    for i in range(batch_len)
                ]

                # Concatenate along the time axis and restore the batch dimension.
                seg_features = np.concatenate(batch_features_list, axis=1)[np.newaxis, :, :]

                token_ids, token_timings = self._transcribe_decode(seg_features)

                # Convert frame indices to seconds; the audio is assumed to be 16 kHz.
                rate = (sum(map(len, wavs[start_idx:end_idx])) / 16000) / max(token_timings)

                result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
                norm_out_timings = [t * rate for t in out_timings]
                ret.append((result_text, norm_out_timings))

                start_idx = end_idx

            return ret
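

# Example usage sketch (kept as a comment: this base class is abstract, and the
# concrete subclass name, model path, and no-argument constructors below are
# assumptions for illustration only):
#
#     model = SomeConcreteASRModel(
#         encoder=ort.InferenceSession("encoder.onnx"),
#         preprocessor=FeatureExtractor(),
#         tokenizer=Tokenizer(),
#     )
#     # Per-utterance transcription:
#     text, timings = model.transcribe(wav)  # wav: 1-D waveform array
#     # Joined transcription: treat the first two wavs as one recording and the
#     # next three as another; timings are rescaled over each joined segment.
#     results = model.transcribe_batch(wavs, join_batches=[2, 3])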