1 Star 0 Fork 0

zxwu/Huatuo-Llama-Med-Chinese

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
infer_literature.py 4.13 KB
一键复制 编辑 原始数据 按行查看 历史
s65b40 提交于 2023-08-07 21:46 . Update Huozi-based model
import sys
import json
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
# Pick the inference device once at module import time. The original only
# assigned `device` when CUDA was present, leaving the name undefined on
# CPU-only hosts and crashing later at `input_ids.to(device)` — fall back
# to "cpu" explicitly.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_instruction(instruct_dir):
    """Load a JSON-Lines instruction file into a list of dicts.

    Parameters
    ----------
    instruct_dir : str
        Path to a file containing one JSON object per line.

    Returns
    -------
    list[dict]
        The parsed records, in file order. Blank lines are skipped so a
        trailing newline does not raise a JSONDecodeError.
    """
    input_data = []
    # Explicit UTF-8: the instructions are Chinese text and must not depend
    # on the platform's default encoding.
    with open(instruct_dir, "r", encoding="utf-8") as f:
        for line in f:  # iterate lazily instead of readlines()
            line = line.strip()
            if not line:
                continue
            input_data.append(json.loads(line))
    return input_data
def main(
    load_8bit: bool = False,
    base_model: str = "",
    # the infer data, if not exists, infer the default instructions in code
    single_or_multi: str = "",
    use_lora: bool = True,
    lora_weights: str = "tloen/alpaca-lora-7b",
    # The prompt template to use, will default to med_template.
    prompt_template: str = "med_template",
):
    """Run literature-QA inference with an (optionally LoRA-adapted) causal LM.

    Parameters
    ----------
    load_8bit : bool
        Load the base model with 8-bit quantization.
    base_model : str
        HuggingFace hub id or local path of the base causal language model.
    single_or_multi : str
        "multi" runs an interactive 5-round chat loop reading from stdin;
        "single" runs the built-in list of demo instructions. Any other
        value prints a warning and exits.
    use_lora : bool
        Whether to wrap the base model with the LoRA adapter weights.
    lora_weights : str
        Hub id or path of the LoRA adapter (ignored when use_lora is False).
    prompt_template : str
        Template name passed to Prompter (defaults to "med_template").
    """
    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    if use_lora:
        print(f"using lora {lora_weights}")
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    if not load_8bit:
        model.half()  # seems to fix bugs for some users.
    model.eval()
    # torch.compile is unavailable on Windows; gate on platform and version.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=256,
        **kwargs,
    ):
        """Generate and return the model's answer for one prompt.

        `instruction`/`input` are rendered through the prompter; extra
        **kwargs are forwarded into GenerationConfig.
        """
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        sequence = generation_output.sequences[0]
        decoded = tokenizer.decode(sequence)
        return prompter.get_response(decoded)

    if single_or_multi == "multi":
        # Interactive chat: the full dialogue history is accumulated into
        # `instruction` so each turn conditions on all previous turns.
        response = ""
        instruction = ""
        for _ in range(5):
            user_text = input("请输入:")
            instruction = instruction + "<user>: " + user_text
            response = evaluate(instruction).replace("\n", "")
            print("Response:", response)
            instruction = instruction + " <bot>: " + response
    elif single_or_multi == "single":
        for instruction in [
            "肝癌是什么?有哪些症状和迹象?",
            "肝癌是如何诊断的?有哪些检查和测试可以帮助诊断?",
            "Sorafenib是一种口服的多靶点酪氨酸激酶抑制剂,它的作用机制是什么?",
            "Regorafenib是一种口服的多靶点酪氨酸激酶抑制剂,它的作用机制是什么?它和Sorafenib有什么不同?",
            "肝癌药物治疗的副作用有哪些?如何缓解这些副作用?",
            "肝癌药物治疗的费用高昂,如何降低治疗的经济负担?",
            "我想了解一下β-谷甾醇是否可作为肝癌的治疗药物",
            "能介绍一下最近Hsa_circ_0008583在肝细胞癌治疗中的潜在应用的研究么?",
        ]:
            print("instruction:", instruction)
            print("Response:", evaluate("<user>: " + instruction))
    else:
        # Previously a typo'd mode silently did nothing after loading the
        # whole model; make the failure visible instead.
        print(
            f"Unknown single_or_multi value {single_or_multi!r}; "
            "expected 'multi' or 'single'."
        )
# Entry point: python-fire exposes main()'s keyword arguments as CLI flags
# (e.g. --base_model, --single_or_multi, --lora_weights).
if __name__ == "__main__":
    fire.Fire(main)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zxwu/Huatuo-Llama-Med-Chinese.git
git@gitee.com:zxwu/Huatuo-Llama-Med-Chinese.git
zxwu
Huatuo-Llama-Med-Chinese
Huatuo-Llama-Med-Chinese
main

搜索帮助