1 Star 1 Fork 1

娄维尧/3D-Lung-nodules-detection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
split_combine.py 3.59 KB
一键复制 编辑 原始数据 按行查看 历史
royce.mao 提交于 2018-12-29 15:45 . CT影像分析
# -*- coding: utf-8 -*-
"""
Created on 2018/12/19 14:00
# SplitComb是一个类,类的主要参数有:side_len=144, max_stride=16, stride=4, margin=32, pad_value=170
# SplitComb.split对数据进行padding,以及z、x、y轴上的处理操作,SplitComb.combine对patch数据进行合并操作
"""
import torch
import numpy as np
class SplitComb():
    """Cut a volume into overlapping cubes and stitch per-cube outputs back.

    ``split`` tiles a 4-D volume (first axis kept whole, then z, h, w) with
    cubes whose core edge is ``side_len`` voxels, each extended by ``margin``
    voxels of context on every side. ``combine`` takes per-cube outputs that
    are ``stride`` times smaller than the input cubes, drops the (scaled)
    margin of each and reassembles the cores into a single grid.

    Args:
        side_len: Core edge length of each cube, in input voxels.
        max_stride: ``side_len`` and ``margin`` must be multiples of this
            (checked in ``split``).
        stride: Down-sampling factor between ``split`` cubes and the outputs
            given to ``combine``.
        margin: Context voxels added on each side of every cube.
        pad_value: Stored on the instance only. ``split`` pads with edge
            replication (``np.pad(..., 'edge')``) and does not read it.
    """

    def __init__(self, side_len, max_stride, stride, margin, pad_value):
        self.side_len = side_len
        self.max_stride = max_stride
        self.stride = stride
        self.margin = margin
        self.pad_value = pad_value

    def split(self, data, side_len=None, max_stride=None, margin=None):
        """Cut ``data`` into overlapping cubes.

        Args:
            data: 4-D array; the first axis (presumably channels) is kept
                whole, and the remaining three axes (z, h, w) are tiled.
            side_len, max_stride, margin: Per-call overrides of the
                instance settings; ``None`` means use the instance value.

        Returns:
            ``(splits, nzhw)``. ``splits`` has shape
            ``(nz*nh*nw, data.shape[0], L, L, L)`` with
            ``L = side_len + 2*margin``, ordered z-major, then h, then w.
            ``nzhw`` is ``[nz, nh, nw]``, the number of cubes per axis. It is
            also saved as ``self.nzhw`` for a later ``combine`` call.

        Raises:
            AssertionError: if ``side_len <= margin`` or if ``side_len`` or
                ``margin`` is not a multiple of ``max_stride``.
        """
        if side_len is None:
            side_len = self.side_len
        if max_stride is None:
            max_stride = self.max_stride
        if margin is None:
            margin = self.margin
        assert side_len > margin
        assert side_len % max_stride == 0
        assert margin % max_stride == 0

        _, z, h, w = data.shape
        # Ceiling division: enough cubes per axis to cover the whole volume.
        nz = int(np.ceil(float(z) / side_len))
        nh = int(np.ceil(float(h) / side_len))
        nw = int(np.ceil(float(w) / side_len))
        nzhw = [nz, nh, nw]
        # Remembered so combine() can be called without passing nzhw.
        self.nzhw = nzhw

        # Pad `margin` in front and `margin` plus the remainder at the back,
        # so every cube (core + 2 * margin) lies inside the padded volume.
        pad = [[0, 0],
               [margin, nz * side_len - z + margin],
               [margin, nh * side_len - h + margin],
               [margin, nw * side_len - w + margin]]
        data = np.pad(data, pad, 'edge')  # replicate the edge voxel values

        splits = []
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz = iz * side_len
                    ez = (iz + 1) * side_len + 2 * margin
                    sh = ih * side_len
                    eh = (ih + 1) * side_len + 2 * margin
                    sw = iw * side_len
                    ew = (iw + 1) * side_len + 2 * margin
                    splits.append(data[np.newaxis, :, sz:ez, sh:eh, sw:ew])
        splits = np.concatenate(splits, 0)
        return splits, nzhw

    def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None):
        """Stitch per-cube outputs into one grid, discarding each margin.

        Args:
            output: Sequence (list or array) of per-cube outputs in the order
                produced by ``split``. Each item is assumed to be 5-D,
                ``(z, h, w, A, B)``; the last two sizes are read from
                ``output[0]``. The original code noted ``(52, 52, 52, 3, 5)``
                for the default settings.
            nzhw: Cubes per axis ``(nz, nh, nw)``. Defaults to the grid
                recorded by the most recent ``split`` call.
            side_len, stride, margin: Per-call overrides of the instance
                settings, in input-voxel units.

        Returns:
            float32 array of shape ``(nz*s, nh*s, nw*s, A, B)``, with
            ``s = side_len // stride``. Cells not covered by any cube keep the
            initial value -1000000.

        Raises:
            AssertionError: if ``side_len`` or ``margin`` is not a multiple
                of ``stride``.
            AttributeError: if ``nzhw`` is omitted and ``split`` has never
                been called on this instance.
        """
        if side_len is None:
            side_len = self.side_len
        if stride is None:
            stride = self.stride
        if margin is None:
            margin = self.margin
        if nzhw is None:
            nzhw = getattr(self, 'nzhw', None)
            if nzhw is None:
                raise AttributeError(
                    'combine() needs nzhw or a prior call to split()')
        # int() so numpy/tensor scalars do not leak into shapes and slices.
        nz, nh, nw = (int(n) for n in nzhw)
        assert side_len % stride == 0
        assert margin % stride == 0
        # Convert to output-grid units. Integer division is exact thanks to
        # the asserts above; true division (/) would yield floats, which
        # numpy rejects as shapes and slice bounds.
        side_len //= stride
        margin //= stride

        patches = list(output)
        grid = np.full((nz * side_len,
                        nh * side_len,
                        nw * side_len,
                        patches[0].shape[3],
                        patches[0].shape[4]), -1000000, np.float32)
        # Central (non-margin) part of every patch along z, h and w.
        core = slice(margin, margin + side_len)
        idx = 0
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz = iz * side_len
                    sh = ih * side_len
                    sw = iw * side_len
                    grid[sz:sz + side_len, sh:sh + side_len, sw:sw + side_len] = \
                        patches[idx][core, core, core]
                    idx += 1
        return grid
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lou_wei_yao/Lung-nodules-detection.git
git@gitee.com:lou_wei_yao/Lung-nodules-detection.git
lou_wei_yao
Lung-nodules-detection
3D-Lung-nodules-detection
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385