Initial commit
This commit is contained in:
0
src/gigaam_onnx/__init__.py
Normal file
0
src/gigaam_onnx/__init__.py
Normal file
182
src/gigaam_onnx/asr_abc.py
Normal file
182
src/gigaam_onnx/asr_abc.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import abc
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .preprocess import FeatureExtractor, load_audio
|
||||
from .decoding import Tokenizer
|
||||
|
||||
DTYPE = np.float32
|
||||
MAX_LETTERS_PER_FRAME = 3
|
||||
|
||||
|
||||
class AudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Helper class for creating batched inputs
|
||||
"""
|
||||
|
||||
def __init__(self, lst: list[str | np.ndarray | torch.Tensor]):
|
||||
if len(lst) == 0:
|
||||
raise ValueError("AudioDataset cannot be initialized with an empty list")
|
||||
assert isinstance(
|
||||
lst[0], (str, np.ndarray, torch.Tensor)
|
||||
), f"Unexpected dtype: {type(lst[0])}"
|
||||
self.lst = lst
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lst)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.lst[idx]
|
||||
if isinstance(item, str):
|
||||
wav_tns = load_audio(item)
|
||||
elif isinstance(item, np.ndarray):
|
||||
wav_tns = torch.from_numpy(item)
|
||||
elif isinstance(item, torch.Tensor):
|
||||
wav_tns = item
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected sample type: {type(item)} at idx={idx}")
|
||||
return wav_tns
|
||||
|
||||
@staticmethod
|
||||
def collate(wavs):
|
||||
lengths = torch.tensor([len(wav) for wav in wavs])
|
||||
max_len = lengths.max().item()
|
||||
wav_tns = torch.zeros(len(wavs), max_len, dtype=wavs[0].dtype)
|
||||
for idx, wav in enumerate(wavs):
|
||||
wav_tns[idx, : wav.shape[-1]] = wav.squeeze()
|
||||
return wav_tns, lengths
|
||||
|
||||
|
||||
class ASRABCModel(abc.ABC):
    """
    Abstract base for ONNX-backed ASR models.

    Handles feature extraction, batch padding and running the ONNX encoder;
    subclasses implement `_transcribe_decode` (greedy CTC or RNN-T decoding).
    """

    # ONNX session running the acoustic encoder.
    encoder: ort.InferenceSession
    # Log-mel feature extractor applied to raw waveforms.
    preprocessor: FeatureExtractor
    # Maps token ids back to text.
    tokenizer: Tokenizer
    # Index of the blank symbol (== vocabulary size, see __init__).
    blank_idx: int

    def __init__(self, encoder: ort.InferenceSession, preprocessor: FeatureExtractor, tokenizer: Tokenizer):
        self.encoder = encoder
        self.preprocessor = preprocessor
        self.tokenizer = tokenizer
        # Blank is the extra class appended after the vocabulary.
        self.blank_idx = len(self.tokenizer)

    def transcribe(self, wav: np.ndarray) -> tuple[str, list[int]]:
        """Transcribe a single waveform; returns (text, timings)."""
        return self.transcribe_batch([wav])[0]

    def _transcribe_encode(self, input_signal: np.ndarray):
        """Run the encoder on a single preprocessed feature tensor.

        NOTE(review): the length input is passed as a plain Python list
        ([input_signal.shape[-1]]) rather than an np.int64 array — confirm
        the runtime accepts this for the exported model.
        """
        # Map positional inputs onto the session's declared input names.
        enc_inputs = {
            node.name: data
            for (node, data) in zip(
                self.encoder.get_inputs(),
                [input_signal.astype(DTYPE), [input_signal.shape[-1]]],
            )
        }
        enc_features = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()], enc_inputs
        )[0]

        return enc_features

    def _transcribe_encode_batch(self,
                                 input_signals: np.ndarray,
                                 input_lengths: np.ndarray) -> np.ndarray:
        """Run the encoder on a zero-padded batch of feature tensors."""

        enc_inputs = {
            node.name: data for (node, data) in zip(
                self.encoder.get_inputs(),
                [
                    input_signals.astype(DTYPE),
                    input_lengths
                ]
            )
        }

        # Only the first output (encoded features) is used.
        outputs = self.encoder.run(
            [node.name for node in self.encoder.get_outputs()],
            enc_inputs
        )[0]

        return outputs

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Subclass hook: decode encoder features to (token_ids, frame_timings)."""
        raise NotImplementedError()

    def transcribe_batch(self,
                         wavs: list[np.ndarray],
                         join_batches: list[int] | None = None) -> list[tuple[str, list[int]]]:
        """Transcribe a batch of raw waveforms.

        Args:
            wavs: raw audio arrays (the rate computation below assumes 16 kHz).
            join_batches: optional group sizes; each consecutive group of
                utterances is concatenated along time and decoded as one
                segment, with timings rescaled to seconds.

        Returns:
            One (text, timings) tuple per utterance (or per joined group).
        """
        input_lengths = []
        processed_wavs = []
        for wav in wavs:
            audio_tensor = load_audio(wav)
            # FeatureExtractor returns (features, out_lengths); keep features.
            processed = self.preprocessor(
                audio_tensor.unsqueeze(0),
                torch.tensor([audio_tensor.shape[-1]])
            )[0]

            if isinstance(processed, torch.Tensor):
                processed = processed.cpu().numpy()

            processed_wavs.append(processed)
            # Feature shape is (1, n_mels, time); index 2 is the time axis.
            input_lengths.append(processed.shape[2])

        max_length = max(input_lengths)
        batch_size = len(wavs)
        features_dim = processed_wavs[0].shape[1]

        # Zero-pad every feature tensor to the longest one in the batch.
        padded_wavs = np.zeros((batch_size, features_dim, max_length), dtype=DTYPE)

        for i, audio in enumerate(processed_wavs):
            length = audio.shape[2]
            padded_wavs[i, :, :length] = audio

        input_lengths_array = np.array(input_lengths, dtype=np.int64)

        features = self._transcribe_encode_batch(
            padded_wavs,
            input_lengths_array
        )

        if join_batches is None:
            # Per-utterance decoding; _transcribe_decode returns the
            # (ids, timings) pair that Tokenizer.decode expects.
            batch_token_ids = [self._transcribe_decode(features[i]) for i in range(batch_size)]
            return [self.tokenizer.decode(ids) for ids in batch_token_ids]
        else:
            ret = []
            start_idx = 0

            for batch_len in join_batches:
                end_idx = start_idx + batch_len
                batch_features_list = []
                batch_lengths = input_lengths_array[start_idx:end_idx]

                # Strip padding from each utterance before concatenation.
                for i in range(batch_len):
                    idx = start_idx + i
                    real_length = batch_lengths[i]
                    real_features = features[idx, :, :real_length]
                    batch_features_list.append(real_features)

                concatenated_features = []
                total_time = 0

                for i, real_features in enumerate(batch_features_list):
                    concatenated_features.append(real_features)
                    total_time += real_features.shape[1]

                # Join the group along the time axis and restore a batch dim.
                seg_features_2d = np.concatenate(concatenated_features, axis=1)

                seg_features = seg_features_2d[np.newaxis, :, :]

                token_ids, token_timings = self._transcribe_decode(seg_features)

                # Seconds of audio per decoder frame; assumes 16 kHz input.
                # NOTE(review): max() raises on an empty token_timings list —
                # confirm fully-silent segments cannot occur here.
                rate = (sum(list(map(len, wavs[start_idx:end_idx]))) / 16000) / max(token_timings)

                result_text, out_timings = self.tokenizer.decode((token_ids, token_timings))
                norm_out_timings = list(map(lambda x: x * rate, out_timings))
                ret.append((result_text, norm_out_timings))

                start_idx = end_idx

            return ret
