18 Star 0 Fork 2

openKylin/kylin-baidu-nlp-engine

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
baidunlpengine_p.cpp 10.95 KB
一键复制 编辑 原始数据 按行查看 历史
/*
* Copyright 2024 KylinSoft Co., Ltd.
*
* This program is free software: you can redistribute it and/or modify it under
* the terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "baidunlpengine_p.h"

#include <cstdio>
#include <utility>

#include "servererror.h"
#include "token.h"
#include "util.h"
// Default system prompt (in Chinese) injected into every conversation by
// chat(); introduces the assistant and the capabilities it offers.
std::string BaiduNlpEnginePrivate::systemRole_{
"我是您的 AI 助手。"
"您可以通过文字和语音对话与我进行互动。我具备跨应用文字处理和系统控制功能,"
"无论是工作学习还是日常生活,我都能为您提供帮助。"};
// Nothing to set up eagerly: credentials arrive later via setConfig().
BaiduNlpEnginePrivate::BaiduNlpEnginePrivate() = default;
// Releases buffered streaming fragments and the conversation context.
BaiduNlpEnginePrivate::~BaiduNlpEnginePrivate() {
    streamChatData_.clear();
    clearContext();
}
// Identifier for this engine implementation.
std::string BaiduNlpEnginePrivate::engineName() const {
    return std::string{"baidu"};
}
// Upper bound on chat tasks this engine accepts concurrently.
int BaiduNlpEnginePrivate::maxConcurrentChatTasks() const {
    constexpr int kMaxConcurrentChatTasks = 6;
    return kMaxConcurrentChatTasks;
}
void BaiduNlpEnginePrivate::setConfig(const std::string &config) {
Json::Value configJson = baidu_nlp_util::formatJsonFromString(config);
if (configJson.isNull() || !configJson.isMember("apiKey") ||
!configJson.isMember("secretKey") || !configJson["apiKey"].isString() ||
!configJson["secretKey"].isString()) {
std::fprintf(stderr, "Invalid config for baidu nlp engine: %s\n",
config.c_str());
return;
}
apiKey_ = configJson["apiKey"].asString();
secretKey_ = configJson["secretKey"].asString();
}
// Static model description: vendor, display name, and the credential
// fields (appId / apiKey / secretKey) a user must supply for ERNIE 4.0.
std::string BaiduNlpEnginePrivate::modelInfo() const {
    return R"(
{
"vendor": "百度",
"name": "语言大模型",
"subConfig": [
{
"name": "文心ERNIE 4.0",
"config": [
{
"displayName": "APPID",
"configName": "appId"
},
{
"displayName": "APIKey",
"configName": "apiKey"
},
{
"displayName": "SecretKey",
"configName": "secretKey"
}
]
}
]
}
)";
}
// Returns the cached access token, fetching one first when none is cached.
// NOTE: a failed fetch leaves the cache empty and still returns it; the
// error is recorded in currentError_ by generateAccessToken_().
std::string &BaiduNlpEnginePrivate::getAccessToken_() {
    if (!accessToken_.empty()) {
        return accessToken_;
    }
    generateAccessToken_();
    return accessToken_;
}
// Fetches a fresh OAuth access token using apiKey_/secretKey_.
// On failure records the cause in currentError_ (network vs. credentials)
// and returns false; on success caches the token and returns true.
bool BaiduNlpEnginePrivate::generateAccessToken_() {
    bool hasNetError = false;
    std::string errorMessage;
    accessToken_ = baidu_nlp_token::getBaiduToken(apiKey_, secretKey_,
                                                  hasNetError, errorMessage);
    if (!accessToken_.empty()) {
        return true;
    }
    fprintf(stderr, "get baidu token failed: %s\n", errorMessage.c_str());
    // Network trouble and rejected credentials map to different codes.
    const auto code =
        hasNetError ? ai_engine::lm::NlpEngineErrorCode::NetworkDisconnected
                    : ai_engine::lm::NlpEngineErrorCode::Unauthorized;
    currentError_ = ai_engine::lm::EngineError(
        ai_engine::lm::AiCapability::Nlp,
        ai_engine::lm::EngineErrorCategory::Initialization, (int)code,
        errorMessage);
    return false;
}
// Registers the callback that receives streamed chat results.
// The parameter is taken by value; move it into the member instead of
// copying a second time (callbacks may own captured state).
void BaiduNlpEnginePrivate::setChatResultCallback(
    ai_engine::lm::nlp::ChatResultCallback callback) {
    chatResultCallback_ = std::move(callback);
}
// Sets how much dialogue history to keep; chat() wipes the context before
// each request when this is 0.
void BaiduNlpEnginePrivate::setContextSize(int size) {
    contextSize_ = size;
}
// Drops the accumulated conversation context (messages and system prompt).
void BaiduNlpEnginePrivate::clearContext() {
    context_.clear();
}
namespace {
// Splits a server-sent-event style buffer into its individual events.
// Events are separated by a blank line ("\n\n"); bytes after the last
// complete delimiter are dropped (the caller buffers incomplete tails).
//
// Bug fix: the original used find_first_of("\n\n", ...), which searches for
// any single character of the set — i.e. the first lone '\n' — and would
// split an event containing an embedded newline. std::string::find locates
// the full two-byte delimiter. Also uses size_type instead of int so the
// npos comparison is well-typed.
std::vector<std::string> splitDataString(const std::string &dataStr) {
    std::vector<std::string> datas;
    std::string::size_type begin = 0;
    std::string::size_type end = 0;
    while ((end = dataStr.find("\n\n", begin)) != std::string::npos) {
        datas.push_back(dataStr.substr(begin, end - begin));
        begin = end + 2;
    }
    return datas;
}
} // namespace
// cpr write callback: receives raw chunks streamed from the Baidu chat
// endpoint, buffers incomplete server-sent events in tempStreamChatData_,
// and forwards each complete event to the engine's chat-result callback.
// Returning false aborts the transfer (stop request, expired token, or a
// server-reported error).
bool writeChatData(std::string data, intptr_t userData) {
    BaiduNlpEnginePrivate *engine = (BaiduNlpEnginePrivate *)userData;
    ai_engine::lm::nlp::ChatResultCallback callback =
        engine->chatResultCallback_;
    if (engine->isStopped_) {
        return false;
    }
    Json::Value resultJson = baidu_nlp_util::formatJsonFromString(data);
    if (baidu_nlp_token::isBaiduTokenExpiredByResult(resultJson)) {
        // Expired token: abort so chat() can refresh the token and retry.
        engine->needRefreshAccessToken_ = true;
        return false;
    }
    if (int errorCode = baidu_nlp_server_error::parseErrorCode(data)) {
        fprintf(stderr, "baidu chat failed: %s\n", data.c_str());
        auto errorTuple =
            baidu_nlp_server_error::errorCode2nlpResult(errorCode);
        engine->currentError_ = ai_engine::lm::EngineError(
            ai_engine::lm::AiCapability::Nlp, std::get<0>(errorTuple),
            (int)std::get<1>(errorTuple), data);
        return false;
    }
    // Prepend whatever was buffered from earlier incomplete chunks.
    std::string resultString = engine->tempStreamChatData_ + data;
    // A complete SSE batch ends with a blank line ("\n\n"). Bug fix: the
    // original called substr(length() - 2) unconditionally, which throws
    // std::out_of_range when fewer than two bytes have accumulated.
    if (resultString.size() < 2 ||
        resultString.compare(resultString.size() - 2, 2, "\n\n") != 0) {
        engine->tempStreamChatData_ += data;
        return true;
    }
    engine->tempStreamChatData_.clear();
    std::vector<std::string> dataStrings = splitDataString(resultString);
    for (std::string event : dataStrings) {
        // Skip fragments too short to carry the 6-char "data: " prefix;
        // substr would throw on them (bug fix, was unguarded).
        if (event.size() < 6) {
            continue;
        }
        event = event.substr(6); // strip the SSE "data: " prefix
        Json::Value obj = baidu_nlp_util::formatJsonFromString(event);
        // Forward only the fields consumers need.
        Json::Value callbackObj;
        callbackObj["sentence_id"] = obj["sentence_id"];
        callbackObj["is_end"] = obj["is_end"];
        callbackObj["result"] = obj["result"];
        // Remember the fragment so processChatResponse_ can assemble the
        // full assistant reply for the context.
        engine->streamChatData_.push_back(obj["result"]);
        ai_engine::lm::nlp::ChatResult result{callbackObj.toStyledString(),
                                              ai_engine::lm::EngineError()};
        callback(result);
    }
    return true;
}
void BaiduNlpEnginePrivate::addContext_(const std::string &role,
const std::string &message) {
if (role == "system") {
context_["system"] = message;
return;
}
int size = context_["messages"].size();
context_["messages"][size]["role"] = role;
context_["messages"][size]["content"] = message;
}
void BaiduNlpEnginePrivate::removeLastContext_() {
int size = context_["messages"].size();
if (size - 1 >= 0 &&
context_["messages"][size - 1]["role"].asString() == "user") {
Json::Value removeValue;
context_["messages"].removeIndex(size - 1, &removeValue);
return;
}
if ((size - 1) >= 0 && (size - 2) >= 0 &&
context_["messages"][size - 1]["role"].asString() == "assistant" &&
context_["messages"][size - 2]["role"].asString() == "user") {
Json::Value removeValue;
context_["messages"].removeIndex(size - 1, &removeValue);
context_["messages"].removeIndex(size - 2, &removeValue);
return;
}
}
// Prepares the chat module by acquiring an access token. On failure the
// error recorded by generateAccessToken_() is copied into `error` and
// false is returned.
bool BaiduNlpEnginePrivate::initChatModule(ai_engine::lm::EngineError &error) {
    currentError_ = error;
    if (generateAccessToken_()) {
        return true;
    }
    error = currentError_;
    return false;
}
// Sends one user message to the Baidu ERNIE chat endpoint and streams the
// reply through chatResultCallback_ (via writeChatData). Returns false —
// with `error` filled in where applicable — when no callback is set, the
// message is empty, or the request fails.
bool BaiduNlpEnginePrivate::chat(const std::string &message,
                                 ai_engine::lm::EngineError &error) {
    tempStreamChatData_.clear();
    isStopped_ = false;
    currentError_ = error;
    if (!chatResultCallback_) {
        return false;
    }
    if (message.empty()) {
        error = ai_engine::lm::EngineError(
            ai_engine::lm::AiCapability::Nlp,
            ai_engine::lm::EngineErrorCategory::Initialization,
            (int)ai_engine::lm::NlpEngineErrorCode::TooLittleData, "文本为空");
        ai_engine::lm::nlp::ChatResult result{std::string{}, currentError_};
        chatResultCallback_(result);
        return false;
    }
    // contextSize_ == 0 means stateless operation: wipe history each turn.
    if (contextSize_ == 0) {
        clearContext();
    }
    addContext_("user", message);
    addContext_("system", systemRole_);
    context_["stream"] = true; // request server-sent-event streaming
    std::string postString = context_.toStyledString();
    cpr::Url url{"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/"
                 "wenxinworkshop/chat/completions_pro"};
    cpr::Parameters parameters{
        cpr::Parameters{{"access_token", getAccessToken_()}}};
    cpr::Header header{cpr::Header{{"Content-Type", "application/json"}}};
    cpr::Body body{cpr::Body{{postString}}};
    cpr::Response response =
        cpr::Post(url, parameters, header, body,
                  cpr::WriteCallback{writeChatData, (intptr_t)this});
    if (needRefreshAccessToken_) {
        // The server reported an expired token mid-stream: roll back this
        // attempt's context entry, refresh the token and retry.
        removeLastContext_();
        streamChatData_.clear();
        needRefreshAccessToken_ = false;
        // Bug fix: the refresh result used to be ignored, so a failed
        // refresh retried with an empty token — and each retry could set
        // needRefreshAccessToken_ again, recursing indefinitely.
        if (!generateAccessToken_()) {
            ai_engine::lm::nlp::ChatResult result{std::string{},
                                                  currentError_};
            chatResultCallback_(result);
            error = currentError_;
            return false;
        }
        return chat(message, error);
    }
    return processChatResponse_(response, error);
}
// Requests cooperative cancellation; the streaming write callback checks
// this flag and aborts the transfer.
void BaiduNlpEnginePrivate::stopChat() {
    isStopped_ = true;
}
// Tears down the chat module: only needs to stop any in-flight chat.
// Always succeeds; `error` is left untouched.
bool BaiduNlpEnginePrivate::destroyChatModule(
    ai_engine::lm::EngineError &error) {
    stopChat();
    return true;
}
// Finalizes a streamed chat exchange: handles user cancellation and
// transport failures (rolling back the context entry for the attempt),
// otherwise appends the assembled assistant reply to the context.
bool BaiduNlpEnginePrivate::processChatResponse_(
    cpr::Response &response, ai_engine::lm::EngineError &error) {
    // User-requested stop: discard this round and emit a terminal event.
    if (isStopped_) {
        removeLastContext_();
        streamChatData_.clear();
        ai_engine::lm::nlp::ChatResult result{
            std::string{"{\"result\":\"\",\"is_end\":true}"}, currentError_};
        chatResultCallback_(result);
        return true;
    }
    // The write callback aborted the transfer (currentError_ was already
    // filled in by writeChatData).
    if (response.error.code == cpr::ErrorCode::REQUEST_CANCELLED) {
        removeLastContext_();
        streamChatData_.clear();
        ai_engine::lm::nlp::ChatResult result{std::string{}, currentError_};
        chatResultCallback_(result);
        error = currentError_;
        return false;
    }
    // Any other transport-level failure.
    if (response.error) {
        removeLastContext_();
        streamChatData_.clear();
        fprintf(stderr, "net error: %s\n", response.error.message.c_str());
        currentError_ = ai_engine::lm::EngineError(
            ai_engine::lm::AiCapability::Nlp,
            ai_engine::lm::EngineErrorCategory::Initialization,
            (int)ai_engine::lm::NlpEngineErrorCode::NetworkDisconnected,
            "网络错误");
        ai_engine::lm::nlp::ChatResult result{std::string{}, currentError_};
        chatResultCallback_(result);
        error = currentError_;
        return false;
    }
    // Success: stitch the streamed fragments into one assistant message so
    // the next turn sees the full reply in the context.
    std::string assistantReply;
    for (const Json::Value &fragment : streamChatData_) {
        assistantReply.append(fragment.asString());
    }
    addContext_("assistant", assistantReply);
    streamChatData_.clear();
    error = currentError_;
    return true;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/openkylin/kylin-baidu-nlp-engine.git
git@gitee.com:openkylin/kylin-baidu-nlp-engine.git
openkylin
kylin-baidu-nlp-engine
kylin-baidu-nlp-engine
upstream

搜索帮助