代码拉取完成,页面将自动刷新
同步操作将从 mynameisi/faiss_dog_cat_question 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import cv2
import numpy as np
import os
from os.path import exists
from imutils import paths
import pickle
from tqdm import tqdm
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing import image
import logging
# Configure the root logger at import time: INFO level, "time - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
def get_size(file_path):
    """Return the size of a file in megabytes (1 MB = 1024 * 1024 bytes).

    Args:
        file_path (str): Path of the file to measure.

    Returns:
        float: File size in MB.
    """
    size_in_bytes = os.stat(file_path).st_size
    return size_in_bytes / 1024 ** 2
def load_existing_data(x_file_path, y_file_path):
    """Load previously pickled X and y from disk.

    First logs that cached data is being reused, then logs the size of the
    X file and the y file in MB. Finally it unpickles X, then y.

    Args:
        x_file_path (str): Path of the pickled X (features) file.
        y_file_path (str): Path of the pickled y (labels) file.

    Returns:
        tuple: (X, y), exactly the objects that were pickled.
    """
    logging.info("X和y已经存在,直接读取")
    named_paths = (("X", x_file_path), ("y", y_file_path))
    for tag, path in named_paths:
        logging.info(f"{tag}文件大小:{get_size(path):.2f}MB")
    # NOTE: pickle.load can execute arbitrary code; only point this at cache
    # files produced by this script, never at untrusted input.
    loaded = []
    for _, path in named_paths:
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return tuple(loaded)
def create_vgg_features(image_paths, model, batch_size):
    """Extract model features for a list of images, one batch at a time.

    Each image is loaded at 224x224 and converted to an array. Each batch is
    passed through ``preprocess_input`` and then ``model.predict``. The
    prediction rows of every batch are collected in input order.

    Args:
        image_paths (list): Paths of the image files to encode.
        model: Keras model used for prediction (a VGG16 instance when called
            from createXY).
        batch_size (int): Number of images sent to ``model.predict`` at once.

    Returns:
        tuple: (X, y).
            X is a list with the prediction rows of every batch.
            y is a list of int labels: 1 when the file name's prefix before
            the first '_' is 'dog', otherwise 0.
    """
    features, labels = [], []
    total = len(image_paths)
    # Ceiling division: a trailing partial batch still gets its own step.
    batch_count = -(-total // batch_size)
    for batch_no in tqdm(range(batch_count), desc="读取图像"):
        lo = batch_no * batch_size
        hi = min(lo + batch_size, total)
        chunk = image_paths[lo:hi]
        pixels = np.array([
            image.img_to_array(image.load_img(path, target_size=(224, 224)))
            for path in chunk
        ])
        predictions = model.predict(preprocess_input(pixels), verbose=0)
        features.extend(predictions)
        labels.extend(
            int(os.path.basename(path).split('_')[0] == 'dog') for path in chunk
        )
    return features, labels
def create_flat_features(image_paths, batch_size):
    """Create raw-pixel features: each image as a flattened 32x32 grayscale vector.

    Args:
        image_paths (list): Paths of the image files to read.
        batch_size (int): Number of images handled per progress-bar step. No
            batched computation happens here; batching only controls tqdm
            granularity.

    Returns:
        tuple: (X, y).
            X is a list with one 1024-element (32 * 32) pixel vector per image.
            y is a list of int labels: 1 when the file name's prefix before
            the first '_' is 'dog', otherwise 0.

    Raises:
        ValueError: If an image file cannot be read by OpenCV.
    """
    X = []
    y = []
    num_batches = len(image_paths) // batch_size + (1 if len(image_paths) % batch_size else 0)
    for idx in tqdm(range(num_batches), desc="读取图像"):
        start = idx * batch_size
        end = min(start + batch_size, len(image_paths))
        for image_path in image_paths[start:end]:
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread reports unreadable/corrupt files by returning None
                # instead of raising. Without this check, cv2.resize fails with
                # an opaque assertion error that does not name the file.
                raise ValueError(f"Could not read image: {image_path}")
            X.append(cv2.resize(img, (32, 32)).flatten())
            label = os.path.basename(image_path).split('_')[0]
            y.append(1 if label == 'dog' else 0)
    return X, y
def save_data(X, y, x_file_path, y_file_path):
    """Pickle X and y to their respective files, overwriting any existing files.

    Args:
        X (list): Image feature list.
        y (list): Label list.
        x_file_path (str): Destination path for X.
        y_file_path (str): Destination path for y.
    """
    for payload, target in ((X, x_file_path), (y, y_file_path)):
        with open(target, 'wb') as handle:
            pickle.dump(payload, handle)
def createXY(train_folder, dest_folder, method='vgg', batch_size=64):
    """Build (or load from cache) the feature list X and label list y.

    If both ``X.pkl`` and ``y.pkl`` already exist in ``dest_folder``, they are
    loaded and returned as-is; ``method`` and ``batch_size`` are not consulted
    in that case. Otherwise every image found under ``train_folder`` is turned
    into a feature vector, and the results are pickled into ``dest_folder``
    for later calls.

    Args:
        train_folder (str): Folder searched for images via
            ``paths.list_images``.
        dest_folder (str): Folder holding the cached ``X.pkl`` / ``y.pkl``.
            Created if it does not exist.
        method (str): ``'vgg'`` for VGG16 features (ImageNet weights,
            ``include_top=False``, ``pooling="max"``), or ``'flat'`` for
            flattened 32x32 grayscale pixels.
        batch_size (int): Number of images processed per batch; must be >= 1.

    Returns:
        tuple: (X, y), the feature list and label list.

    Raises:
        ValueError: If ``method`` is not 'vgg' or 'flat', if ``batch_size``
            is < 1, or if no images are found in ``train_folder``.
    """
    x_file_path = os.path.join(dest_folder, "X.pkl")
    y_file_path = os.path.join(dest_folder, "y.pkl")
    if exists(x_file_path) and exists(y_file_path):
        return load_existing_data(x_file_path, y_file_path)

    # Validate before any expensive work. An unknown method used to fall
    # through both branches and crash later with an UnboundLocalError on X.
    if method not in ('vgg', 'flat'):
        raise ValueError(f"Unknown method {method!r}; expected 'vgg' or 'flat'")
    # batch_size == 0 raised ZeroDivisionError and negative values silently
    # produced no features at all.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    logging.info("读取所有图像,生成X和y")
    image_paths = list(paths.list_images(train_folder))
    if not image_paths:
        # Saving empty results would cache them, and every later call would
        # silently load that empty cache instead of re-reading the images.
        raise ValueError(f"No images found in {train_folder!r}")

    # Make sure the output folder exists now, so save_data cannot fail after
    # all features have already been computed.
    if dest_folder:
        os.makedirs(dest_folder, exist_ok=True)

    if method == 'vgg':
        model = VGG16(weights='imagenet', include_top=False, pooling="max")
        logging.info("完成构建 VGG16 模型")
        X, y = create_vgg_features(image_paths, model, batch_size)
    else:  # method == 'flat' (validated above)
        X, y = create_flat_features(image_paths, batch_size)
    logging.info(f"X.shape: {np.shape(X)}")
    logging.info(f"y.shape: {np.shape(y)}")
    save_data(X, y, x_file_path, y_file_path)
    return X, y
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。