1 Star 0 Fork 245

胡洋/faiss_dog_cat_question

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_ensemble.py 5.31 KB
一键复制 编辑 原始数据 按行查看 历史
胡洋 提交于 2024-11-07 11:09 . PY
import time
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import StackingClassifier
from tabulate import tabulate
import logging
import numpy as np
import gradio as gr
import cv2
import faiss
from util import createXY
import pickle
print("模块载入完成,开始运行主代码")
# 配置logging, 确保能够打印正在运行的函数名
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 数据加载和预处理
X, y = createXY(train_folder="../data/train", dest_folder=".", method='flat')
X = np.array(X).astype('float32')
# 假设faiss.normalize_L2存在于Faiss库中
faiss.normalize_L2(X)
y = np.array(y)
logging.info("数据加载和预处理完成。")
# 数据集分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
logging.info("数据集划分为训练集和测试集。")
# 创建逻辑回归分类器
logistic_clf = LogisticRegression(max_iter=1000)
# 创建随机森林分类器
random_forest_clf = RandomForestClassifier(random_state=42)
# 创建支持向量机分类器
# probability=True表示需要计算概率值,这样才能使用软投票
svm_clf = SVC(probability=True)
# 创建硬投票分类器,包含逻辑回归、随机森林和支持向量机三个分类器
voting_clf_hard = VotingClassifier(
estimators=[('lr', logistic_clf), ('rf', random_forest_clf), ('svc', svm_clf)],
voting='hard' # 采用硬投票, 即选择票数最多的类别作为预测结果
)
# 创建软投票分类器,包含逻辑回归、随机森林和支持向量机三个分类器
voting_clf_soft = VotingClassifier(
estimators=[('lr', logistic_clf), ('rf', random_forest_clf), ('svc', svm_clf)],
voting='soft' # 采用软投票, 即选择概率之和最大的类别作为预测结果
)
# 创建Bagging分类器
bag_clf = BaggingClassifier(
DecisionTreeClassifier(), # 基分类器选用决策树分类器
n_estimators=500, # 500个分类器
max_samples=100, # 每个分类器的训练集包含100个样本
bootstrap=True, # 有放回的采样
n_jobs=-1, # 使用所有CPU核
)
# 创建Pasting分类器
paste_clf = BaggingClassifier(
DecisionTreeClassifier(), # 基分类器选用决策树分类器
n_estimators=500, # 500个分类器
max_samples=100, # 每个分类器的训练集包含100个样本
bootstrap=False, # 无放回的采样
n_jobs=-1, # 使用所有CPU核
)
# 创建AdaBoost分类器
ada_clf = AdaBoostClassifier(
DecisionTreeClassifier(max_depth=1), # 基分类器选用决策树分类器
n_estimators=200, # 200个分类器
algorithm="SAMME.R", # 使用SAMME.R算法
learning_rate=0.5 # 学习率为0.5, 即每个分类器的权重缩减系数为0.5
)
# 创建一个梯度提升分类器
xgb_clf = XGBClassifier(
n_estimators=200, # 200个分类器
max_depth=2, # 每个分类器的最大深度为2
learning_rate=0.5 # 学习率为0.5, 即每个分类器的权重缩减系数为0.5
)
# 创建一个堆叠分类器
stacking_clf = StackingClassifier(
estimators=[('lr', logistic_clf), ('rf', random_forest_clf), ('svc', svm_clf)],
final_estimator=LogisticRegression() # 最终分类器选用逻辑回归分类器
)
clfs = {
"logistic_regression": logistic_clf,
"random_forest": random_forest_clf,
"svm": svm_clf,
"hard_voting": voting_clf_hard,
"soft_voting": voting_clf_soft,
"bagging": bag_clf,
"pasting": paste_clf,
"adaboost": ada_clf,
"gradient_boosting": xgb_clf,
"stacking": stacking_clf
}
results = []
# 训练和评估分类器
for name, clf in clfs.items():# 遍历字典中的每一个键值对:(模型名称, 模型对象)
start_time = time.time()
clf.fit(X_train, y_train)
train_time = time.time() - start_time
logging.info(f"{name}模型训练完成,用时:{train_time:.4f}秒。")
start_time = time.time()
accuracy = clf.score(X_test, y_test)
pred_time = time.time() - start_time
logging.info(f"{name}模型评估完成,用时:{pred_time:.4f}秒。")
with open(f"data/{name}.pkl",'wb') as f:
pickle.dump(clf,f)
results.append([name, train_time, pred_time, accuracy])
# 找到准确率最高的模型
best_model_info = max(results, key=lambda x: x[3])
best_model_name, _, _, best_accuracy = best_model_info
# 保存准确率最高的模型到 best_model.pkl
with open("data/best_model.pkl", 'wb') as f:
best_clf = clfs[best_model_name]
pickle.dump(best_clf, f)
print(f"准确率最高的模型({best_accuracy:.4f})已保存为 best_model.pkl")
# 打印结果表格
headers = ["Classifier", "Training Time (s)", "Prediction Time (s)", "Accuracy"]
print(tabulate(results, headers=headers, tablefmt="simple"))
print("程序运行完毕")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/hu-yang-li/faiss_dog_cat_question.git
git@gitee.com:hu-yang-li/faiss_dog_cat_question.git
hu-yang-li
faiss_dog_cat_question
faiss_dog_cat_question
main

搜索帮助