1 Star 0 Fork 0

houkai2020/temp

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
inference.py 12.57 KB
一键复制 编辑 原始数据 按行查看 历史
houkai 提交于 2020-04-03 19:26 . 1
# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Inference related utilities."""
from __future__ import absolute_import
from __future__ import division
# gtype import
from __future__ import print_function
import copy
import os
from absl import logging
import numpy as np
from PIL import Image
import tensorflow.compat.v1 as tf
from typing import Text, Dict, Any, List
import anchors
import dataloader
import det_model_fn
import hparams_config
import utils
from visualize import vis_utils
# coco_id_mapping = {
# 1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane',
# 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
# 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
# 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow',
# 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack',
# 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
# 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite',
# 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard',
# 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork',
# 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple',
# 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog',
# 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch',
# 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv',
# 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone',
# 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator',
# 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear',
# 89: 'hair drier', 90: 'toothbrush',
# } # pyformat: disable
coco_id_mapping = {
1: 'zhujia', 2: 'fujia'
} # pyformat: disable
def image_preprocess(image, image_size: int):
"""Preprocess image for inference.
Args:
image: input image, can be a tensor or a numpy arary.
image_size: integer of image size.
Returns:
(image, scale): a tuple of processed image and its scale.
"""
input_processor = dataloader.DetectionInputProcessor(image, image_size)
input_processor.normalize_image()
input_processor.set_scale_factors_to_output_size()
image = input_processor.resize_and_crop_image()
image_scale = input_processor.image_scale_to_original
return image, image_scale
def build_inputs(image_path_pattern: Text, image_size: int):
"""Read and preprocess input images.
Args:
image_path_pattern: a path to indicate a single or multiple files.
image_size: a single integer for image width and height.
Returns:
(raw_images, images, scales): raw images, processed images, and scales.
"""
raw_images, images, scales = [], [], []
for f in tf.io.gfile.glob(image_path_pattern):
image = Image.open(f)
raw_images.append(image)
image, scale = image_preprocess(image, image_size)
images.append(image)
scales.append(scale)
return raw_images, tf.stack(images), tf.stack(scales)
def build_model(model_name: Text, inputs: tf.Tensor, **kwargs):
"""Build model for a given model name.
Args:
model_name: the name of the model.
inputs: an image tensor or a numpy array.
**kwargs: extra parameters for model builder.
Returns:
(class_outputs, box_outputs): the outputs for class and box predictions.
Each is a dictionary with key as feature level and value as predictions.
"""
model_arch = det_model_fn.get_model_arch(model_name)
class_outputs, box_outputs = model_arch(inputs, model_name, **kwargs)
return class_outputs, box_outputs
def restore_ckpt(sess, ckpt_path, enable_ema=True, export_ckpt=None):
"""Restore variables from a given checkpoint.
Args:
sess: a tf session for restoring or exporting models.
ckpt_path: the path of the checkpoint. Can be a file path or a folder path.
enable_ema: whether reload ema values or not.
export_ckpt: whether to export the restored model.
"""
sess.run(tf.global_variables_initializer())
if tf.io.gfile.isdir(ckpt_path):
ckpt_path = tf.train.latest_checkpoint(ckpt_path)
if enable_ema:
ema = tf.train.ExponentialMovingAverage(decay=0.0)
ema_vars = utils.get_ema_vars()
var_dict = ema.variables_to_restore(ema_vars)
ema_assign_op = ema.apply(ema_vars)
else:
var_dict = utils.get_ema_vars()
ema_assign_op = None
tf.train.get_or_create_global_step()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_dict, max_to_keep=1)
saver.restore(sess, ckpt_path)
if export_ckpt:
print('export model to {}'.format(export_ckpt))
if ema_assign_op is not None:
sess.run(ema_assign_op)
saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
saver.save(sess, export_ckpt)
def det_post_process(params: Dict[Any, Any], cls_outputs: Dict[int, tf.Tensor],
box_outputs: Dict[int, tf.Tensor], scales: List[float]):
"""Post preprocessing the box/class predictions.
Args:
params: a parameter dictionary that includes `min_level`, `max_level`,
`batch_size`, and `num_classes`.
cls_outputs: an OrderDict with keys representing levels and values
representing logits in [batch_size, height, width, num_anchors].
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in
[batch_size, height, width, num_anchors * 4].
scales: a list of float values indicating image scale.
Returns:
detections_batch: a batch of detection results. Each detection is a tensor
with each row representing [image_id, x, y, width, height, score, class].
"""
# TODO(tanmingxing): refactor the code to make it more explicity.
outputs = {'cls_outputs_all': [None], 'box_outputs_all': [None],
'indices_all': [None], 'classes_all': [None]}
det_model_fn.add_metric_fn_inputs(
params, cls_outputs, box_outputs, outputs)
# Create anchor_label for picking top-k predictions.
eval_anchors = anchors.Anchors(params['min_level'],
params['max_level'],
params['num_scales'],
params['aspect_ratios'],
params['anchor_scale'],
params['image_size'])
anchor_labeler = anchors.AnchorLabeler(eval_anchors, params['num_classes'])
# Add all detections for each input image.
detections_batch = []
for index in range(params['batch_size']):
cls_outputs_per_sample = outputs['cls_outputs_all'][index]
box_outputs_per_sample = outputs['box_outputs_all'][index]
indices_per_sample = outputs['indices_all'][index]
classes_per_sample = outputs['classes_all'][index]
detections = anchor_labeler.generate_detections(
cls_outputs_per_sample, box_outputs_per_sample, indices_per_sample,
classes_per_sample, image_id=[index], image_scale=[scales[index]])
detections_batch.append(detections)
return detections_batch
def visualize_image(image,
boxes,
classes,
scores,
id_mapping,
min_score_thresh=0.2,
max_boxes_to_draw=50,
line_thickness=4,
**kwargs):
"""Visualizes a given image.
Args:
image: a image with shape [H, W, C].
boxes: a box prediction with shape [N, 4] ordered [ymin, xmin, ymax, xmax].
classes: a class prediction with shape [N].
scores: A list of float value with shape [N].
id_mapping: a dictionary from class id to name.
min_score_thresh: minimal score for showing. If claass probability is below
this threshold, then the object will not show up.
max_boxes_to_draw: maximum bounding box to draw.
line_thickness: how thick is the bounding box line.
**kwargs: extra parameters.
Returns:
output_image: an output image with annotated boxes and classes.
"""
category_index = {k: {'id': k, 'name': id_mapping[k]} for k in id_mapping}
img = np.array(image)
vis_utils.visualize_boxes_and_labels_on_image_array(
img,
boxes,
classes,
scores,
category_index,
min_score_thresh=min_score_thresh,
max_boxes_to_draw=max_boxes_to_draw,
line_thickness=line_thickness,
**kwargs)
return img
class InferenceDriver(object):
"""A driver for doing inference."""
def __init__(self, model_name: Text, ckpt_path: Text, image_size: int = None, num_classes: int = None,
label_id_mapping: Dict[int, Text] = None):
"""Initialize the inference driver.
Args:
model_name: target model name, such as efficientdet-d0.
ckpt_path: checkpoint path, such as /tmp/efficientdet-d0/.
image_size: user specified image size. If None, use the default image size
defined by model_name.
label_id_mapping: a dictionary from id to name. If None, use the default
coco_id_mapping (with 90 classes).
"""
self.model_name = model_name
self.ckpt_path = ckpt_path
self.label_id_mapping = label_id_mapping or coco_id_mapping
self.params = hparams_config.get_detection_config(self.model_name).as_dict()
self.params.update(dict(is_training_bn=False, use_bfloat16=False))
if image_size:
self.params.update(dict(image_size=image_size))
if num_classes:
self.params.update(dict(num_classes=num_classes))
def inference(self,
image_path_pattern: Text,
output_dir: Text,
**kwargs):
"""Read and preprocess input images.
Args:
image_path_pattern: Image file pattern such as /tmp/img*.jpg
output_dir: the directory for output images. Output images will be named
as 0.jpg, 1.jpg, ....
**kwargs: extra parameters for for vistualization, such as
min_score_thresh, max_boxes_to_draw, and line_thickness.
Returns:
Annotated image.
"""
params = copy.deepcopy(self.params)
with tf.Session() as sess:
# Buid inputs and preprocessing.
raw_images, images, scales = build_inputs(
image_path_pattern, params['image_size'])
# Build model.
class_outputs, box_outputs = build_model(
self.model_name, images, **self.params)
restore_ckpt(sess, self.ckpt_path, enable_ema=True, export_ckpt=None)
params.update(dict(batch_size=len(raw_images))) # for postprocessing.
# Build postprocessing.
detections_batch = det_post_process(
params, class_outputs, box_outputs, scales)
outputs_np = sess.run(detections_batch)
# Visualize results.
for i, output_np in enumerate(outputs_np):
# output_np has format [image_id, x, y, width, height, score, class]
boxes = output_np[:, 1:5]
classes = output_np[:, 6].astype(int)
scores = output_np[:, 5]
# convert [x, y, width, height] to [ymin, xmin, ymax, xmax]
# TODO(tanmingxing): make this convertion more efficient.
boxes[:, [0, 1, 2, 3]] = boxes[:, [1, 0, 3, 2]]
boxes[:, 2:4] += boxes[:, 0:2]
img = visualize_image(raw_images[i], boxes, classes, scores,
self.label_id_mapping, **kwargs)
output_image_path = os.path.join(output_dir, str(i) + '.jpg')
Image.fromarray(img).save(output_image_path)
logging.info('writing file to %s', output_image_path)
return outputs_np
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/houkai2020/temp.git
git@gitee.com:houkai2020/temp.git
houkai2020
temp
temp
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385