"""Greedy RNN-T (transducer) decoding over ONNX Runtime encoder/predictor/jointer sessions."""

import numpy as np
import onnxruntime as ort

from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
from .decoding import RNNTHead, Tokenizer
from .preprocess import FeatureExtractor


class RNNTASR(ASRABCModel):
    """RNN-Transducer ASR model split across three ONNX sessions.

    The encoder session is held by the base class; the predictor (prediction
    network) and jointer (joint network) sessions are held here and driven by
    the greedy decode loop in ``_transcribe_decode``.
    """

    head: RNNTHead
    predictor: ort.InferenceSession
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ) -> None:
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Greedily decode encoder ``features`` into token ids and frame timings.

        Args:
            features: encoder output array, indexed per frame as
                ``features[:, :, [j]]`` — presumably shaped
                ``(1, hidden, time)``; TODO confirm against the encoder export.

        Returns:
            A ``(token_ids, timings)`` pair: the emitted non-blank token ids
            and, for each, the encoder frame index at which it was emitted.
        """
        token_ids: list[int] = []
        timings: list[int] = []

        prev_token: int = self.blank_idx
        pred_hidden = self.head.decoder.pred_hidden
        # Two zero-initialised recurrent states for the prediction network
        # (shapes match what the ONNX predictor expects at step 0).
        pred_states: list[np.ndarray] = [
            np.zeros(shape=(1, 1, pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, pred_hidden), dtype=DTYPE),
        ]

        # Node-name lookups are loop-invariant — hoist them out of the decode loops.
        pred_input_names = [node.name for node in self.predictor.get_inputs()]
        pred_output_names = [node.name for node in self.predictor.get_outputs()]
        joint_input_names = [node.name for node in self.jointer.get_inputs()]
        joint_output_names = [node.name for node in self.jointer.get_outputs()]

        # The predictor depends only on (prev_token, pred_states), which change
        # only after a non-blank emission.  Cache its output and re-run it only
        # then, instead of once per inner-loop iteration as before.  Decoded
        # output is identical; the predictor simply runs far fewer times.
        pred_outputs = None

        for j in range(features.shape[-1]):
            emitted_letters = 0
            # Cap symbols per frame to avoid a degenerate non-advancing loop.
            while emitted_letters < MAX_LETTERS_PER_FRAME:
                if pred_outputs is None:
                    pred_inputs = dict(
                        zip(
                            pred_input_names,
                            [np.array([[prev_token]])] + pred_states,
                        )
                    )
                    pred_outputs = self.predictor.run(pred_output_names, pred_inputs)

                joint_inputs = dict(
                    zip(
                        joint_input_names,
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                )
                log_probs = self.jointer.run(joint_output_names, joint_inputs)

                token = int(log_probs[0].argmax(-1)[0][0])
                if token == self.blank_idx:
                    break  # blank: advance to the next encoder frame

                # Non-blank: commit the token, advance the prediction network.
                prev_token = token
                pred_states = pred_outputs[1:]
                pred_outputs = None  # state changed -> predictor must be re-run
                token_ids.append(token)
                timings.append(j)
                emitted_letters += 1

        return token_ids, timings