Initial commit
src/gigaam_onnx/rnnt.py (new file, 70 lines added)
@@ -0,0 +1,70 @@
import numpy as np
import onnxruntime as ort

from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
from .decoding import RNNTHead, Tokenizer
from .preprocess import FeatureExtractor


class RNNTASR(ASRABCModel):
    head: RNNTHead
    predictor: ort.InferenceSession
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ):
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        token_ids = []
        timings = []
        prev_token = self.blank_idx

        pred_states = [
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
        ]
        for j in range(features.shape[-1]):
            emitted_letters = 0
            while emitted_letters < MAX_LETTERS_PER_FRAME:
                pred_inputs = {
                    node.name: data
                    for (node, data) in zip(
                        self.predictor.get_inputs(), [np.array([[prev_token]])] + pred_states
                    )
                }
                pred_outputs = self.predictor.run(
                    [node.name for node in self.predictor.get_outputs()], pred_inputs
                )

                joint_inputs = {
                    node.name: data
                    for node, data in zip(
                        self.jointer.get_inputs(),
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                }
                log_probs = self.jointer.run(
                    [node.name for node in self.jointer.get_outputs()], joint_inputs
                )
                token = log_probs[0].argmax(-1)[0][0]

                if token != self.blank_idx:
                    prev_token = int(token)
                    pred_states = pred_outputs[1:]
                    token_ids.append(int(token))
                    timings.append(j)
                    emitted_letters += 1
                else:
                    break
        return token_ids, timings
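For reference, _transcribe_decode is plain greedy RNN-T decoding: for each encoder frame the predictor is run on the last emitted token and its state, the jointer fuses that output with the frame, and the argmax token is either appended (non-blank, predictor state advances) or ends the inner loop (blank, move to the next frame), with MAX_LETTERS_PER_FRAME capping emissions per frame. The toy sketch below reproduces the same control flow with plain callables standing in for the ONNX predictor and jointer sessions; the names, shapes, and blank index used here are illustrative assumptions, not part of this commit.

# Illustrative sketch only: stand-ins for the ONNX sessions, not part of the diff.
import numpy as np

BLANK = 0            # stand-in for self.blank_idx
MAX_PER_FRAME = 3    # stand-in for MAX_LETTERS_PER_FRAME


def greedy_rnnt(enc_frames, predictor, jointer, init_state):
    """Toy greedy RNN-T decode mirroring RNNTASR._transcribe_decode.

    enc_frames: sequence of encoder frames; predictor(token, state) -> (g, new_state);
    jointer(frame, g) -> logits over the vocabulary, blank included.
    """
    tokens, timings = [], []
    prev, state = BLANK, init_state
    for j, frame in enumerate(enc_frames):
        emitted = 0
        while emitted < MAX_PER_FRAME:
            g, new_state = predictor(prev, state)
            k = int(jointer(frame, g).argmax())
            if k == BLANK:
                break                        # blank: move to the next frame
            prev, state = k, new_state       # non-blank: keep token, advance state
            tokens.append(k)
            timings.append(j)
            emitted += 1
    return tokens, timings


# Tiny self-contained demo with random stand-ins for the three networks.
rng = np.random.default_rng(0)
demo_tokens, demo_timings = greedy_rnnt(
    enc_frames=rng.standard_normal((5, 8)),
    predictor=lambda tok, st: (rng.standard_normal(8), st),
    jointer=lambda f, g: rng.standard_normal(16),
    init_state=None,
)

Note that the predictor state only advances on non-blank emissions, which is what allows several tokens to be emitted against a single encoder frame before the blank moves decoding on.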