|
||||
34
src/gigaam_onnx/ctc.py
Normal file
34
src/gigaam_onnx/ctc.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .preprocess import FeatureExtractor, load_audio
|
||||
from .decoding import CTCGreedyDecoding, Tokenizer
|
||||
import onnxruntime as ort
|
||||
from .asr_abc import ASRABCModel
|
||||
|
||||
|
||||
class CTCASR(ASRABCModel):
    """CTC-based ASR model backed by a single ONNX encoder session."""

    preprocessor: FeatureExtractor
    encoder: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        encoder: ort.InferenceSession
    ):
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Greedy CTC decoding: collapse repeats, drop blanks, keep frame indices."""
        # Strip any leading singleton batch axes down to (classes, time).
        while len(features.shape) > 2:
            features = features[0]
        best_path = features.argmax(0).squeeze().tolist()
        token_ids: list[int] = []
        timings: list[int] = []
        last = self.blank_idx
        for frame, tok in enumerate(best_path):
            # Emit a token iff it is not blank and is not a repeat
            # (a repeat directly after blank counts as a new emission).
            if tok < self.blank_idx and (tok != last or last == self.blank_idx):
                token_ids.append(tok)
                timings.append(frame)
            last = tok
        return token_ids, timings
|
||||
133
src/gigaam_onnx/decoder.py
Normal file
133
src/gigaam_onnx/decoder.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Mostly based on https://github.com/salute-developers/GigaAM/blame/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/decoder.py
|
||||
# Original authors:
|
||||
# - https://github.com/georgygospodinov
|
||||
# - https://github.com/Alexander4127
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
    """Projects encoder features to per-frame log class probabilities for CTC."""

    def __init__(self, feat_in: int, num_classes: int):
        super().__init__()
        # A 1x1 convolution acts as a per-frame linear projection.
        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(feat_in, num_classes, kernel_size=1)
        )

    def forward(self, encoder_output: Tensor) -> Tensor:
        """(B, feat_in, T) -> (B, T, num_classes) log-probabilities."""
        logits = self.decoder_layers(encoder_output)
        return torch.nn.functional.log_softmax(logits.transpose(1, 2), dim=-1)
|
||||
|
||||
|
||||
class RNNTJoint(nn.Module):
    """
    RNN-Transducer joint network.

    Projects encoder and prediction-network outputs into a shared space,
    adds them, and maps the result to per-class log-probabilities through
    a ReLU + linear stack.
    """

    def __init__(
        self, enc_hidden: int, pred_hidden: int, joint_hidden: int, num_classes: int
    ):
        super().__init__()
        self.enc_hidden = enc_hidden
        self.pred_hidden = pred_hidden
        self.pred = nn.Linear(pred_hidden, joint_hidden)
        self.enc = nn.Linear(enc_hidden, joint_hidden)
        self.joint_net = nn.Sequential(nn.ReLU(), nn.Linear(joint_hidden, num_classes))

    def joint(self, encoder_out: Tensor, decoder_out: Tensor) -> Tensor:
        """Fuse (B, T, enc_hidden) and (B, U, pred_hidden) into
        (B, T, U, num_classes) log-probabilities."""
        enc_proj = self.enc(encoder_out).unsqueeze(2)
        pred_proj = self.pred(decoder_out).unsqueeze(1)
        fused = self.joint_net(enc_proj + pred_proj)
        return fused.log_softmax(-1)

    def input_example(self) -> Tuple[Tensor, Tensor]:
        """Dummy (enc, dec) inputs, used for ONNX export."""
        device = next(self.parameters()).device
        enc = torch.zeros(1, self.enc_hidden, 1).float().to(device)
        dec = torch.zeros(1, self.pred_hidden, 1).float().to(device)
        return enc, dec

    def input_names(self) -> List[str]:
        return ["enc", "dec"]

    def output_names(self) -> List[str]:
        return ["joint"]

    def forward(self, enc: Tensor, dec: Tensor) -> Tensor:
        # ONNX-friendly entry: inputs arrive channel-first and are
        # transposed to (B, T/U, hidden) before joining.
        return self.joint(enc.transpose(1, 2), dec.transpose(1, 2))
|
||||
|
||||
|
||||
class RNNTDecoder(nn.Module):
    """
    RNN-Transducer prediction network: embeds the previous label and runs
    it through an LSTM.
    """

    def __init__(self, pred_hidden: int, pred_rnn_layers: int, num_classes: int):
        super().__init__()
        # Blank is the last class and doubles as the embedding padding index.
        self.blank_id = num_classes - 1
        self.pred_hidden = pred_hidden
        self.embed = nn.Embedding(num_classes, pred_hidden, padding_idx=self.blank_id)
        self.lstm = nn.LSTM(pred_hidden, pred_hidden, pred_rnn_layers)

    def predict(
        self,
        x: Optional[Tensor],
        state: Optional[Tensor],
        batch_size: int = 1,
    ) -> Tuple[Tensor, Tensor]:
        """Run one prediction step.

        A missing label (x is None) is replaced by a zero embedding, acting
        as the start-of-sequence input.
        """
        if x is None:
            device = next(self.parameters()).device
            emb = torch.zeros((batch_size, 1, self.pred_hidden), device=device)
        else:
            emb = self.embed(x)
        # nn.LSTM expects (seq, batch, features).
        out, new_state = self.lstm(emb.transpose(0, 1), state)
        return out.transpose(0, 1), new_state

    def input_example(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Dummy (label, h, c) inputs, used for ONNX export."""
        device = next(self.parameters()).device
        label = torch.tensor([[0]]).to(device)
        hidden_h = torch.zeros(1, 1, self.pred_hidden).to(device)
        hidden_c = torch.zeros(1, 1, self.pred_hidden).to(device)
        return label, hidden_h, hidden_c

    def input_names(self) -> List[str]:
        return ["x", "h", "c"]

    def output_names(self) -> List[str]:
        return ["dec", "h", "c"]

    def forward(self, x: Tensor, h: Tensor, c: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """ONNX-friendly forward: explicit (h, c) state in and out."""
        emb = self.embed(x)
        out, (h_next, c_next) = self.lstm(emb.transpose(0, 1), (h, c))
        return out.transpose(0, 1), h_next, c_next
|
||||
|
||||
|
||||
class RNNTHead(nn.Module):
    """
    RNN-Transducer head: bundles the prediction network (decoder) and the
    joint network, each configured from a plain keyword dict.
    """

    def __init__(self, decoder: Dict[str, int], joint: Dict[str, int]):
        super().__init__()
        # Config dicts are expanded directly into the sub-module constructors.
        self.decoder = RNNTDecoder(**decoder)
        self.joint = RNNTJoint(**joint)
|
||||
181
src/gigaam_onnx/decoding.py
Normal file
181
src/gigaam_onnx/decoding.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Mostly based on https://github.com/salute-developers/GigaAM/blob/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/decoding.py
|
||||
# Original authors:
|
||||
# - https://github.com/georgygospodinov
|
||||
# - https://github.com/Alexander4127
|
||||
# - https://github.com/sverdoot
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from .decoder import CTCHead, RNNTHead
|
||||
|
||||
|
||||
class Tokenizer:
    """
    Converts token ids back to text, either character-wise (vocab list) or
    through a pre-trained SentencePiece model.
    """

    def __init__(self, vocab: List[str], model_path: Optional[str] = None):
        self.charwise = model_path is None
        if self.charwise:
            self.vocab = vocab
        else:
            self.model = SentencePieceProcessor()
            self.model.load(model_path)

    def decode(self, tokens: tuple[list[int], list[int]]) -> tuple[str, list[int]]:
        """Decode (token_ids, frame_timings) into (text, per-character timings)."""
        ids, timings = tokens
        if self.charwise:
            return "".join(self.vocab[t] for t in ids), timings
        text = ''
        char_timings: list[int] = []
        for piece, stamp in zip(self.model.id_to_piece(ids), timings):
            prefix = ''
            # '▁' marks a word boundary in SentencePiece pieces.
            while piece.startswith('▁'):
                piece = piece[1:]
                prefix = ' '
            text += prefix + piece
            # Each emitted character inherits the piece's timing.
            char_timings += [stamp] * (len(prefix) + len(piece))
        return text, char_timings

    def __len__(self):
        """Vocabulary size."""
        return len(self.vocab) if self.charwise else len(self.model)
|
||||
|
||||
|
||||
class CTCGreedyDecoding:
    """
    Greedy (argmax) decoding of CTC outputs.
    """

    def __init__(self, vocabulary: List[str], model_path: Optional[str] = None):
        self.tokenizer = Tokenizer(vocabulary, model_path)
        # Blank is the extra class appended after the vocabulary.
        self.blank_id = len(self.tokenizer)

    @torch.inference_mode()
    def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]:
        """
        Decode a batch of encoder outputs into text hypotheses.

        Args:
            head: CTC head producing (B, T, C) log-probabilities.
            encoded: encoder output fed to the head.
            lengths: number of valid frames per batch element.
        """
        log_probs = head(encoder_output=encoded)
        assert (
            len(log_probs.shape) == 3
        ), f"Expected log_probs shape {log_probs.shape} == [B, T, C]"
        b, _, c = log_probs.shape
        assert (
            c == len(self.tokenizer) + 1
        ), f"Num classes {c} != len(vocab) + 1 {len(self.tokenizer) + 1}"
        labels = log_probs.argmax(dim=-1, keepdim=False)

        # Keep a frame iff it is non-blank and differs from the previous
        # frame (standard CTC collapse rule).
        skip_mask = labels != self.blank_id
        skip_mask[:, 1:] = torch.logical_and(
            skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1]
        )
        # BUGFIX: mask padded frames per sequence. The old code indexed the
        # batch dimension (skip_mask[length:]) instead of the time dimension.
        for i, length in enumerate(lengths):
            skip_mask[i, length:] = False

        pred_texts: List[str] = []
        for i in range(b):
            ids = labels[i][skip_mask[i]].cpu().tolist()
            # BUGFIX: Tokenizer.decode expects an (ids, timings) pair and
            # returns (text, timings); the old code passed a bare id list
            # and joined the resulting pair. Timings are unused here, so
            # dummy zeros are supplied.
            text, _ = self.tokenizer.decode((ids, [0] * len(ids)))
            pred_texts.append(text)
        return pred_texts
