1 Star 0 Fork 2

顺其_自然/TigerBot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
infer.py 3.06 KB
一键复制 编辑 原始数据 按行查看 历史
Vivicai1005 提交于 2023-06-08 15:43 . improve infer
import os
import fire
import torch
import readline
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.utils import get_balanced_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
# Set the TOKENIZERS_PARALLELISM environment variable to "false" before any
# tokenizer is created further down in this script.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Text markers that introduce an instruction turn and a response turn.
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
# Prompt template: instruction marker, the "{instruction}" placeholder filled
# in by main(), then the response marker the model is expected to continue.
prompt_input = tok_ins + "{instruction}" + tok_res
def get_model(model):
    """Load a causal language model in float16.

    Before loading, the three ``torch.nn.init`` initializers
    ``kaiming_uniform_``, ``uniform_`` and ``normal_`` are replaced with a
    do-nothing function. The replacement is global and is not undone when
    this function returns.

    Args:
        model: Value forwarded to ``AutoModelForCausalLM.from_pretrained``
            as the model identifier.

    Returns:
        The model object returned by ``from_pretrained``.
    """

    def _noop_init(*args, **kwargs):
        pass

    # Swap out the initializers in one pass instead of three assignments.
    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, _noop_init)

    return AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
def main(
    model_path: str,
    max_input_length: int = 512,
    max_generate_length: int = 1024,
):
    """Run an interactive, multi-turn prompt loop against a causal LM.

    The model is loaded with ``get_model``, distributed over devices with
    accelerate (``get_balanced_memory`` / ``infer_auto_device_map`` /
    ``dispatch_model``), and then prompts are read from stdin until the user
    types ``exit`` or stdin is closed. Typing ``clear`` resets the session.

    Every turn appends ``tok_ins + query`` and, after generation,
    ``tok_res + answer`` to the session text; the text sent to the model is
    the session text after its first instruction marker, wrapped in
    ``prompt_input``.

    Args:
        model_path: Passed to ``get_model`` and
            ``AutoTokenizer.from_pretrained``.
        max_input_length: ``max_length`` used when tokenizing the prompt
            (with truncation enabled).
        max_generate_length: ``max_length`` passed to ``model.generate`` and
            upper bound applied to ``tokenizer.model_max_length``.
    """
    print(f"loading model: {model_path}...")
    model = get_model(model_path)

    max_memory = get_balanced_memory(model)
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=["BloomBlock"],
    )
    print("Using the following device map for the model:", device_map)
    model = dispatch_model(model, device_map=device_map, offload_buffers=True)
    device = torch.cuda.current_device()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        model_max_length=max_generate_length,
        padding_side="left",
        truncation_side='left',
        padding=True,
        truncation=True
    )
    if tokenizer.model_max_length is None or tokenizer.model_max_length > max_generate_length:
        tokenizer.model_max_length = max_generate_length

    # NOTE(review): per the transformers generation docs, top_p/temperature
    # only take effect when sampling is enabled (do_sample=True) and
    # early_stopping only applies to beam search. Neither is enabled here, so
    # confirm whether sampling was intended before changing it.
    generation_kwargs = {
        "top_p": 0.95,
        "temperature": 0.8,
        "max_length": max_generate_length,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "early_stopping": True,
        "no_repeat_ngram_size": 4,
    }

    sess_text = ""
    while True:
        try:
            raw_text = input("prompt(\"exit\" to end, \"clear\" to clear session) >>> ")
        except EOFError:
            # Ctrl-D / closed stdin: end the session instead of crashing.
            print()
            print('session ended.')
            break

        # Strip before validating so whitespace-only input is rejected too.
        query_text = raw_text.strip()
        if not query_text:
            print('prompt should not be empty!')
            continue
        if query_text == "exit":
            print('session ended.')
            break
        if query_text == "clear":
            print('session cleared.')
            sess_text = ""
            continue

        sess_text += tok_ins + query_text
        # sess_text always starts with tok_ins here; drop that leading marker
        # because the prompt template adds it back.
        input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
        inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        output = model.generate(**inputs, **generation_kwargs)

        # Keep only the newly generated ids (everything after the prompt),
        # drop EOS, and decode them in a single call: decoding id-by-id can
        # split a multi-byte character across tokens and garble the text.
        prompt_len = inputs['input_ids'].shape[1]
        new_ids = [
            tok_id
            for tok_id in output[0][prompt_len:].tolist()
            if tok_id != tokenizer.eos_token_id
        ]
        answer = tokenizer.decode(new_ids)

        sess_text += tok_res + answer
        print("=" * 100)
        print(answer)
        print("=" * 100)
# Script entry point: hand main() to Fire so its parameters can be supplied
# on the command line.
if __name__ == "__main__":
    fire.Fire(main)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/wanghuan_git/TigerBot.git
git@gitee.com:wanghuan_git/TigerBot.git
wanghuan_git
TigerBot
TigerBot
main

搜索帮助