
[NVIDIA GPU] Fix mem p2p init in collective permute thunk #20086

Closed
wants to merge 3 commits from Tixxx:tixxx/memcpy_p2p_fix

Conversation


@Tixxx Tixxx commented Dec 3, 2024

Move pointer initialization to the thunk initialization stage instead of runtime, to get rid of the runtime blocking wait.
Add a device sync point using an NCCL all-reduce before doing the memcpy, to make sure all GPUs arrive at the same stage; otherwise, data corruption is possible when the receiving rank hasn't arrived at the memcpy.
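The sync-point idea can be illustrated with a small standalone sketch. This is not the actual thunk code: SyncThenCopy, barrier_buf, and the pointer arguments are assumptions for illustration. A one-element ncclAllReduce enqueued by every rank acts as a device-side barrier on the stream, so the copy cannot begin before the receiving rank has arrived.

#include <cuda_runtime.h>
#include <nccl.h>

// Hypothetical helper, not XLA's thunk code: enqueue a 1-element
// all-reduce as a device-side barrier, then do the device-to-device copy.
ncclResult_t SyncThenCopy(ncclComm_t comm, cudaStream_t stream,
                          float* barrier_buf,  // 1-element device scratch buffer
                          void* dst, const void* src, size_t bytes) {
  // The all-reduce kernel completes on this rank's stream only once every
  // rank in `comm` is executing it, so all GPUs arrive at the same stage.
  ncclResult_t r = ncclAllReduce(barrier_buf, barrier_buf, 1, ncclFloat,
                                 ncclSum, comm, stream);
  if (r != ncclSuccess) return r;
  // The receiving rank has now reached this point, so the copy cannot race
  // with its earlier use of the destination buffer.
  cudaError_t e = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice,
                                  stream);
  return (e == cudaSuccess) ? ncclSuccess : ncclInternalError;
}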

@reedwm reedwm requested a review from frgossen December 4, 2024 05:45
@Tixxx Tixxx force-pushed the tixxx/memcpy_p2p_fix branch from 9d3a8a4 to 10d3501 Compare December 9, 2024 06:05
Moved pointer init to thunk init stage and add a sync point before doing memcpy to make sure data consistency across ranks
@Tixxx Tixxx force-pushed the tixxx/memcpy_p2p_fix branch from 10d3501 to 050bc59 Compare December 12, 2024 05:28
@Tixxx Tixxx changed the title Fix mem p2p init in collective permute thunk [NVIDIA GPU] Fix mem p2p init in collective permute thunk Dec 12, 2024
@frgossen frgossen (Member) left a comment

Thanks for the fix!

copybara-service bot pushed a commit that referenced this pull request Dec 12, 2024
Imported from GitHub PR #20086

Move pointer initialization to the thunk init stage instead of runtime to get rid of the runtime blocking wait.
Add a device sync point using nccl allreduce before doing memcpy to make sure all gpus arrive at the same stage. Otherwise it's possible to have data corruptions when the receiving rank hasn't arrived at the memcpy.
Copybara import of the project:

--
ba4ad04 by TJ Xu <[email protected]>:

Moved pointer init to thunk init stage and add a sync point before doing
memcpy to make sure data consistency across ranks

--
050bc59 by TJ Xu <[email protected]>:

Added e2e test for mem cpy p2p in a loop

Merging this change closes #20086

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20086 from Tixxx:tixxx/memcpy_p2p_fix 050bc59
PiperOrigin-RevId: 705647424
if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
  LOG(ERROR) << "Unregistering barrier flag failed.";
}
A reviewer (Member) commented:

Several TensorFlow tests are failing with:
error: non-void function does not return a value in all control paths [-Werror,-Wreturn-type]

The author (@Tixxx) replied:

Added a return status for the cleanup function.
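A minimal sketch of the shape of that fix (an illustrative assumption, not the exact XLA change): make the cleanup return absl::Status on every control path instead of only logging, which resolves the -Wreturn-type error above.

#include "absl/status/status.h"

// Inside the thunk class that owns barrier_flags_ (illustrative only).
absl::Status CleanupBarrierFlag(se::StreamExecutor* executor, int current_id) {
  if (!executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
    // Returning an error instead of just LOG(ERROR) gives this path a
    // value, satisfying -Wreturn-type and letting callers propagate it.
    return absl::InternalError("Unregistering barrier flag failed.");
  }
  return absl::OkStatus();
}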

@Tixxx Tixxx requested a review from thomasjoerg December 16, 2024 19:26
copybara-service bot pushed a commit that referenced this pull request Dec 17, 2024
Imported from GitHub PR #20086

Move pointer initialization to the thunk init stage instead of runtime to get rid of the runtime blocking wait.
Add a device sync point using nccl allreduce before doing memcpy to make sure all gpus arrive at the same stage. Otherwise it's possible to have data corruptions when the receiving rank hasn't arrived at the memcpy.
Copybara import of the project:

--
ba4ad04 by TJ Xu <[email protected]>:

Moved pointer init to thunk init stage and add a sync point before doing
memcpy to make sure data consistency across ranks

--
050bc59 by TJ Xu <[email protected]>:

Added e2e test for mem cpy p2p in a loop

--
1f75328 by TJ Xu <[email protected]>:

Added return status for cleanup functions

Merging this change closes #20086

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20086 from Tixxx:tixxx/memcpy_p2p_fix 1f75328
PiperOrigin-RevId: 707074350
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 17, 2024
Imported from GitHub PR openxla/xla#20086

Move pointer initialization to the thunk init stage instead of runtime to get rid of the runtime blocking wait.
Add a device sync point using nccl allreduce before doing memcpy to make sure all gpus arrive at the same stage. Otherwise it's possible to have data corruptions when the receiving rank hasn't arrived at the memcpy.
Copybara import of the project:

--
ba4ad0445f27d7249b4bcebb4ac573188cf50cb0 by TJ Xu <[email protected]>:

Moved pointer init to thunk init stage and add a sync point before doing
memcpy to make sure data consistency across ranks

--
050bc59c02732da728fe43bd6c4c12702d070c2c by TJ Xu <[email protected]>:

Added e2e test for mem cpy p2p in a loop

--
1f7532815dfdbb6d047339d7189c1287dc72e6a3 by TJ Xu <[email protected]>:

Added return status for cleanup functions

Merging this change closes #20086

PiperOrigin-RevId: 707145351
@reedwm (Member) commented Dec 21, 2024

I'm seeing this regress the Maxtext Llama 7B model with 4-way FSDP and 2-way TP when collective matmul is enabled, using the script you gave me @Tixxx. I see Tokens/s/device: 4125.938 before this PR and Tokens/s/device: 3827.136 after it.
