MatmulTest.TestF32ConstantWeights failure on arm64 linux #19416

Closed
snadampal opened this issue Nov 16, 2024 · 1 comment


@snadampal
Contributor

snadampal commented Nov 16, 2024

The MatmulTest.TestF32ConstantWeights test fails on the arm64 Linux platform with the openxla mainline code. Based on an analysis of the HLO passes and the matmul contraction rewriter logic, the compiler appears to behave as expected, so it is the test that needs to be fixed. My analysis follows:

The failure is caused by a mismatch between the expected and the actual custom-call signature: the weights argument (arg.1) is expected to be a constant literal, but in the optimized HLO it is a broadcast. The broadcast operand appears to be correct in this case.
The expected signature is:

custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %constant{{[a-z,A-Z,0-9,\.]*}}), custom_call_target="__onednn$matmul",

The actual signature, after rewriting and subsequent simplification, is:

custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %broadcast{{[a-z,A-Z,0-9,\.]*}}), custom_call_target="__onednn$matmul",

The analysis shows that, although the oneDNN contraction rewriter emits a constant literal for the weights after pre-packing them, a later HLO pass (here, the algebraic simplifier, algsimp) detects that all values in the literal are identical and replaces the constant with a broadcast of a scalar constant.

file: https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L2059

  // If a literal is all the same element replace it with a scalar broadcast.
  if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsAllFirst()) {
    Literal unique_scalar(
        LiteralUtil::GetFirstScalarLiteral(constant->literal()));
    HloInstruction* scalar = constant->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
    return ReplaceWithNewInstruction(
        constant,
        HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
  }

Here is the test case:
https://github.com/openxla/xla/blob/main/xla/service/cpu/tests/onednn_matmul_test.cc#L997

TEST_F(MatmulTest, TestF32ConstantWeights) {
  const char* matmul_module_str = R"(
  HloModule matmul.test.f32

  ENTRY matmul.test.f32 {
    arg.0 = f32[64,256,16] parameter(0), parameter_replication={false}
    constant = f32[] constant(1)
    arg.1 = f32[16,32] broadcast(constant), dimensions={}
    ROOT onednn.matmul.0 = f32[64,256,32] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  })";

  EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
  MatchOptimizedHlo(matmul_module_str,
                    R"(
  ; CHECK:     %matmul.test.f32
  ; CHECK-NOT: custom_call_target="__onednn$matmul_reorder",
  ; CHECK:     custom-call(%{{[a-z,A-Z,0-9,\.]*}}, %constant{{[a-z,A-Z,0-9,\.]*}}), custom_call_target="__onednn$matmul",
  )");
}
@snadampal
Contributor Author

It's been fixed on the mainline.
