代码拉取完成,页面将自动刷新
同步操作将从 FIRC/TensorFlow2.0_ResNet 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import tensorflow as tf
import config
import pathlib
from config import image_height, image_width, channels
def load_and_preprocess_image(img_path):
    """Read one image file and turn it into a normalized float tensor.

    Args:
        img_path: Scalar string tensor (or str) holding the image file path.

    Returns:
        A float32 tensor of shape ``[image_height, image_width, channels]``
        (sizes taken from ``config``) with pixel values divided by 255.
    """
    img_raw = tf.io.read_file(img_path)
    # decode_image handles JPEG, PNG, GIF and BMP, whereas decode_jpeg fails on
    # anything that is not a JPEG. expand_animations=False keeps the result
    # rank-3 (animated GIFs are truncated to their first frame), which is what
    # tf.image.resize below expects for a single image.
    img_tensor = tf.io.decode_image(
        img_raw, channels=channels, expand_animations=False
    )
    img_tensor = tf.image.resize(img_tensor, [image_height, image_width])
    # resize already yields float32 for its default (bilinear) method; the cast
    # pins the dtype in case the resize method is ever changed.
    img_tensor = tf.cast(img_tensor, tf.float32)
    # Normalization: decoded pixels are uint8 (0-255), so this maps into [0, 1].
    return img_tensor / 255.0
def get_images_and_labels(data_root_dir):
    """Collect image paths and integer class labels from a class-per-folder tree.

    The expected layout is ``data_root_dir/<class_name>/<image_file>``. Class
    names are the immediate subdirectory names, sorted alphabetically. Each
    name maps to its index in that sorted list.

    Args:
        data_root_dir: Dataset root directory (str or path-like).

    Returns:
        ``(all_image_path, all_image_label)``: a sorted list of image file
        paths as strings, and a parallel list of their integer labels. Both
        are empty when the root contains no class subdirectories.
    """
    data_root = pathlib.Path(data_root_dir)
    # Only directories are classes. ``glob('*/')`` also matched plain files on
    # Python < 3.11, so a stray file at the root became a bogus class and
    # shifted every label index.
    label_names = sorted(
        item.name for item in data_root.glob('*') if item.is_dir()
    )
    label_to_index = {label: index for index, label in enumerate(label_names)}
    # Keep only regular files one level down. Sub-folders inside a class
    # directory are not images and would fail to decode later. Sorting makes
    # the order independent of filesystem enumeration order.
    all_image_path = sorted(
        str(path) for path in data_root.glob('*/*') if path.is_file()
    )
    # Each file's parent is a top-level directory, hence always a known label.
    all_image_label = [
        label_to_index[pathlib.Path(single_image_path).parent.name]
        for single_image_path in all_image_path
    ]
    return all_image_path, all_image_label
def get_dataset(dataset_root_dir):
    """Build an unbatched ``(image, label)`` dataset for one split directory.

    Args:
        dataset_root_dir: Root of the split; its immediate subdirectories are
            the class folders scanned by ``get_images_and_labels``.

    Returns:
        ``(dataset, image_count)``:
        - ``dataset`` is a ``tf.data.Dataset`` yielding ``(image, label)``
          pairs. Images come from ``load_and_preprocess_image``; labels are
          the integer class indices.
        - ``image_count`` is the number of image paths found.
    """
    all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir)
    # Load and preprocess images on parallel threads. Dataset.map keeps element
    # order by default, so zipping with the label dataset below stays aligned.
    image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
    image_count = len(all_image_path)
    return dataset, image_count
def generate_datasets():
    """Create batched train/valid/test datasets from the directories in ``config``.

    Only the training split is shuffled (with a buffer as large as the split);
    all three splits are batched with ``config.BATCH_SIZE``.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset,
        train_count, valid_count, test_count)``, where each count is the
        number of images found for that split.
    """
    batch_size = config.BATCH_SIZE

    train_ds, n_train = get_dataset(dataset_root_dir=config.train_dir)
    valid_ds, n_valid = get_dataset(dataset_root_dir=config.valid_dir)
    test_ds, n_test = get_dataset(dataset_root_dir=config.test_dir)

    # The buffer spans the whole training split, giving a full shuffle each epoch.
    # NOTE(review): get_dataset maps decoding before this point, so the buffer
    # holds decoded float images; this may use a lot of memory for large splits.
    train_ds = train_ds.shuffle(buffer_size=n_train)
    train_ds = train_ds.batch(batch_size)
    valid_ds = valid_ds.batch(batch_size)
    test_ds = test_ds.batch(batch_size)

    return train_ds, valid_ds, test_ds, n_train, n_valid, n_test
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。