1 Star 1 Fork 0

zhugeliang1/Attention-Gated-Networks

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
visualise_attention.py 7.71 KB
一键复制 编辑 原始数据 按行查看 历史
Jo Schlemper 提交于 2018-04-13 15:06 . added ultrasound
from torch.utils.data import DataLoader
from dataio.loader import get_dataset, get_dataset_path
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj
from utils.visualiser import Visualiser
from models import get_model
import os, time
# import matplotlib
# matplotlib.use('Agg')
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import math, numpy
import numpy as np
from scipy.misc import imresize
from skimage.transform import resize
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Plot every channel of a feature-map stack as its own subplot in a grid.

    Parameters
    ----------
    units : array-like indexed as ``units[:, :, i]``
        Stack of 2-D maps; the third axis is iterated over as channels.
        Each map is transposed before display.
    figure_id : int or str
        Matplotlib figure number/label; the figure is cleared and reused.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap : Colormap or str
        Colormap used for every map.
    colormap_lim : sequence of two floats or None
        If truthy, ``(vmin, vmax)`` colour limits applied to every subplot.
    title : str
        Figure-level super-title.

    Returns
    -------
    matplotlib.figure.Figure
        The (re)drawn figure, so callers can save or close it.
    """
    plt.ion()
    filters = units.shape[2]
    # Near-square grid. max(1, ...) guards an empty stack, where round(0) == 0
    # would otherwise cause a division by zero on the next line.
    n_columns = max(1, round(math.sqrt(filters)))
    # NOTE(review): the +1 always adds an empty row; kept to preserve the
    # existing layout.
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): width must follow the number of columns.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the frame but hide pixel-index tick labels.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
    return fig
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay feature/attention maps on a grayscale image in a single axes.

    Parameters
    ----------
    input_im : array-like indexed as ``input_im[:, :, 0]``
        Background image; only its first channel is drawn, in grayscale.
    units : array-like indexed as ``units[:, :, i]``
        Stack of maps drawn on top of the background, one per channel.
        Assumed to share the spatial size of ``input_im`` (not checked).
    figure_id : int or str
        Matplotlib figure number/label; the figure is cleared and reused.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap : Colormap or str
        Colormap for the overlaid maps.
    colormap_lim : sequence of two floats or None
        If truthy, ``(vmin, vmax)`` colour limits applied to every overlay.
    title : str
        Axes title.
    alpha : float
        Opacity of the overlays (0 shows only the background image).

    Returns
    -------
    matplotlib.figure.Figure
        The (re)drawn figure; saving it is left to the caller.
    """
    plt.ion()
    filters = units.shape[2]
    fig = plt.figure(figure_id, figsize=(5, 5))
    fig.clf()

    if filters:
        if colormap_lim:
            vmin, vmax = colormap_lim[0], colormap_lim[1]
        else:
            vmin = vmax = None
        # Draw the background once rather than once per channel.
        plt.imshow(input_im[:, :, 0], interpolation=interp, cmap='gray')
        for i in range(filters):
            # vmin/vmax apply the colour limits to every overlay, not only
            # to the last one drawn.
            plt.imshow(units[:, :, i], interpolation=interp, cmap=colormap,
                       alpha=alpha, vmin=vmin, vmax=vmax)
        plt.axis('off')
        # A single colorbar (for the last overlay). Adding one per channel
        # shrank the image axes further on every iteration.
        plt.colorbar()
        plt.title(title, fontsize='small')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    return fig
## Load options
# Seconds passed to plt.pause() after each sample so the figures refresh.
PAUSE = .01

# Config to visualise. Configs previously inspected here include the
# config_sononet_attention_fs8_v*, config_sononet_grid_attention_fs8(_deepsup)_v*
# and config_sononet_grid_att_fs8_avg_v* families; change the name to switch.
CONFIG_DIR = '/vol/bitbucket/js3611/projects/transfer_learning/ultrasound/configs_2'
config_name = 'config_sononet_grid_att_fs4_avg_v12.json'
json_opts = json_file_to_pyobj(os.path.join(CONFIG_DIR, config_name))
train_opts = json_opts.training

