import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pc_utils import PointcloudScaleAndTranslate
import rs_cnn.data.data_utils as rscnn_d_utils
from rs_cnn.data.ModelNet40Loader import ModelNet40Cls as rscnn_ModelNet40Cls
import pointnet2.utils.pointnet2_utils as pointnet2_utils  # used in ModelNet40Rscnn.batch_proc
from pointnet2_tf.modelnet_h5_dataset import ModelNetH5Dataset as pointnet2_ModelNetH5Dataset
from dgcnn.pytorch.data import ModelNet40 as dgcnn_ModelNet40
from pointcloudc_utils import PointCloudC as dgcnn_ModelNetC
# distilled from the following sources:
# https://github.com/Yochengliu/Relation-Shape-CNN/blob/master/data/ModelNet40Loader.py
# https://github.com/Yochengliu/Relation-Shape-CNN/blob/master/train_cls.py
class ModelNet40Rscnn(Dataset):
def __init__(self, split, data_path, train_data_path,
valid_data_path, test_data_path, num_points):
self.split = split
self.num_points = num_points
_transforms = transforms.Compose([rscnn_d_utils.PointcloudToTensor()])
rscnn_params = {
            'num_points': 1024,  # the exact value does not matter here
'root': data_path,
'transforms': _transforms,
'train': (split in ["train", "valid"]),
'data_file': {
'train': train_data_path,
'valid': valid_data_path,
'test': test_data_path
}[self.split]
}
self.rscnn_dataset = rscnn_ModelNet40Cls(**rscnn_params)
self.PointcloudScaleAndTranslate = PointcloudScaleAndTranslate()
    def __len__(self):
        return len(self.rscnn_dataset)

    def __getitem__(self, idx):
        point, label = self.rscnn_dataset[idx]
        # plain numpy array / python int for compatibility with the rest
        # of the code
        point = np.array(point)
        label = label[0].item()
        return {'pc': point, 'label': label}
def batch_proc(self, data_batch, device):
point = data_batch['pc'].to(device)
        if self.split == "train":
            # FPS down to 1200 points, then randomly keep num_points of them
            fps_idx = pointnet2_utils.furthest_point_sample(point, 1200)  # (B, 1200)
            fps_idx = fps_idx[:, np.random.choice(1200, self.num_points,
                                                  False)]
            point = pointnet2_utils.gather_operation(
                point.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            # random scale + translate augmentation at train time
            point.data = self.PointcloudScaleAndTranslate(point.data)
else:
fps_idx = pointnet2_utils.furthest_point_sample(
point, self.num_points) # (B, npoint)
point = pointnet2_utils.gather_operation(
point.transpose(1, 2).contiguous(),
fps_idx).transpose(1, 2).contiguous()
        # move back to CPU so downstream code is device-agnostic
        point = point.cpu()
return {'pc': point, 'label': data_batch['label']}
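
# The two pointnet2_utils calls above require the compiled pointnet2 CUDA
# extension. As a reference for what they compute, here is a minimal
# pure-PyTorch sketch of furthest point sampling (a hypothetical helper,
# not part of this pipeline), assuming a (B, N, 3) float tensor; it is far
# slower than the CUDA op and is included for illustration only.
def _fps_reference(points, num_samples):
    B, N, _ = points.shape
    idx = torch.zeros(B, num_samples, dtype=torch.long, device=points.device)
    # squared distance from every point to the current sample set
    dist = torch.full((B, N), float('inf'), device=points.device)
    farthest = torch.zeros(B, dtype=torch.long, device=points.device)
    batch = torch.arange(B, device=points.device)
    for i in range(num_samples):
        idx[:, i] = farthest
        centroid = points[batch, farthest].unsqueeze(1)  # (B, 1, 3)
        dist = torch.minimum(dist, ((points - centroid) ** 2).sum(-1))
        farthest = dist.argmax(-1)  # next point: farthest from the sample set
    return idx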
# distilled from the following sources:
# https://github.com/charlesq34/pointnet2/blob/7961e26e31d0ba5a72020635cee03aac5d0e754a/modelnet_h5_dataset.py
# https://github.com/charlesq34/pointnet2/blob/7961e26e31d0ba5a72020635cee03aac5d0e754a/train.py
class ModelNet40PN2(Dataset):
def __init__(self, split, train_data_path,
valid_data_path, test_data_path, num_points):
self.split = split
self.dataset_name = 'modelnet40_pn2'
data_path = {
"train": train_data_path,
"valid": valid_data_path,
"test": test_data_path
}[self.split]
pointnet2_params = {
'list_filename': data_path,
            # internal chunk size of the h5 reader; unrelated to the
            # DataLoader batch size used for training
            'batch_size': 32,
'npoints': num_points,
'shuffle': False
}
        # eagerly load the entire pointnet2 dataset into memory
self._dataset = pointnet2_ModelNetH5Dataset(**pointnet2_params)
all_pc = []
all_label = []
        while self._dataset.has_next_batch():
            # augment=False: augmentation happens later in batch_proc,
            # not at load time
            pc, label = self._dataset.next_batch(augment=False)
all_pc.append(pc)
all_label.append(label)
self.all_pc = np.concatenate(all_pc)
self.all_label = np.concatenate(all_label)
def __len__(self):
return self.all_pc.shape[0]
def __getitem__(self, idx):
return {'pc': self.all_pc[idx], 'label': np.int64(self.all_label[idx])}
    def batch_proc(self, data_batch, device):
        if self.split == "train":
            point = np.array(data_batch['pc'])
            point = self._dataset._augment_batch_data(point)
            # back to a tensor for compatibility with the rest of the code
            data_batch['pc'] = torch.tensor(point)
        return data_batch
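
# _augment_batch_data above is the pointnet2 batch-level augmentation. As a
# hedged illustration of the contract it must satisfy (numpy (B, N, 3) in,
# same shape out), here is a toy stand-in that only applies a random
# anisotropic scale per cloud; it is hypothetical and not used anywhere.
def _toy_augment_batch(batch_pc):
    scales = np.random.uniform(0.8, 1.25, size=(batch_pc.shape[0], 1, 3))
    return (batch_pc * scales).astype(batch_pc.dtype)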
class ModelNet40Dgcnn(Dataset):
def __init__(self, split, train_data_path,
valid_data_path, test_data_path, num_points):
self.split = split
self.data_path = {
"train": train_data_path,
"valid": valid_data_path,
"test": test_data_path
}[self.split]
dgcnn_params = {
'partition': 'train' if split in ['train', 'valid'] else 'test',
'num_points': num_points,
"data_path": self.data_path
}
self.dataset = dgcnn_ModelNet40(**dgcnn_params)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        pc, label = self.dataset[idx]
        return {'pc': pc, 'label': label.item()}
class ModelNetC(Dataset):
def __init__(self, split):
dgcnn_params = {
"split": split
}
self.dataset = dgcnn_ModelNetC(**dgcnn_params)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        pc, label = self.dataset[idx]
        return {'pc': pc, 'label': label.item()}
def create_dataloader(split, cfg):
num_workers = cfg.DATALOADER.num_workers
batch_size = cfg.DATALOADER.batch_size
dataset_args = {
"split": split
}
    if cfg.EXP.DATASET == "modelnet40_rscnn":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_RSCNN))
        # augmentation is done inside batch_proc so that the pipeline stays
        # as close to the vanilla RSCNN code as possible
        dataset = ModelNet40Rscnn(**dataset_args)
    elif cfg.EXP.DATASET == "modelnet40_pn2":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_PN2))
        dataset = ModelNet40PN2(**dataset_args)
    elif cfg.EXP.DATASET == "modelnet40_dgcnn":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_DGCNN))
        dataset = ModelNet40Dgcnn(**dataset_args)
    elif cfg.EXP.DATASET == "modelnet_c":
        dataset = ModelNetC(**dataset_args)
    else:
        raise NotImplementedError(f"Unknown dataset: {cfg.EXP.DATASET}")
if "batch_proc" not in dir(dataset):
dataset.batch_proc = None
return DataLoader(
dataset,
batch_size,
num_workers=num_workers,
shuffle=(split == "train"),
drop_last=(split == "train"),
pin_memory=(torch.cuda.is_available()) and (not num_workers)
)
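
# Hedged usage sketch: how a training loop consumes the loader together
# with the optional batch_proc hook. The SimpleNamespace config below is a
# toy stand-in for the project's real yacs-style cfg, for illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        EXP=SimpleNamespace(DATASET="modelnet_c"),
        DATALOADER=SimpleNamespace(num_workers=0, batch_size=16))
    loader = create_dataloader("test", cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for batch in loader:
        # apply per-dataset batch processing (FPS / augmentation), if any
        if loader.dataset.batch_proc is not None:
            batch = loader.dataset.batch_proc(batch, device)
        print(batch['pc'].shape, batch['label'].shape)
        break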