1 Star 1 Fork 1

娄维尧/3D-Lung-nodules-detection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data_detector.py 18.90 KB
一键复制 编辑 原始数据 按行查看 历史
royce.mao 提交于 2018-12-29 15:45 . CT影像分析
# -*- coding: utf-8 -*-
import collections
import collections.abc
import os
import random
import time
import warnings

import numpy as np
import torch
from scipy.ndimage import zoom
from scipy.ndimage.interpolation import rotate
from scipy.ndimage.morphology import binary_dilation, generate_binary_structure
from torch.utils.data import Dataset

from layers import iou
class DataBowl3Detector(Dataset):
    """Detection dataset for 3D lung-nodule CT volumes.

    In 'train'/'val' phase, ``__getitem__`` crops a patch around a sampled
    nodule (or a random negative patch), optionally augments it, and returns
    ``(sample, label, coord)``.  In 'test' phase the whole padded volume is
    split into overlapping sub-volumes by ``split_comber``.
    """

    def __init__(self, split, config, phase='train', split_comber=None):
        assert phase == 'train' or phase == 'val' or phase == 'test'
        self.phase = phase
        self.max_stride = config['max_stride']
        self.stride = config['stride']
        # Nodule-diameter thresholds converted from mm to voxels.
        sizelim = config['sizelim'] / config['reso']
        sizelim2 = config['sizelim2'] / config['reso']
        sizelim3 = config['sizelim3'] / config['reso']
        self.blacklist = config['blacklist']
        self.isScale = config['aug_scale']
        self.r_rand = config['r_rand_crop']
        self.augtype = config['augtype']
        data_dir = config['datadir']
        self.pad_value = config['pad_value']
        self.split_comber = split_comber
        idcs = split
        if phase != 'test':
            idcs = [f for f in idcs if f not in self.blacklist]
        self.channel = config['chanel']  # NOTE: key is spelled 'chanel' in the config
        if self.channel == 2:
            self.filenames = [os.path.join(data_dir, '%s_merge.npy' % idx) for idx in idcs]
        elif self.channel == 1:
            if 'cleanimg' in config and config['cleanimg']:
                self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs]
            else:
                self.filenames = [os.path.join(data_dir, '%s_img.npy' % idx) for idx in idcs]
        # Kaggle series ids are long hashes (>20 chars); LUNA ids are short.
        self.kagglenames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) > 20]
        self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0]) < 20]
        labels = []
        for idx in idcs:
            if config['luna_raw']:
                try:
                    l = np.load(os.path.join(data_dir, '%s_label_raw.npy' % idx))
                except OSError:
                    # Raw label file missing -> fall back to the processed one
                    # (was a bare `except:` that also hid unrelated errors).
                    l = np.load(os.path.join(data_dir, '%s_label.npy' % idx))
            else:
                l = np.load(os.path.join(data_dir, '%s_label.npy' % idx))
            labels.append(l)
        self.sample_bboxes = labels
        if self.phase != 'test':
            # Oversample larger nodules: > sizelim2 -> 3x total, > sizelim3 -> 7x total.
            self.bboxes = []
            for i, l in enumerate(labels):
                if len(l) > 0:
                    for t in l:
                        if t[3] > sizelim:
                            self.bboxes.append([np.concatenate([[i], t])])
                        if t[3] > sizelim2:
                            self.bboxes += [[np.concatenate([[i], t])]] * 2
                        if t[3] > sizelim3:
                            self.bboxes += [[np.concatenate([[i], t])]] * 4
            self.bboxes = np.concatenate(self.bboxes, axis=0)
        # Patch cropping helper.
        self.crop = Crop(config)
        # Online hard negative mining (keeps pos/neg ratio near 7:3).
        self.label_mapping = LabelMapping(config, self.phase)

    def __getitem__(self, idx, split=None):
        # Train/val: crop a patch (and augment in train); test: split the
        # whole padded volume into network-sized sub-volumes.
        t = time.time()
        np.random.seed(int(str(t % 1)[2:7]))  # seed according to time
        isRandomImg = False
        if self.phase != 'test':
            if idx >= len(self.bboxes):
                # Indices past the nodule list become random negative crops.
                isRandom = True
                idx = idx % len(self.bboxes)
                isRandomImg = np.random.randint(2)
            else:
                isRandom = False
        else:
            isRandom = False
        if self.phase != 'test':
            if not isRandomImg:
                bbox = self.bboxes[idx]
                filename = self.filenames[int(bbox[0])]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[int(bbox[0])]
                isScale = self.augtype['scale'] and (self.phase == 'train')
                sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale, isRandom)
                if self.phase == 'train' and not isRandom:
                    sample, target, bboxes, coord = augment(
                        sample, target, bboxes, coord,
                        ifflip=self.augtype['flip'],
                        ifrotate=self.augtype['rotate'],
                        ifswap=self.augtype['swap'])
            else:
                # Random negative patch from a random Kaggle scan.
                randimid = np.random.randint(len(self.kagglenames))
                filename = self.kagglenames[randimid]
                imgs = np.load(filename)[0:self.channel]
                bboxes = self.sample_bboxes[randimid]
                isScale = self.augtype['scale'] and (self.phase == 'train')
                sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True)
            label = self.label_mapping(sample.shape[1:], target, bboxes)
            # Normalize intensities from [0, 255] to [-1, 1].
            sample = (sample.astype(np.float32) - 128) / 128
            return torch.from_numpy(sample), torch.from_numpy(label), coord
        else:
            imgs = np.load(self.filenames[idx])
            bboxes = self.sample_bboxes[idx]
            nz, nh, nw = imgs.shape[1:]
            # Pad each spatial dim up to a multiple of the network stride.
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            ph = int(np.ceil(float(nh) / self.stride)) * self.stride
            pw = int(np.ceil(float(nw) / self.stride)) * self.stride
            imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]],
                          'constant', constant_values=self.pad_value)
            # Normalized [-0.5, 0.5] coordinate grid at output resolution.
            # `//` keeps the linspace sample counts integral under Python 3
            # (true division would pass a float `num` and crash).
            xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[2] // self.stride),
                                     np.linspace(-0.5, 0.5, imgs.shape[3] // self.stride),
                                     indexing='ij')
            coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, ...]],
                                   0).astype('float32')
            # Split the padded volume (and the matching grid) into overlapping
            # sub-volumes; the grid is split at output (stride-reduced) scale.
            imgs, nzhw = self.split_comber.split(imgs)
            coord2, nzhw2 = self.split_comber.split(coord,
                                                    side_len=self.split_comber.side_len // self.stride,
                                                    max_stride=self.split_comber.max_stride // self.stride,
                                                    margin=self.split_comber.margin // self.stride)
            assert np.all(nzhw == nzhw2)
            imgs = (imgs.astype(np.float32) - 128) / 128
            return torch.from_numpy(imgs.astype(np.float32)), bboxes, \
                torch.from_numpy(coord2.astype(np.float32)), np.array(nzhw)

    def __len__(self):
        if self.phase == 'train':
            # Training epoch length exceeds len(bboxes) so the surplus indices
            # turn into random negative crops in __getitem__.  int() is
            # required: Python 3 true division returns a float, which is
            # invalid for len()/DataLoader.
            return int(len(self.bboxes) / (1 - self.r_rand))
        elif self.phase == 'val':
            return len(self.bboxes)
        else:
            return len(self.filenames)
def augment(sample, target, bboxes, coord, ifflip = True, ifrotate=True, ifswap = True):
    """Randomly rotate / axis-swap / flip a cropped patch, keeping the nodule
    target, all bboxes and the coordinate grid consistent with the image.

    sample: patch array indexed (channel, z, h, w) — assumed from the slicing
        below; coord: matching (3, z, h, w) grid.  target/bboxes: [z, h, w, d].
    Returns the transformed (sample, target, bboxes, coord).
    """
    # angle1 = np.random.rand()*180
    if ifrotate:
        validrot = False
        counter = 0
        # Retry up to 3 times for a rotation that keeps the nodule inside the
        # patch; give up (no rotation) otherwise.
        while not validrot:
            newtarget = np.copy(target)
            # Small in-plane rotation in [-10, 10] degrees.
            angle1 = (np.random.rand()-0.5)*20
            size = np.array(sample.shape[2:4]).astype('float')
            # 2D rotation matrix applied in the (h, w) plane about the patch center.
            rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]])
            newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2
            # Accept only if the rotated center stays at least one radius away
            # from every patch border.
            if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]):
                validrot = True
                target = newtarget
                # axes=(2,3) rotates the same (h, w) plane as rotmat above.
                sample = rotate(sample,angle1,axes=(2,3),reshape=False)
                coord = rotate(coord,angle1,axes=(2,3),reshape=False)
                for box in bboxes:
                    box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2
            else:
                counter += 1
                if counter ==3:
                    break
    if ifswap:
        # Axis permutation is only valid for cubic patches.
        if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]:
            axisorder = np.random.permutation(3)
            sample = np.transpose(sample,np.concatenate([[0],axisorder+1]))
            coord = np.transpose(coord,np.concatenate([[0],axisorder+1]))
            target[:3] = target[:3][axisorder]
            bboxes[:,:3] = bboxes[:,:3][:,axisorder]
    if ifflip:
        # flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1
        # z axis is never flipped (first entry fixed to 1); h/w flip randomly.
        flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1
        sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]])
        coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]])
        for ax in range(3):
            if flipid[ax]==-1:
                # Mirror coordinates along every flipped axis.
                target[ax] = np.array(sample.shape[ax+1])-target[ax]
                bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax]
    return sample, target, bboxes, coord
