1 Star 1 Fork 0

Haixu He/波段分布变化检测

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 4.75 KB
一键复制 编辑 原始数据 按行查看 历史
Haixu He 提交于 2022-05-09 19:17 . update file
"""
@Description :分割影像
@Author :hhx
@Date :2022/5/9 14:46
"""
import os
import time
import cv2
import numpy as np
from skimage import segmentation, io
import torch
import torch.nn as nn
from PIL import Image
class Args:
    """Static hyper-parameters for the unsupervised segmentation script.

    Every setting is a class attribute, so ``Args`` and ``Args()`` expose
    identical values.
    """

    # Sample input path (other images tried: image/coral.jpg, image/tiger.jpg).
    input_image_path: str = r'CJCC/11_2016-12.jpg'
    # Upper bound on training iterations (2**8 == 256).
    train_epoch: int = 2 ** 8
    # Channel widths of the two alternating convolution stages of MyNet.
    mod_dim1: int = 64
    mod_dim2: int = 32
    # GPU index. NOTE(review): no active code in this file reads it.
    gpu_id: int = 0
    # Training stops early once fewer than this many distinct labels remain.
    min_label_num: int = 6
    # The visualisation is refreshed only while fewer than this many labels remain.
    max_label_num: int = 256
class MyNet(nn.Module):
    """Four-stage fully convolutional network producing per-pixel class scores.

    Stages alternate 3x3 ("same" padding) and 1x1 convolutions, each followed
    by batch normalisation; every stage except the last also applies an
    in-place ReLU. Spatial size is preserved and the output has ``mod_dim2``
    channels.
    """

    def __init__(self, inp_dim, mod_dim1, mod_dim2):
        super().__init__()
        # (in_channels, out_channels, kernel_size) for each convolution stage.
        stage_specs = [
            (inp_dim, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
            (mod_dim2, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
        ]
        last_stage = len(stage_specs) - 1
        layers = []
        for stage, (c_in, c_out, k) in enumerate(stage_specs):
            # k // 2 gives padding 1 for 3x3 kernels and 0 for 1x1 kernels.
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=k, stride=1, padding=k // 2))
            layers.append(nn.BatchNorm2d(c_out))
            if stage != last_stage:
                layers.append(nn.ReLU(inplace=True))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to a batch ``x`` of images."""
        return self.seq(x)
def _superpixel_groups(image):
    """Split an image into SLIC superpixels.

    :param image: H x W x 3 uint8 array in OpenCV's BGR channel order.
    :return: list with one entry per superpixel, each the flat (row-major)
        pixel indices belonging to it.
    """
    # skimage's SLIC converts 3-channel input to Lab assuming RGB order,
    # whereas cv2.imread returns BGR, so swap channels first.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Alternative: segmentation.felzenszwalb(rgb, scale=32, sigma=0.5, min_size=64)
    seg_map = segmentation.slic(rgb, compactness=16, n_segments=800, max_iter=8)
    seg_map = seg_map.flatten()
    return [np.where(seg_map == u_label)[0] for u_label in np.unique(seg_map)]


def _mean_color_image(image, lab_inverse, n_labels):
    """Render a label map by painting every pixel with its label's mean color.

    :param image: H x W x 3 source image (uint8).
    :param lab_inverse: per-pixel label ids in ``range(n_labels)`` (flat,
        row-major); every id must occur at least once.
    :param n_labels: number of distinct labels.
    :return: array with the same shape and dtype as ``image``.
    """
    flat = image.reshape((-1, 3))
    lab_inverse = np.asarray(lab_inverse).ravel()
    counts = np.bincount(lab_inverse, minlength=n_labels)
    sums = np.stack(
        [np.bincount(lab_inverse, weights=flat[:, ch], minlength=n_labels)
         for ch in range(flat.shape[1])],
        axis=1,
    )
    # Truncating cast to the image dtype; means of uint8 data stay within 0-255.
    means = (sums / counts[:, np.newaxis]).astype(image.dtype)
    return means[lab_inverse].reshape(image.shape)


def run(filapath):
    """Segment one image without supervision and save the result.

    SLIC superpixels are computed first. A small CNN (MyNet) is then trained so
    that its per-pixel argmax labels agree with the majority label inside each
    superpixel. Training stops after ``Args.train_epoch`` iterations or once
    fewer than ``Args.min_label_num`` labels remain. The mean-color rendering is
    shown in an OpenCV window during training. The last rendering is written to
    ``result/class/<name>.npy`` and ``result/image/<name>.jpg``.

    :param filapath: path of an image readable by ``cv2.imread``.
    :raises FileNotFoundError: if the image cannot be read.
    :raises OSError: if the result JPEG cannot be written.
    """
    start_time0 = time.time()
    args = Args()
    # NOTE(review): 'cpu' is not a valid device index, so this hides every GPU
    # from CUDA -- but only if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = 'cpu'
    # Seed the generators of all devices (CPU included) for reproducible init.
    torch.manual_seed(1943)
    np.random.seed(1943)

    image = cv2.imread(filapath)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: {}'.format(filapath))
    print('原始图像的形状为:', image.shape)

    # Superpixel pre-segmentation: refinement forces one label per group.
    seg_lab = _superpixel_groups(image)

    # Training setup.
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    tensor = image.transpose((2, 0, 1)).astype(np.float32) / 255.0
    tensor = torch.from_numpy(tensor[np.newaxis, :, :, :]).to(device)
    model = MyNet(inp_dim=3, mod_dim1=args.mod_dim1, mod_dim2=args.mod_dim2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, momentum=0.9)
    show = image

    start_time1 = time.time()
    model.train()
    for batch_idx in range(args.train_epoch):
        # Forward: (1, C, H, W) output -> (H*W, C) logits, one row per pixel.
        optimizer.zero_grad()
        output = model(tensor)[0]
        output = output.permute(1, 2, 0).reshape(-1, args.mod_dim2)
        im_target = torch.argmax(output, 1).cpu().numpy()

        # Refine: every pixel of a superpixel takes that group's most frequent label.
        for inds in seg_lab:
            u_labels, hist = np.unique(im_target[inds], return_counts=True)
            im_target[inds] = u_labels[np.argmax(hist)]

        # Backward: the refined labels act as the self-supervised target.
        target = torch.from_numpy(im_target).to(device)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Visualise: paint each label with its current mean color.
        un_label, lab_inverse = np.unique(im_target, return_inverse=True)
        if un_label.shape[0] < args.max_label_num:
            # Recompute every time: label identities can change while their count does not.
            show = _mean_color_image(image, lab_inverse, un_label.shape[0])
        cv2.imshow("seg_pt", show)
        cv2.waitKey(1)
        print('Loss:', batch_idx, loss.item(), len(un_label))
        if len(un_label) < args.min_label_num:
            break

    # Save.
    time0 = time.time() - start_time0
    time1 = time.time() - start_time1
    print('PyTorchInit: %.2f\nTimeUsed: %.2f' % (time0, time1))
    # Portable stem extraction (works for both '/' and '\\' separated paths on Windows).
    name = os.path.splitext(os.path.basename(filapath))[0]
    class_dir = os.path.join('result', 'class')
    image_dir = os.path.join('result', 'image')
    os.makedirs(class_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)
    np.save(os.path.join(class_dir, '{}.npy'.format(name)), show)
    image_path = os.path.join(image_dir, '%s.jpg' % name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(image_path, show):
        raise OSError('failed to write {}'.format(image_path))
if __name__ == '__main__':
    # Segment every entry of the CJCC directory, in os.listdir order.
    input_dir = 'CJCC'
    for entry in os.listdir(input_dir):
        run(os.path.join(input_dir, entry))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/HaixuHe/Band-distribution-change-detection.git
git@gitee.com:HaixuHe/Band-distribution-change-detection.git
HaixuHe
Band-distribution-change-detection
波段分布变化检测
master

搜索帮助