# [web-page residue, not part of the script] 代码拉取完成,页面将自动刷新
# Streamlit YOLOv5 Model2X v0.2
# 创建人:曾逸夫
# 创建时间:2022-07-17
# 功能描述:多选,多项模型转换和打包下载
import os
import shutil
import time
import zipfile
import streamlit as st
# 目录操作
def dir_opt(target_dir):
    """Reset *target_dir* to a freshly created, empty directory.

    If the path already exists it is removed recursively first; in either
    case the directory is then created with ``os.mkdir`` (so the parent
    directory must already exist).

    Args:
        target_dir: Path of the directory to (re)create.
    """
    already_present = os.path.exists(target_dir)
    if already_present:
        # Drop the old tree so no stale files survive.
        shutil.rmtree(target_dir)
    os.mkdir(target_dir)
# 文件下载
def download_file(uploaded_file):
    """Render a Streamlit download button that serves the given file.

    The whole file is read into memory and handed to ``st.download_button``.

    Args:
        uploaded_file: Path of the file to offer for download. It is
            converted to ``str`` before use.
    """
    file_path = str(uploaded_file)
    # The context manager closes the handle; no explicit close() is needed.
    with open(file_path, "rb") as fmodel:
        payload = fmodel.read()
    # Offer only the base name so a path with directories still yields a
    # valid download file name.
    st.download_button(label='下载转换后的模型', data=payload, file_name=os.path.basename(file_path))
# 文件压缩
def zipDir(origin_dir, compress_file):
    """Compress every file under *origin_dir* into the zip archive *compress_file*.

    Files are stored under their path relative to *origin_dir*, using
    DEFLATE compression. Directories themselves are not written, so empty
    directories do not appear in the archive.

    Args:
        origin_dir: Root directory whose files are archived.
        compress_file: Path of the zip file to create (overwritten if present).
    """
    source_root = str(origin_dir)
    archive_path = str(compress_file)
    archive_abs = os.path.abspath(archive_path)

    # The context manager guarantees the archive is finalised and closed even
    # if reading one of the source files raises.
    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for current_dir, _, filenames in os.walk(source_root):
            for filename in filenames:
                full_path = os.path.join(current_dir, filename)
                # Never add the archive to itself when it lives inside the tree.
                if os.path.abspath(full_path) == archive_abs:
                    continue
                # relpath strips only the leading root, unlike str.replace,
                # which would also remove later occurrences of the root name.
                archive.write(full_path, os.path.relpath(full_path, source_root))
# params_include_list = ["torchscript", "onnx", "openvino", "engine", "coreml", "saved_model", "pb", "tflite", "tfjs"]
def cb_opt(device, imgSize, weight_name, btn_model_list, params_include_list, iou_conf, tflite_options, onnx_options,
           torchscript_options):
    """Run ``export.py`` once per selected format, then zip and offer the results.

    For every truthy entry in *btn_model_list* the matching format from
    *params_include_list* is exported by invoking ``python export.py`` as a
    subprocess. Afterwards the ``./weights`` directory is packed into
    ``convert_weights.zip`` and a download button is rendered.

    Format-specific handling is selected by list position:
    index 0 (torchscript), 1 (onnx) and 7 (tflite) append their option flags,
    index 8 (tfjs) appends the NMS thresholds, and index 3 (TensorRT) always
    exports on device ``0`` regardless of *device*.

    Args:
        device: Value passed to ``--device`` (not used for index 3).
        imgSize: Value passed to ``--imgsz``.
        weight_name: File name of the weights inside ``./weights``.
        btn_model_list: One selected/not-selected flag per format.
        params_include_list: ``--include`` value per format, parallel to
            *btn_model_list*.
        iou_conf: Two-item sequence ``[iou_thres, conf_thres]`` used for tfjs.
        tflite_options: Flag names (without ``--``) appended for tflite.
        onnx_options: Flag names (without ``--``) appended for onnx.
        torchscript_options: Flag names (without ``--``) appended for torchscript.
    """
    # Local import keeps this block self-contained; the argument-list form of
    # subprocess.run avoids passing the uploaded file name through a shell.
    import subprocess

    weight_path = f"./weights/{weight_name}"

    for i, selected in enumerate(btn_model_list):
        if not selected:
            continue

        fmt = params_include_list[i]
        st.info(f"正在转换{fmt}......")
        started = time.time()

        # One mutually exclusive chain, so each format is exported exactly
        # once. Previously indices 0 and 1 also fell through to the default
        # branch and were exported a second time without their options.
        run_device = device
        if i == 0:  # torchscript
            extra_args = [f"--{opt}" for opt in torchscript_options]
        elif i == 1:  # onnx
            extra_args = [f"--{opt}" for opt in onnx_options]
        elif i == 3:  # TensorRT must be exported in GPU mode
            run_device = 0
            extra_args = []
        elif i == 7:  # tflite, see https://github.com/zldrobit/yolov5
            extra_args = [f"--{opt}" for opt in tflite_options]
        elif i == 8:  # tfjs
            extra_args = ["--iou-thres", str(iou_conf[0]), "--conf-thres", str(iou_conf[1])]
        else:
            extra_args = []

        command = [
            "python", "export.py",
            "--device", str(run_device),
            "--imgsz", str(imgSize),
            "--weights", weight_path,
            "--include", str(fmt),
            *extra_args,
        ]

        try:
            return_code = subprocess.run(command, check=False).returncode
        except OSError as err:
            # The interpreter could not be launched at all.
            st.error(f"{fmt}转换失败:{err}")
            continue

        if return_code != 0:
            # Do not report success when export.py exited with an error.
            st.error(f"{fmt}转换失败,退出码{return_code}")
            continue

        finished = time.time()
        st.success(f"{fmt}转换完成,用时{round((finished - started), 2)}秒")

    zipDir("./weights", "convert_weights.zip")  # pack ./weights: original plus converted weights
    download_file("convert_weights.zip")  # offer the packed archive for download
