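"""Greedy RNN-T speech recognition on top of ONNX Runtime.

The model is split into three exported graphs: an encoder, a prediction
network (predictor) and a joint network (jointer); decoding walks the
encoder output frame by frame.
"""
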
import numpy as np
import onnxruntime as ort

from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
from .decoding import RNNTHead, Tokenizer
from .preprocess import FeatureExtractor


class RNNTASR(ASRABCModel):
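    """RNN-T ASR model whose encoder, prediction network ("predictor") and
    joint network ("jointer") run as separate ONNX Runtime sessions.
    """
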
    head: RNNTHead
    predictor: ort.InferenceSession
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ):
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
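        """Greedy (best-path) RNN-T decoding.

        For each encoder frame, tokens are emitted until the jointer
        predicts the blank symbol or MAX_LETTERS_PER_FRAME is reached.
        Returns the emitted token ids together with the index of the
        frame each token was emitted on.
        """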
        token_ids = []
        timings = []
        prev_token = self.blank_idx

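        # Zero-initialised recurrent state for the prediction network (two
        # state tensors, typically the LSTM hidden and cell states).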
        pred_states = [
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
        ]
        for j in range(features.shape[-1]):
            emitted_letters = 0
            while emitted_letters < MAX_LETTERS_PER_FRAME:
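                # Prediction network step: feed the last non-blank token and
                # the carried-over recurrent state.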
                pred_inputs = {
                    node.name: data
                    for node, data in zip(
                        self.predictor.get_inputs(), [np.array([[prev_token]])] + pred_states
                    )
                }
                pred_outputs = self.predictor.run(
                    [node.name for node in self.predictor.get_outputs()], pred_inputs
                )

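                # Joint network: combine the j-th encoder frame with the
                # predictor output and pick the most likely token.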
                joint_inputs = {
                    node.name: data
                    for node, data in zip(
                        self.jointer.get_inputs(),
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                }
                log_probs = self.jointer.run(
                    [node.name for node in self.jointer.get_outputs()], joint_inputs
                )
                token = log_probs[0].argmax(-1)[0][0]

                if token != self.blank_idx:
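                    # Non-blank token: commit it, advance the predictor state,
                    # and keep decoding on the same frame.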
                    prev_token = int(token)
                    pred_states = pred_outputs[1:]
                    token_ids.append(int(token))
                    timings.append(j)
                    emitted_letters += 1
                else:
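                    # Blank: nothing more to emit on this frame, move on.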
                    break

        return token_ids, timings
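

# A minimal usage sketch, for illustration only: the ONNX file names below are
# assumptions, the constructor arguments of FeatureExtractor, Tokenizer and
# RNNTHead are elided, and the public transcription entry point is expected to
# be inherited from ASRABCModel.
#
#     encoder = ort.InferenceSession("encoder.onnx")
#     predictor = ort.InferenceSession("predictor.onnx")
#     jointer = ort.InferenceSession("jointer.onnx")
#     asr = RNNTASR(preprocessor, tokenizer, head, encoder, predictor, jointer)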