Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:Collective] Support normalizing all-reduce #19819

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/hlo/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"//xla/service:computation_layout",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_module_config",
"//xla/service:hlo_verifier",
"@com_google_absl//absl/algorithm:container",
Expand Down
49 changes: 41 additions & 8 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -77,12 +80,23 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule(
instruction_can_change_layout_func_);
}

// Produces the fallback device assignment for tests: a DeviceAssignment
// constructed from (replica_count, num_partitions) and populated with
// FillIota(0).
DeviceAssignment HloHardwareIndependentTestBase::GetDefaultDeviceAssignment(
    int64_t replica_count, int64_t num_partitions) const {
  DeviceAssignment iota_assignment(replica_count, num_partitions);
  iota_assignment.FillIota(0);
  return iota_assignment;
}

// Parses `hlo_text` with a config obtained from
// GetModuleConfigForTest(replica_count, num_partitions) and delegates to the
// HloModuleConfig overload.
// NOTE(review): this span is a diff rendering with the +/- markers stripped;
// the first parameter list / return statement looks like the pre-change
// version and the second like the post-change version - confirm against the
// real file.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, int64_t replica_count,
int64_t num_partitions) const {
return ParseAndReturnVerifiedModule(
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions,
std::optional<DeviceAssignment> device_assignment) const {
HloModuleConfig config =
GetModuleConfigForTest(replica_count, num_partitions);
// A caller-supplied assignment replaces whatever static device assignment the
// test config already carries.
if (device_assignment.has_value()) {
config.set_static_device_assignment(device_assignment.value());
}
return ParseAndReturnVerifiedModule(hlo_text, config);
}

absl::Status HloHardwareIndependentTestBase::
Expand Down Expand Up @@ -110,12 +124,31 @@ absl::Status HloHardwareIndependentTestBase::

// Convenience overload: forwards `hlo_text`, `config` and `parser_options` to
// the four-argument overload, passing ShapeUtil::ByteSizeOfElements as the
// shape-size function.
// NOTE(review): the first parameter line below appears to be the pre-change
// side of the diff - confirm against the real file.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) const {
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options) const {
return ParseAndReturnVerifiedModule(hlo_text, config, parser_options,
ShapeUtil::ByteSizeOfElements);
}

// Creates a VerifiedHloModule named TestName(), parses and verifies `hlo_text`
// into it using `parser_options`, and returns the module. Any error from
// ParseHloStringAndVerifyModule is returned to the caller.
// NOTE(review): this span is a diff rendering; lines that pass `config`,
// ShapeUtil::ByteSizeOfElements, or call ParseHloStringAndVerifyModule with a
// single argument look like the pre-change side - confirm against the real
// file.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options,
std::function<int64_t(const xla::Shape&)> shape_size_fn) const {
// Work on a copy so the caller's config is never modified.
HloModuleConfig config_with_device_assignment = config;
// Only fill in a default assignment when the caller did not provide one.
// NOTE(review): the default is stored in a mutable member from this const
// method; a local variable may be sufficient if
// set_static_device_assignment copies its argument - confirm.
if (!config.has_static_device_assignment()) {
default_device_assignment_ =
std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment(
config.replica_count(), config.num_partitions()));
config_with_device_assignment.set_static_device_assignment(
*default_device_assignment_);
}
auto module = std::make_unique<VerifiedHloModule>(
TestName(), config, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements,
TestName(), config_with_device_assignment, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_, shape_size_fn,
instruction_can_change_layout_func_);
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
TF_RETURN_IF_ERROR(
module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
return module;
}

Expand Down
38 changes: 30 additions & 8 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/layout.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape_layout.h"
Expand Down Expand Up @@ -95,14 +97,24 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
const std::string& name = TestName(), int64_t replica_count = 1) const;