def main():
    """Render the Streamlit page: upload a ``.pt`` file, choose formats, convert.

    On every run the ``./weights`` directory is reset, the uploaded file (if
    any) is written into it, and the device, image-size and export-format
    widgets are shown. Pressing the convert button passes the current
    selections to ``cb_opt``.
    """
    with st.container():
        st.title("Streamlit YOLOv5 Model2X")
        st.text("基于Streamlit的YOLOv5模型转换工具")
        st.write("-------------------------------------------------------------")

        dir_opt("./weights")

        uploaded_file = st.file_uploader("选择YOLOv5模型文件(.pt)")
        if uploaded_file is not None:
            # Keep only the base name so a crafted upload name cannot point
            # outside ./weights.
            weight_name = os.path.basename(uploaded_file.name)
            st.info(f"正在写入{weight_name}......")
            bytes_data = uploaded_file.getvalue()
            # The context manager closes the file; no explicit close() needed.
            with open(f"./weights/{weight_name}", 'wb') as fb:
                fb.write(bytes_data)
            st.success(f"{weight_name}写入成功!")

            device = st.radio("请选择设备", ('cpu', 'cuda:0'), index=0)
            imgSize = st.radio("请选择图片尺寸", (320, 640, 1280), index=1)

            st.text("请选择转换的类型:")

            # ------------- torchscript -------------
            cb_torchscript = st.checkbox('TorchScript')
            if cb_torchscript:
                # Label fixed: this widget previously read 'onnx选项'.
                torchscript_options = st.multiselect('torchscript选项', ['optimize'])
            else:
                torchscript_options = []

            # ------------- onnx -------------
            cb_onnx = st.checkbox('ONNX')
            if cb_onnx:
                onnx_options = st.multiselect('onnx选项', ['dynamic', 'simplify'])
            else:
                onnx_options = []

            cb_openvino = st.checkbox('OpenVINO')
            cb_engine = st.checkbox('TensorRT')
            cb_coreml = st.checkbox('CoreML')
            cb_saved_model = st.checkbox('TensorFlow SavedModel')
            cb_pb = st.checkbox('TensorFlow GraphDef')

            # ------------- tflite -------------
            cb_tflite = st.checkbox('TensorFlow Lite')
            if cb_tflite:
                tflite_options = st.multiselect('tflite选项', ['int8', 'nms', 'agnostic-nms'])
            else:
                tflite_options = []

            # cb_edgetpu = st.checkbox('TensorFlow Edge TPU')

            # ------------- tfjs -------------
            cb_tfjs = st.checkbox('TensorFlow.js')
            if cb_tfjs:
                iou_thres = st.slider(label='NMS IoU', min_value=0.0, max_value=1.0, value=0.45, step=0.05)
                conf_thres = st.slider(label='NMS CONF', min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            else:
                iou_thres, conf_thres = 0.45, 0.5

            btn_convert = st.button('转换')

            # The two lists are parallel: position i of the flags selects
            # position i of the --include names.
            btn_model_list = [
                cb_torchscript, cb_onnx, cb_openvino, cb_engine, cb_coreml, cb_saved_model, cb_pb, cb_tflite, cb_tfjs]
            params_include_list = [
                "torchscript", "onnx", "openvino", "engine", "coreml", "saved_model", "pb", "tflite", "tfjs"]

            if btn_convert:
                cb_opt(device, imgSize, weight_name, btn_model_list, params_include_list, [iou_thres, conf_thres],
                       tflite_options, onnx_options, torchscript_options)

        st.write("-------------------------------------------------------------")
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
# [web-page residue, not part of the script] 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# [web-page residue, not part of the script] 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。