class Crop(object):
    """Crop a fixed-size cubic patch (plus a matching normalized-coordinate
    grid at output resolution) from a CT volume, padding out-of-bounds
    regions with ``pad_value``."""

    def __init__(self, config):
        self.crop_size = config['crop_size']    # e.g. [128, 128, 128]
        self.bound_size = config['bound_size']  # jitter margin around the nodule
        self.stride = config['stride']          # network output stride
        self.pad_value = config['pad_value']

    def __call__(self, imgs, target, bboxes, isScale=False, isRand=False):
        """Crop around ``target`` (or at a random location when ``isRand``).

        imgs: (C, D, H, W) volume; target: [z, h, w, d] nodule; bboxes: all
        nodules of this scan.  Returns (crop, target, bboxes, coord) with
        target/bboxes shifted into crop coordinates (target becomes NaN for
        random crops).
        """
        if isScale:
            radiusLim = [8., 100.]
            scaleLim = [0.75, 1.25]  # random zoom range (augmentation)
            # Clamp the scale range so the scaled nodule radius stays in radiusLim.
            scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]),
                          np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])]
            scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0]
            crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int')
        else:
            crop_size = self.crop_size
        bound_size = self.bound_size
        target = np.copy(target)
        bboxes = np.copy(bboxes)
        start = []
        for i in range(3):
            if not isRand:
                r = target[3] / 2
                # Legal start range keeping the whole nodule (plus margin) inside.
                s = int(np.floor(target[i] - r)) + 1 - bound_size
                e = int(np.ceil(target[i] + r)) + 1 + bound_size - crop_size[i]
            else:
                # Random crop: pick anywhere that keeps the patch mostly in-volume.
                # `//` keeps these integral under Python 3 true division.
                s = np.max([imgs.shape[i + 1] - crop_size[i] // 2,
                            imgs.shape[i + 1] // 2 + bound_size])
                e = np.min([crop_size[i] // 2, imgs.shape[i + 1] // 2 - bound_size])
                target = np.array([np.nan, np.nan, np.nan, np.nan])
            if s > e:
                start.append(int(np.random.randint(e, s)))  # !
            else:
                # Degenerate range: center on the nodule with small jitter.
                start.append(int(target[i]) - crop_size[i] // 2
                             + int(np.random.randint(-bound_size // 2, bound_size // 2)))
        # Normalized start/extent of the crop within the full volume, used to
        # build the coordinate grid the detector consumes.
        normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
        normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])
        xx, yy, zz = np.meshgrid(
            np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] // self.stride),
            np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] // self.stride),
            np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] // self.stride),
            indexing='ij')
        coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, ...]],
                               0).astype('float32')
        # Pad whatever part of the crop falls outside the volume.
        pad = [[0, 0]]
        for i in range(3):
            leftpad = max(0, -start[i])
            rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1])
            pad.append([leftpad, rightpad])
        crop = imgs[:,
                    max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]),
                    max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]),
                    max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])]
        crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value)
        # Shift labels into crop-local coordinates.
        for i in range(3):
            target[i] = target[i] - start[i]
        for i in range(len(bboxes)):
            for j in range(3):
                bboxes[i][j] = bboxes[i][j] - start[j]
        if isScale:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                crop = zoom(crop, [1, scale, scale, scale], order=1)
            newpad = self.crop_size[0] - crop.shape[1:][0]
            if newpad < 0:
                # Zoomed crop overshot: trim back to exactly crop_size.
                # (The original `crop[:, :-newpad, ...]` kept only |newpad|
                # voxels per axis instead of crop_size — off-by-sign bug.)
                crop = crop[:, :self.crop_size[0], :self.crop_size[1], :self.crop_size[2]]
            elif newpad > 0:
                pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]]
                crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value)
            # Rescale labels (positions and diameter) to match the zoom.
            for i in range(4):
                target[i] = target[i] * scale
            for i in range(len(bboxes)):
                for j in range(4):
                    bboxes[i][j] = bboxes[i][j] * scale
        return crop, target, bboxes, coord
