# GigaAM-onnx/src/gigaam_onnx/rnnt.py
import numpy as np
import onnxruntime as ort
from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
from .decoding import RNNTHead, Tokenizer
from .preprocess import FeatureExtractor
class RNNTASR(ASRABCModel):
    """Greedy RNN-T (transducer) ASR model backed by three ONNX sessions.

    The encoder (managed by the base class) turns audio features into
    acoustic frames; the predictor (prediction network) and jointer
    (joint network) sessions are run here, frame by frame, to greedily
    decode token ids together with the frame index of each emission.
    """

    # Decoding head holding tokenizer/decoder configuration (pred_hidden etc.).
    head: RNNTHead
    # ONNX Runtime session for the prediction network (stateful, LSTM-like).
    predictor: ort.InferenceSession
    # ONNX Runtime session for the joint network.
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ) -> None:
        """Store the RNN-T specific sessions, then delegate to the base class.

        Args:
            preprocessor: feature extractor producing encoder input.
            tokenizer: maps decoded token ids back to text.
            head: RNN-T head config (provides ``decoder.pred_hidden``).
            encoder: ONNX session for the acoustic encoder.
            predictor: ONNX session for the prediction network.
            jointer: ONNX session for the joint network.
        """
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Greedily decode token ids from encoder output.

        Args:
            features: encoder output array; the last axis is the time
                (frame) axis. # assumes shape (1, hidden, T) — TODO confirm
                against the encoder's output layout.

        Returns:
            ``(token_ids, timings)`` — decoded non-blank token ids and,
            in lockstep, the frame index at which each token was emitted.
        """
        token_ids: list[int] = []
        timings: list[int] = []
        prev_token = self.blank_idx
        # Zero-initialized (h, c)-style state pair for the prediction network.
        pred_states = [
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
        ]
        # ONNX node names are loop-invariant: resolve them once up front
        # instead of re-querying the sessions on every decoding step.
        pred_input_names = [node.name for node in self.predictor.get_inputs()]
        pred_output_names = [node.name for node in self.predictor.get_outputs()]
        joint_input_names = [node.name for node in self.jointer.get_inputs()]
        joint_output_names = [node.name for node in self.jointer.get_outputs()]
        for j in range(features.shape[-1]):
            emitted_letters = 0
            # Emit at most MAX_LETTERS_PER_FRAME tokens per acoustic frame,
            # moving to the next frame as soon as blank is predicted.
            while emitted_letters < MAX_LETTERS_PER_FRAME:
                pred_inputs = dict(
                    zip(pred_input_names, [np.array([[prev_token]])] + pred_states)
                )
                pred_outputs = self.predictor.run(pred_output_names, pred_inputs)
                joint_inputs = dict(
                    zip(
                        joint_input_names,
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                )
                log_probs = self.jointer.run(joint_output_names, joint_inputs)
                token = log_probs[0].argmax(-1)[0][0]
                if token != self.blank_idx:
                    # Non-blank: record the token, advance the predictor state.
                    prev_token = int(token)
                    pred_states = pred_outputs[1:]
                    token_ids.append(int(token))
                    timings.append(j)
                    emitted_letters += 1
                else:
                    # Blank: stop emitting for this frame.
                    break
        return token_ids, timings