代码拉取完成,页面将自动刷新
import copy
import numpy as np
import torch
from PIL import Image
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss averaged over classes.

    The class axis of the prediction is dimension 1. The target holds integer
    class indices and is one-hot encoded internally along a new dimension 1,
    so after encoding it must have exactly the same shape as the prediction.
    """

    def __init__(self, n_classes):
        """Store the number of classes used for encoding and averaging."""
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Return a float one-hot encoding of ``input_tensor``.

        One mask per class index in ``range(n_classes)`` is inserted as a new
        dimension 1 and the masks are concatenated along that dimension.
        """
        class_masks = [
            (input_tensor == class_id).unsqueeze(1)
            for class_id in range(self.n_classes)
        ]
        return torch.cat(class_masks, dim=1).float()

    def _dice_loss(self, score, target):
        """Return ``1 - Dice`` between ``score`` and ``target``.

        The denominator uses squared sums; ``smooth`` keeps the ratio defined
        when both tensors are all zero.
        """
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted Dice loss divided by ``n_classes``.

        Args:
            inputs: Predictions with the class axis at dimension 1.
            target: Integer class indices (one-hot encoded here).
            weight: Optional per-class multipliers indexed by class id;
                defaults to 1 for every class.
            softmax: When true, apply softmax over dimension 1 to ``inputs``
                before computing the loss.

        Returns:
            The sum of the weighted per-class losses divided by ``n_classes``.

        Raises:
            AssertionError: If ``inputs`` and the encoded target differ in shape.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        # Accumulate tensors only: no .item() call, so no device-to-host sync
        # is forced while the loss is being built.
        loss = 0.0
        for class_id in range(self.n_classes):
            dice = self._dice_loss(inputs[:, class_id], target[:, class_id])
            loss += dice * weight[class_id]
        return loss / self.n_classes
def calculate_metric_percase(pred, gt):
    """Return ``(dice, hd95)`` for one binarised prediction / ground-truth pair.

    Both arrays are copied before every positive value is set to 1, so the
    caller's arrays are left untouched.

    Args:
        pred: Predicted mask array; values > 0 count as foreground.
        gt: Ground-truth mask array; values > 0 count as foreground.

    Returns:
        ``(dice, hd95)`` from ``medpy.metric.binary`` when both masks have
        foreground; ``(1, 0)`` when only the prediction has foreground;
        ``(0, 0)`` otherwise.
    """
    pred = pred.copy()
    gt = gt.copy()
    pred[pred > 0] = 1
    gt[gt > 0] = 1

    pred_total = pred.sum()
    gt_total = gt.sum()
    if pred_total > 0 and gt_total > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    if pred_total > 0 and gt_total == 0:
        # NOTE(review): a non-empty prediction against an empty ground truth
        # is scored as a perfect Dice of 1; kept as-is because reported
        # metrics depend on it -- confirm this is intended.
        return 1, 0
    return 0, 0
