3 Star 1 Fork 0

Luxian/PointnetKNN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
modelnet_dataset.py 5.45 KB
一键复制 编辑 原始数据 按行查看 历史
'''
ModelNet dataset. Support ModelNet40, ModelNet10, XYZ and normal channels. Up to 10000 points.
'''
import os
import os.path
import json
import numpy as np
import sys
# This file's directory doubles as the project root; its utils/ subdirectory is
# appended to the module search path so that `import provider` can resolve.
BASE_DIR = ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
import provider
def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: Array of shape ``(N, D)``; each row is one point.

    Returns:
        A new array (``pc`` itself is not modified) whose rows have zero mean
        and whose farthest point lies at Euclidean distance 1 from the origin.
        If every point coincides (maximum distance 0), the centered cloud is
        returned unscaled instead of dividing by zero. An empty cloud is
        returned as a copy.
    """
    pc = np.asarray(pc)
    if pc.shape[0] == 0:
        # np.max over zero points would raise; nothing to normalize.
        return pc.copy()
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    # Radius of the cloud: distance from the centroid to the farthest point.
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc
class ModelNetDataset:
    """Batch iterator over the ModelNet point-cloud classification dataset.

    The category list is read from ``shape_names.txt`` (ModelNet40) or
    ``modelnet10_shape_names.txt`` (ModelNet10) under ``root``. The split list
    ``modelnet{40,10}_{train,test}.txt`` holds sample ids of the form
    ``<shape_name>_<index>``. Each sample is loaded from
    ``<root>/<shape_name>/<id>.txt``, a comma-separated numeric table with one
    point per row.

    Args:
        root: Dataset directory containing the shape-name and split lists.
        batch_size: Number of samples returned per ``next_batch`` call.
        npoints: Number of leading rows (points) kept from every sample file.
        split: ``'train'`` or ``'test'``.
        normalize: Center columns 0-2 and scale them into the unit sphere
            (see ``pc_normalize``).
        normal_channel: Keep every column of the sample file instead of only
            columns 0-2. ``num_channel`` then reports 6, so files are
            presumably XYZ + normal rows (TODO confirm).
        modelnet10: Use the ModelNet10 lists instead of the ModelNet40 ones.
        cache_size: Maximum number of loaded samples kept in memory.
        shuffle: Reshuffle sample order on every ``reset``. Defaults to True
            for the train split and False for the test split.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``.
    """

    def __init__(self, root, batch_size=32, npoints=1024, split='train', normalize=True,
                 normal_channel=False, modelnet10=False, cache_size=15000, shuffle=None):
        # Validate first, before any file is opened, with an exception that is
        # not stripped by ``python -O`` (an ``assert`` would be).
        if split not in ('train', 'test'):
            raise ValueError("split must be 'train' or 'test', got %r" % (split,))
        self.root = root
        self.batch_size = batch_size
        self.npoints = npoints
        self.normalize = normalize
        self.normal_channel = normal_channel

        if modelnet10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
        else:
            self.catfile = os.path.join(self.root, 'shape_names.txt')
        self.cat = self._read_lines(self.catfile)
        # Class index = position of the shape name in the category file.
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        # Only the requested split's id list is needed, e.g. modelnet40_train.txt.
        prefix = 'modelnet10' if modelnet10 else 'modelnet40'
        shape_ids = self._read_lines(os.path.join(self.root, '%s_%s.txt' % (prefix, split)))
        # Dropping the last '_'-separated field of an id gives the shape (class)
        # name, which is also the sample's sub-directory.
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids]
        # list of (shape_name, shape_txt_file_path) tuples
        self.datapath = [(name, os.path.join(self.root, name, sid) + '.txt')
                         for name, sid in zip(shape_names, shape_ids)]

        self.cache_size = cache_size  # maximum number of samples kept in memory
        self.cache = {}  # from index to (point_set, cls) tuple
        # By default only the training split is shuffled.
        self.shuffle = (split == 'train') if shuffle is None else shuffle
        self.reset()

    @staticmethod
    def _read_lines(path):
        """Return the right-stripped, non-blank lines of the text file at ``path``."""
        with open(path) as f:
            return [s for s in (line.rstrip() for line in f) if s]

    def _augment_batch_data(self, batch_data):
        """Apply the ``provider`` augmentations to a ``(B, npoints, C)`` batch.

        The rotation helpers (normal-aware variants when ``normal_channel`` is
        set) see all channels. Random scale, shift and jitter are applied to
        columns 0-2 only; the remaining channels keep their rotated values.
        The result goes through ``provider.shuffle_points``, presumably to
        permute point order (see utils/provider.py).
        """
        if self.normal_channel:
            rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
        else:
            rotated_data = provider.rotate_point_cloud(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        rotated_data[:, :, 0:3] = jittered_data
        return provider.shuffle_points(rotated_data)

    def _get_item(self, index):
        """Load sample ``index`` as ``(point_set, cls)``, using the cache if possible.

        ``point_set`` is a float32 array with the first ``npoints`` rows of the
        sample file. It has 3 columns, or every column when ``normal_channel``
        is set. ``cls`` is a shape-``(1,)`` int32 array holding the class index.
        """
        if index in self.cache:
            return self.cache[index]
        shape_name, path = self.datapath[index]
        cls = np.array([self.classes[shape_name]]).astype(np.int32)
        # ndmin=2 keeps a single-row file 2-D so the row/column slicing works.
        point_set = np.loadtxt(path, delimiter=',', ndmin=2).astype(np.float32)
        # Take the first npoints (and only XYZ unless the extra channels are wanted).
        point_set = point_set[0:self.npoints]
        if not self.normal_channel:
            point_set = point_set[:, 0:3]
        # Copy so a cached sample does not keep the whole loaded file alive
        # through a view into it.
        point_set = point_set.copy()
        if self.normalize:
            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if len(self.cache) < self.cache_size:
            self.cache[index] = (point_set, cls)
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)

    def __len__(self):
        return len(self.datapath)

    def num_channel(self):
        """Number of per-point channels returned: 6 with normals, else 3."""
        return 6 if self.normal_channel else 3

    def reset(self):
        """Rewind to the first batch, reshuffling sample order if ``shuffle`` is set."""
        self.idxs = np.arange(0, len(self.datapath))
        if self.shuffle:
            np.random.shuffle(self.idxs)
        # Ceiling division: the last batch may be partial.
        self.num_batches = (len(self.datapath) + self.batch_size - 1) // self.batch_size
        self.batch_idx = 0

    def has_next_batch(self):
        """True while ``next_batch`` still has samples left in this epoch."""
        return self.batch_idx < self.num_batches

    def next_batch(self, augment=False):
        """Return the next ``(batch_data, batch_label)`` pair and advance.

        ``batch_data`` is a float64 array of shape ``(bsize, npoints,
        num_channel())``. ``batch_label`` is an int32 array of shape
        ``(bsize,)``. ``bsize`` equals ``batch_size`` except possibly for the
        last batch of an epoch, which may be smaller. Check
        ``has_next_batch`` first; calling past the end yields an empty or
        invalid batch size. When ``augment`` is true,
        ``_augment_batch_data`` is applied to the batch.
        """
        start_idx = self.batch_idx * self.batch_size
        end_idx = min((self.batch_idx + 1) * self.batch_size, len(self.datapath))
        bsize = end_idx - start_idx
        batch_data = np.zeros((bsize, self.npoints, self.num_channel()))
        batch_label = np.zeros((bsize), dtype=np.int32)
        for i, sample_idx in enumerate(self.idxs[start_idx:end_idx]):
            ps, cls = self._get_item(sample_idx)
            batch_data[i] = ps
            # cls is a 1-element array; store its scalar explicitly.
            batch_label[i] = cls[0]
        self.batch_idx += 1
        if augment:
            batch_data = self._augment_batch_data(batch_data)
        return batch_data, batch_label
if __name__ == '__main__':
    import time

    # Smoke test against the resampled ModelNet40 test split.
    dataset = ModelNetDataset(root='../data/modelnet40_normal_resampled', split='test')
    print(dataset.shuffle)
    print(len(dataset))

    # Time loading the first ten samples, then show the last one loaded.
    start = time.time()
    for idx in range(10):
        points, label = dataset[idx]
    print(time.time() - start)
    print(points.shape, type(points), label)

    # Pull one augmented batch and report its shapes.
    print(dataset.has_next_batch())
    batch_points, batch_labels = dataset.next_batch(True)
    print(batch_points.shape)
    print(batch_labels.shape)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/hanxiaoyang1/pointnetknn.git
git@gitee.com:hanxiaoyang1/pointnetknn.git
hanxiaoyang1
pointnetknn
PointnetKNN
master

搜索帮助