class LabelMapping(object):
    """Map a cropped patch's nodule annotations onto the anchor-based output
    grid, producing a (D', H', W', n_anchors, 5) label tensor of
    [class, dz, dh, dw, d(log-diameter)] regression targets."""

    def __init__(self, config, phase):
        self.stride = np.array(config['stride'])
        self.num_neg = int(config['num_neg'])   # negatives kept per sample
        self.th_neg = config['th_neg']          # IoU above this is not negative
        self.anchors = np.asarray(config['anchors'])
        self.phase = phase
        if phase == 'train':
            self.th_pos = config['th_pos_train']
        elif phase == 'val':
            self.th_pos = config['th_pos_val']
        # NOTE(review): phase == 'test' leaves th_pos unset; __call__ is
        # presumably never invoked in test phase — confirm against callers.

    def __call__(self, input_size, target, bboxes):
        stride = self.stride
        num_neg = self.num_neg
        th_neg = self.th_neg
        anchors = self.anchors
        th_pos = self.th_pos
        struct = generate_binary_structure(3, 1)
        output_size = []
        for i in range(3):
            assert input_size[i] % stride == 0
            # `//` keeps the output dims integral (true division would hand
            # floats to np.zeros below and crash under Python 3).
            output_size.append(input_size[i] // stride)
        label = np.zeros(output_size + [len(anchors), 5], np.float32)
        # Center of the first output cell in input coordinates.
        offset = ((stride.astype('float')) - 1) / 2
        oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride)
        oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride)
        ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride)
        # Mark (then dilate by one cell) every location whose anchor overlaps
        # any nodule above th_neg: these are excluded from the negative pool.
        for bbox in bboxes:
            for i, anchor in enumerate(anchors):
                iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow)
                label[iz, ih, iw, i, 0] = 1
                label[:, :, :, i, 0] = binary_dilation(
                    label[:, :, :, i, 0].astype('bool'),
                    structure=struct, iterations=1).astype('float32')
        # After this, class channel is -1 for candidate negatives, 0 otherwise.
        label = label - 1
        if self.phase == 'train' and self.num_neg > 0:
            # Online hard negative mining: keep a random subset of negatives.
            neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1)
            neg_idcs = random.sample(range(len(neg_z)), min(num_neg, len(neg_z)))
            neg_z, neg_h, neg_w, neg_a = neg_z[neg_idcs], neg_h[neg_idcs], neg_w[neg_idcs], neg_a[neg_idcs]
            label[:, :, :, :, 0] = 0
            label[neg_z, neg_h, neg_w, neg_a, 0] = -1
        if np.isnan(target[0]):
            # Random (negative) crop: no positive target to encode.
            return label
        # Collect every (cell, anchor) pair positive for the sampled target.
        iz, ih, iw, ia = [], [], [], []
        for i, anchor in enumerate(anchors):
            iiz, iih, iiw = select_samples(target, anchor, th_pos, oz, oh, ow)
            iz.append(iiz)
            ih.append(iih)
            iw.append(iiw)
            ia.append(i * np.ones((len(iiz),), np.int64))
        iz = np.concatenate(iz, 0)
        ih = np.concatenate(ih, 0)
        iw = np.concatenate(iw, 0)
        ia = np.concatenate(ia, 0)
        flag = True
        if len(iz) == 0:
            # No anchor clears th_pos: fall back to the nearest cell and the
            # anchor closest in log-diameter.
            pos = []
            for i in range(3):
                pos.append(max(0, int(np.round((target[i] - offset) / stride))))
            idx = np.argmin(np.abs(np.log(target[3] / anchors)))
            pos.append(idx)
            flag = False
        else:
            # One positive is sampled at random among the qualifying pairs.
            idx = random.sample(range(len(iz)), 1)[0]
            pos = [iz[idx], ih[idx], iw[idx], ia[idx]]
        # Encode regression targets relative to the chosen cell/anchor.
        dz = (target[0] - oz[pos[0]]) / anchors[pos[3]]
        dh = (target[1] - oh[pos[1]]) / anchors[pos[3]]
        dw = (target[2] - ow[pos[2]]) / anchors[pos[3]]
        dd = np.log(target[3] / anchors[pos[3]])
        label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd]
        return label
