代码拉取完成,页面将自动刷新
import cv2
from config import args
import numpy as np
import os
from torch.utils.data import Dataset
import struct
import pickle
class MyDataset(Dataset):
    """Map-style dataset of character images and labels loaded from ``.gnt`` files.

    Every ``.gnt`` file in ``data_path`` is parsed eagerly at construction
    time.  Each bitmap is inverse-binarised and resized to
    ``args.image_size`` x ``args.image_size``.  Its label is looked up in
    the pickled ``char_dict`` found under ``args.root``.

    NOTE(review): the record layout matches what looks like the CASIA HWDB
    GNT format; confirm against the data source.
    """

    # Per-sample GNT header, little-endian with no padding (10 bytes):
    #   uint32  total sample size in bytes (header + bitmap)
    #   2 bytes tag code, stored high byte first, decoded as GB2312
    #   uint16  bitmap width
    #   uint16  bitmap height
    # Parsing with struct keeps the arithmetic in Python ints; shifting
    # numpy uint8 scalars overflows under NumPy >= 2 (NEP 50 promotion).
    _GNT_HEADER = struct.Struct('<I2sHH')

    def __init__(self, data_path, transforms=None):
        """Load all samples from the ``.gnt`` files directly inside *data_path*.

        Args:
            data_path: Directory containing ``.gnt`` files.
            transforms: Optional callable applied to each image in
                ``__getitem__`` (e.g. ``ToTensor()``); ``None`` returns the
                raw numpy image.
        """
        super().__init__()
        self.images = []  # preprocessed uint8 images, one per sample
        self.labels = []  # labels from char_dict, parallel to self.images
        self.get_data(data_path)
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, with transforms applied if set."""
        image = self.images[index]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self.labels[index]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.labels)

    @classmethod
    def _iter_gnt_samples(cls, stream):
        """Yield ``(tag_code, image)`` for each sample in a binary GNT *stream*.

        ``tag_code`` is the GB2312-decoded character string.  ``image`` is a
        writable ``(height, width)`` uint8 array.  Iteration stops at end of
        stream, at a truncated header or bitmap, or at the first sample whose
        declared size disagrees with ``header + width * height``.

        Raises:
            UnicodeDecodeError: if a tag code is not valid GB2312.
        """
        header_size = cls._GNT_HEADER.size
        while True:
            header = stream.read(header_size)
            if len(header) < header_size:  # clean EOF or truncated header
                break
            sample_size, tag_bytes, width, height = cls._GNT_HEADER.unpack(header)
            pixel_count = width * height
            if header_size + pixel_count != sample_size:  # inconsistent record
                break
            pixels = stream.read(pixel_count)
            if len(pixels) < pixel_count:  # truncated bitmap
                break
            # frombuffer over bytes is read-only; copy so downstream OpenCV /
            # transforms receive an ordinary writable array.
            image = np.frombuffer(pixels, dtype=np.uint8).reshape((height, width)).copy()
            yield tag_bytes.decode('gb2312'), image

    def get_data(self, data_path):
        """Parse every ``.gnt`` file in *data_path*, appending to images/labels.

        Files are processed in sorted filename order so the sample order is
        reproducible across platforms.

        Raises:
            KeyError: if a decoded character is missing from ``char_dict``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; char_dict must
        # come from a trusted source.
        with open(os.path.join(args.root, 'char_dict'), 'rb') as f:
            char_dict = pickle.load(f)  # maps decoded character string -> label

        for file_name in sorted(os.listdir(data_path)):
            if not file_name.endswith('.gnt'):
                continue
            file_path = os.path.join(data_path, file_name)
            with open(file_path, 'rb') as f:
                for tag_code, image in self._iter_gnt_samples(f):
                    label = char_dict[tag_code]
                    # THRESH_BINARY_INV: pixels > 200 become 0, all others 255
                    # (presumably turns dark strokes on a light background
                    # into white strokes on black).
                    _, image = cv2.threshold(image, 200, 255, cv2.THRESH_BINARY_INV)
                    image = cv2.resize(image, (args.image_size, args.image_size))
                    self.images.append(image)
                    self.labels.append(label)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。