# Output tree: figures of misclassified samples go to neg/, sampled correct
# ones to pos/.
dir_name = os.path.join('visualisation_debug', config_name)
# exist_ok makes re-runs safe and also creates pos/ and neg/ when dir_name
# already exists without them. Previously that case made every savefig fail.
for _subdir in ('pos', 'neg'):
    os.makedirs(os.path.join(dir_name, _subdir), exist_ok=True)
# Setup the NN Model. Networks exposing these switches are forced to classify
# through the attention path, with deep supervision turned off.
model = get_model(json_opts.model)
for _attr, _value in (('classification_mode', 'attention'),
                      ('deep_supervised', False)):
    if hasattr(model.net, _attr):
        setattr(model.net, _attr, _value)

# Setup Dataset and Augmentation
dataset_class = get_dataset(train_opts.arch_type)
dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                               opts=json_opts.augmentation)

# Setup Data Loader: training split, passed through the 'valid' transform,
# served one shuffled sample at a time.
dataset = dataset_class(dataset_path, split='train',
                        transform=dataset_transform['valid'])
data_loader = DataLoader(dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=1)
def _save_figure_for_sample(iteration, is_correct, keep_correct, suffix=''):
    """Save the current figure to neg/ if misclassified, else to pos/ when kept."""
    if not is_correct:
        plt.savefig('{}/neg/{:03d}{}.png'.format(dir_name, iteration, suffix))
    elif keep_correct:
        plt.savefig('{}/pos/{:03d}{}.png'.format(dir_name, iteration, suffix))


# test: run the model on each sample and visualise its attention maps.
try:
    for iteration, data in enumerate(data_loader, 1):
        model.set_input(data[0], data[1])
        cls = dataset.label_names[int(data[1])]

        model.validate()
        pred_cls = dataset.label_names[int(model.pred[1])]
        is_correct = cls == pred_cls

        #########################################################
        # Display the input image (model.input[0, 0]) with an invisible overlay.
        input_img = np.expand_dims(model.input[0, 0].cpu().numpy(), axis=2)
        plotNNFilterOverlay(input_img, np.zeros_like(input_img), figure_id=0,
                            interp='bilinear', colormap=cm.jet,
                            title='[GT:{}|P:{}]'.format(cls, pred_cls), alpha=0)

        # Save every correct non-BACKGROUND sample, but only a random ~1% of
        # correct BACKGROUND ones.
        keep_correct = cls != "BACKGROUND" or np.random.random() < 0.01
        _save_figure_for_sample(iteration, is_correct, keep_correct)

        #########################################################
        # Compatibility scores overlaid on the input.
        attentions = []
        for i in [1, 2]:
            fmap = model.get_feature_maps('compatibility_score%d' % i, upscale=False)
            if not fmap:
                continue
            # fmap[1] is used as the attention map; it is resized to the input
            # resolution so it can be overlaid on the image.
            attention = fmap[1].squeeze().cpu().numpy()
            attention = np.expand_dims(
                resize(attention, (input_img.shape[0], input_img.shape[1]),
                       mode='constant', preserve_range=True),
                axis=2)
            plotNNFilterOverlay(input_img, attention, figure_id=i, interp='bilinear',
                                colormap=cm.jet,
                                title='[GT:{}|P:{}] compat. {}'.format(cls, pred_cls, i),
                                alpha=0.5)
            attentions.append(attention)

        # Mean of all attention maps. Skip it when the network exposed none:
        # np.mean([]) gives a 0-d NaN, and the overlay would crash on it.
        if attentions:
            plotNNFilterOverlay(input_img, np.mean(attentions, 0), figure_id=4,
                                interp='bilinear', colormap=cm.jet,
                                title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls),
                                alpha=0.5)
            _save_figure_for_sample(iteration, is_correct, keep_correct, '_hm')

        plt.show()
        plt.pause(PAUSE)
finally:
    # Also runs when the (interactive) loop is interrupted, e.g. with Ctrl-C.
    model.destructor()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wsq_chd/Attention-Gated-Networks.git
git@gitee.com:wsq_chd/Attention-Gated-Networks.git
wsq_chd
Attention-Gated-Networks
Attention-Gated-Networks
master

搜索帮助