3 Star 9 Fork 4

兜兜丨有糖丶/AutoPic

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
0201剪切图片.py 2.99 KB
一键复制 编辑 原始数据 按行查看 历史
兜兜丨有糖丶 提交于 2023-04-02 15:13 . init
import torch
import torchvision
import cv2
import numpy as np
from PIL import Image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import ToTensor
from skimage.transform import resize
import os
# 初始化并修改Mask R-CNN模型
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask, hidden_layer, num_classes)
# 加载模型权重
model.load_state_dict(torch.load("model/ok.pth"))
# 设置模型为评估模式
model.eval()
# 定义输入和输出目录
input_dir = "images"
output_dir = "image_jc"
# 如果输出目录不存在,创建它
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 遍历输入目录中的所有文件
for filename in os.listdir(input_dir):
# 检查文件是否为图像(这里只检查 .jpg 和 .png 文件)
if filename.lower().endswith((".jpg", ".jpeg", ".png")):
# 读取图像并转换为tensor
image_path = os.path.join(input_dir, filename)
image = Image.open(image_path).convert("RGB")
transform = ToTensor()
image_tensor = transform(image).unsqueeze(0)
# 将图像传递给模型
with torch.no_grad():
predictions = model(image_tensor)
# 处理预测结果
pred_boxes = predictions[0]['boxes'].numpy()
pred_scores = predictions[0]['scores'].numpy()
threshold = 0.8
if pred_scores.size > 0: # 检查 pred_scores 是否为空
# count = len([x for x in pred_scores if x > 0.8])
# if count > 1:
# # 这种情况说明,图中有多个小图,暂时就只选一个,后面如果要改动说
# max_score_idx = 1 # 这里的 1 是下标
# else:
# # 找到得分最高的边界框
# max_score_idx = np.argmax(pred_scores)
max_score_idx = np.argmax(pred_scores)
max_score = pred_scores[max_score_idx]
# 如果得分超过阈值,进行剪裁,否则直接保存原图
if max_score > threshold:
box = pred_boxes[max_score_idx]
cropped_image = image.crop((box[0], box[1], box[2], box[3]))
else:
cropped_image = image
else:
cropped_image = image
# 保存剪裁后的图像
output_image_path = os.path.join(output_dir, filename)
cropped_image.save(output_image_path)
def count_files(dir_path):
return len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
print('剪裁后的文件数量:', count_files(output_dir))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/freyStudio/auto-pic.git
git@gitee.com:freyStudio/auto-pic.git
freyStudio
auto-pic
AutoPic
master

搜索帮助