// Builds the default DeviceAssignment for the given replica and partition
// counts. It is used as the fallback static device assignment when a test does
// not supply one.
DeviceAssignment GetDefaultDeviceAssignment(int64_t replica_count,
int64_t num_partitions) const;
// Parses the given string and returns module as a VerifiedHloModule.
// This overload takes replica/partition counts and an optional device
// assignment rather than a ready-made HloModuleConfig.
// NOTE(review): the declarations below are a diff rendering with +/- markers
// stripped, so pre-change and post-change declarations are interleaved.
virtual absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, int64_t replica_count = 1,
int64_t num_partitions = 1,
std::optional<DeviceAssignment> device_assignment = std::nullopt) const;
// Same, but takes an explicit HloModuleConfig and optional parser options.
virtual absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options = HloParserOptions()) const;
// Variant that additionally takes the shape-size function handed to the
// VerifiedHloModule.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
int64_t replica_count = 1,
int64_t num_partitions = 1) const;
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
const HloModuleConfig& config) const;
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options,
std::function<int64_t(const xla::Shape&)> shape_size_fn) const;

// Runs the hlo_pass with the provided module and returns the result. This
// function also verifies that the module remains unchanged when hlo_pass
Expand Down Expand Up @@ -181,13 +193,22 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
// options (e.g. disabling additional passes).
virtual DebugOptions GetDebugOptionsForTest() const;

// Releases the cached default device assignment when the fixture is torn down.
void TearDown() override {
  default_device_assignment_ = nullptr;
}
// Gets an HloModuleConfig with options appropriate for tests.
// The config receives the test debug options, `replica_count` and
// `num_partitions`, and always ends up with a static device assignment: the
// supplied one if present, otherwise the result of GetDefaultDeviceAssignment.
// NOTE(review): the first two-line signature below appears to be the
// pre-change side of the diff - confirm against the real file.
HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
int64_t num_partitions = 1) const {
HloModuleConfig GetModuleConfigForTest(
int64_t replica_count = 1, int64_t num_partitions = 1,
std::optional<DeviceAssignment> device_assignment = std::nullopt) const {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
config.set_replica_count(replica_count);
config.set_num_partitions(num_partitions);
if (device_assignment.has_value()) {
// The parameter is taken by value, so its contents can be moved out.
config.set_static_device_assignment(std::move(device_assignment.value()));
} else {
// No assignment supplied: generate the default one, keep it in the mutable
// member, and install it on the config.
default_device_assignment_ = std::make_unique<DeviceAssignment>(
GetDefaultDeviceAssignment(replica_count, num_partitions));
config.set_static_device_assignment(*default_device_assignment_);
}
return config;
}

Expand Down Expand Up @@ -269,6 +290,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
bool allow_mixed_precision_in_hlo_verifier_;
HloPredicate instruction_can_change_layout_func_;
std::unique_ptr<HloVerifier> hlo_verifier_;
// Most recently generated default device assignment. Declared mutable because
// const member functions assign to it; cleared in TearDown().
mutable std::unique_ptr<DeviceAssignment> default_device_assignment_;
};

} // namespace xla
Expand Down
38 changes: 38 additions & 0 deletions xla/hlo/transforms/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,44 @@ xla_cc_test(
],
)

# Library target built from all_reduce_normalizer.cc / all_reduce_normalizer.h.
cc_library(
name = "all_reduce_normalizer",
srcs = ["all_reduce_normalizer.cc"],
hdrs = ["all_reduce_normalizer.h"],
deps = [
"//xla:literal",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_sharding_util",
"//xla/service:collective_ops_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

# Test target for :all_reduce_normalizer, built from
# all_reduce_normalizer_test.cc.
# NOTE(review): deps list both @com_google_googletest//:gtest_main and
# @tsl//tsl/platform:test_main; each presumably supplies a main() - confirm
# that only one is needed.
xla_cc_test(
name = "all_reduce_normalizer_test",
srcs = ["all_reduce_normalizer_test.cc"],
deps = [
":all_reduce_normalizer",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "all_reduce_contiguous",
srcs = ["all_reduce_contiguous.cc"],
Expand Down
Loading
Loading