5 Star 2 Fork 0

chenyuming/CVTeamTools

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
split_train_test.py 1.72 KB
一键复制 编辑 原始数据 按行查看 历史
chenyuming 提交于 2021-09-19 15:47 . refactor
"""
Split a YOLO-style dataset into a train list and a test list.

Expected input layout:
└── <root_dir>
    ├── images    # all images
    └── labels    # matching annotation files (.txt)

Output written by this script:
└── <root_dir>
    ├── Main/train.txt
    └── Main/test.txt
"""
import os
import cv2
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import argparse
import shutil
import glob
# Command-line interface: a single option naming the dataset root directory.
# The arguments are parsed as soon as this module is loaded.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--root_dir',
    type=str,
    default='./data',
    help="root path of images and labels, include ./images and ./labels and classes.txt",
)
arg = parser.parse_args()
def self_train_test_split(img_paths, ratio_train=0.9, ratio_test=0.1, random_state=233):
    """Split ``img_paths`` into a train part and a test part.

    The split is delegated to ``sklearn.model_selection.train_test_split``
    with a fixed seed. Repeated runs on the same, identically ordered input
    therefore give the same split. Change the ratio defaults here to change
    how the dataset is divided.

    Args:
        img_paths: Sequence of items (image paths) to split.
        ratio_train: Fraction of items assigned to the train split; must be
            strictly between 0 and 1.
        ratio_test: Fraction of items assigned to the test split; together
            with ``ratio_train`` it must sum to 1 (within float tolerance).
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        A ``(train_img, test_img)`` pair as returned by ``train_test_split``.

    Raises:
        ValueError: If ``ratio_train`` is out of range or the ratios do not
            sum to 1.
    """
    # Explicit checks rather than ``assert``, which is stripped under ``-O``.
    if not 0 < ratio_train < 1:
        raise ValueError(
            "ratio_train must be in (0, 1), got {}".format(ratio_train))
    # Compare with a tolerance: truncating with int() would accept sums such
    # as 1.4, while exact float equality would reject e.g. 0.7 + 0.2 + 0.1.
    if abs(ratio_train + ratio_test - 1) > 1e-9:
        raise ValueError(
            "ratio_train + ratio_test must equal 1, got {} + {}".format(
                ratio_train, ratio_test))
    # test_size is derived from ratio_train so existing splits stay identical.
    train_img, test_img = train_test_split(
        img_paths, test_size=1 - ratio_train, random_state=random_state)
    print("NUMS of train:test = {}:{}".format(len(train_img), len(test_img)))
    return train_img, test_img
def save_train_test_list(root, exts=('.jpg',)):
    """Write ``Main/train.txt`` and ``Main/test.txt`` under ``root``.

    Images are collected from ``root/images`` (non-recursively) and split
    with ``self_train_test_split``. Each output line is one image path as
    returned by ``glob``, terminated by a newline. Paths are sorted before
    splitting, so the seeded split does not depend on the order in which
    the filesystem lists the files.

    Args:
        root: Dataset root directory that contains an ``images`` folder.
        exts: Image file extensions to include, with the leading dot.
            Matching is case-sensitive. Defaults to ``('.jpg',)``.

    Raises:
        ValueError: If no matching images are found in ``root/images``.
    """
    main_dir = os.path.join(root, 'Main')
    train_path = os.path.join(main_dir, 'train.txt')
    test_path = os.path.join(main_dir, 'test.txt')
    images_dir = os.path.join(root, 'images')

    # Escape the root so glob metacharacters in it ('[', '*', '?') are taken
    # literally. Sort because glob order is arbitrary, and the split is only
    # reproducible for a fixed input order.
    pattern_dir = os.path.join(glob.escape(root), 'images')
    img_paths = sorted({
        path
        for ext in exts
        for path in glob.glob(os.path.join(pattern_dir, '*' + ext))
    })
    if not img_paths:
        raise ValueError(
            "no images with extensions {} found in {}".format(exts, images_dir))

    train_img, test_img = self_train_test_split(img_paths)

    # The output folder may not exist yet; open() would fail without it.
    os.makedirs(main_dir, exist_ok=True)
    if train_img:
        with open(train_path, 'w') as f:
            f.writelines(path + '\n' for path in train_img)
    if test_img:
        with open(test_path, 'w') as f:
            f.writelines(path + '\n' for path in test_img)
if __name__ == "__main__":
    # Entry point: split the dataset rooted at --root_dir.
    root_path = arg.root_dir
    # Explicit check instead of ``assert`` so validation still runs under
    # ``python -O``. The root must be a directory, not just an existing path.
    if not os.path.isdir(root_path):
        raise SystemExit(
            "root_dir does not exist or is not a directory: {}".format(root_path))
    save_train_test_list(root_path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/orzchenyuming/CVTeamTools.git
git@gitee.com:orzchenyuming/CVTeamTools.git
orzchenyuming
CVTeamTools
CVTeamTools
master

搜索帮助