代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
"""
Created on 2018/12/19 14:00
# SplitComb是一个类,类的主要参数有:side_len=144, max_stride=16, stride=4, margin=32, pad_value=170
# SplitComb.split对数据进行padding,以及z、x、y轴上的处理操作,SplitComb.combine对patch数据进行合并操作
"""
import torch
import numpy as np
class SplitComb():
    """Split a 4-D volume into overlapping cubic patches and stitch per-patch
    outputs back into one grid.

    Each patch is a cubic "core" of ``side_len`` voxels per axis, plus
    ``margin`` voxels of context on every side. After a network has processed
    the patches (downsampling by ``stride``), :meth:`combine` crops the margin
    from each output and tiles the cores back together.

    Args:
        side_len: edge length, in input voxels, of each patch core
            (e.g. 144).
        max_stride: ``side_len`` and ``margin`` must both be divisible by
            this in :meth:`split` (e.g. 16).
        stride: downsampling factor between the input patches and the
            outputs passed to :meth:`combine` (e.g. 4).
        margin: context, in input voxels, added on each side of a core
            (e.g. 32).
        pad_value: stored on the instance only. NOTE(review): :meth:`split`
            pads with replicated edge values and never reads this.
    """

    def __init__(self, side_len, max_stride, stride, margin, pad_value):
        self.side_len = side_len
        self.max_stride = max_stride
        self.stride = stride
        self.margin = margin
        self.pad_value = pad_value

    def split(self, data, side_len=None, max_stride=None, margin=None):
        """Cut ``data`` into overlapping cubic patches.

        Args:
            data: array unpacked as ``(channels, z, h, w)``.
            side_len, max_stride, margin: per-call overrides of the instance
                settings; ``None`` means "use the instance value".

        Returns:
            ``(splits, nzhw)``:

            * ``splits`` has shape
              ``(nz * nh * nw, channels, W, W, W)`` with
              ``W = side_len + 2 * margin``. Patches are ordered z-major,
              then h, then w.
            * ``nzhw`` is ``[nz, nh, nw]``, the number of cores along each
              axis. It is also stored on ``self.nzhw`` so that
              :meth:`combine` can be called without it.

        Raises:
            AssertionError: if ``side_len <= margin``, or if ``side_len`` or
                ``margin`` is not a multiple of ``max_stride``.
        """
        if side_len is None:
            side_len = self.side_len
        if max_stride is None:
            max_stride = self.max_stride
        if margin is None:
            margin = self.margin
        assert side_len > margin
        assert side_len % max_stride == 0
        assert margin % max_stride == 0

        _, z, h, w = data.shape
        # Number of cores needed to cover each axis (last one may overhang).
        nz = int(np.ceil(float(z) / side_len))
        nh = int(np.ceil(float(h) / side_len))
        nw = int(np.ceil(float(w) / side_len))
        nzhw = [nz, nh, nw]
        self.nzhw = nzhw

        # Pad each spatial axis as follows:
        #   * front: `margin`;
        #   * back: up to a whole number of cores, plus `margin`.
        # The channel axis is left alone. 'edge' mode replicates the border
        # voxel values.
        pad = [[0, 0],
               [margin, nz * side_len - z + margin],
               [margin, nh * side_len - h + margin],
               [margin, nw * side_len - w + margin]]
        data = np.pad(data, pad, 'edge')

        window = side_len + 2 * margin
        splits = []
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz = iz * side_len
                    sh = ih * side_len
                    sw = iw * side_len
                    splits.append(data[np.newaxis, :,
                                       sz:sz + window,
                                       sh:sh + window,
                                       sw:sw + window])
        splits = np.concatenate(splits, 0)
        return splits, nzhw

    def combine(self, output, nzhw=None, side_len=None, stride=None,
                margin=None):
        """Stitch per-patch outputs back into a single grid.

        Args:
            output: sequence (or array) of per-patch results, in the same
                order :meth:`split` produced the patches. Each item is
                indexed with its first three axes as spatial axes. Its axes
                3 and 4 are carried through to the result, e.g. shape
                ``(52, 52, 52, 3, 5)`` for ``side_len=144, stride=4,
                margin=32``.
            nzhw: ``[nz, nh, nw]`` core counts. Defaults to the value
                recorded by the most recent :meth:`split` call.
            side_len, stride, margin: per-call overrides of the instance
                settings, in input-voxel units; ``None`` means "use the
                instance value".

        Returns:
            float32 array of shape
            ``(nz * s, nh * s, nw * s, output[0].shape[3], output[0].shape[4])``
            with ``s = side_len // stride``, pre-filled with ``-1000000``.

        Raises:
            AssertionError: if ``side_len`` or ``margin`` is not a multiple
                of ``stride``.
            AttributeError: if ``nzhw`` is omitted and :meth:`split` has
                never been called.
        """
        if side_len is None:
            side_len = self.side_len
        if stride is None:
            stride = self.stride
        if margin is None:
            margin = self.margin
        if nzhw is None:
            # Fall back to the grid recorded by split().
            nzhw = self.nzhw
        nz, nh, nw = nzhw
        assert side_len % stride == 0
        assert margin % stride == 0
        # Convert to output-resolution units. Floor division keeps these
        # ints, because they are used as array sizes and slice bounds.
        side_len //= stride  # e.g. 144 // 4 = 36
        margin //= stride    # e.g. 32 // 4 = 8

        first = output[0]
        combined = np.full((nz * side_len,
                            nh * side_len,
                            nw * side_len,
                            first.shape[3],
                            first.shape[4]), -1000000, np.float32)

        idx = 0
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz = iz * side_len
                    sh = ih * side_len
                    sw = iw * side_len
                    # Drop the context margin and keep only the core.
                    core = output[idx][margin:margin + side_len,
                                       margin:margin + side_len,
                                       margin:margin + side_len]
                    combined[sz:sz + side_len,
                             sh:sh + side_len,
                             sw:sw + side_len] = core
                    idx += 1
        return combined
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。