# utils.py (jxinwei/TransUnet)

import copy
import numpy as np
import torch
from PIL import Image
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk
class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        # Convert an integer label map [B, H, W] into a one-hot tensor [B, n_classes, H, W].
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        # Soft Dice loss for one class: 1 - (2*sum(X*Y) + eps) / (sum(X^2) + sum(Y^2) + eps),
        # with a small smoothing term to avoid division by zero.
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes
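

# Hedged usage sketch, added for illustration and not part of the original file:
# how DiceLoss is typically combined with cross-entropy on raw logits. The
# 4-class setting, batch size, spatial size and the 0.5/0.5 weighting are
# assumptions, not values taken from this repository.
def _example_dice_loss_usage():
    dice_loss = DiceLoss(n_classes=4)
    ce_loss = nn.CrossEntropyLoss()
    logits = torch.randn(2, 4, 224, 224)            # raw network outputs [B, C, H, W]
    labels = torch.randint(0, 4, (2, 224, 224))     # integer class map [B, H, W]
    # softmax=True because DiceLoss expects probabilities, not logits
    return 0.5 * ce_loss(logits, labels) + 0.5 * dice_loss(logits, labels, softmax=True)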
def calculate_metric_percase(pred, gt):
    # Binarize both masks, then compute Dice and 95% Hausdorff distance with medpy.
    gt = gt.copy()
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    else:
        return 0, 0
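

# Hedged sketch, for illustration only: calculate_metric_percase is meant to be
# called once per foreground class with boolean masks. `prediction` and `label`
# are assumed to be integer-valued NumPy arrays of the same shape, and the
# class count of 4 is an assumption.
def _example_per_class_metrics(prediction, label, classes=4):
    per_class = []
    for c in range(1, classes):                     # class 0 is background, skipped
        dice, hd95 = calculate_metric_percase(prediction == c, label == c)
        per_class.append((c, dice, hd95))
    return per_class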
# def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
#     image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
#     if len(image.shape) == 3:
#         prediction = np.zeros_like(label)
#         for ind in range(image.shape[0]):
#             slice = image[ind, :, :]
#             x, y = slice.shape[0], slice.shape[1]
#             if x != patch_size[0] or y != patch_size[1]:
#                 slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
#             input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
#             net.eval()
#             with torch.no_grad():
#                 outputs = net(input)
#                 out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
#                 out = out.cpu().detach().numpy()
#                 if x != patch_size[0] or y != patch_size[1]:
#                     pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
#                 else:
#                     pred = out
#                 prediction[ind] = pred
#     else:
#         input = torch.from_numpy(image).unsqueeze(
#             0).unsqueeze(0).float().cuda()
#         net.eval()
#         with torch.no_grad():
#             out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
#             prediction = out.cpu().detach().numpy()
#     metric_list = []
#     for i in range(1, classes):
#         metric_list.append(calculate_metric_percase(prediction == i, label == i))
#
#     if test_save_path is not None:
#         img_itk = sitk.GetImageFromArray(image.astype(np.float32))
#         prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
#         lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
#         img_itk.SetSpacing((1, 1, z_spacing))
#         prd_itk.SetSpacing((1, 1, z_spacing))
#         lab_itk.SetSpacing((1, 1, z_spacing))
#         sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
#         sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
#         sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
#     return metric_list
# Save single-channel prediction maps and compute per-class metrics
def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    # If the tensor `image` has shape [1, C, H, W], squeezing the batch dim gives [C, H, W]
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    # print("image data")
    # print(image)
    # print(label)
    # print("label data")
    _, x, y = image.shape
    if x != patch_size[0] or y != patch_size[1]:
        # Resize the image to match the network input size
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    net.eval()
    with torch.no_grad():
        # Get the predicted class-index map (argmax over the channel dimension)
        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
        # Move the prediction to the CPU and convert it to a NumPy array
        out = out.cpu().detach().numpy()
        if x != patch_size[0] or y != patch_size[1]:
            # Resize the prediction back to the original size
            prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
        else:
            prediction = out
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    if test_save_path is not None:
        # Remap class indices to grayscale intensities (class 2 keeps its raw value)
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)
        a1[a1 == 0] = 0
        a1[a1 == 1] = 1
        a1[a1 == 3] = 255
        a1[a1 == 4] = 20
        a2[a2 == 0] = 0
        a2[a2 == 1] = 1
        a2[a2 == 3] = 0
        a2[a2 == 4] = 10
        a3[a3 == 0] = 0
        a3[a3 == 1] = 1
        a3[a3 == 3] = 0
        a3[a3 == 4] = 120
        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        # prediction = Image.merge('RGB', [a1, a2, a3])
        prediction = a1
        prediction.save(test_save_path + '/' + case + '.PNG')
    return metric_list
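

# Hedged sketch of a typical evaluation loop around test_single_volume, for
# illustration only: `net` and `testloader` are hypothetical arguments, each
# batch is assumed to carry 'image', 'label' and 'case_name' entries as in
# TransUNet-style Synapse data loaders, and patch_size=[224, 224] is an assumed
# value rather than one taken from this repository.
def _example_eval_loop(net, testloader, classes=4, save_dir=None):
    metric_sum = np.zeros((classes - 1, 2))         # running (dice, hd95) per foreground class
    for batch in testloader:
        image, label, case = batch['image'], batch['label'], batch['case_name'][0]
        metric_i = test_single_volume(image, label, net, classes=classes,
                                      patch_size=[224, 224],
                                      test_save_path=save_dir, case=case)
        metric_sum += np.array(metric_i)
    return metric_sum / len(testloader)             # mean dice / hd95 per class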
# Save three-channel (RGB) comparison maps
def test_three_channel_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, vis_save_path=None, case=None, z_spacing=1):
    # If the tensor `image` has shape [1, C, H, W], squeezing the batch dim gives [C, H, W]
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    # print("image data")
    # print(image)
    # print(label)
    # print("label data")
    _, x, y = image.shape
    if x != patch_size[0] or y != patch_size[1]:
        # Resize the image to match the network input size
        image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
    input = torch.from_numpy(image).unsqueeze(0).float().cuda()
    net.eval()
    with torch.no_grad():
        # Get the predicted class-index map (argmax over the channel dimension)
        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
        # Move the prediction to the CPU and convert it to a NumPy array
        out = out.cpu().detach().numpy()
        if x != patch_size[0] or y != patch_size[1]:
            # Resize the prediction back to the original size
            prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
        else:
            prediction = out
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    if test_save_path is not None:
        # Remap each class index to a per-channel intensity (R, G, B), then merge into an RGB image
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)
        a1[a1 == 0] = 0
        a1[a1 == 1] = 255
        a1[a1 == 3] = 255
        a1[a1 == 4] = 20
        a2[a2 == 0] = 0
        a2[a2 == 1] = 255
        a2[a2 == 3] = 0
        a2[a2 == 4] = 10
        a3[a3 == 0] = 0
        a3[a3 == 1] = 255
        a3[a3 == 3] = 0
        a3[a3 == 4] = 120
        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        prediction = Image.merge('RGB', [a1, a2, a3])
        prediction.save(test_save_path + '/' + case + '.PNG')
    return metric_list
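

# Hedged sketch, for illustration only: the per-channel remapping in
# test_three_channel_volume amounts to a single class-id -> RGB palette lookup.
# The colors below mirror the values used above (class 2 is left at its raw
# value there, so it stays at 2 here as well); this helper is an added example,
# not part of the original file.
def _example_colorize_prediction(prediction):
    palette = {0: (0, 0, 0), 1: (255, 255, 255), 2: (2, 2, 2),
               3: (255, 0, 0), 4: (20, 10, 120)}
    h, w = prediction.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in palette.items():
        rgb[prediction == cls] = color              # paint every pixel of this class
    return Image.fromarray(rgb, mode='RGB')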