# my_openai_api.py: OpenAI-compatible Flask API server for MiniMind
# encoding: utf-8
import json
import re
import time
import uuid
import warnings
import tiktoken
import torch
import numpy as np
from typing import List
from flask import Flask, current_app, request, Blueprint, stream_with_context
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import PolynomialFeatures
from transformers import AutoTokenizer, AutoModelForCausalLM
from marshmallow import validate, Schema, fields
from pydantic import BaseModel
warnings.filterwarnings('ignore', category=UserWarning)
# ------------------------------------------------------------------------------------------------------------------
DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)
MODEL_PATH = "./minimind-small-T"
TOKENIZE_PATH = MODEL_PATH
max_new_tokens = 2048
temperature = 0.7
top_k = 8
# ------------------------------------------------------------------------------------------------------------------
class Transformers:
    """Thin wrapper holding the tokenizer/model and exposing chat helpers."""

    def __init__(self, app=None, tokenizer=None, model=None):
        if app is not None:
            self.init_app(app, tokenizer, model)

    def init_app(self, app, tokenizer=None, model=None, chat=None):
        self.tokenizer = tokenizer
        self.model = model

    def build_chat_input(self, tokenizer, messages: List[dict]):
        # Render the chat template, keeping only the tail that fits the budget.
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )[-(max_new_tokens - 1):]
        inputs_ids = tokenizer(new_prompt).data['input_ids']
        inputs_ids = torch.tensor(inputs_ids, dtype=torch.long, device=DEVICE)[None, ...]
        return inputs_ids, tokenizer.eos_token_id, new_prompt
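    # For reference, apply_chat_template() renders the message list into the
    # model's own prompt string; the exact template ships with the MiniMind
    # tokenizer. Assuming a ChatML-style template (an assumption, not verified
    # here), a single user message would render roughly as:
    #   <|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n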
    def chat_stream(self, tokenizer, messages: List[dict], stream=True):
        input_ids, eos_token_id, new_prompt = self.build_chat_input(tokenizer, messages)
        if stream:
            res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
                                        temperature=temperature, top_k=top_k, stream=True)
            y = next(res_y)
            history_idx = 0
            while y is not None:
                answer = tokenizer.decode(y[0].tolist())
                # A trailing U+FFFD means we are mid-way through a multi-byte
                # character; fetch more tokens before emitting anything.
                if answer and answer[-1] == '�':
                    try:
                        y = next(res_y)
                    except StopIteration:
                        break
                    continue
                if not len(answer):
                    try:
                        y = next(res_y)
                    except StopIteration:
                        break
                    continue
                # Emit only the text decoded since the previous yield.
                yield answer[history_idx:]
                try:
                    y = next(res_y)
                except StopIteration:
                    break
                history_idx = len(answer)
                if not stream:
                    break

    def chat_no_stream(self, tokenizer, messages: List[dict]):
        input_ids, eos_token_id, new_prompt = self.build_chat_input(tokenizer, messages)
        res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
                                    temperature=temperature, top_k=top_k, stream=False)
        y = next(res_y)
        answer = tokenizer.decode(y[0].tolist())
        return answer
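# A minimal usage sketch (illustrative only; assumes app, tokenizer and model
# have already been created as in create_app() below):
#
#   tfs_demo = Transformers()
#   tfs_demo.init_app(app, tokenizer, model)
#   for delta in tfs_demo.chat_stream(tokenizer, [{"role": "user", "content": "Hi"}]):
#       print(delta, end="", flush=True)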
tfs = Transformers()
base_tfs = Transformers()
models_bp = Blueprint('Models', __name__, url_prefix='/v1/models')
chat_bp = Blueprint('Chat', __name__, url_prefix='/v1/chat')
completions_bp = Blueprint('Completions', __name__, url_prefix='/v1/completions')
embedding_bp = Blueprint('Embeddings', __name__, url_prefix='/v1')
def sse(line, field="data"):
    # Format one server-sent-events frame: "<field>: <payload>\n\n".
    return "{}: {}\n\n".format(
        field, json.dumps(line, ensure_ascii=False) if isinstance(line, dict) else line)
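# For example, sse({"a": 1}) returns the frame 'data: {"a": 1}\n\n', and
# sse('[DONE]') returns the stream terminator 'data: [DONE]\n\n'.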
def empty_cache():
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
def create_app():
    app = Flask(__name__)
    CORS(app)
    app.register_blueprint(models_bp)
    app.register_blueprint(chat_bp)
    app.register_blueprint(completions_bp)
    app.register_blueprint(embedding_bp)

    @app.after_request
    def after_request(resp):
        empty_cache()
        return resp

    tokenizer = AutoTokenizer.from_pretrained(
        TOKENIZE_PATH, trust_remote_code=True, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, trust_remote_code=True).to(DEVICE)
    tfs.init_app(app, tokenizer, model)
    base_tfs.init_app(app, tokenizer, model)
    return app
class ModelSchema(Schema):
    id = fields.Str()
    object = fields.Str(dump_default="model", metadata={"example": "model"})
    created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567})
    owned_by = fields.Str(dump_default="owner", metadata={"example": "owner"})


class ModelListSchema(Schema):
    object = fields.Str(dump_default="list", metadata={"example": "list"})
    data = fields.List(fields.Nested(ModelSchema), dump_default=[])


class ChatMessageSchema(Schema):
    role = fields.Str(required=True, metadata={"example": "system"})
    content = fields.Str(required=True, metadata={"example": "You are a helpful assistant."})


class CreateChatCompletionSchema(Schema):
    model = fields.Str(required=True, metadata={"example": "minimind"})
    messages = fields.List(
        fields.Nested(ChatMessageSchema), required=True,
        metadata={"example": [
            ChatMessageSchema().dump({"role": "system", "content": "You are a helpful assistant."}),
            ChatMessageSchema().dump({"role": "user", "content": "Hello!"})
        ]}
    )
    temperature = fields.Float(load_default=1.0, metadata={"example": 1.0})
    top_p = fields.Float(load_default=1.0, metadata={"example": 1.0})
    n = fields.Int(load_default=1, metadata={"example": 1})
    max_tokens = fields.Int(load_default=None, metadata={"example": None})
    stream = fields.Bool(load_default=False, metadata={"example": False})
    presence_penalty = fields.Float(load_default=0.0, metadata={"example": 0.0})
    frequency_penalty = fields.Float(load_default=0.0, metadata={"example": 0.0})


class ChatCompletionChoiceSchema(Schema):
    index = fields.Int(metadata={"example": 0})
    message = fields.Nested(ChatMessageSchema, metadata={
        "example": ChatMessageSchema().dump(
            {"role": "assistant", "content": "\n\nHello there, how may I assist you today?"}
        )})
    finish_reason = fields.Str(
        validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]),
        metadata={"example": "stop"})


class ChatCompletionSchema(Schema):
    id = fields.Str(
        dump_default=lambda: uuid.uuid4().hex,
        metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"})
    object = fields.Constant("chat.completion")
    created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567})
    model = fields.Str(metadata={"example": "minimind"})
    choices = fields.List(fields.Nested(ChatCompletionChoiceSchema))


class ChatDeltaSchema(Schema):
    role = fields.Str(metadata={"example": "assistant"})
    content = fields.Str(required=True, metadata={"example": "Hello"})


class ChatCompletionChunkChoiceSchema(Schema):
    index = fields.Int(metadata={"example": 0})
    delta = fields.Nested(ChatDeltaSchema, metadata={"example": ChatDeltaSchema().dump(
        {"role": "assistant", "content": "Hello"})})
    finish_reason = fields.Str(
        validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]),
        metadata={"example": "stop"})


class ChatCompletionChunkSchema(Schema):
    id = fields.Str(
        dump_default=lambda: uuid.uuid4().hex,
        metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"})
    object = fields.Constant("chat.completion.chunk")
    created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567})
    model = fields.Str(metadata={"example": "minimind"})
    choices = fields.List(fields.Nested(ChatCompletionChunkChoiceSchema))


class CreateCompletionSchema(Schema):
    model = fields.Str(required=True, metadata={"example": "minimind"})
    prompt = fields.Raw(metadata={"example": "Say this is a test"})
    max_tokens = fields.Int(load_default=16, metadata={"example": 256})
    temperature = fields.Float(load_default=1.0, metadata={"example": 1.0})
    top_p = fields.Float(load_default=1.0, metadata={"example": 1.0})
    n = fields.Int(load_default=1, metadata={"example": 1})
    stream = fields.Bool(load_default=False, metadata={"example": False})
    logit_bias = fields.Dict(load_default=None, metadata={"example": {}})
    presence_penalty = fields.Float(load_default=0.0, metadata={"example": 0.0})
    frequency_penalty = fields.Float(load_default=0.0, metadata={"example": 0.0})


class CompletionChoiceSchema(Schema):
    index = fields.Int(load_default=0, metadata={"example": 0})
    text = fields.Str(required=True, metadata={"example": "登鹳雀楼->王之涣\n夜雨寄北->"})
    logprobs = fields.Dict(load_default=None, metadata={"example": {}})
    finish_reason = fields.Str(
        validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]),
        metadata={"example": "stop"})


class CompletionUsageSchema(Schema):
    prompt_tokens = fields.Int(metadata={"example": 5})
    completion_tokens = fields.Int(metadata={"example": 7})
    total_tokens = fields.Int(metadata={"example": 12})


class CompletionSchema(Schema):
    id = fields.Str(
        dump_default=lambda: uuid.uuid4().hex,
        metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"})
    object = fields.Constant("text_completion")
    created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567})
    model = fields.Str(metadata={"example": "minimind"})
    choices = fields.List(fields.Nested(CompletionChoiceSchema))
    usage = fields.Nested(CompletionUsageSchema)
@stream_with_context
def stream_chat_generate(messages):
    # The first chunk carries the assistant role only.
    delta = ChatDeltaSchema().dump(
        {"role": "assistant"})
    choice = ChatCompletionChunkChoiceSchema().dump(
        {"index": 0, "delta": delta, "finish_reason": None})
    yield sse(
        ChatCompletionChunkSchema().dump({
            "model": "minimind",
            "choices": [choice]})
    )
    # Call chat_stream() and relay each decoded fragment as a chunk.
    for response in tfs.chat_stream(tfs.tokenizer, messages):
        delta = ChatDeltaSchema().dump(
            {"content": response})
        choice = ChatCompletionChunkChoiceSchema().dump(
            {"index": 0, "delta": delta, "finish_reason": None})
        yield sse(
            ChatCompletionChunkSchema().dump({
                "model": "minimind",
                "choices": [choice]})
        )
    yield sse('[DONE]')
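# On the wire a streamed reply therefore looks like (ids and timestamps
# elided; values illustrative):
#   data: {"object": "chat.completion.chunk", "model": "minimind",
#          "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}
#   data: {"object": "chat.completion.chunk", "model": "minimind",
#          "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}
#   data: [DONE]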
@chat_bp.route("/completions", methods=['POST'])
def create_chat_completion():
    chat_request = CreateChatCompletionSchema().load(request.json)
    if chat_request["stream"]:
        return current_app.response_class(
            stream_chat_generate(chat_request["messages"]),
            mimetype="text/event-stream"
        )
    else:
        response = tfs.chat_no_stream(tfs.tokenizer, chat_request["messages"])
        message = ChatMessageSchema().dump(
            {"role": "assistant", "content": response})
        choice = ChatCompletionChoiceSchema().dump(
            {"index": 0, "message": message, "finish_reason": "stop"})
        return ChatCompletionSchema().dump({
            "model": "minimind",
            "choices": [choice]})
class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str
@embedding_bp.route("/embeddings", methods=['POST'])
def get_embeddings():
    request_data = request.get_json()  # JSON body of the POST request
    request_params = EmbeddingRequest(**request_data)  # validate into an EmbeddingRequest

    def expand_features(embedding, target_length):
        # Expand the vector with degree-2 polynomial features, then truncate
        # or zero-pad so the result has exactly target_length dimensions.
        poly = PolynomialFeatures(degree=2)
        expanded_embedding = poly.fit_transform(embedding.reshape(1, -1))
        expanded_embedding = expanded_embedding.flatten()
        if len(expanded_embedding) > target_length:
            expanded_embedding = expanded_embedding[:target_length]
        elif len(expanded_embedding) < target_length:
            expanded_embedding = np.pad(
                expanded_embedding, (0, target_length - len(expanded_embedding))
            )
        return expanded_embedding
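    # For reference: degree-2 polynomial features turn a d-dimensional vector
    # into (d + 1)(d + 2) / 2 values (bias, linear, and pairwise terms).
    # E.g. d = 3 maps [x1, x2, x3] to the 10 values
    #   [1, x1, x2, x3, x1^2, x1*x2, x1*x3, x2^2, x2*x3, x3^2],
    # which expand_features then truncates or pads to target_length.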
    def num_tokens_from_string(string: str) -> int:
        """Returns the number of tokens in a text string."""
        encoding = tiktoken.get_encoding('cl100k_base')
        num_tokens = len(encoding.encode(string))
        return num_tokens

    def has_chinese_char(s):
        # True if the string contains any character in the common CJK range;
        # Chinese text is routed to the m3e encoder, everything else to bge.
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        return bool(pattern.search(s))
    # Compute embeddings: Chinese text goes through m3e, other text through bge.
    embeddings = [embeddings_model_m3e.encode(text)
                  if has_chinese_char(text)
                  else embeddings_model_bge.encode(text)
                  for text in request_params.input]
    # If an embedding has fewer than 768 dimensions, expand it to 768.
    embeddings = [
        expand_features(embedding, 768) if len(embedding) < 768 else embedding
        for embedding in embeddings
    ]
    # L2-normalize each vector to unit length.
    embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
    # Convert numpy arrays to plain lists for JSON serialization.
    embeddings = [embedding.tolist() for embedding in embeddings]
    prompt_tokens = sum(len(text.split()) for text in request_params.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request_params.input)
    response = {
        "data": [
            {"embedding": embedding, "index": index, "object": "embedding"}
            for index, embedding in enumerate(embeddings)
        ],
        "model": request_params.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        },
    }
    return response
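# Illustrative request (assumes the server runs locally with use_emb = True so
# the sentence-transformer models are actually loaded):
#
#   curl http://localhost:8000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"model": "minimind", "input": ["你好", "hello"]}'
#
# The response mirrors OpenAI's embeddings format: a "data" list of
# {"embedding", "index", "object"} entries plus a "usage" block.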
app = create_app()

if __name__ == '__main__':
    use_emb = False
    try:
        # Optional: expose the dev server through an ngrok tunnel if available.
        import ngrok
        import logging
        logging.basicConfig(level=logging.INFO)
        listener = ngrok.werkzeug_develop()
    except Exception:
        pass
    embeddings_model_m3e = SentenceTransformer('.\\m3e-base', device='cpu') if use_emb else None
    embeddings_model_bge = SentenceTransformer('.\\bge-base-en-v1.5', device='cpu') if use_emb else None
    app.run(debug=False, host="0.0.0.0", port=8000)
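# To run (sketch): `python my_openai_api.py` starts the server on port 8000.
# Set use_emb = True (with the m3e-base and bge-base-en-v1.5 models saved next
# to this file) to enable the /v1/embeddings route as well.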