1 Star 0 Fork 0

CardinalSystem/Dain-App

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
balancedsampler.py 1.69 KB
一键复制 编辑 原始数据 按行查看 历史
User 提交于 2020-12-01 13:55 . - First release
from torch.utils.data.sampler import Sampler
import torch
class RandomBalancedSampler(Sampler):
    """Endless random sampler that visits every dataset index once per shuffled cycle.

    Indices are taken from a random permutation of ``range(len(data_source))``.
    A fresh permutation is drawn each time a full cycle has been consumed, so
    every index is returned exactly once per ``len(data_source)`` draws,
    independent of the epoch length. For a non-empty dataset the iterator never
    raises ``StopIteration``: the consumer decides when an epoch ends.
    ``__len__`` only reports the nominal epoch length.

    Arguments:
        data_source (Dataset): dataset to sample from; only ``len()`` is used.
        epoch_size (int): nominal number of samples per epoch. Values <= 0 mean
            "use the dataset length". ``__len__`` is capped at the dataset length.
    """

    def __init__(self, data_source, epoch_size):
        self.data_size = len(data_source)
        self.epoch_size = epoch_size
        # Position of the next index to return within the current permutation;
        # 0 means a new permutation must be drawn first.
        self.index = 0

    def __next__(self):
        """Return the next dataset index as a Python int.

        Raises:
            StopIteration: if the dataset is empty (there is nothing to sample,
                and the modulo below would otherwise divide by zero).
        """
        if self.data_size == 0:
            raise StopIteration
        if self.index == 0:
            # Start of a new cycle: re-shuffle the sampler.
            self.indices = torch.randperm(self.data_size)
        # .item() yields a plain int (like torch's built-in samplers) rather
        # than a 0-dim tensor.
        sample = self.indices[self.index].item()
        self.index = (self.index + 1) % self.data_size
        return sample

    def next(self):
        """Python 2 style alias for ``__next__``."""
        return self.__next__()

    def __iter__(self):
        # The sampler is its own iterator; iteration resumes from the current
        # position instead of restarting.
        return self

    def __len__(self):
        """Return ``min(data_size, epoch_size)``, or ``data_size`` when ``epoch_size <= 0``."""
        if self.epoch_size > 0:
            return min(self.data_size, self.epoch_size)
        return self.data_size
class SequentialBalancedSampler(Sampler):
    """Endless sampler that cycles through the dataset in index order.

    Yields ``0, 1, ..., len(data_source) - 1`` and then wraps back to ``0``, so
    the whole dataset is visited before any index repeats, independent of the
    epoch length. For a non-empty dataset the iterator never raises
    ``StopIteration``: the consumer decides when an epoch ends. ``__len__`` only
    reports the nominal epoch length.

    Arguments:
        data_source (Dataset): dataset to sample from; only ``len()`` is used.
        epoch_size (int): nominal number of samples per epoch. Values <= 0 mean
            "use the dataset length". ``__len__`` is capped at the dataset length.
    """

    def __init__(self, data_source, epoch_size):
        self.data_size = len(data_source)
        self.epoch_size = epoch_size
        # Next index to return.
        self.index = 0

    def __next__(self):
        """Return the next dataset index, wrapping around at the end.

        Raises:
            StopIteration: if the dataset is empty (there is nothing to sample,
                and the modulo below would otherwise divide by zero).
        """
        if self.data_size == 0:
            raise StopIteration
        # Return the current position before advancing, so sampling starts at
        # index 0 rather than skipping it until the end of the first cycle.
        sample = self.index
        self.index = (self.index + 1) % self.data_size
        return sample

    def next(self):
        """Python 2 style alias for ``__next__``."""
        return self.__next__()

    def __iter__(self):
        # The sampler is its own iterator; iteration resumes from the current
        # position instead of restarting.
        return self

    def __len__(self):
        """Return ``min(data_size, epoch_size)``, or ``data_size`` when ``epoch_size <= 0``."""
        if self.epoch_size > 0:
            return min(self.data_size, self.epoch_size)
        return self.data_size
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/cardinalsystem/Dain-App.git
git@gitee.com:cardinalsystem/Dain-App.git
cardinalsystem
Dain-App
Dain-App
master

搜索帮助