Add a differentiable sparse matrix vector product on top of our ops #392
Conversation
Looks great! Happy to know we now have this; it can be used, for instance, with iterative solvers (not sure if that was already the plan).
I want to add at least Conjugate Gradient at some point, but there are only so many hours in the day :) Will probably get to it eventually.
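As a rough sketch of how a differentiable sparse matrix-vector product could feed an iterative solver, here is a minimal NumPy conjugate gradient that only needs a matvec closure. The function name and signature are illustrative, not Theseus's actual API:

```python
import numpy as np

def conjugate_gradient(matvec, b, x0=None, tol=1e-8, max_iter=100):
    # Solve A x = b for symmetric positive-definite A, using only
    # matrix-vector products (so a sparse matvec is all that's needed).
    x = np.zeros_like(b) if x0 is None else x0.copy()
    r = b - matvec(x)          # initial residual
    p = r.copy()               # initial search direction
    rs = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)  # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p  # new conjugate direction
        rs = rs_new
    return x

# Example: a small SPD system, with the matvec passed as a closure.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = conjugate_gradient(lambda v: A @ v, b)
```

In the Theseus setting, the dense `A @ v` closure above would be replaced by the sparse matvec from this PR, which is what makes the solve usable inside an autograd graph.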
LGTM!
…acebookresearch#392)
* Add autograd function for sparse matrix vector product.
* Add wrapper for sparse_mv in SparseLinearization.
* Add autograd function for sparse matrix transpose vector product.
* Add wrapper for sparse_mtv in SparseLinearization to make Atb differentiable.
* Fix dtype index bug.
The backward pass could be made more efficient on GPU with a custom CUDA kernel, but this should be reasonable enough for now.
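For reference, the math behind the forward and backward passes can be sketched in plain NumPy on a CSR matrix. This is only an illustration of the gradient formulas (dL/dx = Aᵀ·grad_y, i.e. the transpose matvec, and dL/dA_vals[k] = grad_y[row(k)] · x[col(k)]), not the actual Theseus implementation, which batches over problems and runs on GPU:

```python
import numpy as np

def sparse_mv(crow, col, vals, x):
    # Forward pass: y = A x, with A stored in CSR format
    # (crow: row pointers, col: column indices, vals: nonzero values).
    y = np.zeros(len(crow) - 1)
    for i in range(len(y)):
        for k in range(crow[i], crow[i + 1]):
            y[i] += vals[k] * x[col[k]]
    return y

def sparse_mv_backward(crow, col, vals, x, grad_y):
    # Backward pass, given grad_y = dL/dy:
    #   dL/dx       = A^T grad_y            (the transpose matvec)
    #   dL/dvals[k] = grad_y[row(k)] * x[col(k)]
    grad_x = np.zeros_like(x)
    grad_vals = np.zeros_like(vals)
    for i in range(len(crow) - 1):
        for k in range(crow[i], crow[i + 1]):
            grad_x[col[k]] += vals[k] * grad_y[i]
            grad_vals[k] = grad_y[i] * x[col[k]]
    return grad_x, grad_vals

# Example: A = [[2, 0], [1, 3]] in CSR form.
crow = np.array([0, 1, 3])
col = np.array([0, 0, 1])
vals = np.array([2.0, 1.0, 3.0])
x = np.array([1.0, 2.0])
y = sparse_mv(crow, col, vals, x)
grad_x, grad_vals = sparse_mv_backward(crow, col, vals, x, np.ones(2))
```

Note that the backward pass wrt x is itself a sparse transpose matvec, which is why the PR adds `sparse_mtv` alongside `sparse_mv`.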