"""
Utilities for importing the CIFAR10 dataset.
Each image in the dataset is a numpy array of shape (32, 32, 3), with the values
being unsigned integers (i.e., in the range 0,1,...,255).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import random
import sys

import numpy as np
import tensorflow as tf

version = sys.version_info
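
# A minimal, comment-only sketch of the shapes this module produces, assuming
# the standard CIFAR-10 "python version" archive has been extracted on disk
# (the 'cifar10_data' path below is purely illustrative):
#
#     data = CIFAR10Data('cifar10_data')
#     data.train_data.xs.shape   # (50000, 32, 32, 3), dtype uint8
#     data.train_data.ys.shape   # (50000,), integer labels in 0..9
#     data.eval_data.xs.shape    # (10000, 32, 32, 3)
#     data.label_names           # ['airplane', 'automobile', ...]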
class CIFAR10Data(object):
"""
Unpickles the CIFAR10 dataset from a specified folder containing a pickled
version following the format of Krizhevsky, which can be found
[here](https://www.cs.toronto.edu/~kriz/cifar.html).
Inputs to constructor
=====================
- path: path to the pickled dataset. The training data must be pickled into
five files named data_batch_i for i = 1, ..., 5, containing 10,000 examples
each; the test data must be pickled into a single file called test_batch
containing 10,000 examples; and the 10 class names must be pickled into a
file called batches.meta. Each data_batch_i file and the test_batch file
unpickle to a dictionary whose b'data' entry is a (10000, 3072) uint8 array
(one image per row, stored channels-first in row-major order) and whose
b'labels' entry lists the 10,000 true labels.
"""
def __init__(self, path):
train_filenames = ['data_batch_{}'.format(ii + 1) for ii in range(5)]
eval_filename = 'test_batch'
metadata_filename = 'batches.meta'
train_images = np.zeros((50000, 32, 32, 3), dtype='uint8')
train_labels = np.zeros(50000, dtype='int32')
for ii, fname in enumerate(train_filenames):
cur_images, cur_labels = self._load_datafile(
os.path.join(path, fname))
train_images[ii * 10000 : (ii+1) * 10000, ...] = cur_images
train_labels[ii * 10000 : (ii+1) * 10000, ...] = cur_labels
eval_images, eval_labels = self._load_datafile(
os.path.join(path, eval_filename))
with open(os.path.join(path, metadata_filename), 'rb') as fo:
if version.major == 3:
data_dict = pickle.load(fo, encoding='bytes')
else:
data_dict = pickle.load(fo)
self.label_names = data_dict[b'label_names']
for ii in range(len(self.label_names)):
self.label_names[ii] = self.label_names[ii].decode('utf-8')
self.train_data = Dataset(train_images, train_labels)
self.eval_data = Dataset(eval_images, eval_labels)
@staticmethod
def _load_datafile(filename):
with open(filename, 'rb') as fo:
if version.major == 3:
data_dict = pickle.load(fo, encoding='bytes')
else:
data_dict = pickle.load(fo)
assert data_dict[b'data'].dtype == np.uint8
image_data = data_dict[b'data']
image_data = image_data.reshape((10000, 3, 32, 32)).transpose(0,2,3,1)
return image_data, np.array(data_dict[b'labels'])
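
# The reshape/transpose in _load_datafile mirrors the CIFAR-10 on-disk layout:
# each pickled row holds 3072 bytes, the first 1024 the red channel, the next
# 1024 green and the last 1024 blue, each channel stored row-major. A small
# self-contained check of that conversion (illustrative only, not used by the
# loader itself):
#
#     flat = np.arange(3 * 32 * 32).reshape(1, 3072)
#     chw = flat.reshape((1, 3, 32, 32))       # channels-first, as on disk
#     hwc = chw.transpose(0, 2, 3, 1)          # channels-last, as returned above
#     assert hwc[0, 0, 0, 0] == flat[0, 0]     # red value of the top-left pixel
#     assert hwc[0, 0, 0, 1] == flat[0, 1024]  # green value of the top-left pixel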
class AugmentedCIFAR10Data(object):
"""
Data augmentation wrapper over a loaded dataset.
Inputs to constructor
=====================
- raw_cifar10data: the loaded CIFAR10 dataset, via the CIFAR10Data class
- sess: current tensorflow session
"""
def __init__(self, raw_cifar10data, sess):
assert isinstance(raw_cifar10data, CIFAR10Data)
self.image_size = 32
# create augmentation computational graph
self.x_input_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
padded = tf.map_fn(lambda img: tf.image.resize_image_with_crop_or_pad(
img, self.image_size + 4, self.image_size + 4),
self.x_input_placeholder)
cropped = tf.map_fn(lambda img: tf.random_crop(img, [self.image_size,
self.image_size,
3]), padded)
flipped = tf.map_fn(lambda img: tf.image.random_flip_left_right(img), cropped)
self.augmented = flipped
self.train_data = AugmentedDataset(raw_cifar10data.train_data, sess,
self.x_input_placeholder,
self.augmented)
self.eval_data = AugmentedDataset(raw_cifar10data.eval_data, sess,
self.x_input_placeholder,
self.augmented)
self.label_names = raw_cifar10data.label_names
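
# A short TF1-style usage sketch for the augmentation graph built above,
# assuming a tf.Session and the same illustrative dataset path as before;
# this is not part of the module's API, only an example of how the pieces fit:
#
#     raw = CIFAR10Data('cifar10_data')
#     with tf.Session() as sess:
#         aug = AugmentedCIFAR10Data(raw, sess)
#         x_raw, x_aug, y = aug.train_data.get_next_batch(128, multiple_passes=True)
#         # x_raw is the raw uint8 batch; x_aug has been padded to 36x36,
#         # randomly cropped back to 32x32 and randomly flipped left/right.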
class Dataset(object):
"""
Dataset object implementing a simple batching procedure.
"""
def __init__(self, xs, ys):
self.xs = xs
self.n = xs.shape[0]
self.ys = ys
self.batch_start = 0
self.cur_order = np.random.permutation(self.n)
def get_next_batch(self, batch_size, multiple_passes=False,
reshuffle_after_pass=True):
if self.n < batch_size:
raise ValueError('Batch size can be at most the dataset size')
if not multiple_passes:
actual_batch_size = min(batch_size, self.n - self.batch_start)
if actual_batch_size <= 0:
raise ValueError('Pass through the dataset is complete.')
batch_end = self.batch_start + actual_batch_size
batch_xs = self.xs[self.cur_order[self.batch_start : batch_end],...]
batch_ys = self.ys[self.cur_order[self.batch_start : batch_end],...]
self.batch_start += actual_batch_size
return batch_xs, batch_ys
actual_batch_size = min(batch_size, self.n - self.batch_start)
if actual_batch_size < batch_size:
if reshuffle_after_pass:
self.cur_order = np.random.permutation(self.n)
self.batch_start = 0
batch_end = self.batch_start + batch_size
batch_xs = self.xs[self.cur_order[self.batch_start : batch_end], ...]
batch_ys = self.ys[self.cur_order[self.batch_start : batch_end], ...]
# advance by a full batch: right after a wrap-around, actual_batch_size is
# the old remainder, and using it here would make consecutive batches overlap
self.batch_start += batch_size
return batch_xs, batch_ys
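
# Batching behaviour in brief: with multiple_passes=False each example is
# yielded at most once and a ValueError is raised once the pass is exhausted;
# with multiple_passes=True the order is reshuffled (by default) and full
# batches keep coming indefinitely. A tiny illustrative sketch on toy data:
#
#     toy = Dataset(np.zeros((10, 32, 32, 3), dtype='uint8'),
#                   np.arange(10, dtype='int32'))
#     xs, ys = toy.get_next_batch(4, multiple_passes=True)  # always 4 examples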
class AugmentedDataset(object):
"""
Dataset object with built-in data augmentation. When performing
adversarial attacks, we cannot include data augmentation as part of the
model; if we did, the adversary would simply backprop through it.
"""
def __init__(self, raw_datasubset, sess, x_input_placeholder,
augmented):
self.sess = sess
self.raw_datasubset = raw_datasubset
self.x_input_placeholder = x_input_placeholder
self.augmented = augmented
def get_next_batch(self, batch_size, multiple_passes=False,
reshuffle_after_pass=True):
raw_batch = self.raw_datasubset.get_next_batch(batch_size,
multiple_passes,
reshuffle_after_pass)
images = raw_batch[0].astype(np.float32)
# return both the raw and augmented input
# for adversarial training with rotations/translations, we start
# from the raw input to avoid compounding augmentations
return (raw_batch[0],
        self.sess.run(
            self.augmented,
            feed_dict={self.x_input_placeholder: images}),
        raw_batch[1])
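

# A minimal smoke test, intended only as an illustration: the 'cifar10_data'
# path is an assumption about where the extracted batches live, and the
# TF1-style Session matches the graph constructed in AugmentedCIFAR10Data.
if __name__ == '__main__':
    data_path = 'cifar10_data'  # hypothetical location of the extracted batches
    raw_cifar = CIFAR10Data(data_path)
    print('train:', raw_cifar.train_data.xs.shape,
          'eval:', raw_cifar.eval_data.xs.shape)
    with tf.Session() as sess:
        cifar = AugmentedCIFAR10Data(raw_cifar, sess)
        x_raw, x_aug, y = cifar.train_data.get_next_batch(8, multiple_passes=True)
        print('raw batch:', x_raw.shape, x_raw.dtype,
              '| augmented batch:', x_aug.shape, x_aug.dtype)
        print('labels:', [cifar.label_names[int(i)] for i in y])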