1 Star 0 Fork 3

handsomestWei/CoHelperGLM

forked from ColdCurlyFu/CoHelperGLM 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
web_demo2.py 4.34 KB
一键复制 编辑 原始数据 按行查看 历史
ColdCurlyFu 提交于 2023-08-19 15:49 . 更新readme,支持int4模型
from transformers import AutoModel, AutoTokenizer
import streamlit as st
# Page-level Streamlit settings: browser-tab title, tab icon and a wide layout.
_page_settings = {
    "page_title": "CoHelperGLM",
    "page_icon": ":robot:",
    "layout": "wide",
}
st.set_page_config(**_page_settings)
@st.cache_resource
def get_model(model_path="CODEGEEX2", quantize_bits=None):
    """Load the CodeGeeX2 tokenizer and model, cached as a Streamlit resource.

    Args:
        model_path: Path or model id passed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to ``"CODEGEEX2"``, the
            path this app has always used.
        quantize_bits: ``None`` (default) loads the non-quantized
            CodeGeeX2-6B model directly with ``device='cuda'``. An int
            (e.g. ``4`` for the int4 variant) loads the model, calls its
            ``quantize(quantize_bits)`` method and then moves it to ``"cuda"``.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if quantize_bits is None:
        # CodeGeeX2-6B model, placed on the GPU at load time.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device='cuda')
    else:
        # Quantized variant (e.g. CodeGeeX2-6B-int4). NOTE(review): `quantize`
        # is presumably supplied by the model's remote code (trust_remote_code),
        # not by transformers itself — confirm for other checkpoints.
        model = (
            AutoModel.from_pretrained(model_path, trust_remote_code=True)
            .quantize(quantize_bits)
            .to("cuda")
        )
    model = model.eval()
    return tokenizer, model
# Load the (cached) tokenizer and model, then draw the page title.
tokenizer, model = get_model()
st.title("Prompts in, Bugs out")

# Generation settings shown in the sidebar.
max_length = st.sidebar.slider(
    'max_length', 0, 2560, 512, step=1
)
top_k = st.sidebar.slider(
    'top_k', 1, 10, 1, step=1
)
# With top_k == 1 only the single most likely token survives filtering, so
# top_p and temperature cannot change the result; show them greyed out then.
sampling_locked = int(top_k) == 1
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01, disabled=sampling_locked
)
# Lower bound is 0.01 rather than 0.0: transformers rejects a sampling
# temperature that is not strictly positive (ValueError), so letting the
# slider reach 0.0 would crash generation whenever top_k > 1.
temperature = st.sidebar.slider(
    'temperature', 0.01, 1.0, 0.95, step=0.01, disabled=sampling_locked
)
# Target-language table. Each radio option maps to a tuple of:
#   lan    - language name; lower-cased for the markdown code fence and for
#            the HTML check when the prompt is built,
#   p      - language-tag line placed at the top of the prompt,
#   prefix - line-comment marker used to comment out the user's request.
# HTML requests are wrapped per line as <!-- ... --> by html_text instead;
# its prefix is listed only so that `prefix` is always defined (the previous
# if/elif chain left it unset for HTML).
# Insertion order defines the radio-button order.
_LANGUAGE_SETTINGS = {
    'Python': ("Python", "# language: Python", "#"),
    'HTML': ("HTML", "<!--language: HTML-->", "<!--"),
    'Shell': ("shell", "# language: Shell", "#"),
    "Go": ("Go", "// language: Go", "//"),
    "C++": ("C++", "// language: C++", "//"),
    "Java": ("Java", "// language: Java", "//"),
    "JavaScript": ("JavaScript", "// language: JavaScript", "//"),
}
genre = st.sidebar.radio(
    "语言选择:",
    tuple(_LANGUAGE_SETTINGS))
# Fall back to Python for any unexpected selection, as the old `else` did.
lan, p, prefix = _LANGUAGE_SETTINGS.get(genre, _LANGUAGE_SETTINGS['Python'])
# Per-session chat state, created only the first time a session runs the script.
for _state_key, _initial_value in (('history', []), ('past_key_values', None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Replay earlier exchanges: the user's text as plain text, the model's answer
# as a fenced code block tagged with the currently selected language.
for query, response in st.session_state.history:
    st.chat_message(name="user", avatar="user").text(query)
    st.chat_message(name="assistant", avatar="assistant").markdown(
        f"```{lan.lower()}\n{response}\n```"
    )

# Empty slots for the upcoming exchange, filled when the send button is pressed.
input_placeholder = st.chat_message(name="user", avatar="user").empty()
message_placeholder = st.chat_message(name="assistant", avatar="assistant").empty()

# Request input box and send button.
prompt_text = st.text_area(
    label="用户命令输入",
    height=100,
    placeholder="用户命令输入,例如:写一个冒泡排序算法",
)
button = st.button("发送", key="predict")
def pre_text(s, prefix):
    """Turn free-form request text into source-code comment lines.

    Each non-empty line of ``s`` gets ``prefix + ' '`` prepended, unless it
    already starts with ``prefix`` (ignoring leading whitespace), in which
    case it is kept unchanged. Empty lines are dropped.

    Args:
        s: The user's request, possibly spanning several lines.
        prefix: Line-comment marker of the target language, e.g. ``"#"`` or
            ``"//"``.

    Returns:
        The resulting lines joined with newline characters.
    """
    commented = []
    for line in s.split('\n'):
        if not line:
            continue  # blank lines are omitted from the prompt
        # Test the whole (left-stripped) line, not just its first character:
        # a one-character slice can never start with a multi-character marker
        # such as "//", which made already-commented lines get prefixed twice.
        if line.lstrip().startswith(prefix):
            commented.append(line)
        else:
            commented.append(prefix + ' ' + line)
    return '\n'.join(commented)
def html_text(s):
    """Turn free-form request text into HTML comment lines.

    Each non-empty line of ``s`` is wrapped as ``'<!--' + line + '-->'``
    unless it is already a ``<!-- ... -->`` comment (ignoring surrounding
    whitespace), in which case it is kept unchanged. Empty lines are dropped.

    Args:
        s: The user's request, possibly spanning several lines.

    Returns:
        The resulting lines joined with newline characters.
    """
    wrapped = []
    for line in s.split('\n'):
        if not line:
            continue  # blank lines are omitted from the prompt
        stripped = line.strip()
        if stripped.startswith('<!--') and stripped.endswith('-->'):
            wrapped.append(line)
        else:
            # Every other line is wrapped, including ones shorter than four
            # characters, which used to be passed through as raw text.
            wrapped.append('<!--' + line + '-->')
    return '\n'.join(wrapped)
if button:
    if not prompt_text.strip():
        # Nothing to send: without this guard the model would be queried with
        # only the language-tag line and an empty request.
        st.warning("请先输入用户命令")
    else:
        # Comment out the request in the chosen language and put the
        # language-tag line `p` on top; this is the prompt sent to the model.
        if lan.lower() == "html":
            prompt_t = html_text(prompt_text)
        else:
            prompt_t = pre_text(prompt_text, prefix)
        prompt_t = "{}\n{}\n".format(p, prompt_t)
        input_placeholder.text(prompt_text)
        # Each request is sent stand-alone: an empty history and no cached
        # past_key_values are passed, and none are requested back.
        response = ""  # stays empty if the stream yields nothing
        for response, _ in model.stream_chat(tokenizer, prompt_t, [],
                                             past_key_values=None,
                                             max_length=max_length, top_p=top_p, top_k=top_k,
                                             temperature=temperature,
                                             return_past_key_values=False):
            # Re-render the partial answer after every streamed chunk.
            message_placeholder.markdown("```{}\n".format(lan.lower()) + response + "\n```")
        st.session_state.history.append((prompt_text, response))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/handsomestwei/CoHelperGLM.git
git@gitee.com:handsomestwei/CoHelperGLM.git
handsomestwei
CoHelperGLM
CoHelperGLM
main

搜索帮助