代码拉取完成,页面将自动刷新
import paddle
import cv2
import numpy as np
from dataset import TestDataset
import config
class StepTest(paddle.callbacks.Callback):
def __init__(self):
self.testdataset = TestDataset('test')
def on_train_batch_end(self, step, logs=None):
outs = self.model.predict(self.testdataset)
outs = outs[0]
for index, out in enumerate(outs):
path = self.testdataset.indexs[index]
img = cv2.imread(path)
# # 输出热力图
# res = np.reshape(out, (config.LABLE_SIZE, config.LABLE_SIZE, config.CLASS_NUMBER))
# lab = (np.argmax(res, axis=-1) / config.CLASS_NUMBER * 255).astype(np.uint8)
# color = cv2.applyColorMap(lab, cv2.COLORMAP_JET)
# cv2.imwrite('result/' + str(index) + 'result.jpg', color)
# 输出每张图
res = np.reshape(out, (config.LABLE_SIZE, config.LABLE_SIZE, config.CLASS_NUMBER))
res = res * 254
res = res.astype(np.uint8)
# for ch in range(config.CLASS_NUMBER):
# r = res[:, :, ch]
# lab = config.ID2LABEL[ch]
# cv2.imwrite('step_log/' + str(step) + 'result' + lab + '.jpg', r)
ch = 1
r = res[:, :, ch]
lab = config.ID2LABEL[ch]
cv2.imwrite('step_log/' + str(step) + 'result' + lab + '.jpg', r)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。