1 Star 0 Fork 0

杨显杰/asr_server_cpu

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
asr_vad_detect.py 5.81 KB
一键复制 编辑 原始数据 按行查看 历史
xianjie.yang 提交于 2023-12-07 11:21 . timeout set to 10s
import os
import uuid
from concurrent import futures
import queue
import threading
import copy
import numpy as np
import sys
import threading
import logging
import time
VAD_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, VAD_BASE_DIR)
from AsrStDecoder import warp_asr_alloc_stVAD,warp_asr_free_stVAD,warp_asr_vad_input,warp_asr_vad_reset
"""初始化全局变量"""
# Module-level state shared by the functions below. The None placeholders and
# the default concurrency value are overwritten by vad_init_model().

# Becomes a futures.ThreadPoolExecutor once vad_init_model() has run.
g_thread_pools = None
# Concurrency limit; replaced by the max_concurrent_numbers argument on init.
g_max_concurrent = 20
# Becomes a threading.Lock guarding g_session_map once vad_init_model() has run.
g_session_map_lock = None
# session_id -> per-session dict with the keys 'VADINFO', 'TIMER' and 'silent'.
g_session_map = {}
# Becomes the VAD_MODEL pool instance once vad_init_model() has run.
g_vad_model = None
# Model directory "model_vad", located next to this source file.
MODEL_PATH = os.path.join(VAD_BASE_DIR, "model_vad")
def set_log():
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s | (%(module)s:%(lineno)d)| %(levelname)s | %(message)s')
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)
file_handler = logging.FileHandler('vad.log')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stdout_handler)
class VAD_MODEL:
    """Fixed-size pool of VAD decoder handles, each bound to at most one session.

    Every pool entry is a dict ``{'sid': <owner session id or None>, 'vad': <handle>}``.
    An entry whose ``'sid'`` is ``None`` is considered free.
    """

    def __init__(self, modelpath, modelnum):
        """Allocate ``modelnum`` decoder handles from the model at ``modelpath``.

        :param modelpath: model location handed to ``warp_asr_alloc_stVAD``
        :param modelnum: number of handles to pre-allocate (the pool size)
        """
        self.max_num = modelnum
        self.model_list = []
        for _ in range(self.max_num):
            # One-element int64 array passed to the allocator
            # (presumably filled with the native handle -- TODO confirm).
            handle = np.zeros((1), dtype=np.int64)
            warp_asr_alloc_stVAD(handle, modelpath)
            self.model_list.append({'sid': None, 'vad': handle})
        # NOTE: created here but not acquired by any method of this class.
        self._model_lock = threading.Lock()

    def alloc_model(self, sid):
        """Bind the first free handle to ``sid`` and return it.

        :param sid: session id that will own the handle
        :return: the handle, or ``None`` (after logging a warning) when every
            handle is already in use
        """
        free_slot = next(
            (slot for slot in self.model_list[:self.max_num] if slot['sid'] is None),
            None,
        )
        if free_slot is None:
            logging.warning("error: reach maximum concurrent session {}".format(self.max_num))
            return None
        free_slot['sid'] = sid
        return free_slot['vad']

    def free_model(self, sid):
        """Release every handle owned by ``sid`` and reset its decoder state.

        Does nothing when no handle is bound to ``sid``.

        :param sid: session id whose handle(s) should be returned to the pool
        """
        for slot in self.model_list[:self.max_num]:
            owner = slot['sid']
            if owner is None:
                continue
            if owner == sid:
                slot['sid'] = None
                warp_asr_vad_reset(slot['vad'])
def vad_init_model(device_id=0, max_concurrent_numbers=40):
    """Initialise logging, the worker thread pool, the session lock and the VAD pool.

    :param device_id: accepted for interface compatibility; not used by this function
    :param max_concurrent_numbers: size of both the thread pool and the
        pre-allocated VAD decoder pool
    """
    global g_thread_pools, g_max_concurrent, g_session_map_lock, g_vad_model

    set_log()
    g_thread_pools = futures.ThreadPoolExecutor(max_workers=max_concurrent_numbers)
    g_max_concurrent = max_concurrent_numbers
    g_session_map_lock = threading.Lock()
    # The decoder pool is given the model path as UTF-8 bytes.
    g_vad_model = VAD_MODEL(MODEL_PATH.encode('utf-8'), max_concurrent_numbers)
def vad_predict_call(session_id, vad_detect_object, paras):
    """Feed one audio chunk to the VAD decoder and translate its flag into a result dict.

    :param session_id: id echoed back under ``"session_id"``
    :param vad_detect_object: decoder handle passed to ``warp_asr_vad_input``
    :param paras: request dict; the keys ``"data"``, ``"is_final"`` and
        ``"audio_slice"`` are read
    :return: dict with the keys ``session_id``, ``audio_slice``, ``is_final``,
        ``audio_label`` and ``result``
    """
    # The payload is viewed as a flat int8 buffer before being handed to the decoder.
    samples = np.frombuffer(paras["data"], dtype=np.int8)
    is_final = paras["is_final"]
    vad_flag = warp_asr_vad_input(vad_detect_object, samples, samples.shape[0], is_final)

    # Decoder flag 1 -> label 2, flag 0 -> label 1, anything else -> label 0.
    if vad_flag == 1:
        label, verdict = 2, [True, True]
    elif vad_flag == 0:
        label, verdict = 1, [False, False]
    else:
        label, verdict = 0, "insufficient audio"

    return {
        "session_id": session_id,
        'audio_slice': paras["audio_slice"],
        'is_final': is_final,
        'audio_label': label,
        'result': verdict,
    }
