1 Star 1 Fork 0

Dr-Zhuang/image-captioning-seqgan

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
datasets.py 3.77 KB
一键复制 编辑 原始数据 按行查看 历史
import json
import torch
from torch.utils.data import Dataset
import numpy as np
from utils import pil_loader
class ImageCaptionDataset(Dataset):
    """Dataset of (image-or-features, caption, caption_length) examples.

    For the 'train' split, ``__getitem__`` returns a 3-tuple
    ``(img_or_feats, cap, cap_len)``; for other splits it additionally
    returns every caption of the image as a fourth element (used for
    evaluation metrics such as BLEU that compare against all references).

    When ``use_img_feats`` is true, precomputed CNN features are loaded
    from ``.npy`` files instead of decoding and transforming raw images.
    """

    def __init__(self, dataset,
                 split_type,
                 use_img_feats,
                 transform,
                 img_src_path,
                 processed_data_path,
                 cnn_architecture):
        """
        Args:
            dataset: dataset name (sub-directory under the data roots).
            split_type: 'train', 'val', ...
            use_img_feats: if true, load precomputed CNN features.
            transform: callable applied to raw PIL images (unused when
                ``use_img_feats`` is true).
            img_src_path: root directory of the raw images.
            processed_data_path: root of preprocessed JSON/feature files.
            cnn_architecture: sub-directory name of the feature extractor.
        """
        super(ImageCaptionDataset, self).__init__()
        self.split_type = split_type
        self.use_img_feats = use_img_feats
        self.transform = transform
        self.img_src_path = img_src_path
        self.processed_data_path = processed_data_path
        self.dataset = dataset
        self.cnn_architecture = cnn_architecture
        # All preprocessed metadata lives under <processed>/<dataset>/<split>/.
        split_dir = f'{processed_data_path}/{dataset}/{split_type}'
        self.img_names = self._load_json(f'{split_dir}/img_names.json')
        self.caps = self._load_json(f'{split_dir}/captions.json')
        self.cap_lens = self._load_json(f'{split_dir}/captions_len.json')
        if split_type == 'val':
            # NOTE(review): loaded only for 'val', but __getitem__ reads it
            # for every non-train split — a 'test' split would raise
            # AttributeError. Confirm intended splits with the caller.
            self.caps_per_img = self._load_json(f'{split_dir}/caps_per_img.json')

    @staticmethod
    def _load_json(path):
        """Read and parse one JSON metadata file."""
        with open(path) as f:
            return json.load(f)

    def __getitem__(self, index):
        """Return the example for caption ``index`` (see class docstring)."""
        img_name = self.img_names[index]
        cap = torch.LongTensor(self.caps[index])
        cap_len = torch.LongTensor([self.cap_lens[index]])
        if self.use_img_feats:
            # Features are stored per image basename, independent of the
            # sub-directory the raw image lives in.
            feats_path = (f'{self.processed_data_path}/{self.dataset}/'
                          f'{self.split_type}/image_features/'
                          f'{self.cnn_architecture}/'
                          f'{img_name.split("/")[-1]}.npy')
            img_feats = torch.from_numpy(np.load(feats_path))
            if self.split_type == 'train':
                return img_feats, cap, cap_len
            all_caps = self.caps_per_img[img_name]
            return img_feats, cap, cap_len, torch.LongTensor(all_caps)
        img = pil_loader(f'{self.img_src_path}/{self.dataset}/{img_name}')
        img = torch.FloatTensor(self.transform(img))
        if self.split_type == 'train':
            return img, cap, cap_len
        all_caps = self.caps_per_img[img_name]
        return img, cap, cap_len, torch.LongTensor(all_caps)

    def __len__(self):
        # One example per caption (images repeat when they have several).
        return len(self.caps)
class ImageDataset(Dataset):
    """Iterates over the distinct images of one dataset split.

    Yields ``(transformed_image, image_name)`` pairs in sorted name
    order — e.g. for precomputing CNN features once per image.
    """

    def __init__(self, split_type, dataset, transform, img_src_path, processed_data_path):
        """
        Args:
            split_type: 'train', 'val', ...
            dataset: dataset name (sub-directory under the data roots).
            transform: callable applied to each loaded PIL image.
            img_src_path: root directory of the raw images.
            processed_data_path: root of the preprocessed JSON files.
        """
        super(ImageDataset, self).__init__()
        self.split_type = split_type
        self.transform = transform
        self.processed_data_path = processed_data_path
        self.img_src_path = img_src_path
        self.dataset = dataset
        names_file = f'{processed_data_path}/{dataset}/{split_type}/img_names.json'
        with open(names_file) as handle:
            names = json.load(handle)
        # The caption list repeats image paths; keep each image exactly
        # once, in a deterministic (sorted) order.
        self.img_names = sorted(set(names))

    def __getitem__(self, index):
        """Load, transform and return the ``index``-th unique image."""
        name = self.img_names[index]
        image = pil_loader(f'{self.img_src_path}/{self.dataset}/{name}')
        image = torch.FloatTensor(self.transform(image))
        return image, name

    def __len__(self):
        return len(self.img_names)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhuang_shuo/image-captioning-seqgan.git
git@gitee.com:zhuang_shuo/image-captioning-seqgan.git
zhuang_shuo
image-captioning-seqgan
image-captioning-seqgan
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385