代码拉取完成,页面将自动刷新
同步操作将从 mynameisi/faiss_dog_cat_question 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# 导入必要的库
import numpy as np # 用于处理多维数组和矩阵运算
import faiss # 用于高效相似性搜索和稠密向量聚类
from util import createXY # 用于创建数据集的特征和标签
from sklearn.model_selection import train_test_split # 用于拆分数据集为训练集和测试集
from sklearn.neighbors import KNeighborsClassifier # sklearn中的K近邻分类器
import argparse # 用于解析命令行参数
import logging # 用于记录日志
from tqdm import tqdm # 用于在循环中显示进度条
from FaissKNeighbors import FaissKNeighbors # 导入自定义的FaissKNeighbors类
# 配置logging, 确保能够打印正在运行的函数名
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 获取命令行参数
def get_args():
parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。') # 创建命令行参数解析器对象
parser.add_argument('-m', '--mode', type=str, required=True, choices=['cpu', 'gpu'], help='选择训练模式:CPU或GPU。')
parser.add_argument('-f', '--feature', type=str, required=True, choices=['flat', 'vgg'], help='选择特征提取方法:flat或vgg。')
parser.add_argument('-l', '--library', type=str, required=True, choices=['sklearn', 'faiss'], help='选择使用的库:sklearn或faiss。')
args = parser.parse_args()
return args
# 主函数,运行训练过程
def main():
args = get_args()
# 根据mode初始化FAISS所需的资源
res = faiss.StandardGpuResources() if args.mode == 'gpu' else None
logging.info(f"选择模式是 {args.mode.upper()}")
logging.info(f"选择特征提取方法是 {args.feature.upper()}")
logging.info(f"选择使用的库是 {args.library.upper()}")
# 载入和预处理数据
train_folder ="D:\\wangguozheng\\333\\faiss_dog_cat_question\\cat_dog_data\\data\\train"
dest_folder = "."
X, y = createXY(train_folder=train_folder, dest_folder=dest_folder, method=args.feature)
# 检查X的形状和内容
print("Shape of X before conversion:", X.shape)
print("First few elements of X before conversion:", X[:5])
X = np.array(X).astype('float32')
# 确保X是一个二维数组
if len(X.shape) != 2:
raise ValueError(f"X must be a 2D array. Current shape: {X.shape}")
# 再次检查X的形状和内容
print("Shape of X after conversion:", X.shape)
print("First few elements of X after conversion:", X[:5])
# 确保X是一个二维数组
if len(X.shape) != 2:
raise ValueError(f"X must be a 2D array. Current shape: {X.shape}")
# 确保X不为空
if X.size == 0:
raise ValueError("X is empty. Check data loading process.")
faiss.normalize_L2(X) # 对数据进行L2归一化
y = np.array(y)
logging.info("数据加载和预处理完成。")
# 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
logging.info("数据集划分为训练集和测试集。")
# 初始化变量,跟踪最佳的k值和相应的准确率
best_k = -1
best_accuracy = 0.0
# 定义测试的k值范围
k_values = range(1, 20)
# 根据提供的库选择K近邻算法实现
KNNClass = FaissKNeighbors if args.library == 'faiss' else KNeighborsClassifier
logging.info(f"使用的库为: {args.library.upper()}")
# 遍历k值,训练并评估模型
for k in tqdm(k_values, desc='寻找最佳k值'):
knn = KNNClass(k=k, res=res) if args.library == 'faiss' else KNNClass(n_neighbors=k)
knn.fit(X_train, y_train)
accuracy = knn.score(X_test, y_test)
# 更新最佳k值和准确率
if accuracy > best_accuracy:
best_k = k
best_accuracy = accuracy
# 打印结果
logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')
# 如果是主脚本,则执行main函数
if __name__ == '__main__':
main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。