def select_samples(bbox, anchor, th, oz, oh, ow):
    """Find the output-grid cells whose anchor box overlaps ``bbox`` enough.

    bbox: [z, h, w, d] cube (center + side); anchor: side length of the
    candidate cube; oz/oh/ow: 1D arrays of cell-center coordinates per axis;
    th: minimum IoU.  Returns three equal-length index arrays (iz, ih, iw)
    into oz/oh/ow for every cell with IoU >= th.
    """
    z, h, w, d = bbox
    max_overlap = min(d, anchor)
    # Per-axis overlap needed for the total IoU to possibly reach th.
    min_overlap = np.power(max(d, anchor), 3) * th / max_overlap / max_overlap
    empty = np.zeros((0,), np.int64)
    if min_overlap > max_overlap:
        # th unreachable for this size pair.
        return empty, empty, empty

    # Coarse per-axis pre-filter: keep cell centers within the distance that
    # still permits min_overlap along that axis.
    span = 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap)

    def _axis_hits(center, grid):
        return np.where(np.logical_and(grid >= center - span, grid <= center + span))[0]

    iz = _axis_hits(z, oz)
    ih = _axis_hits(h, oh)
    iw = _axis_hits(w, ow)
    if len(iz) == 0 or len(ih) == 0 or len(iw) == 0:
        return empty, empty, empty

    # Cartesian product of the surviving per-axis indices.
    lz, lh, lw = len(iz), len(ih), len(iw)
    iz = np.tile(iz.reshape((-1, 1, 1)), (1, lh, lw)).reshape((-1))
    ih = np.tile(ih.reshape((1, -1, 1)), (lz, 1, lw)).reshape((-1))
    iw = np.tile(iw.reshape((1, 1, -1)), (lz, lh, 1)).reshape((-1))
    centers = np.concatenate([oz[iz].reshape((-1, 1)),
                              oh[ih].reshape((-1, 1)),
                              ow[iw].reshape((-1, 1))], axis=1)

    # Exact cube-vs-cube IoU for every candidate center.
    half_anchor = anchor / 2
    half_d = d / 2
    box_lo = (bbox[:3] - half_d).reshape((1, -1))
    box_hi = (bbox[:3] + half_d).reshape((1, -1))
    edge = np.maximum(0, np.minimum(centers + half_anchor, box_hi)
                      - np.maximum(centers - half_anchor, box_lo))
    intersection = edge[:, 0] * edge[:, 1] * edge[:, 2]
    union = anchor * anchor * anchor + d * d * d - intersection
    keep = (intersection / union) >= th
    return iz[keep], ih[keep], iw[keep]
def collate(batch):
    """Custom DataLoader collate: keep tensors as a list (each with a new
    leading batch dim), pass numpy arrays through untouched, pack ints into a
    LongTensor, and recurse into other iterables (e.g. tuples of the above).
    """
    if torch.is_tensor(batch[0]):
        return [b.unsqueeze(0) for b in batch]
    elif isinstance(batch[0], np.ndarray):
        return batch
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc.
    elif isinstance(batch[0], collections.abc.Iterable):
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lou_wei_yao/Lung-nodules-detection.git
git@gitee.com:lou_wei_yao/Lung-nodules-detection.git
lou_wei_yao
Lung-nodules-detection
3D-Lung-nodules-detection
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385