1 Star 1 Fork 1

付昌陇/MMBSSL

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dataloader.py 4.28 KB
一键复制 编辑 原始数据 按行查看 历史
fu-changlong 提交于 2023-07-03 18:30 . 提交
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from batch import BatchSubstructContext, BatchMasking, BatchAE, BatchMaskingND
class DataLoaderSubstructContext(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch using
    ``BatchSubstructContext.from_data_list``.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. ``collate_fn`` must not be
            passed here because this class always supplies its own.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Hand over the collate callable by reference instead of wrapping it
        # in a lambda: a lambda cannot be pickled, which breaks
        # ``num_workers > 0`` when worker processes are started with "spawn".
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=BatchSubstructContext.from_data_list,
            **kwargs)
class DataLoaderMasking1(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch using
    ``BatchMasking.from_data_list``.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. ``collate_fn`` must not be
            passed here because this class always supplies its own.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Hand over the collate callable by reference instead of wrapping it
        # in a lambda: a lambda cannot be pickled, which breaks
        # ``num_workers > 0`` when worker processes are started with "spawn".
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=BatchMasking.from_data_list,
            **kwargs)
class DataLoaderMasking2(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch using
    ``BatchMaskingND.from_data_list``.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. ``collate_fn`` must not be
            passed here because this class always supplies its own.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Hand over the collate callable by reference instead of wrapping it
        # in a lambda: a lambda cannot be pickled, which breaks
        # ``num_workers > 0`` when worker processes are started with "spawn".
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=BatchMaskingND.from_data_list,
            **kwargs)
from util import MaskAtom
class DataLoaderMaskingPred(torch.utils.data.DataLoader):
    r"""Data loader that applies a :class:`MaskAtom` transform to every
    sample and then merges the results into a mini-batch with
    ``BatchMasking.from_data_list``.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        mask_rate (float, optional): Forwarded to :class:`MaskAtom` as
            ``mask_rate``. (default: :obj:`0.0`)
        mask_edge (float, optional): Forwarded to :class:`MaskAtom` as
            ``mask_edge``. (default: :obj:`0.0`)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, mask_rate=0.0, mask_edge=0.0, **kwargs):
        # The transform is created before the base initialiser runs so that
        # it already exists when ``self.collate_fn`` is handed to the loader.
        # ``num_atom_type`` and ``num_edge_type`` are fixed at 119 and 5 here.
        self._transform = MaskAtom(
            num_atom_type=119,
            num_edge_type=5,
            mask_rate=mask_rate,
            mask_edge=mask_edge,
        )
        torch.utils.data.DataLoader.__init__(
            self,
            dataset,
            batch_size,
            shuffle,
            collate_fn=self.collate_fn,
            **kwargs)

    def collate_fn(self, batches):
        """Run the stored transform on each sample, then batch the results.

        Args:
            batches: The list of samples drawn for one mini-batch.

        Returns:
            Whatever ``BatchMasking.from_data_list`` returns for the list of
            transformed samples.
        """
        masked_samples = list(map(self._transform, batches))
        return BatchMasking.from_data_list(masked_samples)
class DataLoaderAE(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch using
    ``BatchAE.from_data_list``.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.utils.data.DataLoader`. ``collate_fn`` must not be
            passed here because this class always supplies its own.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Hand over the collate callable by reference instead of wrapping it
        # in a lambda: a lambda cannot be pickled, which breaks
        # ``num_workers > 0`` when worker processes are started with "spawn".
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=BatchAE.from_data_list,
            **kwargs)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/fu-changlong/mmbssl.git
git@gitee.com:fu-changlong/mmbssl.git
fu-changlong
mmbssl
MMBSSL
master

搜索帮助