苍漠潇潇/voice_assistant
web_demo_audio.py 6.71 KB
import gradio as gr
import modelscope_studio as mgr
import librosa
import prompt
import json
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from argparse import ArgumentParser
DEFAULT_CKPT_PATH = 'Qwen/Qwen2-Audio-7B-Instruct'
def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")
    args = parser.parse_args()
    return args

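# Example invocations (a sketch only; assumes the default checkpoint can be
# fetched from the Hugging Face / ModelScope hub and that a CUDA GPU is
# available unless --cpu-only is passed):
#   python web_demo_audio.py --server-name 0.0.0.0 --server-port 8000 --inbrowser
#   python web_demo_audio.py --cpu-only
#   python web_demo_audio.py -c /path/to/local/Qwen2-Audio-7B-Instruct
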
# Append the user's multimodal input (uploaded/recorded audio plus optional text)
# to the model-facing task history and to the visible chat.
def add_text(chatbot, task_history, input):
    text_content = input.text
    content = []
    if len(input.files) > 0:
        for i in input.files:
            content.append({'type': 'audio', 'audio_url': i.path})
    if text_content:
        content.append({'type': 'text', 'text': text_content})
    task_history.append({"role": "user", "content": content})
    chatbot.append([{
        "text": input.text,
        "files": input.files,
    }, None])
    return chatbot, task_history, None

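# For reference, a user turn produced by add_text looks roughly like the sketch
# below (illustrative only; the temporary audio path is chosen by Gradio):
# {
#     "role": "user",
#     "content": [
#         {"type": "audio", "audio_url": "/tmp/gradio/.../recording.wav"},
#         {"type": "text", "text": "my answer"},
#     ],
# }
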
# Append an uploaded audio file to the chat history (alternative upload path).
def add_file(chatbot, task_history, audio_file):
    """Add audio file to the chat history."""
    # Use the same content schema as add_text so apply_chat_template can consume it.
    task_history.append({"role": "user", "content": [{"type": "audio", "audio_url": audio_file.name}]})
    chatbot.append((f"[Audio file: {audio_file.name}]", None))
    return chatbot, task_history

def reset_user_input():
    """Reset the user input field."""
    return gr.Textbox.update(value='')

def reset_state(task_history):
    """Reset the chat history."""
    return [], []

def regenerate(chatbot, task_history, survey_json):
    """Regenerate the last bot response."""
    if task_history and task_history[-1]['role'] == 'assistant':
        task_history.pop()
        chatbot.pop()
    if task_history:
        chatbot, task_history = predict(chatbot, task_history, survey_json)
    return chatbot, task_history

def predict(chatbot, task_history, survey_json):
    # If there is no chat history yet, open with the first question of the survey.
    if not task_history:
        first = survey_json['survey']['sections'][0]['questions'][0]
        first_question = first['question']
        task_history.append({'role': 'assistant', 'content': first_question, 'question_id': first['id']})
        chatbot.append((None, first_question))
        return chatbot, task_history

    # Get the user's latest input.
    user_input = task_history[-1]['content']
    # Build the prompt from the whole conversation (audio and text turns alike).
    text = processor.apply_chat_template(task_history, add_generation_prompt=True, tokenize=False)
    audios = []
    # Decode every audio file referenced anywhere in the chat history.
    for message in task_history:
        if isinstance(message["content"], list):
            for ele in message["content"]:
                if ele["type"] == "audio":
                    audios.append(librosa.load(ele['audio_url'],
                                               sr=processor.feature_extractor.sampling_rate)[0])
    if len(audios) == 0:
        audios = None
    print(f"{text=}")
    print(f"{audios=}")
    inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
    # Move all input tensors onto the model's device (GPU unless --cpu-only was used).
    inputs = inputs.to(model.device)
    # Generate the assistant response and strip the prompt tokens from the output.
    generate_ids = model.generate(**inputs, max_length=256)
    generate_ids = generate_ids[:, inputs.input_ids.size(1):]
    response = processor.batch_decode(generate_ids, skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False)[0]
    print(f"{response=}")
    # Record the generated response in the task history and the visible chat.
    task_history.append({'role': 'assistant', 'content': response})
    chatbot.append((None, response))
    # Ask the next unanswered survey question, if any remain.
    next_question = get_next_question(task_history, survey_json)
    if next_question:
        task_history.append(
            {'role': 'assistant', 'content': next_question['question'], 'question_id': next_question['id']})
        chatbot.append((None, next_question['question']))
    return chatbot, task_history

def get_next_question(task_history, survey_json):
    """Find the next unanswered question from the survey JSON."""
    answered_ids = [q['question_id'] for q in task_history if 'question_id' in q]
    # Walk the survey and return the first question that has not been asked yet.
    for section in survey_json['survey']['sections']:
        for question in section['questions']:
            if question['id'] not in answered_ids:
                return question
    return None  # Every question has been answered.

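# The survey-handling code above assumes Q.json is shaped roughly like the
# sketch below (illustrative only; the real file may carry extra fields such
# as section titles):
#
# {
#   "survey": {
#     "sections": [
#       {
#         "questions": [
#           {"id": 1, "question": "How are you feeling today?"},
#           {"id": 2, "question": "..."}
#         ]
#       }
#     ]
#   }
# }
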
def _launch_demo(args):
    with gr.Blocks() as demo:
        chatbot = mgr.Chatbot(label='Qwen2-Audio-7B-Instruct', elem_classes="control-height", height=750)
        user_input = mgr.MultimodalInput(
            interactive=True,
            sources=['microphone', 'upload'],
            submit_button_props=dict(value="发送"),  # "Send"
            upload_button_props=dict(value="上传文件", show_progress=True),  # "Upload file"
        )
        task_history = gr.State([])
        # Hold the survey definition in a gr.State so it can be wired into event handlers.
        survey_state = gr.State(survey_json)
        with gr.Row():
            empty_bin = gr.Button("清除历史")  # "Clear history"
            regen_btn = gr.Button("重试")  # "Retry"

        user_input.submit(fn=add_text,
                          inputs=[chatbot, task_history, user_input],
                          outputs=[chatbot, task_history, user_input]).then(
            predict, [chatbot, task_history, survey_state], [chatbot, task_history], show_progress=True
        )
        empty_bin.click(reset_state, inputs=[task_history], outputs=[chatbot, task_history], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history, survey_state], [chatbot, task_history],
                        show_progress=True)

    demo.queue().launch(
        share=False,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )

if __name__ == "__main__":
    # Load the survey that drives the question-by-question conversation flow.
    with open('Q.json', 'r', encoding='utf-8') as f:
        survey_json = json.load(f)

    args = _get_args()
    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=True,
    ).eval()
    model.generation_config.max_new_tokens = 2048  # For chat.
    print("generation_config", model.generation_config)
    processor = AutoProcessor.from_pretrained(args.checkpoint_path, resume_download=True)
    _launch_demo(args)