1 Star 1 Fork 0

小荷才露尖尖角/cardiac-segmentation

Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
文件
Clone or Download
fcn_model.py 7.46 KB
Copy Edit Raw Blame History
#!/usr/bin/env python2.7
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Lambda
from keras.layers import Input, average
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import ZeroPadding2D, Cropping2D
from keras import backend as K
def mvn(tensor):
    '''Mean-variance normalize every channel of ``tensor`` over axes 1 and 2.

    For each sample and each channel, the values along axes 1 and 2 (height
    and width for the channels-last inputs this module builds) are shifted
    to zero mean and divided by their standard deviation. A 1e-6 term is
    added to the standard deviation so constant feature maps do not divide
    by zero.
    '''
    spatial_axes = (1, 2)
    centered = tensor - K.mean(tensor, axis=spatial_axes, keepdims=True)
    scale = K.std(tensor, axis=spatial_axes, keepdims=True) + 1e-6
    return centered / scale
def crop(tensors):
    '''Center-crop ``tensors[1]`` to the spatial size of ``tensors[0]``.

    Meant to be wrapped in a ``Lambda`` layer. Both tensors are rank 4 and
    their sizes are read from axes 1 and 2 (height, width in channels-last
    layout); those sizes must be statically known.

    Args:
        tensors: sequence whose first two items are ``[reference, to_crop]``.
            ``to_crop`` must be at least as large as ``reference`` along both
            spatial axes.

    Returns:
        ``to_crop`` passed through a ``Cropping2D`` layer so that its height
        and width match ``reference``. When a size difference is odd, the
        extra row/column is removed from the bottom/right edge.

    Raises:
        ValueError: if a spatial size is unknown, or if ``to_crop`` is
            smaller than ``reference`` along either spatial axis.
    '''
    reference, to_crop = tensors[0], tensors[1]
    # K.int_shape is the supported way to read a static shape;
    # K.get_variable_shape was a deprecated alias of it.
    ref_shape = K.int_shape(reference)
    big_shape = K.int_shape(to_crop)
    _, ref_h, ref_w, _ = ref_shape
    _, big_h, big_w, _ = big_shape
    if None in (ref_h, ref_w, big_h, big_w):
        raise ValueError(
            'crop needs statically known spatial sizes, got {} and {}'.format(
                ref_shape, big_shape))

    crop_h = big_h - ref_h
    crop_w = big_w - ref_w
    # A negative amount would be handed to Cropping2D unchecked and yield
    # a tensor whose real shape disagrees with its inferred shape.
    if crop_h < 0 or crop_w < 0:
        raise ValueError(
            'crop expects the second tensor to be at least as large as the '
            'first, got {} and {}'.format(ref_shape, big_shape))

    # (top, bottom) and (left, right); the odd remainder goes to bottom/right.
    crop_h_dims = (crop_h // 2, crop_h - crop_h // 2)
    crop_w_dims = (crop_w // 2, crop_w - crop_w // 2)
    return Cropping2D(cropping=(crop_h_dims, crop_w_dims))(to_crop)
def dice_coef(y_true, y_pred, smooth=1.0):
    '''Soft Dice coefficient per sample, averaged over the batch (axis 0).

    For each sample:
        dice = (2 * sum(y_true * y_pred) + smooth)
               / (sum(y_true) + sum(y_pred) + smooth)
    The sums run over axes 1, 2 and 3 (assumes rank-4 tensors). ``smooth``
    keeps the ratio defined, and equal to 1, when both masks are empty.
    '''
    reduce_axes = (1, 2, 3)
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    total = K.sum(y_true, axis=reduce_axes) + K.sum(y_pred, axis=reduce_axes)
    per_sample = (2.0 * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)
def dice_coef_loss(y_true, y_pred):
    '''Training loss equal to one minus the batch-averaged soft Dice (smooth=1).'''
    dice = dice_coef(y_true, y_pred, smooth=1.0)
    return 1.0 - dice
def jaccard_coef(y_true, y_pred, smooth=1.0):
    '''Soft Jaccard index (IoU) per sample, averaged over the batch (axis 0).

    For each sample:
        iou = (sum(y_true * y_pred) + smooth)
              / (sum(y_true) + sum(y_pred) - sum(y_true * y_pred) + smooth)
    The sums run over axes 1, 2 and 3 (assumes rank-4 tensors).
    '''
    reduce_axes = (1, 2, 3)
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    union = (K.sum(y_true, axis=reduce_axes)
             + K.sum(y_pred, axis=reduce_axes)
             - overlap)
    return K.mean((overlap + smooth) / (union + smooth), axis=0)
def fcn_model(input_shape, num_classes, weights=None):
    ''' "Skip" FCN architecture similar to Long et al., 2015
    https://arxiv.org/abs/1411.4038

    The encoder has four stacks of 3x3 ReLU convolutions, each followed by
    per-channel mean-variance normalization. Stacks are separated by 3x3,
    stride-2 max pooling. The decoder upsamples class scores three times,
    averaging in scores from the mvn11 and mvn7 skip connections, and crops
    the result back to the input's height and width.

    Args:
        input_shape: shape of one image without the batch axis, e.g.
            ``(100, 100, 1)`` (channels-last).
        num_classes: number of segmentation classes, at least 2. With
            exactly 2 the model emits one sigmoid channel trained with
            ``dice_coef_loss``. Otherwise it emits ``num_classes`` softmax
            channels trained with categorical cross-entropy.
        weights: optional path handed to ``model.load_weights`` before
            compiling.

    Returns:
        A ``Model`` compiled with Adam and the metrics accuracy,
        ``dice_coef`` and ``jaccard_coef``.

    Raises:
        ValueError: if ``num_classes`` is less than 2 (a single softmax
            channel would be constantly 1 and could not learn anything).
    '''
    if num_classes < 2:
        raise ValueError(
            'num_classes must be at least 2, got {}'.format(num_classes))
    if num_classes == 2:
        # Binary segmentation: a single foreground-probability channel.
        num_classes = 1
        loss = dice_coef_loss
        activation = 'sigmoid'
    else:
        loss = 'categorical_crossentropy'
        activation = 'softmax'

    kwargs = dict(
        kernel_size=3,
        strides=1,
        activation='relu',
        padding='same',
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
    )

    # Layer names and creation order below match the original hand-written
    # version exactly, so previously saved weight files still load.

    def conv_mvn_stack(x, filters, first, last):
        '''Apply conv<i> -> mvn<i> for i in [first, last] (3x3 ReLU convs).'''
        for i in range(first, last + 1):
            x = Conv2D(filters=filters, name='conv{}'.format(i), **kwargs)(x)
            x = Lambda(mvn, name='mvn{}'.format(i))(x)
        return x

    def max_pool(x, name):
        '''3x3 max pooling, stride 2, no padding.'''
        return MaxPooling2D(pool_size=3, strides=2,
                            padding='valid', name=name)(x)

    def score_conv(x, name, act=None):
        '''1x1 convolution projecting features onto ``num_classes`` channels.'''
        return Conv2D(filters=num_classes, kernel_size=1,
                      strides=1, activation=act, padding='valid',
                      kernel_initializer='glorot_uniform', use_bias=True,
                      name=name)(x)

    def upsample(x, name):
        '''Learned upsampling: 3x3 transposed convolution, stride 2, no bias.'''
        return Conv2DTranspose(filters=num_classes, kernel_size=3,
                               strides=2, activation=None, padding='valid',
                               kernel_initializer='glorot_uniform',
                               use_bias=False, name=name)(x)

    data = Input(shape=input_shape, dtype='float', name='data')
    mvn0 = Lambda(mvn, name='mvn0')(data)
    # The 5-pixel margin keeps upsample3 at least as large as the input
    # after the three valid-mode pool/upsample stages; crop3 trims it back.
    pad = ZeroPadding2D(padding=5, name='pad')(mvn0)

    # Encoder.
    mvn3 = conv_mvn_stack(pad, 64, 1, 3)
    pool1 = max_pool(mvn3, 'pool1')
    mvn7 = conv_mvn_stack(pool1, 128, 4, 7)
    pool2 = max_pool(mvn7, 'pool2')
    mvn11 = conv_mvn_stack(pool2, 256, 8, 11)
    pool3 = max_pool(mvn11, 'pool3')
    drop1 = Dropout(rate=0.5, name='drop1')(pool3)
    mvn15 = conv_mvn_stack(drop1, 512, 12, 15)
    drop2 = Dropout(rate=0.5, name='drop2')(mvn15)

    # Decoder: upsample coarse scores and average them with skip-connection
    # scores that have been center-cropped to the same size.
    score_conv15 = score_conv(drop2, 'score_conv15')
    upsample1 = upsample(score_conv15, 'upsample1')
    score_conv11 = score_conv(mvn11, 'score_conv11')
    crop1 = Lambda(crop, name='crop1')([upsample1, score_conv11])
    fuse_scores1 = average([crop1, upsample1], name='fuse_scores1')

    upsample2 = upsample(fuse_scores1, 'upsample2')
    score_conv7 = score_conv(mvn7, 'score_conv7')
    crop2 = Lambda(crop, name='crop2')([upsample2, score_conv7])
    fuse_scores2 = average([crop2, upsample2], name='fuse_scores2')

    upsample3 = upsample(fuse_scores2, 'upsample3')
    crop3 = Lambda(crop, name='crop3')([data, upsample3])
    predictions = score_conv(crop3, 'predictions', act=activation)

    model = Model(inputs=data, outputs=predictions)
    if weights is not None:
        model.load_weights(weights)
    model.compile(optimizer=optimizers.Adam(), loss=loss,
                  metrics=['accuracy', dice_coef, jaccard_coef])
    return model
if __name__ == '__main__':
    # Smoke test: building and compiling the model fails loudly if the layer
    # shapes do not line up for a 100x100 single-channel, two-class setup.
    model = fcn_model(input_shape=(100, 100, 1), num_classes=2)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/Heconnor/cardiac-segmentation.git
git@gitee.com:Heconnor/cardiac-segmentation.git
Heconnor
cardiac-segmentation
cardiac-segmentation
master

Search