代码拉取完成,页面将自动刷新
import torch
import torchvision
import cv2
import numpy as np
from PIL import Image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import ToTensor
from skimage.transform import resize
import os
# 初始化并修改Mask R-CNN模型
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask, hidden_layer, num_classes)
# 加载模型权重
model.load_state_dict(torch.load("model/ok.pth"))
# 设置模型为评估模式
model.eval()
# 定义输入和输出目录
input_dir = "images"
output_dir = "image_jc"
# 如果输出目录不存在,创建它
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 遍历输入目录中的所有文件
for filename in os.listdir(input_dir):
# 检查文件是否为图像(这里只检查 .jpg 和 .png 文件)
if filename.lower().endswith((".jpg", ".jpeg", ".png")):
# 读取图像并转换为tensor
image_path = os.path.join(input_dir, filename)
image = Image.open(image_path).convert("RGB")
transform = ToTensor()
image_tensor = transform(image).unsqueeze(0)
# 将图像传递给模型
with torch.no_grad():
predictions = model(image_tensor)
# 处理预测结果
pred_boxes = predictions[0]['boxes'].numpy()
pred_scores = predictions[0]['scores'].numpy()
threshold = 0.8
if pred_scores.size > 0: # 检查 pred_scores 是否为空
# count = len([x for x in pred_scores if x > 0.8])
# if count > 1:
# # 这种情况说明,图中有多个小图,暂时就只选一个,后面如果要改动说
# max_score_idx = 1 # 这里的 1 是下标
# else:
# # 找到得分最高的边界框
# max_score_idx = np.argmax(pred_scores)
max_score_idx = np.argmax(pred_scores)
max_score = pred_scores[max_score_idx]
# 如果得分超过阈值,进行剪裁,否则直接保存原图
if max_score > threshold:
box = pred_boxes[max_score_idx]
cropped_image = image.crop((box[0], box[1], box[2], box[3]))
else:
cropped_image = image
else:
cropped_image = image
# 保存剪裁后的图像
output_image_path = os.path.join(output_dir, filename)
cropped_image.save(output_image_path)
def count_files(dir_path):
return len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
print('剪裁后的文件数量:', count_files(output_dir))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。