Skip to content

Commit

Permalink
[XLA:Collective] Support normalizing all-reduce
Browse files Browse the repository at this point in the history
1. Add a normalizer to normalier unsupported All-reduce into All-To-All + Reduce + All-Gather

PiperOrigin-RevId: 698916344
  • Loading branch information
Tongfei-Guo authored and Google-ML-Automation committed Jan 10, 2025
1 parent 0947e8e commit eb9c5b2
Show file tree
Hide file tree
Showing 13 changed files with 511 additions and 30 deletions.
1 change: 1 addition & 0 deletions xla/hlo/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cc_library(
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"//xla/service:computation_layout",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_module_config",
"//xla/service:hlo_verifier",
"@com_google_absl//absl/algorithm:container",
Expand Down
42 changes: 36 additions & 6 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -81,12 +83,23 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule(
instruction_can_change_layout_func_);
}

DeviceAssignment HloHardwareIndependentTestBase::GetDefaultDeviceAssignment(
int64_t replica_count, int64_t num_partitions) const {
DeviceAssignment device_assignment(replica_count, num_partitions);
device_assignment.FillIota(0);
return device_assignment;
}

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, int64_t replica_count,
int64_t num_partitions) const {
return ParseAndReturnVerifiedModule(
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions,
std::optional<DeviceAssignment> device_assignment) const {
HloModuleConfig config =
GetModuleConfigForTest(replica_count, num_partitions);
if (device_assignment.has_value()) {
config.set_static_device_assignment(device_assignment.value());
}
return ParseAndReturnVerifiedModule(hlo_text, config);
}

absl::Status HloHardwareIndependentTestBase::
Expand Down Expand Up @@ -115,9 +128,26 @@ absl::Status HloHardwareIndependentTestBase::
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) const {
return ParseAndReturnVerifiedModule(hlo_text, config,
ShapeUtil::ByteSizeOfElements);
}

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
std::function<int64_t(const xla::Shape&)> shape_size_fn) const {
LOG(INFO) << "my_yy: " << config.has_static_device_assignment();
HloModuleConfig config_with_device_assignment = config;
if (!config.has_static_device_assignment()) {
default_device_assignment_ =
std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment(
config.replica_count(), config.num_partitions()));
config_with_device_assignment.set_static_device_assignment(
*default_device_assignment_);
}
auto module = std::make_unique<VerifiedHloModule>(
TestName(), config, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements,
TestName(), config_with_device_assignment, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_, shape_size_fn,
instruction_can_change_layout_func_);
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
return module;
Expand Down
24 changes: 19 additions & 5 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/layout.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape_layout.h"
Expand Down Expand Up @@ -95,14 +96,22 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
const std::string& name = TestName(), int64_t replica_count = 1) const;

//
DeviceAssignment GetDefaultDeviceAssignment(int64_t replica_count,
int64_t num_partitions) const;
// Parses the given string and returns module as a VerifiedHloModule.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
int64_t replica_count = 1,
int64_t num_partitions = 1) const;
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
virtual absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, int64_t replica_count = 1,
int64_t num_partitions = 1,
std::optional<DeviceAssignment> device_assignment = std::nullopt) const;
virtual absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
const HloModuleConfig& config) const;
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
std::function<int64_t(const xla::Shape&)> shape_size_fn) const;

// Runs the hlo_pass with the provided module and returns the result. This
// function also verifies that the module remains unchanged when hlo_pass
Expand Down Expand Up @@ -171,13 +180,17 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
// options (e.g. disabling additional passes).
virtual DebugOptions GetDebugOptionsForTest() const;

void TearDown() override { default_device_assignment_.reset(); }
// Gets an HloModuleConfig with options appropriate for tests.
HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
int64_t num_partitions = 1) const {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
config.set_replica_count(replica_count);
config.set_num_partitions(num_partitions);
default_device_assignment_ = std::make_unique<DeviceAssignment>(
GetDefaultDeviceAssignment(replica_count, num_partitions));
config.set_static_device_assignment(*default_device_assignment_);
return config;
}

Expand Down Expand Up @@ -259,6 +272,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
bool allow_mixed_precision_in_hlo_verifier_;
HloPredicate instruction_can_change_layout_func_;
std::unique_ptr<HloVerifier> hlo_verifier_;
mutable std::unique_ptr<DeviceAssignment> default_device_assignment_;
};

} // namespace xla
Expand Down
38 changes: 38 additions & 0 deletions xla/hlo/transforms/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,44 @@ xla_cc_test(
],
)

cc_library(
name = "all_reduce_normalizer",
srcs = ["all_reduce_normalizer.cc"],
hdrs = ["all_reduce_normalizer.h"],
deps = [
"//xla:literal",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_sharding_util",
"//xla/service:collective_ops_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "all_reduce_normalizer_test",
srcs = ["all_reduce_normalizer_test.cc"],
deps = [
":all_reduce_normalizer",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "all_reduce_contiguous",
srcs = ["all_reduce_contiguous.cc"],
Expand Down
Loading

0 comments on commit eb9c5b2

Please sign in to comment.