Initial commit

2025-12-03 16:44:29 +03:00
commit 4f7f22e95f
15 changed files with 1693 additions and 0 deletions

src/gigaam_onnx/asr_abc.py

@@ -0,0 +1,182 @@
import abc
import math
import numpy as np
import onnxruntime as ort
import torch
from torch.nn.utils.rnn import pad_sequence
from .preprocess import FeatureExtractor, load_audio
from .decoding import Tokenizer

DTYPE = np.float32
MAX_LETTERS_PER_FRAME = 3


class AudioDataset(torch.utils.data.Dataset):
    """
    Helper class for creating batched inputs.
    """

    def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
        if len(lst) == 0:
            raise ValueError("AudioDataset cannot be initialized with an empty list")
        assert isinstance(
            lst[0], (str, np.ndarray, torch.Tensor)
        ), f"Unexpected dtype: {type(lst[0])}"
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx):
        item = self.lst[idx]
        if isinstance(item, str):
            wav_tns = load_audio(item)
        elif isinstance(item, np.ndarray):
            wav_tns = torch.from_numpy(item)
        elif isinstance(item, torch.Tensor):
            wav_tns = item
        else:
            raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
        return wav_tns

    @staticmethod
    def collate(wavs):
        # Zero-pad all waveforms in the batch to the length of the longest one.
        lengths = torch.tensor([len(wav) for wav in wavs])
        max_len = lengths.max().item()
        wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
        for idx, wav in enumerate(wavs):
            wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
        return wav_tns, lengths


class ASRABCModel(abc.ABC):
    encoder: ort.InferenceSession
    preprocessor: FeatureExtractor
    tokenizer: Tokenizer
    blank_idx: int

    def __init__(
        self,
        encoder: ort.InferenceSession,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
    ):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # By convention the blank token is the last class, right after the vocabulary.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray):
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signal.astype(DTYPE),
                    # onnxruntime expects numpy arrays, not plain Python lists.
                    np.array([input_signal.shape[-1]], dtype=np.int64),
                ],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return enc_features

    def _transcribe_encode_batch(
        self, input_signals: np.ndarray, input_lengths: np.ndarray
    ) -> np.ndarray:
        # Bind positional inputs (features, lengths) to the encoder's input names.
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signals.astype(DTYPE), input_lengths],
            )
        }
        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]
        return outputs

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        raise NotImplementedError()

    def transcribe_batch(
        self,
        wavs: list[np.ndarray],
        join_batches: list[int] | None = None,
    ) -> list[tuple[str, list[int]]]:
        # Extract log-mel features per waveform, then pad them to a common length.
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0), torch.tensor([audio_tensor.shape[-1]])
            )[0]
            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()
            processed_wavs.append(processed)
            input_lengths.append(processed.shape[2])
        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]
        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)
        for i, audio in enumerate(processed_wavs):
            padded_wavs[i, :, : audio.shape[2]] = audio
        input_lengths_array = np.array(input_lengths, dtype=np.int64)
        features = self._transcribe_encode_batch(padded_wavs, input_lengths_array)
        if join_batches is None:
            # Keep the batch axis ([1, D, T]) so both the CTC and RNN-T
            # decoders receive a 3-D array.
            batch_token_ids = [
                self._transcribe_decode(features[[i]]) for i in range(batch_size)
            ]
            return [self.tokenizer.decode(ids) for ids in batch_token_ids]
        ret = []
        start_idx = 0
        for batch_len in join_batches:
            # Concatenate this group's (unpadded) encoder outputs along time
            # and decode them as a single utterance.
            end_idx = start_idx + batch_len
            batch_lengths = input_lengths_array[start_idx:end_idx]
            batch_features_list = [
                features[start_idx + i, :, : batch_lengths[i]]
                for i in range(batch_len)
            ]
            seg_features = np.concatenate(batch_features_list, axis=1)[np.newaxis, :, :]
            token_ids, token_timings = self._transcribe_decode(seg_features)
            # Rescale frame indices to seconds: total audio duration over the
            # last emitted frame index (guarded against empty hypotheses).
            total_seconds = sum(map(len, wavs[start_idx:end_idx])) / 16000
            max_frame = max(token_timings, default=0)
            rate = total_seconds / max_frame if max_frame else 0.0
            result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
            ret.append((result_text, [t * rate for t in out_timings]))
            start_idx = end_idx
        return ret
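
As a quick illustration of the dataset helper above, here is a minimal sketch (not part of the commit) that batches two in-memory waveforms with AudioDataset and its collate function; the zero arrays are stand-ins for real 16 kHz audio.

import numpy as np
import torch
from gigaam_onnx.asr_abc import AudioDataset

ds = AudioDataset([
    np.zeros(16000, dtype=np.float32),
    np.zeros(24000, dtype=np.float32),
])
loader = torch.utils.data.DataLoader(ds, batch_size=2, collate_fn=AudioDataset.collate)
wavs, lengths = next(iter(loader))
# The shorter clip is zero-padded up to the longest one in the batch.
assert wavs.shape == (2, 24000) and lengths.tolist() == [16000, 24000]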

src/gigaam_onnx/ctc.py

@@ -0,0 +1,34 @@
import numpy as np
import torch
from .preprocess import FeatureExtractor, load_audio
from .decoding import CTCGreedyDecoding, Tokenizer
import onnxruntime as ort
from .asr_abc import ASRABCModel


class CTCASR(ASRABCModel):
    preprocessor: FeatureExtractor
    encoder: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        encoder: ort.InferenceSession,
    ):
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        # Greedy CTC collapse: drop blanks and repeated tokens, remembering
        # the frame index at which each token was first emitted.
        token_ids = []
        prev_token = self.blank_idx
        timings = []
        while len(features.shape) > 2:
            features = features[0]
        for i, tok in enumerate(features.argmax(0).squeeze().tolist()):
            if (tok != prev_token or prev_token == self.blank_idx) and tok < self.blank_idx:
                token_ids.append(tok)
                timings.append(i)
            prev_token = tok
        return token_ids, timings
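
The collapse rule in _transcribe_decode can be sanity-checked on a toy sequence (illustrative only, not part of the commit): with blank_idx = 3, the framewise argmax [0, 0, 3, 1, 1, 3, 3, 2] collapses to tokens [0, 1, 2], first emitted at frames [0, 3, 7].

blank_idx = 3
token_ids, timings, prev_token = [], [], blank_idx
for i, tok in enumerate([0, 0, 3, 1, 1, 3, 3, 2]):
    # Same condition as above: emit only non-blank tokens that start a new run.
    if (tok != prev_token or prev_token == blank_idx) and tok < blank_idx:
        token_ids.append(tok)
        timings.append(i)
    prev_token = tok
assert token_ids == [0, 1, 2] and timings == [0, 3, 7]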

src/gigaam_onnx/decoder.py

@@ -0,0 +1,133 @@
# Mostly based on https://github.com/salute-developers/GigaAM/blame/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/decoder.py
# Original authors:
# - https://github.com/georgygospodinov
# - https://github.com/Alexander4127
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn


class CTCHead(nn.Module):
    """
    CTC Head module for Connectionist Temporal Classification.
    """

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        return torch.nn.functional.log_softmax(
            self.decoder_layers(encoder_output).transpose(1, 2), dim=-1
        )


class RNNTJoint(nn.Module):
    """
    RNN-Transducer Joint Network Module.
    This module combines the outputs of the encoder and the prediction network using
    a linear transformation followed by ReLU activation and another linear projection.
    """

    def __init__(
        self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
    ):
        super().__init__()
        self.enc_hidden = enc_hidden
        self.pred_hidden = pred_hidden
        self.pred = nn.Linear(pred_hidden, joint_hidden)
        self.enc = nn.Linear(enc_hidden, joint_hidden)
        self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes))

    def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
        """
        Combine the encoder and prediction network outputs into a joint representation.
        """
        enc = self.enc(encoder_out).unsqueeze(2)
        pred = self.pred(decoder_out).unsqueeze(1)
        return self.joint_net(enc + pred).log_softmax(-1)

    def input_example(self) -> Tuple[Tensor, Tensor]:
        device = next(self.parameters()).device
        enc = torch.zeros(1, self.enc_hidden, 1)
        dec = torch.zeros(1, self.pred_hidden, 1)
        return enc.float().to(device), dec.float().to(device)

    def input_names(self) -> List[str]:
        return ["enc", "dec"]

    def output_names(self) -> List[str]:
        return ["joint"]

    def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
        return self.joint(enc.transpose(1, 2), dec.transpose(1, 2))


class RNNTDecoder(nn.Module):
    """
    RNN-Transducer Decoder Module.
    This module handles the prediction network part of the RNN-Transducer architecture.
    """

    def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
        super().__init__()
        self.blank_id = num_classes - 1
        self.pred_hidden = pred_hidden
        self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
        self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)

    def predict(
        self,
        x: Optional[Tensor],
        state: Optional[Tensor],
        batch_size: int = 1,
    ) -> Tuple[Tensor, Tensor]:
        """
        Make predictions based on the current input and previous states.
        If no input is provided, use zeros as the initial input.
        """
        if x is not None:
            emb: Tensor = self.embed(x)
        else:
            emb = torch.zeros(
                (batch_size, 1, self.pred_hidden), device=next(self.parameters()).device
            )
        g, hid = self.lstm(emb.transpose(0, 1), state)
        return g.transpose(0, 1), hid

    def input_example(self) -> Tuple[Tensor, Tensor, Tensor]:
        device = next(self.parameters()).device
        label = torch.tensor([[0]]).to(device)
        hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device)
        hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device)
        return label, hidden_h, hidden_c

    def input_names(self) -> List[str]:
        return ["x", "h", "c"]

    def output_names(self) -> List[str]:
        return ["dec", "h", "c"]

    def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        ONNX-specific forward with x, state = (h, c) -> x, h, c.
        """
        emb = self.embed(x)
        g, (h, c) = self.lstm(emb.transpose(0, 1), (h, c))
        return g.transpose(0, 1), h, c


class RNNTHead(nn.Module):
    """
    RNN-Transducer Head Module.
    This module combines the decoder and joint network components of the RNN-Transducer architecture.
    """

    def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
        super().__init__()
        self.decoder = RNNTDecoder(**decoder)
        self.joint = RNNTJoint(**joint)
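
Since these modules exist mainly to be exported (note the input_example/input_names/output_names helpers), a sketch of how the prediction network could be exported may be useful. This is an assumption about the export procedure, not a script shipped in this commit; the output path is a placeholder and the sizes follow the v3 RNN-T config used later in this commit.

import torch
from gigaam_onnx.decoder import RNNTDecoder

dec = RNNTDecoder(pred_hidden=320, pred_rnn_layers=1, num_classes=1025).eval()
torch.onnx.export(
    dec,
    dec.input_example(),          # (label, h, c) matching forward(x, h, c)
    "v3_rnnt_decoder.onnx",       # placeholder path
    input_names=dec.input_names(),
    output_names=dec.output_names(),
)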

src/gigaam_onnx/decoding.py

@@ -0,0 +1,181 @@
# Mostly based on https://github.com/salute-developers/GigaAM/blob/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/decoding.py
# Original authors:
# - https://github.com/georgygospodinov
# - https://github.com/Alexander4127
# - https://github.com/sverdoot
from typing import List, Optional, Tuple

import torch
from sentencepiece import SentencePieceProcessor
from torch import Tensor

from .decoder import CTCHead, RNNTHead


class Tokenizer:
    """
    Tokenizer for converting between text and token IDs.
    The tokenizer can operate either character-wise or using a pre-trained SentencePiece model.
    """

    def __init__(self, vocab: List[str], model_path: Optional[str] = None):
        self.charwise = model_path is None
        if self.charwise:
            self.vocab = vocab
        else:
            self.model = SentencePieceProcessor()
            self.model.load(model_path)

    def decode(self, tokens: tuple[list[int], list[int]]) -> tuple[str, list[int]]:
        """
        Convert token IDs (with per-token frame timings) back to a string,
        expanding the timings to one entry per output character.
        """
        tokens, timings = tokens
        if self.charwise:
            return "".join(self.vocab[tok] for tok in tokens), timings
        pieces = self.model.id_to_piece(tokens)
        ret = ''
        out_timings = []
        for piece, time in zip(pieces, timings):
            # SentencePiece marks word starts with the '▁' meta symbol;
            # replace it with a plain space.
            space = ''
            while piece.startswith('▁'):
                piece = piece[1:]
                space = ' '
            ret += space + piece
            out_timings += [time] * (len(space) + len(piece))
        return ret, out_timings

    def __len__(self):
        """
        Get the total number of tokens in the vocabulary.
        """
        return len(self.vocab) if self.charwise else len(self.model)


class CTCGreedyDecoding:
    """
    Class for performing greedy decoding of CTC outputs.
    """

    def __init__(self, vocabulary: List[str], model_path: Optional[str] = None):
        self.tokenizer = Tokenizer(vocabulary, model_path)
        self.blank_id = len(self.tokenizer)

    @torch.inference_mode()
    def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]:
        """
        Decode the output of a CTC model into a list of hypotheses.
        """
        log_probs = head(encoder_output=encoded)
        assert (
            len(log_probs.shape) == 3
        ), f"Expected log_probs shape {log_probs.shape} == [B, T, C]"
        b, _, c = log_probs.shape
        assert (
            c == len(self.tokenizer) + 1
        ), f"Num classes {c} != len(vocab) + 1 {len(self.tokenizer) + 1}"
        labels = log_probs.argmax(dim=-1, keepdim=False)
        skip_mask = labels != self.blank_id
        skip_mask[:, 1:] = torch.logical_and(
            skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1]
        )
        # Mask out frames beyond each sequence's true length (per batch row).
        for i, length in enumerate(lengths):
            skip_mask[i, length:] = False
        pred_texts: List[str] = []
        for i in range(b):
            ids = labels[i][skip_mask[i]].cpu().tolist()
            frames = skip_mask[i].nonzero(as_tuple=True)[0].cpu().tolist()
            text, _ = self.tokenizer.decode((ids, frames))
            pred_texts.append(text)
        return pred_texts


class RNNTGreedyDecoding:
    def __init__(
        self,
        vocabulary: List[str],
        model_path: Optional[str] = None,
        max_symbols_per_step: int = 10,
    ):
        """
        Class for performing greedy decoding of RNN-T outputs.
        """
        self.tokenizer = Tokenizer(vocabulary, model_path)
        self.blank_id = len(self.tokenizer)
        self.max_symbols = max_symbols_per_step

    @torch.inference_mode()
    def _greedy_decode(
        self, head: RNNTHead, x: Tensor, seqlen: Tensor
    ) -> Tuple[str, List[float]]:
        """
        Internal helper for greedy decoding of a single sequence.
        Returns the decoded text and the per-token softmax confidences.
        """
        hyp: List[int] = []
        timings: List[int] = []
        confidences: List[float] = []
        dec_state: Optional[Tensor] = None
        last_label: Optional[Tensor] = None
        for t in range(seqlen):
            f = x[t, :, :].unsqueeze(1)
            is_blank = False
            new_symbols = 0
            # Emit at most max_symbols tokens per encoder frame, then move on.
            while not is_blank and new_symbols < self.max_symbols:
                g, hidden = head.decoder.predict(last_label, dec_state)
                logits = head.joint.joint(f, g)[0, 0, 0, :]
                probs = torch.softmax(logits, dim=0)
                k = int(torch.argmax(probs).item())
                if k == self.blank_id:
                    is_blank = True
                else:
                    confidences.append(probs[k].item())
                    hyp.append(k)
                    timings.append(t)
                    dec_state = hidden
                    last_label = torch.tensor([[hyp[-1]]]).to(x.device)
                    new_symbols += 1
        text, _ = self.tokenizer.decode((hyp, timings))
        return text, confidences

    def decode(
        self, head: RNNTHead, encoded: Tensor, enc_len: Tensor
    ) -> List[Tuple[str, List[float]]]:
        """
        Decode the output of an RNN-T model into a list of hypotheses.
        """
        b = encoded.shape[0]
        pred_texts = []
        encoded = encoded.transpose(1, 2)
        for i in range(b):
            inseq = encoded[i, :, :].unsqueeze(1)
            pred_texts.append(self._greedy_decode(head, inseq, enc_len[i]))
        return pred_texts
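
For reference, a character-wise Tokenizer round trip looks like this (illustrative only): timings pass through unchanged in character mode, while SentencePiece mode expands each piece's timing to one entry per output character.

from gigaam_onnx.decoding import Tokenizer

tok = Tokenizer([' ', 'а', 'б'])
text, timings = tok.decode(([1, 2, 0, 1], [5, 6, 9, 12]))
assert text == 'аб а' and timings == [5, 6, 9, 12]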

src/gigaam_onnx/preprocess.py

@@ -0,0 +1,91 @@
# Mostly based on https://github.com/salute-developers/GigaAM/blame/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/preprocess.py
# Original authors:
# - https://github.com/georgygospodinov
# - https://github.com/Alexander4127
from typing import Tuple

import numpy as np
import torch
import torchaudio
from torch import Tensor, nn

SAMPLE_RATE = 16000


def load_audio(src: np.ndarray, sample_rate: int = SAMPLE_RATE) -> Tensor:
    """
    Convert an in-memory audio array to a float32 tensor in [-1, 1].
    Floating-point input is assumed to be normalized already; anything
    else is treated as 16-bit PCM.
    """
    if len(src) == 0:
        raise ValueError('Empty audio provided')
    audio_np = src
    if audio_np.dtype == np.float32:
        return torch.frombuffer(audio_np.tobytes(), dtype=torch.float32).float()
    elif audio_np.dtype == np.float16:
        return torch.frombuffer(audio_np.tobytes(), dtype=torch.float16).float()
    elif audio_np.dtype == np.float64:
        return torch.frombuffer(audio_np.tobytes(), dtype=torch.float64).float()
    else:
        audio_np = audio_np.astype(np.int16)
        return torch.frombuffer(audio_np.tobytes(), dtype=torch.int16).float() / 32768.0


class SpecScaler(nn.Module):
    """
    Module that applies logarithmic scaling to spectrogram values.
    This module clamps the input values within a certain range and then applies a natural logarithm.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.log(x.clamp_(1e-9, 1e9))


class FeatureExtractor(nn.Module):
    """
    Module for extracting log-mel spectrogram features from raw audio signals.
    This module uses torchaudio's MelSpectrogram transform to extract features
    and applies logarithmic scaling.
    """

    def __init__(self, sample_rate: int, features: int, **kwargs):
        super().__init__()
        self.hop_length = kwargs.get("hop_length", sample_rate // 100)
        self.win_length = kwargs.get("win_length", sample_rate // 40)
        self.n_fft = kwargs.get("n_fft", sample_rate // 40)
        self.center = kwargs.get("center", True)
        self.featurizer = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_mels=features,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_fft=self.n_fft,
                center=self.center,
                # Forward the mel options callers pass explicitly; the
                # defaults match torchaudio's ('htk' scale, no normalization).
                mel_scale=kwargs.get("mel_scale", "htk"),
                norm=kwargs.get("mel_norm", None),
            ),
            SpecScaler(),
        )

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the number of feature frames produced for given input lengths.
        """
        if self.center:
            return (
                input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
            )
        else:
            return (
                (input_lengths - self.win_length)
                .div(self.hop_length, rounding_mode="floor")
                .add(1)
                .long()
            )

    def forward(self, input_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Extract log-mel spectrogram features from the input audio signal.
        """
        return self.featurizer(input_signal), self.out_len(length)
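
The out_len formula is easy to check by hand for the configuration the v3 wrappers below use (hop_length=160, win_length=320, center=False at 16 kHz): one second of audio gives (16000 - 320) // 160 + 1 = 99 frames. An illustrative check:

import torch
from gigaam_onnx.preprocess import FeatureExtractor

fe = FeatureExtractor(
    sample_rate=16000, features=64,
    win_length=320, hop_length=160, n_fft=320, center=False,
)
assert fe.out_len(torch.tensor([16000])).item() == 99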

src/gigaam_onnx/rnnt.py

@@ -0,0 +1,70 @@
import numpy as np
import onnxruntime as ort

from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
from .decoding import RNNTHead, Tokenizer
from .preprocess import FeatureExtractor


class RNNTASR(ASRABCModel):
    head: RNNTHead
    predictor: ort.InferenceSession
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ):
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        token_ids = []
        timings = []
        prev_token = self.blank_idx
        # LSTM hidden and cell states of the prediction network, zero-initialized.
        pred_states = [
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
        ]
        for j in range(features.shape[-1]):
            emitted_letters = 0
            # Greedy transducer loop: emit until blank, capped at
            # MAX_LETTERS_PER_FRAME symbols per encoder frame.
            while emitted_letters < MAX_LETTERS_PER_FRAME:
                pred_inputs = {
                    node.name: data
                    for (node, data) in zip(
                        self.predictor.get_inputs(),
                        # int64 labels are assumed by the exported predictor;
                        # being explicit avoids platform-dependent numpy defaults.
                        [np.array([[prev_token]], dtype=np.int64)] + pred_states,
                    )
                }
                pred_outputs = self.predictor.run(
                    [node.name for node in self.predictor.get_outputs()], pred_inputs
                )
                joint_inputs = {
                    node.name: data
                    for node, data in zip(
                        self.jointer.get_inputs(),
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                }
                log_probs = self.jointer.run(
                    [node.name for node in self.jointer.get_outputs()], joint_inputs
                )
                token = log_probs[0].argmax(-1)[0][0]
                if token != self.blank_idx:
                    # Advance the prediction network only on non-blank emissions.
                    prev_token = int(token)
                    pred_states = pred_outputs[1:]
                    token_ids.append(int(token))
                    timings.append(j)
                    emitted_letters += 1
                else:
                    break
        return token_ids, timings

src/gigaam_onnx/v3_ctc.py

@@ -0,0 +1,66 @@
import numpy as np
from .preprocess import FeatureExtractor, load_audio
import onnxruntime as rt
from .decoding import CTCGreedyDecoding, Tokenizer
from .ctc import CTCASR

# Character-wise vocabulary of the v3 CTC model: space plus the Russian alphabet.
_CTC_VOCAB = [
    ' ',
    'а',
    'б',
    'в',
    'г',
    'д',
    'е',
    'ж',
    'з',
    'и',
    'й',
    'к',
    'л',
    'м',
    'н',
    'о',
    'п',
    'р',
    'с',
    'т',
    'у',
    'ф',
    'х',
    'ц',
    'ч',
    'ш',
    'щ',
    'ъ',
    'ы',
    'ь',
    'э',
    'ю',
    'я',
]


class GigaAMV3CTC(CTCASR):
    preprocessor: FeatureExtractor
    model_path: str
    decoding: CTCGreedyDecoding

    def __init__(self, model_path: str, provider: str, opts: rt.SessionOptions):
        self.model_path = model_path
        preprocessor = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        tokenizer = Tokenizer(_CTC_VOCAB)
        encoder = rt.InferenceSession(self.model_path, providers=[provider], sess_options=opts)
        super().__init__(preprocessor, tokenizer, encoder)
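
A minimal usage sketch for this wrapper, assuming the v3 CTC encoder has already been exported to ONNX; "v3_ctc.onnx" is a placeholder path and the zero array stands in for real 16 kHz audio.

import numpy as np
import onnxruntime as rt
from gigaam_onnx.v3_ctc import GigaAMV3CTC

model = GigaAMV3CTC("v3_ctc.onnx", "CPUExecutionProvider", rt.SessionOptions())
wav = np.zeros(16000, dtype=np.float32)
text, timings = model.transcribe(wav)
# join_batches=[2] concatenates the two clips' encoder outputs along time and
# decodes them as one utterance, with timings rescaled to seconds.
results = model.transcribe_batch([wav, wav], join_batches=[2])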


@@ -0,0 +1,27 @@
import onnxruntime as rt

from .ctc import CTCASR
from .preprocess import FeatureExtractor
from .decoding import Tokenizer


class GigaAMV3E2ECTC(CTCASR):
    model_path: str
    tokenizer_path: str

    def __init__(self, model_path: str, tokenizer_path: str, provider: str, opts: rt.SessionOptions):
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        preprocessor = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        # The E2E model ships a SentencePiece tokenizer instead of a fixed vocabulary.
        tokenizer = Tokenizer([], self.tokenizer_path)
        encoder = rt.InferenceSession(self.model_path, providers=[provider], sess_options=opts)
        super().__init__(preprocessor, tokenizer, encoder)


@@ -0,0 +1,57 @@
import onnxruntime as rt

from .decoder import RNNTHead
from .decoding import Tokenizer
from .preprocess import FeatureExtractor
from .rnnt import RNNTASR


class GigaAMV3E2ERNNT(RNNTASR):
    model_decoder_path: str
    model_encoder_path: str
    model_joint_path: str
    tokenizer_path: str

    def __init__(
        self,
        model_decoder_path: str,
        model_encoder_path: str,
        model_joint_path: str,
        tokenizer_path: str,
        provider: str,
        opts: rt.SessionOptions,
    ):
        self.model_decoder_path = model_decoder_path
        self.model_encoder_path = model_encoder_path
        self.model_joint_path = model_joint_path
        self.tokenizer_path = tokenizer_path
        preprocessor = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        tokenizer = Tokenizer([], self.tokenizer_path)
        encoder = rt.InferenceSession(self.model_encoder_path, providers=[provider], sess_options=opts)
        predictor = rt.InferenceSession(self.model_decoder_path, providers=[provider], sess_options=opts)
        jointer = rt.InferenceSession(self.model_joint_path, providers=[provider], sess_options=opts)
        # Head hyperparameters mirror the exported v3 E2E RNN-T graphs
        # (1024 sentencepiece tokens + blank).
        head = RNNTHead(
            {
                'pred_hidden': 320,
                'pred_rnn_layers': 1,
                'num_classes': 1025,
            },
            {
                'enc_hidden': 768,
                'pred_hidden': 320,
                'joint_hidden': 320,
                'num_classes': 1025,
            },
        )
        super().__init__(preprocessor, tokenizer, head, encoder, predictor, jointer)


@@ -0,0 +1,91 @@
import numpy as np
import onnxruntime as rt

from .decoder import RNNTHead
from .decoding import Tokenizer
from .preprocess import FeatureExtractor
from .rnnt import RNNTASR

# Character-wise vocabulary of the v3 RNN-T model: space plus the Russian alphabet.
_RNNT_VOCAB = [
    ' ',
    'а',
    'б',
    'в',
    'г',
    'д',
    'е',
    'ж',
    'з',
    'и',
    'й',
    'к',
    'л',
    'м',
    'н',
    'о',
    'п',
    'р',
    'с',
    'т',
    'у',
    'ф',
    'х',
    'ц',
    'ч',
    'ш',
    'щ',
    'ъ',
    'ы',
    'ь',
    'э',
    'ю',
    'я',
]


class GigaAMV3RNNT(RNNTASR):
    model_decoder_path: str
    model_encoder_path: str
    model_joint_path: str

    def __init__(
        self,
        model_decoder_path: str,
        model_encoder_path: str,
        model_joint_path: str,
        provider: str,
        opts: rt.SessionOptions,
    ):
        self.model_decoder_path = model_decoder_path
        self.model_encoder_path = model_encoder_path
        self.model_joint_path = model_joint_path
        preprocessor = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        tokenizer = Tokenizer(_RNNT_VOCAB)
        encoder = rt.InferenceSession(self.model_encoder_path, providers=[provider], sess_options=opts)
        predictor = rt.InferenceSession(self.model_decoder_path, providers=[provider], sess_options=opts)
        jointer = rt.InferenceSession(self.model_joint_path, providers=[provider], sess_options=opts)
        # 34 classes = 33 vocabulary characters + blank.
        head = RNNTHead(
            {
                'pred_hidden': 320,
                'pred_rnn_layers': 1,
                'num_classes': 34,
            },
            {
                'enc_hidden': 768,
                'pred_hidden': 320,
                'joint_hidden': 320,
                'num_classes': 34,
            },
        )
        super().__init__(preprocessor, tokenizer, head, encoder, predictor, jointer)
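
And a matching sketch for the RNN-T wrapper, with placeholder .onnx paths; the module name gigaam_onnx.v3_rnnt is an assumption, and the zero array again stands in for real 16 kHz audio.

import numpy as np
import onnxruntime as rt
from gigaam_onnx.v3_rnnt import GigaAMV3RNNT  # module name assumed

model = GigaAMV3RNNT(
    model_decoder_path="v3_rnnt_decoder.onnx",
    model_encoder_path="v3_rnnt_encoder.onnx",
    model_joint_path="v3_rnnt_joint.onnx",
    provider="CPUExecutionProvider",
    opts=rt.SessionOptions(),
)
text, timings = model.transcribe(np.zeros(16000, dtype=np.float32))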