1 Star 0 Fork 0

婷婷的小双子/pytorch-segmentation

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
transforms.py 2.75 KB
一键复制 编辑 原始数据 按行查看 历史
dusty-nv 提交于 2019-08-13 12:22 . added Resize() transform
import numpy as np
from PIL import Image
import random
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
def pad_if_smaller(img, size, fill=0):
    """Pad ``img`` so that neither of its sides is shorter than ``size``.

    Args:
        img: Image object exposing ``.size`` as a two-element sequence; the
            code unpacks it as (width, height), which is the PIL convention.
        size: Minimum side length required after padding.
        fill: Value handed to ``F.pad`` for the newly added pixels.

    Returns:
        ``img`` itself when both sides are already at least ``size``;
        otherwise the result of ``F.pad``.
    """
    width, height = img.size
    if min(width, height) >= size:
        # Already large enough: hand back the very same object.
        return img

    # Only the trailing two entries of the padding tuple are ever non-zero,
    # so the first two sides are left untouched.
    extra_w = max(size - width, 0)
    extra_h = max(size - height, 0)
    return F.pad(img, (0, 0, extra_w, extra_h), fill=fill)
class Compose(object):
    """Chain several paired (image, target) transforms into one callable."""

    def __init__(self, transforms):
        """Store the transforms to apply.

        Args:
            transforms: Iterable of callables, each taking ``(image, target)``
                and returning a two-element ``(image, target)`` result.
        """
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every stored transform in order and return the final pair."""
        sample = (image, target)
        for transform in self.transforms:
            # Unpack explicitly so each step must yield exactly two values.
            new_image, new_target = transform(*sample)
            sample = (new_image, new_target)
        return sample
class Resize(object):
    """Resize an image and its target to the same fixed size."""

    def __init__(self, size):
        """Remember the size that is forwarded to ``F.resize``."""
        self.size = size

    def __call__(self, image, target):
        """Return the resized ``(image, target)`` pair.

        The image uses ``F.resize``'s default interpolation. The target is
        resized with ``Image.NEAREST`` so that no new pixel values are
        blended in (presumably the target holds class labels).
        """
        out_size = self.size
        resized_image = F.resize(image, out_size)
        resized_target = F.resize(target, out_size, interpolation=Image.NEAREST)
        return resized_image, resized_target
class RandomResize(object):
    """Resize an image/target pair to a randomly drawn integer size."""

    def __init__(self, min_size, max_size=None):
        """Store the inclusive range the size is sampled from.

        Args:
            min_size: Lower bound of the sampled size.
            max_size: Upper bound of the sampled size; when ``None`` it falls
                back to ``min_size``, so the size is always ``min_size``.
        """
        self.min_size = min_size
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        """Draw one size and resize both inputs with it.

        The same sampled size is used for the image and the target. The
        target is resized with ``Image.NEAREST`` interpolation.
        """
        # random.randint includes both endpoints.
        sampled = random.randint(self.min_size, self.max_size)
        resized_image = F.resize(image, sampled)
        resized_target = F.resize(target, sampled, interpolation=Image.NEAREST)
        return resized_image, resized_target
class RandomHorizontalFlip(object):
    """Horizontally flip an image and its target together, at random."""

    def __init__(self, flip_prob):
        """Store the probability threshold compared against ``random.random()``."""
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        """Return the pair, flipped when the random draw is below ``flip_prob``.

        A single draw decides for both inputs, so the image and target are
        either both flipped or both returned unchanged.
        """
        should_flip = random.random() < self.flip_prob
        if not should_flip:
            return image, target
        return F.hflip(image), F.hflip(target)
class RandomCrop(object):
    """Cut the same random square window out of an image and its target."""

    def __init__(self, size):
        """Store the side length of the square crop."""
        self.size = size

    def __call__(self, image, target):
        """Pad both inputs if needed, then crop them with shared parameters.

        The image is padded with the default fill of ``pad_if_smaller`` and
        the target with 255 (presumably an "ignore" label; confirm against
        the loss configuration). Crop parameters are sampled once from the
        padded image and reused for the target.
        """
        side = self.size
        image = pad_if_smaller(image, side)
        target = pad_if_smaller(target, side, fill=255)

        # get_params yields the four positional arguments F.crop expects
        # after the image itself.
        top, left, height, width = T.RandomCrop.get_params(image, (side, side))
        cropped_image = F.crop(image, top, left, height, width)
        cropped_target = F.crop(target, top, left, height, width)
        return cropped_image, cropped_target
class CenterCrop(object):
    """Apply ``F.center_crop`` with one size to an image and its target."""

    def __init__(self, size):
        """Store the size that is forwarded to ``F.center_crop``."""
        self.size = size

    def __call__(self, image, target):
        """Return the center-cropped ``(image, target)`` pair."""
        crop_size = self.size
        cropped_image = F.center_crop(image, crop_size)
        cropped_target = F.center_crop(target, crop_size)
        return cropped_image, cropped_target
class ToTensor(object):
    """Convert an image/target pair into tensors."""

    def __call__(self, image, target):
        """Return ``(image_tensor, target_tensor)``.

        The image goes through ``F.to_tensor``. The target is viewed as a
        NumPy array and turned into a ``torch.int64`` tensor, without any
        other processing of its values.
        """
        image_tensor = F.to_tensor(image)
        target_array = np.asarray(target)
        target_tensor = torch.as_tensor(target_array, dtype=torch.int64)
        return image_tensor, target_tensor
class Normalize(object):
    """Normalize the image with a fixed mean and std; pass the target through."""

    def __init__(self, mean, std):
        """Store the ``mean`` and ``std`` forwarded to ``F.normalize``."""
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        """Return ``(normalized_image, target)``; the target is not modified."""
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lift-zpz/pytorch-segmentation.git
git@gitee.com:lift-zpz/pytorch-segmentation.git
lift-zpz
pytorch-segmentation
pytorch-segmentation
master

搜索帮助