Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
fix Expected isFloatingType error for pytorch version 1.2+
Browse files Browse the repository at this point in the history
  • Loading branch information
Saining Xie committed Feb 7, 2020
1 parent da10b32 commit f610e09
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pointnet2/pointnet2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def forward(ctx, xyz, npoint):
torch.Tensor
(B, npoint) tensor containing the set
"""
return _ext.furthest_point_sampling(xyz, npoint)
fps_inds = _ext.furthest_point_sampling(xyz, npoint)
ctx.mark_non_differentiable(fps_inds)
return fps_inds

@staticmethod
def backward(xyz, a=None):
Expand Down Expand Up @@ -277,7 +279,9 @@ def forward(ctx, radius, nsample, xyz, new_xyz):
torch.Tensor
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
return _ext.ball_query(new_xyz, xyz, radius, nsample)
inds = _ext.ball_query(new_xyz, xyz, radius, nsample)
ctx.mark_non_differentiable(inds)
return inds

@staticmethod
def backward(ctx, a=None):
Expand Down

0 comments on commit f610e09

Please sign in to comment.