Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
fixed three_interpolate grad ops and other minor updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesq34 committed Oct 21, 2019
1 parent 3992eea commit da10b32
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 7 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ If you find our work useful in your research, please consider citing:

## Installation

Install [Pytorch](https://pytorch.org/get-started/locally/) and [Tensorflow](https://github.com/tensorflow/tensorflow) (for TensorBoard). It is required that you have access to GPUs. Matlab is required to prepare data for SUN RGB-D. The code is tested with Ubuntu 18.04, Pytorch v1.1, TensorFlow v1.14, CUDA 10.0 and cuDNN v7.4.
Install [Pytorch](https://pytorch.org/get-started/locally/) and [Tensorflow](https://github.com/tensorflow/tensorflow) (for TensorBoard). It is required that you have access to GPUs. Matlab is required to prepare data for SUN RGB-D. The code is tested with Ubuntu 18.04, Pytorch v1.1, TensorFlow v1.14, CUDA 10.0 and cuDNN v7.4. Note: there is some incompatibility with newer version of Pytorch (e.g. v1.3), which is to be fixed.

Compile the CUDA layers for [PointNet++](http://arxiv.org/abs/1706.02413), which we used in the backbone network:

Expand All @@ -35,9 +35,9 @@ To see if the compilation is successful, try to run `python models/votenet.py` t
Install the following Python dependencies (with `pip install`):

matplotlib
cv2
opencv-python
plyfile
trimesh>=2.35.39,<2.35.40
'trimesh>=2.35.39,<2.35.40'

## Run demo

Expand All @@ -46,13 +46,13 @@ Unzip the file under the project root path (`/path/to/project/demo_files`) and t

python demo.py

The demo uses a pre-trained model (on SUN RGB-D) to detect objects in a point cloud from an indoor room of a table and a few chairs (from SUN RGB-D val set). You can use 3D visualization software such as the [MeshLab](http://www.meshlab.net/) to open the dumped file under `demo_files/sunrgbd` to see the 3D detection output. Specifically, open `***_pc.ply` and `***_pred_confident_nms_bbox.ply` to see the input point cloud and predicted 3D bounding boxes.
The demo uses a pre-trained model (on SUN RGB-D) to detect objects in a point cloud from an indoor room of a table and a few chairs (from SUN RGB-D val set). You can use 3D visualization software such as the [MeshLab](http://www.meshlab.net/) to open the dumped file under `demo_files/sunrgbd_results` to see the 3D detection output. Specifically, open `***_pc.ply` and `***_pred_confident_nms_bbox.ply` to see the input point cloud and predicted 3D bounding boxes.

You can also run the following command to use another pretrained model on a ScanNet:

python demo.py --dataset scannet --num_point 40000

Detection results will be dumped to `demo_files/scannet`.
Detection results will be dumped to `demo_files/scannet_results`.

## Training and evaluating

Expand Down Expand Up @@ -99,3 +99,6 @@ We want to thank Erik Wijmans for his PointNet++ implementation in Pytorch ([ori

## License
votenet is relased under the MIT License. See the [LICENSE file](https://arxiv.org/pdf/1904.09664.pdf) for more details.

## Change log
10/20/2019: Fixed a bug of the 3D interpolation customized ops (corrected gradient computation). Re-training the model after the fix slightly improves mAP (less than 1 point).
1 change: 1 addition & 0 deletions models/proposal_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def forward(self, xyz, features, end_points):
elif self.sampling == 'random':
# Random sampling from the votes
num_seed = end_points['seed_xyz'].shape[1]
batch_size = end_points['seed_xyz'].shape[0]
sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal), dtype=torch.int).cuda()
xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds)
else:
Expand Down
2 changes: 1 addition & 1 deletion pointnet2/_ext_src/src/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
at::device(grad_out.device()).dtype(at::ScalarType::Float));

if (grad_out.type().is_cuda()) {
three_interpolate_kernel_wrapper(
three_interpolate_grad_kernel_wrapper(
grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
grad_out.data<float>(), idx.data<int>(), weight.data<float>(),
output.data<float>());
Expand Down
33 changes: 33 additions & 0 deletions pointnet2/pointnet2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

''' Testing customized ops. '''

import torch
from torch.autograd import gradcheck
import numpy as np

import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
import pointnet2_utils

def test_interpolation_grad():
batch_size = 1
feat_dim = 2
m = 4
feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()

def interpolate_func(inputs):
idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
return interpolated_feats

assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))

if __name__=='__main__':
test_interpolation_grad()
2 changes: 1 addition & 1 deletion utils/pc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64):
height = np.sqrt(np.dot(vec, vec))
scene.add_geometry(trimesh.creation.cylinder(radius=rad, height=height, sections=res, transform=M))
mesh_list = trimesh.util.concatenate(scene.dump())
trimesh.io.export.export_mesh(mesh_list, f'{filename}.ply', file_type='ply')
trimesh.io.export.export_mesh(mesh_list, '%s.ply'%(filename), file_type='ply')

# ----------------------------------------
# Testing
Expand Down

0 comments on commit da10b32

Please sign in to comment.