Skip to content

Commit

Permalink
Add a test in hlo_evaluator_test to demonstrate how to obtain diagona…
Browse files Browse the repository at this point in the history
…l from a matrix.

PiperOrigin-RevId: 698575912
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 25, 2024
1 parent 074a691 commit aa32b07
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3708,6 +3708,31 @@ ENTRY main {
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result));
}

TEST_F(HloEvaluatorTest, EvaluateGather_GetDiagonal) {
const std::string hlo_text = R"(
HloModule module
ENTRY %module {
%operand = f32[4,4] parameter(0)
%indices = s32[4,1] iota(), iota_dimension=0
ROOT %gather = f32[4,1] gather(%operand, %indices), offset_dims={},
collapsed_slice_dims={1}, start_index_map={1}, operand_batching_dims={0},
start_indices_batching_dims={0}, index_vector_dim=2, slice_sizes={1,1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

Literal operand = LiteralUtil::CreateR2<float>({{0.0, 0.1, 0.2, 0.3},
{1.0, 1.1, 1.2, 1.3},
{2.0, 2.1, 2.2, 2.3},
{3.0, 3.1, 3.2, 3.3}});
Literal expected_result =
LiteralUtil::CreateR2<float>({{0.0}, {1.1}, {2.2}, {3.3}});

TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand}));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result));
}

TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
const char* hlo_text = R"(
HloModule TensorFlowScatterV1
Expand Down

0 comments on commit aa32b07

Please sign in to comment.