This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

added support for boxnet.
charlesq34 committed Sep 11, 2019
1 parent 257b8d8 commit 7c19af3
Showing 6 changed files with 274 additions and 28 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -68,15 +68,15 @@ To train a new VoteNet model on SUN RGB-D data (depth images):

CUDA_VISIBLE_DEVICES=0 python train.py --dataset sunrgbd --log_dir log_sunrgbd

You can use `CUDA_VISIBLE_DEVICES=0,1,2` to specify which GPU(s) to use. Without specifying CUDA devices, the training will use all available GPUs and train with data parallelism (note that due to I/O load, the training speedup is not linear in the number of GPUs used). Run `python train.py -h` to see more training options.
You can use `CUDA_VISIBLE_DEVICES=0,1,2` to specify which GPU(s) to use. Without specifying CUDA devices, the training will use all available GPUs and train with data parallelism (note that due to I/O load, the training speedup is not linear in the number of GPUs used). Run `python train.py -h` to see more training options (e.g. you can also set `--model boxnet` to train with the baseline BoxNet model).
While training you can check the `log_sunrgbd/log_train.txt` file for progress, or use TensorBoard to see loss curves.

To test the trained model with its checkpoint:

python eval.py --dataset sunrgbd --checkpoint_path log_sunrgbd/checkpoint.tar --dump_dir eval_sunrgbd --cluster_sampling seed_fps --use_3d_nms --use_cls_nms --per_class_proposal

Example results will be dumped in the `eval_sunrgbd` folder (or any other folder you specify). You can run `python eval.py -h` to see the full options for evaluation. After the evaluation, you can use MeshLab to visualize the predicted votes and 3D bounding boxes (select wireframe mode to view the boxes).
Final evaluation results will be printed on screen and also written to the `log_eval.txt` file under the dump directory. By default we evaluate with both mAP@0.25 and mAP@0.5, using 3D IoU on oriented boxes.
Final evaluation results will be printed on screen and also written to the `log_eval.txt` file under the dump directory. By default we evaluate with both mAP@0.25 and mAP@0.5, using 3D IoU on oriented boxes. A properly trained VoteNet should reach around 57 mAP@0.25 and 32 mAP@0.5.

### Train and test on ScanNet

@@ -88,7 +88,7 @@ To test the trained model with its checkpoint:

python eval.py --dataset scannet --checkpoint_path log_scannet/checkpoint.tar --dump_dir eval_scannet --num_point 40000 --cluster_sampling seed_fps --use_3d_nms --use_cls_nms --per_class_proposal

Example results will be dumped in the `eval_scannet` folder (or any other folder you specify). By default we evaluate with both mAP@0.25 and mAP@0.5, using 3D IoU on axis-aligned boxes.
Example results will be dumped in the `eval_scannet` folder (or any other folder you specify). By default we evaluate with both mAP@0.25 and mAP@0.5, using 3D IoU on axis-aligned boxes. A properly trained VoteNet should reach around 58 mAP@0.25 and 35 mAP@0.5.

## Acknowledgements
We want to thank Erik Wijmans for his PointNet++ implementation in PyTorch ([original codebase](https://github.com/erikwijmans/Pointnet2_PyTorch)).
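To make the new `--model boxnet` option concrete, here is a possible command sequence for the SUN RGB-D baseline, in the same style as the README commands above. The directory names are placeholders, and passing `--model boxnet` to `eval.py` is inferred from the `FLAGS.model` check this commit adds to that script rather than stated in the README:

CUDA_VISIBLE_DEVICES=0 python train.py --dataset sunrgbd --log_dir log_sunrgbd_boxnet --model boxnet

tensorboard --logdir log_sunrgbd_boxnet

python eval.py --dataset sunrgbd --model boxnet --checkpoint_path log_sunrgbd_boxnet/checkpoint.tar --dump_dir eval_sunrgbd_boxnet --use_3d_nms --use_cls_nms --per_class_proposal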
21 changes: 13 additions & 8 deletions eval.py
@@ -98,14 +98,19 @@ def my_worker_init_fn(worker_id):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_input_channel = int(FLAGS.use_color)*3 + int(not FLAGS.no_height)*1

net = MODEL.VoteNet(num_class=DATASET_CONFIG.num_class,
                    num_heading_bin=DATASET_CONFIG.num_heading_bin,
                    num_size_cluster=DATASET_CONFIG.num_size_cluster,
                    mean_size_arr=DATASET_CONFIG.mean_size_arr,
                    num_proposal=FLAGS.num_target,
                    input_feature_dim=num_input_channel,
                    vote_factor=FLAGS.vote_factor,
                    sampling=FLAGS.cluster_sampling)
if FLAGS.model == 'boxnet':
    Detector = MODEL.BoxNet
else:
    Detector = MODEL.VoteNet

net = Detector(num_class=DATASET_CONFIG.num_class,
               num_heading_bin=DATASET_CONFIG.num_heading_bin,
               num_size_cluster=DATASET_CONFIG.num_size_cluster,
               mean_size_arr=DATASET_CONFIG.mean_size_arr,
               num_proposal=FLAGS.num_target,
               input_feature_dim=num_input_channel,
               vote_factor=FLAGS.vote_factor,
               sampling=FLAGS.cluster_sampling)
net.to(device)
criterion = MODEL.get_loss

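To show how the selection above plugs into the rest of the script, here is a hedged sketch. Only the `FLAGS.model` check itself appears in this diff; the `--model` argparse flag and the importlib-based module loading are assumptions about the surrounding, unshown code:

import argparse
import importlib
import os
import sys

sys.path.append(os.path.join(os.getcwd(), 'models'))  # assumed: models/ is put on the path

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='votenet',
                    help="Detector module to use: 'votenet' or 'boxnet' (assumed flag)")
FLAGS, _ = parser.parse_known_args()

MODEL = importlib.import_module(FLAGS.model)  # loads models/votenet.py or models/boxnet.py

# BoxNet keeps VoteNet's constructor signature, so the same keyword arguments work for both.
Detector = MODEL.BoxNet if FLAGS.model == 'boxnet' else MODEL.VoteNet
# MODEL.get_loss (used as `criterion` above) likewise resolves to the BoxNet-specific loss
# when --model boxnet is chosen, since boxnet.py imports it from loss_helper_boxnet.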
115 changes: 115 additions & 0 deletions models/boxnet.py
@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import numpy as np
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)
from backbone_module import Pointnet2Backbone
from proposal_module import ProposalModule
from dump_helper import dump_results
from loss_helper_boxnet import get_loss


class BoxNet(nn.Module):
    r"""
    A baseline deep network for 3D object detection that predicts bounding boxes directly
    from backbone seed points, i.e. VoteNet without the Hough voting stage.

    Parameters
    ----------
    num_class: int
        Number of semantic classes to predict over -- size of softmax classifier
    num_heading_bin: int
    num_size_cluster: int
    input_feature_dim: (default: 0)
        Input dim in the feature descriptor for each point. If the point cloud is Nx9, this
        value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors
    num_proposal: int (default: 128)
        Number of proposals/detections generated from the network. Each proposal is a 3D OBB with a semantic class.
    vote_factor: (default: 1)
        Number of votes generated from each seed point (kept for interface compatibility with VoteNet; unused here).
    """

    def __init__(self, num_class, num_heading_bin, num_size_cluster, mean_size_arr,
                 input_feature_dim=0, num_proposal=128, vote_factor=1, sampling='vote_fps'):
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert(mean_size_arr.shape[0] == self.num_size_cluster)
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.vote_factor = vote_factor
        self.sampling = sampling

        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(input_feature_dim=self.input_feature_dim)

        # Box proposal, aggregation and detection
        self.pnet = ProposalModule(num_class, num_heading_bin, num_size_cluster,
                                   mean_size_arr, num_proposal, sampling)

    def forward(self, inputs):
        """ Forward pass of the network

        Args:
            inputs: dict
                {point_clouds}

                point_clouds: Variable(torch.cuda.FloatTensor)
                    (B, N, 3 + input_channels) tensor
                    Point cloud to run predictions on.
                    Each point in the point cloud MUST
                    be formatted as (x, y, z, features...)
        Returns:
            end_points: dict
        """
        end_points = {}
        batch_size = inputs['point_clouds'].shape[0]

        end_points = self.backbone_net(inputs['point_clouds'], end_points)
        xyz = end_points['fp2_xyz']
        features = end_points['fp2_features']
        end_points['seed_inds'] = end_points['fp2_inds']
        end_points['seed_xyz'] = xyz
        end_points['seed_features'] = features

        # Directly predict bounding boxes (skips voting)
        end_points = self.pnet(xyz, features, end_points)

        return end_points


if __name__=='__main__':
    sys.path.append(os.path.join(ROOT_DIR, 'sunrgbd'))
    from sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset, DC

    # Define dataset
    TRAIN_DATASET = SunrgbdDetectionVotesDataset('train', num_points=20000, use_v1=True)

    # Define model
    model = BoxNet(10,12,10,np.random.random((10,3))).cuda()

    # Model forward pass
    sample = TRAIN_DATASET[5]
    inputs = {'point_clouds': torch.from_numpy(sample['point_clouds']).unsqueeze(0).cuda()}
    end_points = model(inputs)
    for key in end_points:
        print(key, end_points[key])

    # Compute loss
    for key in sample:
        end_points[key] = torch.from_numpy(sample[key]).unsqueeze(0).cuda()
    loss, end_points = get_loss(end_points, DC)
    print('loss', loss)
    end_points['point_clouds'] = inputs['point_clouds']
    end_points['pred_mask'] = np.ones((1,128))
    dump_results(end_points, 'tmp', DC)
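For a quick smoke test of the new class, a minimal sketch is below. It assumes the compiled PointNet++ CUDA ops and a GPU are available, feeds random points instead of the SUN RGB-D sample used in the `__main__` block above, and uses the same constructor arguments shown there; the printed keys are the ones `dump_helper.py` and the loss helper read from `end_points`:

import os
import sys
import numpy as np
import torch

sys.path.append(os.path.join(os.getcwd(), 'models'))  # assumes we run from the repo root
from boxnet import BoxNet

# 10 classes, 12 heading bins, 10 size clusters, random mean sizes; num_proposal defaults to 128.
net = BoxNet(10, 12, 10, np.random.random((10, 3))).cuda().eval()

points = torch.rand(2, 20000, 3).cuda()  # (B, N, 3): xyz only, matching input_feature_dim=0
with torch.no_grad():
    end_points = net({'point_clouds': points})

print(end_points['center'].shape)             # expected (2, 128, 3): one box center per proposal
print(end_points['objectness_scores'].shape)  # expected (2, 128, 2)
print('vote_xyz' in end_points)               # False -- BoxNet skips the voting stage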
17 changes: 8 additions & 9 deletions models/dump_helper.py
@@ -40,12 +40,11 @@ def dump_results(end_points, dump_dir, config, inference_switch=False):

    # NETWORK OUTPUTS
    seed_xyz = end_points['seed_xyz'].detach().cpu().numpy() # (B,num_seed,3)
    aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
    #seed_labels = end_points['seed_labels'].detach().cpu().numpy() # (B,num_seed,)
    vote_xyz = end_points['vote_xyz'].detach().cpu().numpy() # (B,num_seed,3)
    aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
    if 'vote_xyz' in end_points:
        aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
        vote_xyz = end_points['vote_xyz'].detach().cpu().numpy() # (B,num_seed,3)
        aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
    objectness_scores = end_points['objectness_scores'].detach().cpu().numpy() # (B,K,2)
    #object_seed_scores = end_points['seed_logits'].detach().cpu().numpy() # (B,num_seed,2)
    pred_center = end_points['center'].detach().cpu().numpy() # (B,K,3)
    pred_heading_class = torch.argmax(end_points['heading_scores'], -1) # B,num_proposal
    pred_heading_residual = torch.gather(end_points['heading_residuals'], 2, pred_heading_class.unsqueeze(-1)) # B,num_proposal,1
@@ -62,14 +61,14 @@ def dump_results(end_points, dump_dir, config, inference_switch=False):
    for i in range(batch_size):
        pc = point_clouds[i,:,:]
        objectness_prob = softmax(objectness_scores[i,:,:])[:,1] # (K,)
        #object_seed_prob = softmax(object_seed_scores[i,:,:])[:,1] # (num_seed,)

        # Dump various point clouds
        pc_util.write_ply(pc, os.path.join(dump_dir, '%06d_pc.ply'%(idx_beg+i)))
        pc_util.write_ply(seed_xyz[i,:,:], os.path.join(dump_dir, '%06d_seed_pc.ply'%(idx_beg+i)))
        pc_util.write_ply(end_points['vote_xyz'][i,:,:], os.path.join(dump_dir, '%06d_vgen_pc.ply'%(idx_beg+i)))
        pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
        pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
        if 'vote_xyz' in end_points:
            pc_util.write_ply(end_points['vote_xyz'][i,:,:], os.path.join(dump_dir, '%06d_vgen_pc.ply'%(idx_beg+i)))
            pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
            pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
        pc_util.write_ply(pred_center[i,:,0:3], os.path.join(dump_dir, '%06d_proposal_pc.ply'%(idx_beg+i)))
        if np.sum(objectness_prob>DUMP_CONF_THRESH)>0:
            pc_util.write_ply(pred_center[i,objectness_prob>DUMP_CONF_THRESH,0:3], os.path.join(dump_dir, '%06d_confident_proposal_pc.ply'%(idx_beg+i)))
122 changes: 122 additions & 0 deletions models/loss_helper_boxnet.py
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import numpy as np
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
from nn_distance import nn_distance, huber_loss
sys.path.append(BASE_DIR)
from loss_helper import compute_box_and_sem_cls_loss

OBJECTNESS_CLS_WEIGHTS = [0.2,0.8] # put larger weights on positive objectness

def compute_objectness_loss(end_points):
    """ Compute objectness loss for the proposals.

    Args:
        end_points: dict (read-only)

    Returns:
        objectness_loss: scalar Tensor
        objectness_label: (batch_size, num_seed) Tensor with value 0 or 1
        objectness_mask: (batch_size, num_seed) Tensor with value 0 or 1
        object_assignment: (batch_size, num_seed) Tensor with long int
            within [0,num_gt_object-1]
    """
    # Associate proposal and GT objects by point-to-point distances
    aggregated_vote_xyz = end_points['aggregated_vote_xyz']
    gt_center = end_points['center_label'][:,:,0:3]
    B = gt_center.shape[0]
    K = aggregated_vote_xyz.shape[1]
    K2 = gt_center.shape[1]
    dist1, ind1, dist2, _ = nn_distance(aggregated_vote_xyz, gt_center) # dist1: BxK, dist2: BxK2

    # Generate objectness label and mask
    # NOTE: Different from VoteNet, here we use seed label as objectness label.
    seed_inds = end_points['seed_inds'].long() # B,num_seed in [0,num_points-1]
    seed_gt_votes_mask = torch.gather(end_points['vote_label_mask'], 1, seed_inds)
    end_points['seed_labels'] = seed_gt_votes_mask
    aggregated_vote_inds = end_points['aggregated_vote_inds']
    objectness_label = torch.gather(end_points['seed_labels'], 1, aggregated_vote_inds.long()) # select (B,K) from (B,1024)
    objectness_mask = torch.ones((objectness_label.shape[0], objectness_label.shape[1])).cuda() # no ignore zone anymore

    # Compute objectness loss
    objectness_scores = end_points['objectness_scores']
    criterion = nn.CrossEntropyLoss(torch.Tensor(OBJECTNESS_CLS_WEIGHTS).cuda(), reduction='none')
    objectness_loss = criterion(objectness_scores.transpose(2,1), objectness_label)
    objectness_loss = torch.sum(objectness_loss * objectness_mask)/(torch.sum(objectness_mask)+1e-6)

    # Set assignment
    object_assignment = ind1 # (B,K) with values in 0,1,...,K2-1

    return objectness_loss, objectness_label, objectness_mask, object_assignment


def get_loss(end_points, config):
    """ Loss functions

    Args:
        end_points: dict
            {
                seed_xyz, seed_inds,
                center,
                heading_scores, heading_residuals_normalized,
                size_scores, size_residuals_normalized,
                sem_cls_scores, #seed_logits,#
                center_label,
                heading_class_label, heading_residual_label,
                size_class_label, size_residual_label,
                sem_cls_label,
                box_label_mask,
                vote_label, vote_label_mask
            }
        config: dataset config instance
    Returns:
        loss: pytorch scalar tensor
        end_points: dict
    """

    # Obj loss
    objectness_loss, objectness_label, objectness_mask, object_assignment = \
        compute_objectness_loss(end_points)
    end_points['objectness_loss'] = objectness_loss
    end_points['objectness_label'] = objectness_label
    end_points['objectness_mask'] = objectness_mask
    end_points['object_assignment'] = object_assignment
    total_num_proposal = objectness_label.shape[0]*objectness_label.shape[1]
    end_points['pos_ratio'] = \
        torch.sum(objectness_label.float().cuda())/float(total_num_proposal)
    end_points['neg_ratio'] = \
        torch.sum(objectness_mask.float())/float(total_num_proposal) - end_points['pos_ratio']

    # Box loss and sem cls loss
    center_loss, heading_cls_loss, heading_reg_loss, size_cls_loss, size_reg_loss, sem_cls_loss = \
        compute_box_and_sem_cls_loss(end_points, config)
    end_points['center_loss'] = center_loss
    end_points['heading_cls_loss'] = heading_cls_loss
    end_points['heading_reg_loss'] = heading_reg_loss
    end_points['size_cls_loss'] = size_cls_loss
    end_points['size_reg_loss'] = size_reg_loss
    end_points['sem_cls_loss'] = sem_cls_loss
    box_loss = center_loss + 0.1*heading_cls_loss + heading_reg_loss + 0.1*size_cls_loss + size_reg_loss
    end_points['box_loss'] = box_loss

    # Final loss function
    loss = 0.5*objectness_loss + box_loss + 0.1*sem_cls_loss
    loss *= 10
    end_points['loss'] = loss

    # --------------------------------------------
    # Some other statistics
    obj_pred_val = torch.argmax(end_points['objectness_scores'], 2) # B,K
    obj_acc = torch.sum((obj_pred_val==objectness_label.long()).float()*objectness_mask)/(torch.sum(objectness_mask)+1e-6)
    end_points['obj_acc'] = obj_acc

    return loss, end_points
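Spelled out as equations, with every weight taken directly from the code above, the BoxNet training objective is:

$$
\begin{aligned}
L_{\text{box}} &= L_{\text{center}} + 0.1\,L_{\text{heading-cls}} + L_{\text{heading-reg}} + 0.1\,L_{\text{size-cls}} + L_{\text{size-reg}},\\
L &= 10\,\bigl(0.5\,L_{\text{objectness}} + L_{\text{box}} + 0.1\,L_{\text{sem-cls}}\bigr).
\end{aligned}
$$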
21 changes: 13 additions & 8 deletions train.py
@@ -149,14 +149,19 @@ def my_worker_init_fn(worker_id):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_input_channel = int(FLAGS.use_color)*3 + int(not FLAGS.no_height)*1

net = MODEL.VoteNet(num_class=DATASET_CONFIG.num_class,
                    num_heading_bin=DATASET_CONFIG.num_heading_bin,
                    num_size_cluster=DATASET_CONFIG.num_size_cluster,
                    mean_size_arr=DATASET_CONFIG.mean_size_arr,
                    num_proposal=FLAGS.num_target,
                    input_feature_dim=num_input_channel,
                    vote_factor=FLAGS.vote_factor,
                    sampling=FLAGS.cluster_sampling)
if FLAGS.model == 'boxnet':
    Detector = MODEL.BoxNet
else:
    Detector = MODEL.VoteNet

net = Detector(num_class=DATASET_CONFIG.num_class,
               num_heading_bin=DATASET_CONFIG.num_heading_bin,
               num_size_cluster=DATASET_CONFIG.num_size_cluster,
               mean_size_arr=DATASET_CONFIG.mean_size_arr,
               num_proposal=FLAGS.num_target,
               input_feature_dim=num_input_channel,
               vote_factor=FLAGS.vote_factor,
               sampling=FLAGS.cluster_sampling)

if torch.cuda.device_count() > 1:
    log_string("Let's use %d GPUs!" % (torch.cuda.device_count()))
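The train.py diff is truncated right after the multi-GPU log line, and the README above notes that multi-GPU training uses data parallelism. As a hedged, self-contained sketch of that pattern (with a toy module standing in for the detector; this is an illustration, not code from the commit):

import torch
import torch.nn as nn

model = nn.Linear(8, 2)                      # toy stand-in for the VoteNet/BoxNet detector
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)           # replicate the module and split each batch across GPUs
model.to(device)

x = torch.rand(16, 8).to(device)             # the batch dimension is what DataParallel shards
print(model(x).shape)                        # torch.Size([16, 2])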
