utils.py 8.42 KB
一键复制 编辑 原始数据 按行查看 历史
jxinwei 提交于 2024-01-01 20:28
import copy
import numpy as np
import torch
from PIL import Image
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over ``n_classes`` classes.

    ``forward`` takes per-class scores ``inputs`` (class axis at dim 1) and an
    integer label map ``target``. The labels are one-hot encoded into a new
    dim 1, after which they must match ``inputs`` exactly in shape.
    """

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """One-hot encode an integer label tensor along a new dim 1.

        A tensor of shape ``(B, *rest)`` becomes a float tensor of shape
        ``(B, n_classes, *rest)`` whose channel ``i`` is 1.0 where the label
        equals ``i`` and 0.0 elsewhere.
        """
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i
            # Each class mask must be collected; without this append torch.cat
            # would receive an empty list and raise.
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        """Return ``1 - soft Dice`` between ``score`` and ``target``.

        Sums run over every element (batch included). ``smooth`` keeps the
        ratio defined when both tensors are all zeros.
        """
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted, class-averaged soft Dice loss.

        Args:
            inputs: class scores with the class axis at dim 1.
            target: integer label map; after one-hot encoding it must have the
                same size as ``inputs``.
            weight: optional per-class multipliers; defaults to 1 per class.
            softmax: if True, apply ``torch.softmax(dim=1)`` to ``inputs`` first.

        Returns:
            Scalar tensor ``sum_i(weight[i] * dice_loss_i) / n_classes``
            (the divisor is ``n_classes``, not the sum of the weights).
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            loss += dice * weight[i]
        return loss / self.n_classes
def calculate_metric_percase(pred, gt):
    """Score one binary class for one case as ``(dice, hd95)``.

    Any value > 0 in ``pred`` or ``gt`` counts as foreground. Both arrays are
    copied before binarisation, so the caller's inputs are never modified.

    Returns:
        ``(metric.binary.dc, metric.binary.hd95)`` when both masks contain
        foreground; ``(1, 0)`` when only ``pred`` has foreground and ``gt``
        sums to zero; ``(0, 0)`` otherwise (e.g. empty prediction).
    """
    pred = pred.copy()
    gt = gt.copy()
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    pred_sum = pred.sum()
    gt_sum = gt.sum()
    if pred_sum > 0 and gt_sum > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    if pred_sum > 0 and gt_sum == 0:
        # NOTE(review): foreground predicted where the ground truth has none is
        # scored as a perfect Dice here. Kept unchanged so results stay
        # comparable with earlier runs -- confirm this is intended.
        return 1, 0
    return 0, 0
# def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
# image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
# if len(image.shape) == 3:
# prediction = np.zeros_like(label)
# for ind in range(image.shape[0]):
# slice = image[ind, :, :]
# x, y = slice.shape[0], slice.shape[1]
# if x != patch_size[0] or y != patch_size[1]:
# slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0
# input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
# net.eval()
# with torch.no_grad():
# outputs = net(input)
# out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
# out = out.cpu().detach().numpy()
# if x != patch_size[0] or y != patch_size[1]:
# pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
# else:
# pred = out
# prediction[ind] = pred
# else:
# input = torch.from_numpy(image).unsqueeze(
# 0).unsqueeze(0).float().cuda()
# net.eval()
# with torch.no_grad():
# out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
# prediction = out.cpu().detach().numpy()
# metric_list = []
# for i in range(1, classes):
# metric_list.append(calculate_metric_percase(prediction == i, label == i))
# if test_save_path is not None:
# img_itk = sitk.GetImageFromArray(image.astype(np.float32))
# prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
# lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
# img_itk.SetSpacing((1, 1, z_spacing))
# prd_itk.SetSpacing((1, 1, z_spacing))
# lab_itk.SetSpacing((1, 1, z_spacing))
# sitk.WriteImage(prd_itk, test_save_path + '/'+case + "_pred.nii.gz")
# sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz")
# sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz")
# return metric_list
def test_single_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, case=None, z_spacing=1):
    """Run ``net`` on one multi-channel 2-D image and score each foreground class.

    Args:
        image: tensor with a leading singleton dim; after squeezing it is
            unpacked as ``(channels, x, y)``.
        label: tensor with a leading singleton dim; after squeezing it is
            compared element-wise with the ``(x, y)`` prediction
            (presumably an ``(x, y)`` label map -- confirm against caller).
        net: model called on a ``(1, channels, *patch_size)`` CUDA tensor; its
            output's class axis is dim 1. This function does not call
            ``net.eval()`` -- presumably the caller does; confirm.
        classes: number of classes including background (label 0).
        patch_size: spatial size fed to the network. The image is resized to
            it with cubic interpolation, and the prediction is resized back
            with nearest-neighbour interpolation.
        test_save_path: when not None, a grayscale preview is rendered (see
            the note below; nothing is written to disk).
        case, z_spacing: accepted for interface compatibility; unused here.

    Returns:
        List of ``calculate_metric_percase`` results, one per class
        ``1 .. classes - 1``.
    """
    # Squeeze an input of shape [1, C, H, W] to [C, H, W] and move it to NumPy.
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()
    _, x, y = image.shape
    needs_resize = x != patch_size[0] or y != patch_size[1]
    if needs_resize:
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    net_input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    with torch.no_grad():
        out = torch.argmax(torch.softmax(net(net_input), dim=1), dim=1).squeeze(0)
    # Move the predicted class-index map back to the CPU as a NumPy array.
    out = out.cpu().detach().numpy()
    if needs_resize:
        # order=0 (nearest neighbour) so class indices are never blended.
        prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        prediction = out
    metric_list = [
        calculate_metric_percase(prediction == i, label == i)
        for i in range(1, classes)
    ]
    if test_save_path is not None:
        # NOTE(review): this renders a grayscale preview (class 3 -> 255,
        # class 4 -> 20, other labels keep their index value) but never writes
        # it anywhere. TODO: save it under test_save_path or remove this block.
        vis = prediction.copy()
        vis[vis == 3] = 255
        vis[vis == 4] = 20
        vis_image = Image.fromarray(np.uint8(vis)).convert('L')  # noqa: F841
    return metric_list
def test_three_channel_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, vis_save_path=None, case=None, z_spacing=1):
    """Run ``net`` on one multi-channel 2-D image, score it, and render an RGB preview.

    Args:
        image: tensor with a leading singleton dim; after squeezing it is
            unpacked as ``(channels, x, y)``.
        label: tensor with a leading singleton dim; after squeezing it is
            compared element-wise with the ``(x, y)`` prediction
            (presumably an ``(x, y)`` label map -- confirm against caller).
        net: model called on a ``(1, channels, *patch_size)`` CUDA tensor; its
            output's class axis is dim 1. This function does not call
            ``net.eval()`` -- presumably the caller does; confirm.
        classes: number of classes including background (label 0).
        patch_size: spatial size fed to the network. The image is resized to
            it with cubic interpolation, and the prediction is resized back
            with nearest-neighbour interpolation.
        test_save_path: when not None, an RGB preview is rendered (see the
            note below; nothing is written to disk).
        vis_save_path, case, z_spacing: accepted for interface compatibility;
            unused here.

    Returns:
        List of ``calculate_metric_percase`` results, one per class
        ``1 .. classes - 1``. (Previously the return was commented out and
        the function returned None.)
    """
    # Squeeze an input of shape [1, C, H, W] to [C, H, W] and move it to NumPy.
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()
    _, x, y = image.shape
    needs_resize = x != patch_size[0] or y != patch_size[1]
    if needs_resize:
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    net_input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    with torch.no_grad():
        out = torch.argmax(torch.softmax(net(net_input), dim=1), dim=1).squeeze(0)
    # Move the predicted class-index map back to the CPU as a NumPy array.
    out = out.cpu().detach().numpy()
    if needs_resize:
        # order=0 (nearest neighbour) so class indices are never blended.
        prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        prediction = out
    metric_list = [
        calculate_metric_percase(prediction == i, label == i)
        for i in range(1, classes)
    ]
    if test_save_path is not None:
        # (R, G, B) per class; labels not listed keep their index value in
        # every channel.
        palette = {1: (255, 255, 255), 3: (255, 0, 0), 4: (20, 10, 120)}
        channels = []
        for c in range(3):
            chan = prediction.copy()
            for cls, rgb in palette.items():
                chan[prediction == cls] = rgb[c]
            channels.append(Image.fromarray(np.uint8(chan)).convert('L'))
        # NOTE(review): the merged image is never written anywhere.
        # TODO: save it (vis_save_path looks intended for this) or remove it.
        vis_image = Image.merge('RGB', channels)  # noqa: F841
    return metric_list