|
||||
|
||||
|
||||
class RNNTGreedyDecoding:
    """
    Greedy decoding of RNN-T outputs with per-token confidences.
    """

    def __init__(
        self,
        vocabulary: List[str],
        model_path: Optional[str] = None,
        max_symbols_per_step: int = 10,
    ):
        self.tokenizer = Tokenizer(vocabulary, model_path)
        # Blank is the extra class appended after the vocabulary.
        self.blank_id = len(self.tokenizer)
        # Cap on symbols emitted per encoder frame; guards against the
        # decoder looping forever on a single frame.
        self.max_symbols = max_symbols_per_step

    # BUGFIX: the class previously defined _greedy_decode twice; the first
    # (timing-less) version was dead code, silently shadowed by this one,
    # and also called tokenizer.decode with a bare id list. Only the
    # surviving definition is kept.
    @torch.inference_mode()
    def _greedy_decode(self,
                       head: RNNTHead,
                       x: Tensor,
                       seqlen: Tensor,
                       topk: int = 10) -> Tuple[str, List[float]]:
        """
        Greedy-decode a single sequence.

        Args:
            head: RNN-T head (prediction + joint networks).
            x: encoder output for one sequence, time-major with a batch dim.
            seqlen: number of valid frames in x.
            topk: unused; kept for interface compatibility.

        Returns:
            (decoded text, per-token confidences).
        """
        hyp: List[int] = []
        confidences: List[float] = []
        dec_state: Optional[Tensor] = None
        last_label: Optional[Tensor] = None
        for t in range(seqlen):
            f = x[t, :, :].unsqueeze(1)
            is_blank = False
            new_symbols = 0
            while not is_blank and new_symbols < self.max_symbols:
                g, hidden = head.decoder.predict(last_label, dec_state)
                logits = head.joint.joint(f, g)[0, 0, 0, :]
                probs = torch.softmax(logits, dim=0)
                k = torch.argmax(probs).item()
                if k == self.blank_id:
                    # Blank: advance to the next encoder frame without
                    # updating the prediction-network state.
                    is_blank = True
                else:
                    confidences.append(probs[k].item())
                    hyp.append(k)
                    dec_state = hidden
                    last_label = torch.tensor([[hyp[-1]]]).to(x.device)
                    new_symbols += 1

        # BUGFIX: Tokenizer.decode expects an (ids, timings) pair and returns
        # (text, timings); the old code passed a bare id list, which raised.
        # Dummy timings are supplied since only the text is needed here.
        text, _ = self.tokenizer.decode((hyp, [0] * len(hyp)))
        return text, confidences

    def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[Tuple[str, List[float]]]:
        """
        Decode a batch of encoder outputs into (text, confidences) pairs.
        """
        b = encoded.shape[0]
        pred_texts = []
        encoded = encoded.transpose(1, 2)
        for i in range(b):
            inseq = encoded[i, :, :].unsqueeze(1)
            pred_texts.append(self._greedy_decode(head, inseq, enc_len[i]))
        return pred_texts
