# train_ensemble.py
import numpy as np
import time
import faiss
from util import createXY
from sklearn.model_selection import train_test_split
import logging
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import (
    BaggingClassifier,
    AdaBoostClassifier,
    RandomForestClassifier,
    StackingClassifier,
    VotingClassifier,
    GradientBoostingClassifier
)
from tabulate import tabulate
import pickle
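
# Third-party dependencies (inferred from the imports above): numpy, faiss,
# scikit-learn, xgboost, and tabulate; createXY comes from the project-local
# util module.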
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
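# Note: %(msecs)03d is appended manually because datefmt is handed to
# time.strftime, which has no sub-second directive.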
print("\n" + "="*50)
print("开始机器学习模型训练和评估流程")
print("="*50 + "\n")
# Load the data
logging.info("Loading dataset, please wait...")
X, y = createXY(train_folder="../data/train", dest_folder=".")
X = np.array(X).astype('float32')
faiss.normalize_L2(X)  # L2-normalize the feature vectors in place
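# Normalizing to unit L2 norm (presumably why faiss is imported here) makes
# inner products equal to cosine similarity and puts all samples on a common
# scale, which scale-sensitive models such as SVC tend to benefit from.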
y = np.array(y)
logging.info("数据加载和预处理完成。")
# Log basic dataset statistics
logging.info(f"X.shape: {X.shape}")
logging.info(f"y.shape: {y.shape}")
logging.info(f"X size in memory: {X.nbytes / (1024*1024):.2f} MB")
logging.info(f"y size in memory: {y.nbytes / (1024*1024):.2f} MB")

# Split into training and test sets
logging.info("Splitting into training and test sets, please wait...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Ensure consistent dtypes after the split
X_train = np.array(X_train, dtype=np.float32)
X_test = np.array(X_test, dtype=np.float32)
y_train = np.array(y_train, dtype=np.int64)
y_test = np.array(y_test, dtype=np.int64)
logging.info(f"Split complete! Training set: {X_train.shape[0]} samples, test set: {X_test.shape[0]} samples")

# Classifier configurations
classifiers = {
    "logistic_regression": LogisticRegression(
        C=1.0,
        max_iter=2000,
        class_weight='balanced',
        solver='lbfgs'  # lbfgs handles multinomial problems natively
    ),
"random_forest": RandomForestClassifier(
n_estimators=50, # 减少树的数量以加快测试
max_depth=10, # 限制树的深度
min_samples_split=2,
min_samples_leaf=1,
max_features='sqrt',
bootstrap=True,
random_state=42,
n_jobs=-1,
class_weight='balanced'
),
"svm": SVC(
C=10.0,
kernel='rbf',
gamma='scale',
probability=True,
class_weight='balanced',
random_state=42,
cache_size=2000
),
"hard_voting": VotingClassifier( # 新增:添加hard_voting分类器
estimators=[
('rf', RandomForestClassifier(n_estimators=10, random_state=42)),
('svc', SVC(kernel='linear', probability=True, random_state=42))
],
voting='hard'
),
"soft_voting": VotingClassifier( # 新增:添加soft_voting分类器
estimators=[
('rf', RandomForestClassifier(n_estimators=10, random_state=42)),
('svc', SVC(kernel='linear', probability=True, random_state=42))
],
voting='soft'
),
"bagging": BaggingClassifier(
estimator=DecisionTreeClassifier(max_depth=10, class_weight='balanced'),
n_estimators=50, # 减少估计器数量
max_samples=0.8,
max_features=0.8,
bootstrap=True,
n_jobs=-1,
random_state=42
),
"adaboost": AdaBoostClassifier(
estimator=DecisionTreeClassifier(max_depth=3, class_weight='balanced'),
n_estimators=50, # 减少估计器数量
learning_rate=0.1,
algorithm="SAMME.R",
random_state=42
),
"gradient_boosting": GradientBoostingClassifier( # 新增:添加gradient_boosting分类器
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42
),
"xgboost": XGBClassifier( # 新增:添加xgboost分类器
n_estimators=100,
learning_rate=0.1,
max_depth=3,
use_label_encoder=False,
eval_metric='mlogloss',
random_state=42
)
}
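# Keeping every model in one dict lets the loop below drive them all through
# the shared scikit-learn fit/score interface (XGBClassifier implements it too).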
results = []
best_accuracy = 0
best_model = None
print("\n" + "="*50)
print("开始模型训练和评估")
print("="*50 + "\n")
# Train and evaluate each classifier
for name, clf in classifiers.items():
    try:
        # Training phase
        logging.info(f"Training the {name} model, please wait...")
        start_time = time.time()
        clf.fit(X_train, y_train)
        fit_time = time.time() - start_time
        logging.info(f"{name} training finished in {fit_time:.4f} s.")

        # Evaluation phase
        logging.info(f"Evaluating the {name} model, please wait...")
        start_time = time.time()
        accuracy = clf.score(X_test, y_test)
        score_time = time.time() - start_time
        logging.info(f"{name} evaluation finished in {score_time:.4f} s.")

        results.append([name, fit_time, score_time, accuracy])
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model = clf
            logging.info(f"New best model! {name} reached an accuracy of {accuracy:.4f}")
    except Exception as e:
        logging.error(f"Training or evaluation of {name} failed: {e}")
        continue

# Save the best model
if best_model is not None:
    logging.info(f"Best model overall: {best_model.__class__.__name__} with accuracy {best_accuracy:.4f}")
    logging.info("Saving the best model...")
    with open("best_model.pkl", "wb") as f:
        pickle.dump(best_model, f)
    logging.info("Best model saved to best_model.pkl.")

# Print the results table
print("\n" + "="*70)
print("Model performance comparison")
print("="*70)
headers = ["Classifier", "Training Time (s)", "Prediction Time (s)", "Accuracy"]
print(tabulate(results, headers=headers, floatfmt=".6f"))
print("="*70)
print("\nTraining and evaluation pipeline complete!")