1 Star 3 Fork 2

szw/LittleRAG

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
chat_model.py 1.76 KB
一键复制 编辑 原始数据 按行查看 历史
szw 提交于 2024-04-05 08:33 . 第一次提交
import base_util
from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 默认的聊天模型加载基类
class BaseChatModel():
def __init__(self, model_name):
#
self.model_download_path = f"{base_util.current_project_dir}\\\models\ChatModels\\"
self.model_name = model_name
self.model = None
def load_model(self,model_name):
pass
def chat(self,msg,history):
pass
# 视情况扩展方法 后期可能有一个超级通用类。 把所有的流程整合到一起。 然后直接 .emb .rerank . search 都在一个class里面就搞定了
class internlm(BaseChatModel):
# 加载书生模型
def __init__(self, model_name='Shanghai_AI_Laboratory/internlm-chat-20b'):
# super().__init__('BAAI/bge-large-zh-v1.5')
super().__init__(model_name)
self.load_model(model_name)
def load_model(self,model_name):
self.model_path=snapshot_download(model_name,cache_dir=self.model_download_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
# `torch_dtype=torch.float16` 可以令模型以 float16 精度加载,否则 transformers 会将模型加载为 float32,导致显存不足
self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto", torch_dtype=torch.bfloat16,
trust_remote_code=True).eval()
return self.model
def chat(self,msg,history):
output, history = self.model.chat(self.tokenizer, msg,history)
return output,history
class chatGLM(BaseChatModel):
# 加载chatglm聊天模型
pass
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/szw1259577135/little-rag.git
git@gitee.com:szw1259577135/little-rag.git
szw1259577135
little-rag
LittleRAG
master

搜索帮助