rjyrjy/super-resolution

data.py
Martin Krasser committed on 2019-08-22 14:34: Rewrite for Tensorflow 2.0
import os
import tensorflow as tf

from tensorflow.python.data.experimental import AUTOTUNE


class DIV2K:
    """DIV2K dataset as a tf.data input pipeline of LR/HR image pairs for super-resolution."""

    def __init__(self,
                 scale=2,
                 subset='train',
                 downgrade='bicubic',
                 images_dir='.div2k/images',
                 caches_dir='.div2k/caches'):

        self._ntire_2018 = True

        _scales = [2, 3, 4, 8]

        if scale in _scales:
            self.scale = scale
        else:
            raise ValueError(f'scale must be in {_scales}')

        if subset == 'train':
            self.image_ids = range(1, 801)
        elif subset == 'valid':
            self.image_ids = range(801, 901)
        else:
            raise ValueError("subset must be 'train' or 'valid'")

        _downgrades_a = ['bicubic', 'unknown']
        _downgrades_b = ['mild', 'difficult']

        if scale == 8 and downgrade != 'bicubic':
            raise ValueError('scale 8 only allowed for bicubic downgrade')

        if downgrade in _downgrades_b and scale != 4:
            raise ValueError(f'{downgrade} downgrade requires scale 4')

        # The NTIRE 2018 downgrades ('mild', 'difficult', bicubic x8) use a different
        # directory layout and file naming scheme than the bicubic/unknown x2-x4 downgrades.
        if downgrade == 'bicubic' and scale == 8:
            self.downgrade = 'x8'
        elif downgrade in _downgrades_b:
            self.downgrade = downgrade
        else:
            self.downgrade = downgrade
            self._ntire_2018 = False

        self.subset = subset
        self.images_dir = images_dir
        self.caches_dir = caches_dir

        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(caches_dir, exist_ok=True)
    def __len__(self):
        return len(self.image_ids)

    def dataset(self, batch_size=16, repeat_count=None, random_transform=True):
        # Pair LR and HR images and optionally apply random crop, rotation and flip augmentations.
        ds = tf.data.Dataset.zip((self.lr_dataset(), self.hr_dataset()))
        if random_transform:
            ds = ds.map(lambda lr, hr: random_crop(lr, hr, scale=self.scale), num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_flip, num_parallel_calls=AUTOTUNE)
        ds = ds.batch(batch_size)
        ds = ds.repeat(repeat_count)
        ds = ds.prefetch(buffer_size=AUTOTUNE)
        return ds

    def hr_dataset(self):
        # Download and extract the HR archive on first use, then cache decoded images on disk.
        if not os.path.exists(self._hr_images_dir()):
            download_archive(self._hr_images_archive(), self.images_dir, extract=True)

        ds = self._images_dataset(self._hr_image_files()).cache(self._hr_cache_file())

        if not os.path.exists(self._hr_cache_index()):
            self._populate_cache(ds, self._hr_cache_file())

        return ds

    def lr_dataset(self):
        # Download and extract the LR archive on first use, then cache decoded images on disk.
        if not os.path.exists(self._lr_images_dir()):
            download_archive(self._lr_images_archive(), self.images_dir, extract=True)

        ds = self._images_dataset(self._lr_image_files()).cache(self._lr_cache_file())

        if not os.path.exists(self._lr_cache_index()):
            self._populate_cache(ds, self._lr_cache_file())

        return ds

    def _hr_cache_file(self):
        return os.path.join(self.caches_dir, f'DIV2K_{self.subset}_HR.cache')

    def _lr_cache_file(self):
        return os.path.join(self.caches_dir, f'DIV2K_{self.subset}_LR_{self.downgrade}_X{self.scale}.cache')

    def _hr_cache_index(self):
        return f'{self._hr_cache_file()}.index'

    def _lr_cache_index(self):
        return f'{self._lr_cache_file()}.index'

    def _hr_image_files(self):
        images_dir = self._hr_images_dir()
        return [os.path.join(images_dir, f'{image_id:04}.png') for image_id in self.image_ids]

    def _lr_image_files(self):
        images_dir = self._lr_images_dir()
        return [os.path.join(images_dir, self._lr_image_file(image_id)) for image_id in self.image_ids]

    def _lr_image_file(self, image_id):
        if not self._ntire_2018 or self.scale == 8:
            return f'{image_id:04}x{self.scale}.png'
        else:
            return f'{image_id:04}x{self.scale}{self.downgrade[0]}.png'

    def _hr_images_dir(self):
        return os.path.join(self.images_dir, f'DIV2K_{self.subset}_HR')

    def _lr_images_dir(self):
        if self._ntire_2018:
            return os.path.join(self.images_dir, f'DIV2K_{self.subset}_LR_{self.downgrade}')
        else:
            return os.path.join(self.images_dir, f'DIV2K_{self.subset}_LR_{self.downgrade}', f'X{self.scale}')

    def _hr_images_archive(self):
        return f'DIV2K_{self.subset}_HR.zip'

    def _lr_images_archive(self):
        if self._ntire_2018:
            return f'DIV2K_{self.subset}_LR_{self.downgrade}.zip'
        else:
            return f'DIV2K_{self.subset}_LR_{self.downgrade}_X{self.scale}.zip'

    @staticmethod
    def _images_dataset(image_files):
        ds = tf.data.Dataset.from_tensor_slices(image_files)
        ds = ds.map(tf.io.read_file)
        ds = ds.map(lambda x: tf.image.decode_png(x, channels=3), num_parallel_calls=AUTOTUNE)
        return ds

    @staticmethod
    def _populate_cache(ds, cache_file):
        # Iterating once over the dataset writes the decoded images to the cache file.
        print(f'Caching decoded images in {cache_file} ...')
        for _ in ds: pass
        print(f'Cached decoded images in {cache_file}.')
# -----------------------------------------------------------
# Transformations
# -----------------------------------------------------------
def random_crop(lr_img, hr_img, hr_crop_size=96, scale=2):
    lr_crop_size = hr_crop_size // scale
    lr_img_shape = tf.shape(lr_img)[:2]

    lr_w = tf.random.uniform(shape=(), maxval=lr_img_shape[1] - lr_crop_size + 1, dtype=tf.int32)
    lr_h = tf.random.uniform(shape=(), maxval=lr_img_shape[0] - lr_crop_size + 1, dtype=tf.int32)

    hr_w = lr_w * scale
    hr_h = lr_h * scale

    lr_img_cropped = lr_img[lr_h:lr_h + lr_crop_size, lr_w:lr_w + lr_crop_size]
    hr_img_cropped = hr_img[hr_h:hr_h + hr_crop_size, hr_w:hr_w + hr_crop_size]

    return lr_img_cropped, hr_img_cropped


def random_flip(lr_img, hr_img):
    rn = tf.random.uniform(shape=(), maxval=1)
    return tf.cond(rn < 0.5,
                   lambda: (lr_img, hr_img),
                   lambda: (tf.image.flip_left_right(lr_img),
                            tf.image.flip_left_right(hr_img)))


def random_rotate(lr_img, hr_img):
    rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    return tf.image.rot90(lr_img, rn), tf.image.rot90(hr_img, rn)
# -----------------------------------------------------------
# IO
# -----------------------------------------------------------
def download_archive(file, target_dir, extract=True):
    source_url = f'http://data.vision.ee.ethz.ch/cvl/DIV2K/{file}'
    target_dir = os.path.abspath(target_dir)
    tf.keras.utils.get_file(file, source_url, cache_subdir=target_dir, extract=extract)
    os.remove(os.path.join(target_dir, file))
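A minimal usage sketch of the DIV2K class above: it builds a scale-4 bicubic training pipeline with the default batch size and 96-pixel HR crop. The module name `data` and the chosen configuration are illustrative, and downloading assumes network access to the DIV2K source URL.

# Illustrative usage, assuming this module is saved as data.py.
from data import DIV2K

div2k_train = DIV2K(scale=4, subset='train', downgrade='bicubic')
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)

# With the default 96-pixel HR crop, LR patches are 96 // 4 = 24 pixels.
for lr_batch, hr_batch in train_ds.take(1):
    print(lr_batch.shape, hr_batch.shape)  # (16, 24, 24, 3) (16, 96, 96, 3)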