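"""Interactive CLI inference for a TigerBot (BLOOM-family) chat model.

Loads the model in fp16, shards it across the available GPUs with accelerate,
and runs a multi-turn prompt loop using the "### Instruction / ### Response"
template.
"""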
import os

import fire
import torch
import readline  # noqa: F401 (side effect: enables line editing/history for input())

from accelerate import infer_auto_device_map, dispatch_model
from accelerate.utils import get_balanced_memory
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Alpaca-style markers used to build the Instruction/Response prompt.
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
prompt_input = tok_ins + "{instruction}" + tok_res

def get_model(model):
    # Monkey-patch the default weight-init functions to no-ops: the random
    # initialization is immediately overwritten by the pretrained checkpoint,
    # so skipping it speeds up model construction.
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    return model

def main(
    model_path: str,
    max_input_length: int = 512,
    max_generate_length: int = 1024,
):
    print(f"loading model: {model_path}...")
    model = get_model(model_path)

    # Shard the model across the available GPUs, keeping each BloomBlock on a
    # single device so no transformer layer is split.
    max_memory = get_balanced_memory(model)
    device_map = infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=["BloomBlock"]
    )
    print("Using the following device map for the model:", device_map)
    model = dispatch_model(model, device_map=device_map, offload_buffers=True)
    device = torch.cuda.current_device()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        model_max_length=max_generate_length,
        padding_side="left",
        truncation_side="left",
        padding=True,
        truncation=True,
    )
    # Clamp the tokenizer's max length so generation never exceeds the budget.
    if tokenizer.model_max_length is None or tokenizer.model_max_length > max_generate_length:
        tokenizer.model_max_length = max_generate_length
    # Nucleus sampling with a 4-gram repetition blocker.
    generation_kwargs = {
        "top_p": 0.95,
        "temperature": 0.8,
        "max_length": max_generate_length,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "early_stopping": True,
        "no_repeat_ngram_size": 4,
    }
sess_text = ""
while True:
raw_text = input("prompt(\"exit\" to end, \"clear\" to clear session) >>> ")
if not raw_text:
print('prompt should not be empty!')
continue
if raw_text.strip() == "exit":
print('session ended.')
break
if raw_text.strip() == "clear":
print('session cleared.')
sess_text = ""
continue
query_text = raw_text.strip()
sess_text += tok_ins + query_text
input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
inputs = {k: v.to(device) for k, v in inputs.items()}
output = model.generate(**inputs, **generation_kwargs)
answer = ''
for tok_id in output[0][inputs['input_ids'].shape[1]:]:
if tok_id != tokenizer.eos_token_id:
answer += tokenizer.decode(tok_id)
sess_text += tok_res + answer
print("=" * 100)
print(answer)
print("=" * 100)
if __name__ == "__main__":
fire.Fire(main)
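
# Usage sketch (the file name and model path below are illustrative
# assumptions, not taken from the repository; fire maps CLI flags onto
# main()'s parameters):
#   python infer.py --model_path=TigerResearch/tigerbot-7b-sft
#   python infer.py --model_path=/path/to/checkpoint --max_input_length=512 --max_generate_length=1024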