Skip to content

Commit

Permalink
Power VJP fix for 0 (ml-explore#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Jan 20, 2024
1 parent 6bf779e commit b207c2c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(multiply(
outputs[0], divide(primals[1], primals[0], stream()), stream()));
power(
primals[0],
subtract(primals[1], array(1, primals[0].dtype()), stream()),
stream()),
primals[1],
stream()));
} else {
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
}
Expand Down
13 changes: 13 additions & 0 deletions python/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,19 @@ def fun(x, y):
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
self.assertEqual(out.dtype, t)

def test_power_grad(self):
x = mx.array(0.0)
g = mx.grad(lambda x: x**2)(x)
self.assertEqual(g.item(), 0.0)

x = mx.array(0.0)
g = mx.grad(lambda x: x**1.5)(x)
self.assertEqual(g.item(), 0.0)

x = mx.array(2.0)
g = mx.grad(lambda x: x**2)(x)
self.assertAlmostEqual(g.item(), 4.0)


if __name__ == "__main__":
unittest.main()

0 comments on commit b207c2c

Please sign in to comment.