1 Star 0 Fork 1

ZYung/TensorFlow2.0_ResNet

forked from FIRC/TensorFlow2.0_ResNet 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
prepare_data.py 2.30 KB
一键复制 编辑 原始数据 按行查看 历史
372046933 提交于 2020-01-08 15:27 . Update prepare_data.py
import tensorflow as tf
import config
import pathlib
from config import image_height, image_width, channels
def load_and_preprocess_image(img_path):
    """Load a single image file and return it as a normalized float tensor.

    The target size and channel count come from the module-level
    ``image_height``, ``image_width`` and ``channels`` settings imported
    from ``config``.

    Args:
        img_path: Scalar string tensor (or str) holding the image file path.

    Returns:
        A float32 tensor of size ``[image_height, image_width]`` with
        ``channels`` channels, with pixel values divided by 255.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=channels)
    resized = tf.cast(
        tf.image.resize(decoded, [image_height, image_width]), tf.float32
    )
    # decode_jpeg yields 8-bit pixel values, so dividing by 255 rescales
    # them to roughly the [0, 1] range.
    return resized / 255.0
def get_images_and_labels(data_root_dir):
    """Collect image paths and integer labels from a class-per-folder tree.

    The expected layout is ``data_root_dir/<class_name>/<image_file>``.
    Each visible, immediate sub-directory of ``data_root_dir`` is one class,
    and every visible regular file directly inside it is one sample.
    Hidden entries (names starting with ".", such as ``.DS_Store`` or
    ``.ipynb_checkpoints``) are ignored.

    Args:
        data_root_dir: Dataset root directory (str or path-like).

    Returns:
        A ``(all_image_path, all_image_label)`` tuple of equal-length lists.
        ``all_image_path`` holds the image paths as strings.
        ``all_image_label`` holds the integer class index of each image.
        Class indices follow the sorted order of the class folder names.
        Images are listed class by class, sorted by path within each class,
        so the output is reproducible across runs and filesystems.

    Raises:
        FileNotFoundError: if ``data_root_dir`` is not an existing directory.
    """
    data_root = pathlib.Path(data_root_dir)
    if not data_root.is_dir():
        raise FileNotFoundError(
            "Dataset directory does not exist or is not a directory: {}".format(data_root)
        )

    # Only real, non-hidden directories are classes. Stray files at the root
    # (or hidden folders) would otherwise become labels and shift the indices.
    label_names = sorted(
        item.name
        for item in data_root.iterdir()
        if item.is_dir() and not item.name.startswith(".")
    )

    all_image_path = []
    all_image_label = []
    for index, label in enumerate(label_names):
        for image_path in sorted((data_root / label).iterdir()):
            # Skip nested directories and hidden files; they are not images.
            if image_path.is_file() and not image_path.name.startswith("."):
                all_image_path.append(str(image_path))
                all_image_label.append(index)
    return all_image_path, all_image_label
def get_dataset(dataset_root_dir):
    """Build an unbatched ``(image, label)`` dataset for one split directory.

    Args:
        dataset_root_dir: Root of a class-per-folder image tree, as read by
            ``get_images_and_labels``.

    Returns:
        ``(dataset, image_count)``.
        ``dataset`` yields ``(preprocessed_image, label_index)`` pairs in the
        order produced by ``get_images_and_labels``.
        ``image_count`` is the number of images found.

    Raises:
        ValueError: if no images were found under ``dataset_root_dir``.
    """
    all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir)
    if not all_image_path:
        # Without this guard, an empty path list surfaces later as a confusing
        # TensorFlow error inside read_file; fail early with a clear message.
        raise ValueError("No images found under {!r}".format(dataset_root_dir))

    # Load and preprocess images. Decoding runs in parallel; tf.data keeps the
    # output order deterministic, so images stay aligned with their labels
    # in the zip below.
    image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
    image_count = len(all_image_path)
    return dataset, image_count
def generate_datasets():
    """Create batched train/validation/test datasets from the configured dirs.

    Reads ``config.train_dir``, ``config.valid_dir``, ``config.test_dir`` and
    ``config.BATCH_SIZE``. Only the training split is shuffled; it is
    reshuffled on every iteration, which is the ``tf.data`` default.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset,
        train_count, valid_count, test_count)``.
        Each dataset yields batches of ``(images, labels)``.
        Each count is the number of images in that split.

    Raises:
        ValueError: if the training directory contains no images.
    """
    train_paths, train_labels = get_images_and_labels(data_root_dir=config.train_dir)
    train_count = len(train_paths)
    if not train_paths:
        raise ValueError("No images found under {!r}".format(config.train_dir))

    # Shuffle the cheap (path, label) pairs *before* decoding. A full-size
    # shuffle buffer placed after decoding would hold every decoded float
    # image of the training set in memory at once.
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
        .shuffle(buffer_size=train_count)
        .map(
            lambda path, label: (load_and_preprocess_image(path), label),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        .batch(batch_size=config.BATCH_SIZE)
    )

    valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)
    test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)
    # Evaluation splits keep their on-disk order; they are only batched.
    valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE)
    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/XIAOTUOYABG/TensorFlow2.0_ResNet.git
git@gitee.com:XIAOTUOYABG/TensorFlow2.0_ResNet.git
XIAOTUOYABG
TensorFlow2.0_ResNet
TensorFlow2.0_ResNet
master

搜索帮助