代码拉取完成,页面将自动刷新
同步操作将从 cold/基于深度学习的车牌识别系统 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import numpy as np
import os
import cv2
from tensorflow.keras import layers, losses, models
'''
def unet_train():
height = 512
width = 512
path = ''
input_name = os.listdir(path + 'train_image')
n = len(input_name)
print(n)
X_train, y_train = [], []
for i in range(n):
print("正在读取第%d张图片" % i)
img = cv2.imread(path + 'train_image/%d.png' % i)
label = cv2.imread(path + 'train_label/%d.png' % i)
X_train.append(img)
y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):
x = layers.Conv2D(nb_filter, kernel_size, strides=strides, padding=padding)(x)
x = layers.BatchNormalization(axis=3)(x)
x = layers.LeakyReLU(alpha=0.1)(x)
return x
def Conv2dT_BN(x, filters, kernel_size, strides=(2, 2), padding='same'):
x = layers.Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
x = layers.BatchNormalization(axis=3)(x)
x = layers.LeakyReLU(alpha=0.1)(x)
return x
inpt = layers.Input(shape=(height, width, 3))
conv1 = Conv2d_BN(inpt, 8, (3, 3))
conv1 = Conv2d_BN(conv1, 8, (3, 3))
pool1 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv1)
conv2 = Conv2d_BN(pool1, 16, (3, 3))
conv2 = Conv2d_BN(conv2, 16, (3, 3))
pool2 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv2)
conv3 = Conv2d_BN(pool2, 32, (3, 3))
conv3 = Conv2d_BN(conv3, 32, (3, 3))
pool3 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv3)
conv4 = Conv2d_BN(pool3, 64, (3, 3))
conv4 = Conv2d_BN(conv4, 64, (3, 3))
pool4 = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv4)
conv5 = Conv2d_BN(pool4, 128, (3, 3))
conv5 = layers.Dropout(0.5)(conv5)
conv5 = Conv2d_BN(conv5, 128, (3, 3))
conv5 = layers.Dropout(0.5)(conv5)
convt1 = Conv2dT_BN(conv5, 64, (3, 3))
concat1 = layers.concatenate([conv4, convt1], axis=3)
concat1 = layers.Dropout(0.5)(concat1)
conv6 = Conv2d_BN(concat1, 64, (3, 3))
conv6 = Conv2d_BN(conv6, 64, (3, 3))
convt2 = Conv2dT_BN(conv6, 32, (3, 3))
concat2 = layers.concatenate([conv3, convt2], axis=3)
concat2 = layers.Dropout(0.5)(concat2)
conv7 = Conv2d_BN(concat2, 32, (3, 3))
conv7 = Conv2d_BN(conv7, 32, (3, 3))
convt3 = Conv2dT_BN(conv7, 16, (3, 3))
concat3 = layers.concatenate([conv2, convt3], axis=3)
concat3 = layers.Dropout(0.5)(concat3)
conv8 = Conv2d_BN(concat3, 16, (3, 3))
conv8 = Conv2d_BN(conv8, 16, (3, 3))
convt4 = Conv2dT_BN(conv8, 8, (3, 3))
concat4 = layers.concatenate([conv1, convt4], axis=3)
concat4 = layers.Dropout(0.5)(concat4)
conv9 = Conv2d_BN(concat4, 8, (3, 3))
conv9 = Conv2d_BN(conv9, 8, (3, 3))
conv9 = layers.Dropout(0.5)(conv9)
outpt = layers.Conv2D(filters=3, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu')(conv9)
model = models.Model(inpt, outpt)
model.compile(optimizer='adam',
loss='mean_squared_error',
metrics=['accuracy'])
model.summary()
print("开始训练u-net")
model.fit(X_train, y_train, epochs=100, batch_size=15)
model.save('unet.h5')
print('unet.h5保存成功!!!')
'''
def unet_predict(unet, img_src_path):
img_src = cv2.imdecode(np.fromfile(img_src_path, dtype=np.uint8), -1)
# img_src=cv2.imread(img_src_path)
if img_src.shape != (512, 512, 3):
img_src = cv2.resize(img_src, dsize=(512, 512), interpolation=cv2.INTER_AREA)[:, :, :3] # dsize=(宽度,高度),[:,:,:3]是防止图片为4通道图片,后续无法reshape
img_src = img_src.reshape(1, 512, 512, 3)
img_mask = unet.predict(img_src) # 归一化除以255后进行预测
img_src = img_src.reshape(512, 512, 3) # 将原图reshape为3维
img_mask = img_mask.reshape(512, 512, 3) # 将预测后图片reshape为3维
img_mask = img_mask / np.max(img_mask) * 255 # 归一化后乘以255
img_mask[:, :, 2] = img_mask[:, :, 1] = img_mask[:, :, 0] # 三个通道保持相同
img_mask = img_mask.astype(np.uint8) # 将img_mask类型转为int型
return img_src, img_mask
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。