Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorization refactor #205

Merged
merged 27 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
424c714
Created a wrapper cost function class that combines the aux vars for …
luisenp Jun 1, 2022
ee9e235
Disabled support for optimization variables in cost weights.
luisenp Jun 1, 2022
ea74465
Changed Objective to iterate over CFWrapper if available, and Theseus…
luisenp Jun 1, 2022
1b3af0b
Added a Vectorizer class and moved CFWrappers there.
luisenp Jun 1, 2022
2d3f9d2
Renamed vectorizer as Vectorize, added logic to replace Objective ite…
luisenp Jun 1, 2022
6a146cb
Added a CostFunctionSchema -> List[CostFunction] to use for vectoriza…
luisenp Jun 2, 2022
6c6a887
_CostFunctionWrapper is now meant to just store a cached value coming…
luisenp Jun 2, 2022
77ac280
Added code to automatically compute shared vars in Vectorize.
luisenp Jun 2, 2022
31237da
Changed vectorized costs construction to ensure that their weight is …
luisenp Jun 2, 2022
d30e1af
Implemented main cost function vectorization logic.
luisenp Jun 6, 2022
36e89c7
Updated bug that was causing detached gradients.
luisenp Jun 6, 2022
376e8ef
Fixed invalid check in theseus end-to-end unit tests.
luisenp Jun 6, 2022
ae6db18
Added unit test for schema and shared var computation.
luisenp Jun 6, 2022
0a2ee0a
Added a test to check that computed vectorized errors are correct.
luisenp Jun 6, 2022
58cee83
Moved vectorization update call to base linearization class.
luisenp Jun 7, 2022
7e60f87
Changed code to allow batch_size > 1 in shared variables.
luisenp Jun 7, 2022
399bb90
Fixed unit test and added call to Objective.update() in update_vector…
luisenp Jun 7, 2022
10cbf1c
Added new private iterator for vectorized costs.
luisenp Jun 7, 2022
10b208a
Replaced _register_vars_in_list with TheseusFunction.register_vars.
luisenp Jun 9, 2022
db5f366
Renamed vectorize_cost_fns kwarg as vectorize.
luisenp Jun 9, 2022
bb83db3
Added license headers.
luisenp Jun 9, 2022
1d0cd20
Small refactor.
luisenp Jun 9, 2022
e902924
Fixed bug that was preventing vectorized costs to work with to(). End…
luisenp Jun 9, 2022
0ec439f
Renamed the private Objective cost function iterator to _get_iterator().
luisenp Jun 9, 2022
aab9ead
Renamed kwarg in register_vars.
luisenp Jun 9, 2022
e57f310
Set vectorize=True for inverse kinematics and backward tests.
luisenp Jun 9, 2022
d6a434f
Remove lingering comments.
luisenp Jun 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed unit test and added call to Objective.update() in update_vector…
…ization() if batch_size is None.
  • Loading branch information
luisenp committed Jun 7, 2022
commit 399bb902a31561fcbd0623f38a5866c5b8626d84
2 changes: 2 additions & 0 deletions theseus/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
def update_vectorization(self):
if self._vectorization_run is None:
return
if self._batch_size is None:
self.update()
self._vectorization_run()

# iterates over cost functions
Expand Down
2 changes: 1 addition & 1 deletion theseus/core/tests/test_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_vectorized_error():
objective.add(th.Difference(s, ws, s_target))

vectorization = th.Vectorize(objective)
objective.update()
objective.update_vectorization()

assert objective._cost_functions_iterable is vectorization._cost_fn_wrappers
for w in vectorization._cost_fn_wrappers:
Expand Down