def exception_callback(future):
    """Done-callback that logs any exception raised by a pool task.

    :param future: the completed future whose outcome is inspected
    :return: None
    """
    try:
        # result() re-raises whatever the task raised, so it can be logged here.
        future.result()
    except Exception:
        logging.getLogger(__name__).exception("Executor Exception")
def vad_decode_timeout_update(sessionid, time_wait=10):
    """Watchdog loop that force-finalises a session left idle for a whole interval.

    On each pass the session's ``'silent'`` flag is inspected under the
    session-map lock:

    * session no longer registered -> the watchdog exits;
    * flag is ``1`` (nothing reset it since the previous pass) -> an empty
      final chunk is sent through ``vad_api`` and the watchdog exits;
    * otherwise the flag is set to ``1`` and the watchdog sleeps ``time_wait``
      seconds. ``vad_api`` resets the flag to ``0`` whenever it receives a
      chunk for an already-registered session.

    :param sessionid: id of the session to watch
    :param time_wait: seconds to sleep between two inspections
    """
    while True:
        with g_session_map_lock:
            session_obj = g_session_map.get(sessionid)
            if session_obj is None:
                # Session already finished; nothing left to supervise.
                return
            timed_out = session_obj['silent'] == 1
            if not timed_out:
                session_obj['silent'] = 1

        if timed_out:
            # vad_api() acquires g_session_map_lock itself and threading.Lock
            # is not reentrant, so the call must happen only after the lock
            # above has been released; otherwise this thread would deadlock.
            logging.warning('timeout del session {}'.format(sessionid))
            vad_api({"session_id": sessionid, "data": bytes([]), "is_final": True, "audio_slice": 1})
            return

        time.sleep(time_wait)
def vad_api(params):
    """Online VAD detection entry point for one audio chunk.

    A first chunk for an unknown session reserves a decoder handle from the
    pool and, when the chunk is not final, starts a timeout watchdog on the
    thread pool. A chunk for a known session clears the session's ``'silent'``
    flag. The chunk is then run through ``vad_predict_call``. When the chunk
    is final, the decoder handle and the session entry are released, even if
    prediction raised.

    :param params: dict describing the chunk; the keys ``'session_id'``,
        ``'data'``, ``'is_final'`` and ``'audio_slice'`` are read
    :return: the dict built by ``vad_predict_call``; or, when no decoder handle
        is free, ``{'audio_label': 0, 'result': [False], 'is_final': ..., 'audio_slice': 0}``
    """
    with g_session_map_lock:
        logging.info('g_max_concurrent {} g_session_map {}'.format(g_max_concurrent, len(g_session_map)))
        session_id = params['session_id']
        if session_id not in g_session_map:
            vadinfo = g_vad_model.alloc_model(session_id)
            if vadinfo is None:
                logging.warning('reach max cconcurrent {}'.format(g_max_concurrent))
                return {'audio_label': 0, 'result': [False], 'is_final': params['is_final'], 'audio_slice': 0}
            g_session_map[session_id] = {'VADINFO': vadinfo, 'TIMER': None, 'silent': 0}
            if params["is_final"] == False:
                # Only sessions that expect more chunks need a timeout watchdog.
                g_thread_pools.submit(vad_decode_timeout_update, session_id).add_done_callback(exception_callback)
        else:
            # Activity on a live session: tell the watchdog it is not idle.
            g_session_map[session_id]['silent'] = 0
        vad_detect_object = g_session_map[session_id]['VADINFO']

    try:
        logging.info('recv params {} {} {}'.format(params['session_id'], params['audio_slice'], len(params['data'])))
        detect_result = vad_predict_call(session_id, vad_detect_object, params)
        logging.info('send result {} {} {}'.format(detect_result['session_id'], detect_result['result'], len(params['data'])))
    finally:
        # Release on the final chunk even when prediction failed, so a bad
        # request cannot permanently consume a decoder slot.
        if params["is_final"]:
            with g_session_map_lock:
                g_vad_model.free_model(session_id)
                # pop() instead of del: a concurrent finaliser (the timeout
                # watchdog) may already have removed the session.
                g_session_map.pop(session_id, None)
    return detect_result
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yang-xianjie/asr_server_cpu.git
git@gitee.com:yang-xianjie/asr_server_cpu.git
yang-xianjie
asr_server_cpu
asr_server_cpu
dev

搜索帮助

0d507c66 1850385 C8b1a773 1850385