Skip to content

Commit

Permalink
Add support for AVX512 f32x32 to bf16x32 conversion (halide#5711)
Browse files Browse the repository at this point in the history
The vcvtne2ps2bf16 instruction combines two f32x16 vectors and converts
them to one bf16x32 vector. We can use this to support converting a
f32x32 vector to bf16x32 vector by splitting the input vector into two.
  • Loading branch information
jwlawson authored Feb 5, 2021
1 parent 8ee7f4c commit 1b22dfe
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.sse2.pmadd.wd", Int(32, 4), "pmaddwd", {Int(16, 8), Int(16, 8)}},

// Convert FP32 to BF16
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.512", BFloat(16, 16), "f32_to_bf16", {Float(32, 16)}, Target::AVX512_SapphireRapids},
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_SapphireRapids},
// LLVM does not provide an unmasked 128bit cvtneps2bf16 intrinsic, so provide a wrapper around the masked version.
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/x86_avx512.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@

; Split a 32 element f32 vector into two 16 element vectors to use the cvtne2ps2bf16 intrinsic.
define weak_odr <32 x i16> @vcvtne2ps2bf16x32(<32 x float> %arg) nounwind alwaysinline {
%1 = shufflevector <32 x float> %arg, <32 x float> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
%2 = shufflevector <32 x float> %arg, <32 x float> undef, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
%3 = tail call <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %2, <16 x float> %1)
ret <32 x i16> %3
}

declare <32 x i16> @llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float>, <16 x float>)

; LLVM does not have an unmasked version of cvtneps2bf16.128, so provide a wrapper around the masked version.
define weak_odr <4 x i16> @vcvtneps2bf16x4(<4 x float> %arg) nounwind alwaysinline {
%1 = tail call <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %arg, <8 x i16> undef, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
Expand Down
1 change: 1 addition & 0 deletions test/correctness/simd_op_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ class SimdOpCheck : public SimdOpCheckTest {
check("vpminsq", 8, min(i64_1, i64_2));
}
if (use_avx512 && target.has_feature(Target::AVX512_SapphireRapids)) {
check("vcvtne2ps2bf16*zmm", 32, cast(BFloat(16), f32_1));
check("vcvtneps2bf16*ymm", 16, cast(BFloat(16), f32_1));
check("vcvtneps2bf16*xmm", 8, cast(BFloat(16), f32_1));
check("vcvtneps2bf16*xmm", 4, cast(BFloat(16), f32_1));
Expand Down

0 comments on commit 1b22dfe

Please sign in to comment.