# def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
# image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
# if len(image.shape) == 3:
# prediction = np.zeros_like(label)
# for ind in range(image.shape[0]):
# slice = image[ind, :, :]
# x, y = slice.shape[0], slice.shape[1]
# if x != patch_size[0] or y != patch_size[1]:
# slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0
# input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
# net.eval()
# with torch.no_grad():
# outputs = net(input)
# out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
# out = out.cpu().detach().numpy()
# if x != patch_size[0] or y != patch_size[1]:
# pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
# else:
# pred = out
# prediction[ind] = pred
# else:
# input = torch.from_numpy(image).unsqueeze(
# 0).unsqueeze(0).float().cuda()
# net.eval()
# with torch.no_grad():
# out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
# prediction = out.cpu().detach().numpy()
# metric_list = []
# for i in range(1, classes):
# metric_list.append(calculate_metric_percase(prediction == i, label == i))
#
# if test_save_path is not None:
# img_itk = sitk.GetImageFromArray(image.astype(np.float32))
# prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
# lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
# img_itk.SetSpacing((1, 1, z_spacing))
# prd_itk.SetSpacing((1, 1, z_spacing))
# lab_itk.SetSpacing((1, 1, z_spacing))
# sitk.WriteImage(prd_itk, test_save_path + '/'+case + "_pred.nii.gz")
# sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz")
# sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz")
# return metric_list
# Output a single-channel image and compute metrics
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    """Predict one sample, score it per class and optionally save a grayscale mask.

    Args:
        image: Tensor that is 3-D after ``squeeze(0)``; its last two axes are
            the spatial size (presumably [1, C, H, W] -> [C, H, W]).
        label: Tensor of ground-truth class indices; ``squeeze(0)`` is applied
            and it is compared with the prediction at the original size.
        net: Network called on the batched image; it is switched to eval mode
            and left in that mode.
        classes: Number of classes; metrics are computed for 1..classes-1.
        patch_size: Spatial size fed to the network; the image is resampled to
            it when its size differs. Only read, never modified.
        test_save_path: Directory for the output PNG; ``None`` disables saving.
        case: File stem of the saved PNG (required when saving).
        z_spacing: Unused; kept for signature compatibility.

    Returns:
        A list with one ``calculate_metric_percase`` result per class in
        ``range(1, classes)``.
    """
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()

    _, x, y = image.shape
    needs_resize = x != patch_size[0] or y != patch_size[1]
    if needs_resize:
        # Resample the spatial axes to the network input size (cubic spline).
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)

    # Send the batch to the device holding the network weights so evaluation
    # also works on CPU or a non-default GPU; fall back to the default CUDA
    # device when the network exposes no parameters.
    try:
        device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')
    batch = torch.from_numpy(image).unsqueeze(0).float().to(device)

    net.eval()
    with torch.no_grad():
        # Per-pixel class index with the highest probability.
        out = torch.argmax(torch.softmax(net(batch), dim=1), dim=1).squeeze(0)
    out = out.cpu().detach().numpy()

    if needs_resize:
        # Nearest-neighbour resampling back to the original size keeps labels integral.
        prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        prediction = out

    metric_list = [
        calculate_metric_percase(prediction == class_id, label == class_id)
        for class_id in range(1, classes)
    ]

    if test_save_path is not None:
        # Gray-level remap: label 3 -> 255 and label 4 -> 20; every other
        # label keeps its raw index as the pixel value.
        gray = prediction.copy()
        gray[prediction == 3] = 255
        gray[prediction == 4] = 20
        mask_image = Image.fromarray(np.uint8(gray)).convert('L')
        mask_image.save(test_save_path+'/'+case+'.PNG')
    return metric_list
# Output a three-channel comparison image
def test_three_channel_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, vis_save_path=None, case=None, z_spacing=1):
    """Predict one sample, score it per class and optionally save an RGB mask.

    Args:
        image: Tensor that is 3-D after ``squeeze(0)``; its last two axes are
            the spatial size (presumably [1, C, H, W] -> [C, H, W]).
        label: Tensor of ground-truth class indices; ``squeeze(0)`` is applied
            and it is compared with the prediction at the original size.
        net: Network called on the batched image; it is switched to eval mode
            and left in that mode.
        classes: Number of classes; metrics are computed for 1..classes-1.
        patch_size: Spatial size fed to the network; the image is resampled to
            it when its size differs. Only read, never modified.
        test_save_path: Directory for the output PNG; ``None`` disables saving.
        vis_save_path: Unused; kept for signature compatibility.
        case: File stem of the saved PNG (required when saving).
        z_spacing: Unused; kept for signature compatibility.

    Returns:
        A list with one ``calculate_metric_percase`` result per class in
        ``range(1, classes)``. The list was already being computed; it is now
        returned instead of discarded, matching ``test_single_volume``.
    """
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()

    _, x, y = image.shape
    needs_resize = x != patch_size[0] or y != patch_size[1]
    if needs_resize:
        # Resample the spatial axes to the network input size (cubic spline).
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)

    # Send the batch to the device holding the network weights so evaluation
    # also works on CPU or a non-default GPU; fall back to the default CUDA
    # device when the network exposes no parameters.
    try:
        device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')
    batch = torch.from_numpy(image).unsqueeze(0).float().to(device)

    net.eval()
    with torch.no_grad():
        # Per-pixel class index with the highest probability.
        out = torch.argmax(torch.softmax(net(batch), dim=1), dim=1).squeeze(0)
    out = out.cpu().detach().numpy()

    if needs_resize:
        # Nearest-neighbour resampling back to the original size keeps labels integral.
        prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        prediction = out

    metric_list = [
        calculate_metric_percase(prediction == class_id, label == class_id)
        for class_id in range(1, classes)
    ]

    if test_save_path is not None:
        # label -> (R, G, B). Labels not listed here keep their raw index in
        # every channel.
        palette = {
            1: (255, 255, 255),
            3: (255, 0, 0),
            4: (20, 10, 120),
        }
        channels = [prediction.copy() for _ in range(3)]
        for class_id, rgb in palette.items():
            mask = prediction == class_id
            for channel, value in zip(channels, rgb):
                channel[mask] = value
        bands = [Image.fromarray(np.uint8(channel)).convert('L') for channel in channels]
        Image.merge('RGB', bands).save(test_save_path+'/'+case+'.PNG')
    return metric_list
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。