1 Star 0 Fork 245

发故宫/faiss_dog_cat_question

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 5.77 KB
一键复制 编辑 原始数据 按行查看 历史
发故宫 提交于 2024-09-28 15:52 . 1
import numpy as np
import faiss
from util import createXY
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import argparse
import logging
from tqdm import tqdm
from FaissKNeighbors import FaissKNeighbors
import os
import sys
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_args():
parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。')
parser.add_argument('-m', '--mode', type=str, required=True, choices=['cpu', 'gpu'], help='选择训练模式:CPU或GPU。')
parser.add_argument('-f', '--feature', type=str, required=True, choices=['flat', 'vgg'], help='选择特征提取方法:flat或vgg。')
parser.add_argument('-l', '--library', type=str, required=True, choices=['sklearn', 'faiss'], help='选择使用的库:sklearn或faiss。')
args = parser.parse_args()
return args
def check_cuda_environment():
logging.info("检查 CUDA 环境...")
if 'CUDA_VISIBLE_DEVICES' in os.environ:
logging.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
else:
logging.warning("未设置 CUDA_VISIBLE_DEVICES 环境变量")
try:
import torch
logging.info(f"PyTorch 版本: {torch.__version__}")
logging.info(f"CUDA 是否可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logging.info(f"CUDA 版本: {torch.version.cuda}")
logging.info(f"可用 GPU 数量: {torch.cuda.device_count()}")
except ImportError:
logging.warning("未安装 PyTorch,无法检查 CUDA 可用性")
def check_gpu_support():
check_cuda_environment()
try:
logging.info(f"FAISS 库路径: {faiss.__file__}")
logging.info(f"FAISS 编译信息: {faiss.get_compile_options()}")
gpu_count = faiss.get_num_gpus()
logging.info(f"FAISS 报告的 GPU 数量: {gpu_count}")
if gpu_count > 0:
for i in range(gpu_count):
logging.info(f"测试 GPU {i}")
res = faiss.StandardGpuResources()
logging.info(f"GPU {i} 测试成功")
logging.info(f"FAISS 检测到 {gpu_count} 个 GPU 设备")
return True
else:
logging.warning("FAISS 未检测到 GPU 设备,将使用 CPU 模式")
return False
except AttributeError as e:
logging.error(f"FAISS GPU 支持错误: {str(e)}")
logging.error("FAISS 不支持 GPU。请确保安装了支持 GPU 的 FAISS 版本")
return False
except Exception as e:
logging.error(f"检查 GPU 支持时发生未知错误: {str(e)}")
return False
def main():
args = get_args()
logging.info(f"Python 版本: {sys.version}")
logging.info(f"FAISS 版本: {faiss.__version__}")
# 检查 GPU 支持
if args.mode == 'gpu':
if not check_gpu_support():
logging.error("由于 GPU 不可用,程序将退出。请使用 -m cpu 选项重新运行程序")
return
try:
res = faiss.StandardGpuResources() if args.mode == 'gpu' else None
logging.info("成功创建 GPU 资源" if args.mode == 'gpu' else "使用 CPU 模式")
except AttributeError:
logging.error("创建 GPU 资源失败。可能的原因:")
logging.error("1. 安装的 FAISS 版本不支持 GPU")
logging.error("2. CUDA 环境未正确配置")
logging.error("请检查 FAISS 安装和 CUDA 配置,或使用 CPU 模式运行")
return
except Exception as e:
logging.error(f"创建 GPU 资源时发生未知错误: {str(e)}")
return
logging.info(f"选择模式是 {args.mode.upper()}")
logging.info(f"选择特征提取方法是 {args.feature.upper()}")
logging.info(f"选择使用的库是 {args.library.upper()}")
try:
X, y = createXY(train_folder="../data/train", dest_folder=".",method=args.feature)
X = np.array(X).astype('float32')
faiss.normalize_L2(X)
y = np.array(y)
logging.info("数据加载和预处理完成。")
logging.info(f"X 形状: {X.shape}, y 形状: {y.shape}")
except Exception as e:
logging.error(f"数据加载或预处理失败:{str(e)}")
logging.error("请检查数据文件路径是否正确,以及数据格式是否符合要求")
return
try:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
logging.info("数据集划分为训练集和测试集。")
logging.info(f"训练集形状: X_train {X_train.shape}, y_train {y_train.shape}")
logging.info(f"测试集形状: X_test {X_test.shape}, y_test {y_test.shape}")
except Exception as e:
logging.error(f"数据集划分失败:{str(e)}")
return
best_k = -1
best_accuracy = 0.0
k_values = range(1, 6)
KNNClass = FaissKNeighbors if args.library == 'faiss' else KNeighborsClassifier
logging.info(f"使用的库为: {args.library.upper()}")
try:
for k in tqdm(k_values, desc='寻找最佳k值'):
knn = KNNClass(k=k, res=res) if args.library == 'faiss' else KNNClass(n_neighbors=k)
knn.fit(X_train, y_train)
accuracy = knn.score(X_test, y_test)
if accuracy > best_accuracy:
best_k = k
best_accuracy = accuracy
except Exception as e:
logging.error(f"模型训练或评估过程中出错:{str(e)}")
if args.library == 'faiss' and args.mode == 'gpu':
logging.error("如果使用 FAISS 和 GPU 模式,请确保 FAISS 正确安装并支持 GPU")
return
logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')
if __name__ == '__main__':
main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/fa-forbidden-city/faiss_dog_cat_question.git
git@gitee.com:fa-forbidden-city/faiss_dog_cat_question.git
fa-forbidden-city
faiss_dog_cat_question
faiss_dog_cat_question
main

搜索帮助