代码拉取完成,页面将自动刷新
同步操作将从 ColdCurlyFu/CoHelperGLM 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from transformers import AutoModel, AutoTokenizer
import streamlit as st
# Page-level configuration (title, icon, wide layout).
# Keep this ahead of any other st.* command; Streamlit expects page config first.
st.set_page_config(layout='wide', page_title="CoHelperGLM", page_icon=":robot:")
@st.cache_resource
def get_model():
    """Load the CodeGeeX2 tokenizer and model.

    Wrapped in ``st.cache_resource`` so script reruns reuse the same objects
    instead of reloading the weights.

    Returns:
        A ``(tokenizer, model)`` pair, with the model switched to eval mode.
    """
    checkpoint = "CODEGEEX2"
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    # CodeGeeX2-6B model placed on CUDA. For the int4-quantized variant use instead:
    #   AutoModel.from_pretrained(checkpoint, trust_remote_code=True).quantize(4).to("cuda")
    net = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, device='cuda').eval()
    return tok, net
# Load (or reuse the cached) tokenizer and model.
tokenizer, model = get_model()

st.title("Prompts in, Bugs out")

# --- Sampling controls (sidebar) ---------------------------------------------
# NOTE(review): 0 is selectable here; a length budget that leaves no room past
# the prompt will produce no useful output - consider a higher minimum.
max_length = st.sidebar.slider(
    'max_length', 0, 2560, 512, step=1
)
top_k = st.sidebar.slider(
    'top_k', 1, 10, 1, step=1
)
# top_k == 1 keeps only the single most likely token, so top_p and temperature
# cannot change the output; show those sliders greyed out in that case.
sampling_locked = top_k == 1
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01, disabled=sampling_locked
)
# Temperature must be strictly positive: Hugging Face sampling rejects 0.0
# (TemperatureLogitsWarper raises ValueError), so the range starts at 0.01.
temperature = st.sidebar.slider(
    'temperature', 0.01, 1.0, 0.95, step=0.01, disabled=sampling_locked
)
# --- Target language -----------------------------------------------------------
# Radio option -> (code-fence language, prompt header line, line-comment prefix).
# HTML prompts are wrapped by html_text() instead of being prefixed, but HTML
# still gets a prefix entry so `prefix` is bound whichever option is selected.
# Dict order defines the order of the radio options.
LANGUAGE_SETTINGS = {
    'Python': ("Python", "# language: Python", "#"),
    'HTML': ("HTML", "<!--language: HTML-->", "<!--"),
    'Shell': ("shell", "# language: Shell", "#"),
    "Go": ("Go", "// language: Go", "//"),
    "C++": ("C++", "// language: C++", "//"),
    "Java": ("Java", "// language: Java", "//"),
    "JavaScript": ("JavaScript", "// language: JavaScript", "//"),
}

genre = st.sidebar.radio(
    "语言选择:",
    tuple(LANGUAGE_SETTINGS))
# Any unexpected value falls back to the Python settings.
lan, p, prefix = LANGUAGE_SETTINGS.get(genre, LANGUAGE_SETTINGS['Python'])
# Per-session state that must survive Streamlit reruns; seed it on first run only.
for _state_key, _initial in (('history', []), ('past_key_values', None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# Redraw the conversation so far; Streamlit rebuilds the whole page on each rerun.
# NOTE(review): past replies are fenced with the *currently selected* language,
# so switching languages re-highlights older answers - confirm this is intended.
for query, response in st.session_state.history:
    with st.chat_message(name="user", avatar="user"):
        st.text(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(f"```{lan.lower()}\n{response}\n```")
# Empty chat slots for the pending exchange; the send handler fills them in.
with st.chat_message("user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message("assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Request input and the send button.
prompt_text = st.text_area(
    "用户命令输入",
    placeholder="用户命令输入,例如:写一个冒泡排序算法",
    height=100,
)
button = st.button("发送", key="predict")
def pre_text(s, prefix):
    """Turn free-form user text into line comments for a code-completion prompt.

    Every non-empty line of *s* gets ``prefix + ' '`` prepended, unless the
    line (ignoring leading whitespace) already starts with *prefix*.
    Empty lines are dropped.

    Args:
        s: Raw multi-line user input.
        prefix: The target language's line-comment marker, e.g. ``"#"`` or ``"//"``.

    Returns:
        The resulting lines joined with newlines.
    """
    commented = []
    for line in s.split('\n'):
        if not line:
            continue  # blank lines add nothing to the prompt
        # Test the whole line, not just its first character: a one-character
        # string can never start with a two-character marker such as "//".
        if line.lstrip().startswith(prefix):
            commented.append(line)
        else:
            commented.append(f"{prefix} {line}")
    return '\n'.join(commented)
def html_text(s):
    """Wrap each non-empty line of *s* in an HTML comment for the HTML prompt.

    A line that already is a complete ``<!-- ... -->`` comment (ignoring
    surrounding whitespace) is kept unchanged. Every other non-empty line,
    however short, is wrapped as ``<!--line-->``. Empty lines are dropped.

    Args:
        s: Raw multi-line user input.

    Returns:
        The resulting lines joined with newlines.
    """
    wrapped = []
    for line in s.split('\n'):
        if not line:
            continue  # blank lines add nothing to the prompt
        stripped = line.strip()
        if stripped.startswith('<!--') and stripped.endswith('-->'):
            wrapped.append(line)
        else:
            # Includes lines shorter than the 4-char opener, which would
            # otherwise leak into the prompt as raw HTML.
            wrapped.append('<!--' + line + '-->')
    return '\n'.join(wrapped)
# --- Generation ----------------------------------------------------------------
if button and not prompt_text.strip():
    # Nothing to send; avoid running the model on a bare language header.
    st.warning("请输入用户命令")
elif button:
    # Build the prompt: a "language:" header line followed by the request
    # rendered as comments in the target language.
    if lan.lower() == "html":
        prompt_t = html_text(prompt_text)
    else:
        prompt_t = pre_text(prompt_text, prefix)
    prompt_t = "{}\n{}\n".format(p, prompt_t)
    input_placeholder.text(prompt_text)

    # Each request is generated independently: an empty history and no cached
    # key/values are passed, so nothing carries over between turns.
    response = ""  # stays defined (and fresh) even if the stream yields nothing
    for response, _ in model.stream_chat(tokenizer, prompt_t, [],
                                         past_key_values=None,
                                         max_length=max_length, top_p=top_p, top_k=top_k,
                                         temperature=temperature,
                                         return_past_key_values=False):
        # Re-render the partial answer as it streams in.
        message_placeholder.markdown("```{}\n".format(lan.lower()) + response + "\n```")
    st.session_state.history.append((prompt_text, response))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。