67 lines
1.2 KiB
Python
67 lines
1.2 KiB
Python
import numpy as np
|
||
|
||
from .preprocess import FeatureExtractor, load_audio
|
||
import onnxruntime as rt
|
||
|
||
from .decoding import CTCGreedyDecoding, Tokenizer
|
||
from .ctc import CTCASR
|
||
|
||
_CTC_VOCAB = [
|
||
' ',
|
||
'а',
|
||
'б',
|
||
'в',
|
||
'г',
|
||
'д',
|
||
'е',
|
||
'ж',
|
||
'з',
|
||
'и',
|
||
'й',
|
||
'к',
|
||
'л',
|
||
'м',
|
||
'н',
|
||
'о',
|
||
'п',
|
||
'р',
|
||
'с',
|
||
'т',
|
||
'у',
|
||
'ф',
|
||
'х',
|
||
'ц',
|
||
'ч',
|
||
'ш',
|
||
'щ',
|
||
'ъ',
|
||
'ы',
|
||
'ь',
|
||
'э',
|
||
'ю',
|
||
'я',
|
||
]
|
||
|
||
|
||
class GigaAMV3CTC(CTCASR):
    """CTC speech-recognition model for GigaAM v3, run through ONNX Runtime.

    The constructor wires three components together and hands them,
    positionally, to ``CTCASR.__init__``:

    * a ``FeatureExtractor`` with fixed front-end settings (16 kHz sample
      rate, 64 features, 320-sample window and FFT, 160-sample hop, HTK mel
      scale, ``mel_norm=None``, ``center=False``);
    * a ``Tokenizer`` built over the module-level ``_CTC_VOCAB``;
    * an ``onnxruntime.InferenceSession`` loaded from ``model_path``, used as
      the encoder.
    """

    # Attribute type declarations. Only ``model_path`` is assigned in this
    # class; ``preprocessor`` and ``decoding`` are presumably populated by
    # ``CTCASR.__init__`` -- verify against the base class.
    preprocessor: FeatureExtractor
    model_path: str
    decoding: CTCGreedyDecoding

    def __init__(self, model_path: str, provider: str, opts: rt.SessionOptions):
        """Build the preprocessor, tokenizer and ONNX encoder session.

        Args:
            model_path: Path to the ONNX model file; kept on ``self.model_path``.
            provider: Name of a single onnxruntime execution provider; it is
                wrapped in a one-element ``providers`` list.
            opts: Session options forwarded to ONNX Runtime as ``sess_options``.
        """
        self.model_path = model_path

        # At 16 kHz, 320 samples is a 20 ms window and 160 samples a 10 ms hop.
        front_end_settings = dict(
            sample_rate=16000,
            features=64,
            win_length=320,
            hop_length=160,
            mel_scale='htk',
            n_fft=320,
            mel_norm=None,
            center=False,
        )

        # Call arguments are evaluated left to right, so the components are
        # still constructed in the order preprocessor -> tokenizer -> session.
        super().__init__(
            FeatureExtractor(**front_end_settings),
            Tokenizer(_CTC_VOCAB),
            rt.InferenceSession(
                self.model_path,
                providers=[provider],
                sess_options=opts,
            ),
        )
|