|
||||
91
src/gigaam_onnx/preprocess.py
Normal file
91
src/gigaam_onnx/preprocess.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Mostly based on https://github.com/salute-developers/GigaAM/blame/bd77657d48f73633ed1d237ce0d6f99108f3c875/gigaam/preprocess.py
|
||||
# Original authors:
|
||||
# - https://github.com/georgygospodinov
|
||||
# - https://github.com/Alexander4127
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import Tensor, nn
|
||||
|
||||
SAMPLE_RATE = 16000


def load_audio(src: np.ndarray, sample_rate: int = SAMPLE_RATE) -> Tensor:
    """Convert a raw audio array into a float32 torch tensor.

    Floating-point inputs are converted directly; any other dtype is cast
    to int16 PCM and normalized to [-1, 1). NOTE(review): despite the name,
    no file loading or resampling happens here — `sample_rate` is unused.
    """
    if len(src) <= 0:
        raise ValueError('Empty file provided')

    float_map = {
        np.float32: torch.float32,
        np.float16: torch.float16,
        np.float64: torch.float64,
    }
    torch_dtype = float_map.get(src.dtype.type)
    if torch_dtype is not None:
        return torch.frombuffer(src.tobytes(), dtype=torch_dtype).float()
    pcm = src.astype(np.int16)
    return torch.frombuffer(pcm.tobytes(), dtype=torch.int16).float() / 32768.0
