#!/usr/bin/env python2.7
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Lambda
from keras.layers import Input, average
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import ZeroPadding2D, Cropping2D
from keras import backend as K


def mvn(tensor):
    '''Performs per-channel spatial mean-variance normalization.'''
    epsilon = 1e-6
    mean = K.mean(tensor, axis=(1,2), keepdims=True)
    std = K.std(tensor, axis=(1,2), keepdims=True)
    mvn = (tensor - mean) / (std + epsilon)
    return mvn
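# Illustrative note: for a batch of feature maps with shape (N, H, W, C),
# Lambda(mvn) rescales each channel of each sample so that its (H, W) slice
# has approximately zero mean and unit variance; epsilon keeps the division
# stable when a channel is constant.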


def crop(tensors):
    '''
    Crops the second tensor to the spatial dimensions of the first.
    `tensors` is a list of 2 tensors, the second having the larger
    spatial dimensions.
    '''
    h_dims, w_dims = [], []
    for t in tensors:
        b, h, w, d = K.get_variable_shape(t)
        h_dims.append(h)
        w_dims.append(w)
    crop_h, crop_w = (h_dims[1] - h_dims[0]), (w_dims[1] - w_dims[0])
    rem_h = crop_h % 2
    rem_w = crop_w % 2
    crop_h_dims = (crop_h // 2, crop_h // 2 + rem_h)
    crop_w_dims = (crop_w // 2, crop_w // 2 + rem_w)
    cropped = Cropping2D(cropping=(crop_h_dims, crop_w_dims))(tensors[1])
    return cropped
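# Cropping arithmetic, for reference: with the 100x100 input used at the
# bottom of this file, the final skip connection compares a 100x100 tensor
# against a 103x103 upsampled map, so crop_h = crop_w = 3 and Cropping2D
# removes (1, 2) pixels from each axis, yielding a 100x100 output.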


def dice_coef(y_true, y_pred, smooth=1.0):
    '''Average dice coefficient per batch.'''
    axes = (1,2,3)
    intersection = K.sum(y_true * y_pred, axis=axes)
    summation = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)
    # print('intersection:',intersection,'\n summation: ',summation)
    return K.mean((2.0 * intersection + smooth) / (summation + smooth), axis=0)


def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred, smooth=1.0)


def jaccard_coef(y_true, y_pred, smooth=1.0):
    '''Average jaccard coefficient per batch.'''
    axes = (1,2,3)
    intersection = K.sum(y_true * y_pred, axis=axes)
    union = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes) - intersection
    # print('intersection:',intersection,'\n union: ',union)
    return K.mean((intersection + smooth) / (union + smooth), axis=0)
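# Quick consistency check (comment-only sketch): for a perfect binary
# prediction (y_pred == y_true with |Y| foreground pixels) the Dice score is
# (2*|Y| + smooth) / (2*|Y| + smooth) = 1 and the Jaccard score is
# (|Y| + smooth) / (|Y| + smooth) = 1; as smooth -> 0 the two metrics are
# related by J = D / (2 - D).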


def fcn_model(input_shape, num_classes, weights=None):
    ''' "Skip" FCN architecture similar to Long et al., 2015
    https://arxiv.org/abs/1411.4038
    '''
    # Binary segmentation uses a single sigmoid output with a Dice loss;
    # multi-class segmentation uses softmax with categorical cross-entropy.
    if num_classes == 2:
        num_classes = 1
        loss = dice_coef_loss
        activation = 'sigmoid'
    else:
        loss = 'categorical_crossentropy'
        activation = 'softmax'

    # Shared arguments for all 3x3 'same' convolutions in the encoder.
    kwargs = dict(
        kernel_size=3,
        strides=1,
        activation='relu',
        padding='same',
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
    )

    data = Input(shape=input_shape, dtype='float', name='data')
    mvn0 = Lambda(mvn, name='mvn0')(data)
    pad = ZeroPadding2D(padding=5, name='pad')(mvn0)

    # Encoder stage 1: three 64-filter convs, each followed by MVN, then pool.
    conv1 = Conv2D(filters=64, name='conv1', **kwargs)(pad)
    mvn1 = Lambda(mvn, name='mvn1')(conv1)
    conv2 = Conv2D(filters=64, name='conv2', **kwargs)(mvn1)
    mvn2 = Lambda(mvn, name='mvn2')(conv2)
    conv3 = Conv2D(filters=64, name='conv3', **kwargs)(mvn2)
    mvn3 = Lambda(mvn, name='mvn3')(conv3)
    pool1 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool1')(mvn3)

    # Encoder stage 2: four 128-filter convs, then pool.
    conv4 = Conv2D(filters=128, name='conv4', **kwargs)(pool1)
    mvn4 = Lambda(mvn, name='mvn4')(conv4)
    conv5 = Conv2D(filters=128, name='conv5', **kwargs)(mvn4)
    mvn5 = Lambda(mvn, name='mvn5')(conv5)
    conv6 = Conv2D(filters=128, name='conv6', **kwargs)(mvn5)
    mvn6 = Lambda(mvn, name='mvn6')(conv6)
    conv7 = Conv2D(filters=128, name='conv7', **kwargs)(mvn6)
    mvn7 = Lambda(mvn, name='mvn7')(conv7)
    pool2 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool2')(mvn7)

    # Encoder stage 3: four 256-filter convs, then pool.
    conv8 = Conv2D(filters=256, name='conv8', **kwargs)(pool2)
    mvn8 = Lambda(mvn, name='mvn8')(conv8)
    conv9 = Conv2D(filters=256, name='conv9', **kwargs)(mvn8)
    mvn9 = Lambda(mvn, name='mvn9')(conv9)
    conv10 = Conv2D(filters=256, name='conv10', **kwargs)(mvn9)
    mvn10 = Lambda(mvn, name='mvn10')(conv10)
    conv11 = Conv2D(filters=256, name='conv11', **kwargs)(mvn10)
    mvn11 = Lambda(mvn, name='mvn11')(conv11)
    pool3 = MaxPooling2D(pool_size=3, strides=2,
                         padding='valid', name='pool3')(mvn11)

    # Encoder stage 4: four 512-filter convs with dropout before and after.
    drop1 = Dropout(rate=0.5, name='drop1')(pool3)
    conv12 = Conv2D(filters=512, name='conv12', **kwargs)(drop1)
    mvn12 = Lambda(mvn, name='mvn12')(conv12)
    conv13 = Conv2D(filters=512, name='conv13', **kwargs)(mvn12)
    mvn13 = Lambda(mvn, name='mvn13')(conv13)
    conv14 = Conv2D(filters=512, name='conv14', **kwargs)(mvn13)
    mvn14 = Lambda(mvn, name='mvn14')(conv14)
    conv15 = Conv2D(filters=512, name='conv15', **kwargs)(mvn14)
    mvn15 = Lambda(mvn, name='mvn15')(conv15)
    drop2 = Dropout(rate=0.5, name='drop2')(mvn15)

    # Decoder: 1x1 score maps, transposed-conv upsampling, and fusion with
    # skip connections from the corresponding encoder stages (mvn11, mvn7).
    score_conv15 = Conv2D(filters=num_classes, kernel_size=1,
                          strides=1, activation=None, padding='valid',
                          kernel_initializer='glorot_uniform', use_bias=True,
                          name='score_conv15')(drop2)
    upsample1 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample1')(score_conv15)
    score_conv11 = Conv2D(filters=num_classes, kernel_size=1,
                          strides=1, activation=None, padding='valid',
                          kernel_initializer='glorot_uniform', use_bias=True,
                          name='score_conv11')(mvn11)
    crop1 = Lambda(crop, name='crop1')([upsample1, score_conv11])
    fuse_scores1 = average([crop1, upsample1], name='fuse_scores1')

    upsample2 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample2')(fuse_scores1)
    score_conv7 = Conv2D(filters=num_classes, kernel_size=1,
                         strides=1, activation=None, padding='valid',
                         kernel_initializer='glorot_uniform', use_bias=True,
                         name='score_conv7')(mvn7)
    crop2 = Lambda(crop, name='crop2')([upsample2, score_conv7])
    fuse_scores2 = average([crop2, upsample2], name='fuse_scores2')

    upsample3 = Conv2DTranspose(filters=num_classes, kernel_size=3,
                                strides=2, activation=None, padding='valid',
                                kernel_initializer='glorot_uniform', use_bias=False,
                                name='upsample3')(fuse_scores2)
    crop3 = Lambda(crop, name='crop3')([data, upsample3])
    predictions = Conv2D(filters=num_classes, kernel_size=1,
                         strides=1, activation=activation, padding='valid',
                         kernel_initializer='glorot_uniform', use_bias=True,
                         name='predictions')(crop3)

    model = Model(inputs=data, outputs=predictions)
    if weights is not None:
        model.load_weights(weights)
    model.compile(optimizer=optimizers.Adam(), loss=loss,
                  metrics=['accuracy', dice_coef, jaccard_coef])
    # sgd = optimizers.SGD(lr=0.01, momentum=0.9, nesterov=True)
    # model.compile(optimizer=sgd, loss=loss,
    #               metrics=['accuracy', dice_coef, jaccard_coef])
    return model


if __name__ == '__main__':
    model = fcn_model((100, 100, 1), 2, weights=None)
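    # Minimal smoke-test sketch (assumes the TensorFlow backend and NumPy
    # are installed): print the layer graph and run a dummy forward pass;
    # with a 100x100 single-channel input the output should come back as
    # (1, 100, 100, 1) after the final crop layer.
    import numpy as np
    model.summary()
    dummy = np.zeros((1, 100, 100, 1), dtype='float32')
    print(model.predict(dummy).shape)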