-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-batch unit tests for Lie groups #522
Conversation
…ned to single-batch-dim output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Only minor comments.
all_input_types, atol = get_test_cfg(op_name, dtype, dim, data_shape, module=module) | ||
for input_types in all_input_types: | ||
inputs = sample_inputs(input_types, batch_size, dtype, rng) | ||
funcs = ( | ||
tuple(left_project_func(module, x) for x in inputs) | ||
tuple(left_project_func(module, x, bs) for x in inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if we need to check bs=None
or dim_out=None
, which assumes no broadcasting and is most common user-case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the case tuple()
covered this for the no batch size (see line 10). Or do you mean something else?
The tests for different dim_out
are in #527.
matrix = matrix[sels, ..., sels, :, :] | ||
g = group.reshape(d, *group.shape[-2:]) | ||
# Compute projected gradient matrix | ||
ret = module._left_project_autograd_fn(g, matrix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we assume the same dimension of inputs and outputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think so. There should be no broadcasting in the test where this function is used.
Missing unit test for gradient ops