|
||||
|
||||
|
||||
class SpecScaler(nn.Module):
    """
    Log-scales spectrogram values.

    Input is clamped (in place) to [1e-9, 1e9] first so that zero bins
    never reach log().
    """

    def forward(self, x: Tensor) -> Tensor:
        # NOTE: clamp_ mutates the input tensor, as in the original code.
        return x.clamp_(1e-9, 1e9).log()
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
    """
    Module for extracting Log-mel spectrogram features from raw audio signals.
    This module uses Torchaudio's MelSpectrogram transform to extract features
    and applies logarithmic scaling.
    """

    def __init__(self, sample_rate: int, features: int, **kwargs):
        """Build the mel-spectrogram pipeline.

        Args:
            sample_rate: input sampling rate in Hz.
            features: number of mel bins.
            **kwargs: optional hop_length / win_length / n_fft / center overrides.

        NOTE(review): only hop_length, win_length, n_fft and center are read
        from kwargs; any other keys (e.g. mel_scale / mel_norm passed by some
        callers in this package) are silently ignored — confirm this is intended.
        """
        super().__init__()
        # Defaults: 10 ms hop, 25 ms window, n_fft equal to the window size.
        self.hop_length = kwargs.get("hop_length", sample_rate // 100)
        self.win_length = kwargs.get("win_length", sample_rate // 40)
        self.n_fft = kwargs.get("n_fft", sample_rate // 40)
        self.center = kwargs.get("center", True)
        self.featurizer = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_mels=features,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_fft=self.n_fft,
                center=self.center,
            ),
            # Log-scale the spectrogram (clamped to avoid log(0)).
            SpecScaler(),
        )

    def out_len(self, input_lengths: Tensor) -> Tensor:
        """
        Calculates the output length after the feature extraction process.
        """
        if self.center:
            # Centered STFT pads the signal: len // hop + 1 frames.
            return (
                input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
            )
        else:
            # Without padding, only fully covered windows produce frames.
            return (
                (input_lengths - self.win_length)
                .div(self.hop_length, rounding_mode="floor")
                .add(1)
                .long()
            )

    def forward(self, input_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Extract Log-mel spectrogram features from the input audio signal.

        Returns (features, output_lengths).
        """
        return self.featurizer(input_signal), self.out_len(length)
|
||||
70
src/gigaam_onnx/rnnt.py
Normal file
70
src/gigaam_onnx/rnnt.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from .asr_abc import ASRABCModel, DTYPE, MAX_LETTERS_PER_FRAME
|
||||
from .decoding import RNNTHead, Tokenizer
|
||||
from .preprocess import FeatureExtractor
|
||||
|
||||
|
||||
class RNNTASR(ASRABCModel):
    """
    RNN-T ASR model: an ONNX encoder plus separate ONNX predictor
    (prediction network) and jointer (joint network) sessions, decoded
    greedily frame by frame.
    """

    # Torch-side head; used here only for configuration (state sizes).
    head: RNNTHead
    # ONNX session for the prediction network.
    predictor: ort.InferenceSession
    # ONNX session for the joint network.
    jointer: ort.InferenceSession

    def __init__(
        self,
        preprocessor: FeatureExtractor,
        tokenizer: Tokenizer,
        head: RNNTHead,
        encoder: ort.InferenceSession,
        predictor: ort.InferenceSession,
        jointer: ort.InferenceSession,
    ):
        self.head = head
        self.predictor = predictor
        self.jointer = jointer
        super().__init__(encoder, preprocessor, tokenizer)

    def _transcribe_decode(self, features) -> tuple[list[int], list[int]]:
        """Greedy RNN-T decoding over ONNX sessions.

        Args:
            features: encoder output; the last axis is iterated as time.
        Returns:
            (token_ids, frame_indices) for the emitted tokens.
        """
        token_ids = []
        timings = []
        # Start from the blank token (no previous emission).
        prev_token = self.blank_idx

        # Prediction-network LSTM (h, c) state, zero-initialized.
        pred_states = [
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
            np.zeros(shape=(1, 1, self.head.decoder.pred_hidden), dtype=DTYPE),
        ]
        for j in range(features.shape[-1]):
            emitted_letters = 0
            # Emit at most MAX_LETTERS_PER_FRAME symbols per encoder frame.
            while emitted_letters < MAX_LETTERS_PER_FRAME:
                # Prediction-network step: previous token + LSTM state,
                # mapped positionally onto the session's input names.
                pred_inputs = {
                    node.name: data
                    for (node, data) in zip(
                        self.predictor.get_inputs(), [np.array([[prev_token]])] + pred_states
                    )
                }
                pred_outputs = self.predictor.run(
                    [node.name for node in self.predictor.get_outputs()], pred_inputs
                )

                # Joint step: j-th encoder frame (kept 3-D via [j]) combined
                # with the predictor output transposed to channel-first.
                joint_inputs = {
                    node.name: data
                    for node, data in zip(
                        self.jointer.get_inputs(),
                        [features[:, :, [j]], pred_outputs[0].swapaxes(1, 2)],
                    )
                }
                log_probs = self.jointer.run(
                    [node.name for node in self.jointer.get_outputs()], joint_inputs
                )
                token = log_probs[0].argmax(-1)[0][0]

                if token != self.blank_idx:
                    # Non-blank: advance the predictor state and record
                    # the token together with its frame index.
                    prev_token = int(token)
                    pred_states = pred_outputs[1:]
                    token_ids.append(int(token))
                    timings.append(j)
                    emitted_letters += 1
                else:
                    # Blank: move on to the next encoder frame.
                    break
        return token_ids, timings
|
||||
66
src/gigaam_onnx/v3_ctc.py
Normal file
66
src/gigaam_onnx/v3_ctc.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import numpy as np
|
||||
|
||||
from .preprocess import FeatureExtractor, load_audio
|
||||
import onnxruntime as rt
|
||||
|
||||
from .decoding import CTCGreedyDecoding, Tokenizer
|
||||
from .ctc import CTCASR
|
||||
|
||||
_CTC_VOCAB = [
|
||||
' ',
|
||||
'а',
|
||||
'б',
|
||||
'в',
|
||||
'г',
|
||||
'д',
|
||||
'е',
|
||||
'ж',
|
||||
'з',
|
||||
'и',
|
||||
'й',
|
||||
'к',
|
||||
'л',
|
||||
'м',
|
||||
'н',
|
||||
'о',
|
||||
'п',
|
||||
'р',
|
||||
'с',
|
||||
'т',
|
||||
'у',
|
||||
'ф',
|
||||
'х',
|
||||
'ц',
|
||||
'ч',
|
||||
'ш',
|
||||
'щ',
|
||||
'ъ',
|
||||
'ы',
|
||||
'ь',
|
||||
'э',
|
||||
'ю',
|
||||
'я',
|
||||
]
|
||||
|
||||
|
||||
class GigaAMV3CTC(CTCASR):
    """GigaAM v3 CTC model with a character-level Russian vocabulary."""

    preprocessor: FeatureExtractor
    model_path: str
    decoding: CTCGreedyDecoding

    def __init__(self, model_path: str, provider: str, opts: rt.SessionOptions):
        """Load the ONNX encoder and assemble the feature pipeline.

        Args:
            model_path: path to the encoder ONNX file.
            provider: ONNX Runtime execution provider name.
            opts: session options forwarded to InferenceSession.
        """
        self.model_path = model_path
        session = rt.InferenceSession(
            self.model_path, providers=[provider], sess_options=opts
        )
        # 16 kHz input, 64 mel bins, 20 ms window with a 10 ms hop.
        mel_frontend = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        super().__init__(mel_frontend, Tokenizer(_CTC_VOCAB), session)
|
||||
27
src/gigaam_onnx/v3_e2e_ctc.py
Normal file
27
src/gigaam_onnx/v3_e2e_ctc.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import onnxruntime as rt
|
||||
|
||||
from .ctc import CTCASR
|
||||
from .preprocess import FeatureExtractor
|
||||
from .decoding import Tokenizer
|
||||
|
||||
|
||||
class GigaAMV3E2ECTC(CTCASR):
    """GigaAM v3 end-to-end CTC model using a SentencePiece tokenizer."""

    model_path: str
    tokenizer_path: str

    def __init__(self, model_path: str, tokenizer_path: str, provider: str, opts: rt.SessionOptions):
        """Load the ONNX encoder and the SentencePiece tokenizer.

        Args:
            model_path: path to the encoder ONNX file.
            tokenizer_path: path to the SentencePiece model.
            provider: ONNX Runtime execution provider name.
            opts: session options forwarded to InferenceSession.
        """
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        session = rt.InferenceSession(
            self.model_path, providers=[provider], sess_options=opts
        )
        # 16 kHz input, 64 mel bins, 20 ms window with a 10 ms hop.
        mel_frontend = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        # Empty vocab list: tokens come entirely from the SentencePiece model.
        super().__init__(mel_frontend, Tokenizer([], self.tokenizer_path), session)
|
||||
57
src/gigaam_onnx/v3_e2e_rnnt.py
Normal file
57
src/gigaam_onnx/v3_e2e_rnnt.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import onnxruntime as rt
|
||||
|
||||
from .decoder import RNNTHead
|
||||
from .decoding import Tokenizer
|
||||
from .preprocess import FeatureExtractor
|
||||
from .rnnt import RNNTASR
|
||||
|
||||
|
||||
class GigaAMV3E2ERNNT(RNNTASR):
    """GigaAM v3 end-to-end RNN-T model using a SentencePiece tokenizer."""

    model_decoder_path: str
    model_encoder_path: str
    model_joint_path: str
    tokenizer_path: str

    def __init__(
        self,
        model_decoder_path: str,
        model_encoder_path: str,
        model_joint_path: str,
        tokenizer_path: str,
        provider: str,
        opts: rt.SessionOptions
    ):
        """Load the encoder, predictor and jointer ONNX sessions."""
        self.model_decoder_path = model_decoder_path
        self.model_encoder_path = model_encoder_path
        self.model_joint_path = model_joint_path

        self.tokenizer_path = tokenizer_path

        def _session(path: str) -> rt.InferenceSession:
            # All three sessions share the same provider and options.
            return rt.InferenceSession(path, providers=[provider], sess_options=opts)

        # 16 kHz input, 64 mel bins, 20 ms window with a 10 ms hop.
        mel_frontend = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        # num_classes = 1025: SentencePiece vocabulary plus the blank,
        # which RNNTDecoder places at index num_classes - 1.
        head = RNNTHead(
            {
                'pred_hidden': 320,
                'pred_rnn_layers': 1,
                'num_classes': 1025,
            },
            {
                'enc_hidden': 768,
                'pred_hidden': 320,
                'joint_hidden': 320,
                'num_classes': 1025,
            },
        )

        super().__init__(
            mel_frontend,
            Tokenizer([], self.tokenizer_path),
            head,
            _session(self.model_encoder_path),
            _session(self.model_decoder_path),
            _session(self.model_joint_path),
        )
|
||||
91
src/gigaam_onnx/v3_rnnt.py
Normal file
91
src/gigaam_onnx/v3_rnnt.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import numpy as np
|
||||
import onnxruntime as rt
|
||||
|
||||
from .decoder import RNNTHead
|
||||
from .decoding import Tokenizer
|
||||
from .preprocess import FeatureExtractor
|
||||
from .rnnt import RNNTASR
|
||||
|
||||
_RNNT_VOCAB = [
|
||||
' ',
|
||||
'а',
|
||||
'б',
|
||||
'в',
|
||||
'г',
|
||||
'д',
|
||||
'е',
|
||||
'ж',
|
||||
'з',
|
||||
'и',
|
||||
'й',
|
||||
'к',
|
||||
'л',
|
||||
'м',
|
||||
'н',
|
||||
'о',
|
||||
'п',
|
||||
'р',
|
||||
'с',
|
||||
'т',
|
||||
'у',
|
||||
'ф',
|
||||
'х',
|
||||
'ц',
|
||||
'ч',
|
||||
'ш',
|
||||
'щ',
|
||||
'ъ',
|
||||
'ы',
|
||||
'ь',
|
||||
'э',
|
||||
'ю',
|
||||
'я',
|
||||
]
|
||||
|
||||
|
||||
class GigaAMV3RNNT(RNNTASR):
    """GigaAM v3 RNN-T model with a character-level Russian vocabulary."""

    model_decoder_path: str
    model_encoder_path: str
    model_joint_path: str

    def __init__(
        self,
        model_decoder_path: str,
        model_encoder_path: str,
        model_joint_path: str,
        provider: str,
        opts: rt.SessionOptions
    ):
        """Load the encoder, predictor and jointer ONNX sessions."""
        self.model_decoder_path = model_decoder_path
        self.model_encoder_path = model_encoder_path
        self.model_joint_path = model_joint_path

        def _session(path: str) -> rt.InferenceSession:
            # All three sessions share the same provider and options.
            return rt.InferenceSession(path, providers=[provider], sess_options=opts)

        # 16 kHz input, 64 mel bins, 20 ms window with a 10 ms hop.
        mel_frontend = FeatureExtractor(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )
        # num_classes = 34: the 33-entry character vocabulary plus the blank,
        # which RNNTDecoder places at index num_classes - 1.
        head = RNNTHead(
            {
                'pred_hidden': 320,
                'pred_rnn_layers': 1,
                'num_classes': 34,
            },
            {
                'enc_hidden': 768,
                'pred_hidden': 320,
                'joint_hidden': 320,
                'num_classes': 34,
            },
        )

        super().__init__(
            mel_frontend,
            Tokenizer(_RNNT_VOCAB),
            head,
            _session(self.model_encoder_path),
            _session(self.model_decoder_path),
            _session(self.model_joint_path),
        )
|
||||
Reference in New Issue
Block a user