2 Star 10 Fork 1

陈泽艇/手写汉字识别

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
MyDataset.py 2.20 KB
一键复制 编辑 原始数据 按行查看 历史
陈泽艇 提交于 2021-05-02 21:51 . 完善文件读取错误
import cv2
from config import args
import numpy as np
import os
from torch.utils.data import Dataset
import struct
import pickle
class MyDataset(Dataset):
def __init__(self, data_path, transforms):
super(MyDataset, self).__init__()
self.images = [] # 用于存储图片
self.labels = [] # 用于存储标签
self.get_data(data_path) # 通过原始数据集获取图片和标签
self.transforms = transforms # 图片需要进行的变换,ToTensor()等等
def __getitem__(self, index):
image = self.images[index]
image = self.transforms(image) # 进行变换
label = self.labels[index]
return image, label
def __len__(self):
return len(self.labels)
def get_data(self, data_path):
f = open(args.root + '/char_dict', 'rb')
char_dict = pickle.load(f) # 获取码表
f.close()
for file_name in os.listdir(data_path):
if file_name.endswith('.gnt'):
file_path = os.path.join(data_path, file_name)
with open(file_path, 'rb') as f:
header_size = 10
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size:
break
sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
tag_code = header[5] + (header[4] << 8)
tag_code = struct.pack('>H', tag_code).decode('gb2312')
label = char_dict[tag_code]
width = header[6] + (header[7] << 8)
height = header[8] + (header[9] << 8)
if header_size + width * height != sample_size:
break
image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
_, image = cv2.threshold(image, 200, 255, cv2.THRESH_BINARY_INV)
image = cv2.resize(image, (args.image_size, args.image_size))
self.images.append(image)
self.labels.append(label)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zeting-chen/wordCNN.git
git@gitee.com:zeting-chen/wordCNN.git
zeting-chen
wordCNN
手写汉字识别
master

搜索帮助