Add support for AVX512 BF16 dot product (halide#5712)
* Add support for AVX512 BF16 dot product

* Match on f32*f32

* Remove f32 check
jwlawson authored Feb 9, 2021
1 parent 3e034d6 commit 3fbb12a
Showing 4 changed files with 60 additions and 11 deletions.
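In Halide terms, this commit lets x86 codegen turn a two-element add reduction over an f32 product of bf16 operands into AVX512-BF16's vdpbf16ps instruction. As a rough illustration (not part of the commit), a pipeline along these lines should now hit the new path; the names and the vectorization width are assumptions, but the expression shape mirrors the new simd_op_check test:

#include "Halide.h"
using namespace Halide;

int main() {
    // Two bf16 inputs; each f32 output lane consumes a pair of elements from each.
    ImageParam a(BFloat(16), 1, "a"), b(BFloat(16), 1, "b");
    Var x("x");
    RDom r(0, 2);  // factor-2 reduction, the shape the new pattern matches

    Func dot("dot");
    dot(x) = sum(cast<float>(a(2 * x + r)) * cast<float>(b(2 * x + r)));
    dot.vectorize(x, 16);  // 16 f32 accumulator lanes -> the zmm form of vdpbf16ps

    Target t = get_host_target().with_feature(Target::AVX512_SapphireRapids);
    dot.compile_to_assembly("dot_bf16.s", {a, b}, "dot_bf16", t);
    return 0;
}

If the operands are plain f32 values that cannot be losslessly narrowed to bf16, the new pattern is skipped and codegen falls back to the existing lowering.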
33 changes: 24 additions & 9 deletions src/CodeGen_X86.cpp
@@ -145,6 +145,12 @@ const x86Intrinsic intrinsic_defs[] = {
{"llvm.x86.avx512bf16.cvtneps2bf16.256", BFloat(16, 8), "f32_to_bf16", {Float(32, 8)}, Target::AVX512_SapphireRapids},
// LLVM does not provide an unmasked 128bit cvtneps2bf16 intrinsic, so provide a wrapper around the masked version.
{"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_SapphireRapids},

// Dot product vector reduction
// The LLVM intrinsics combine the bf16 pairs into i32, so provide a wrapper to correctly call the intrinsic.
{"dpbf16psx16", Float(32, 16), "dot_product", {Float(32, 16), BFloat(16, 32), BFloat(16, 32)}, Target::AVX512_SapphireRapids},
{"dpbf16psx8", Float(32, 8), "dot_product", {Float(32, 8), BFloat(16, 16), BFloat(16, 16)}, Target::AVX512_SapphireRapids},
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_SapphireRapids},
};
// clang-format on

@@ -481,9 +487,14 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
Expr pattern;
const char *intrin;
Type narrow_type;
uint32_t flags = 0;
enum {
CombineInit = 1 << 0,
};
};
// clang-format off
static const Pattern patterns[] = {
{2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},
{2, i32(widening_mul(wild_i16x_, wild_i16x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_i8x_, wild_i8x_)), "pmaddwd", Int(16)},
{2, i32(widening_mul(wild_i8x_, wild_u8x_)), "pmaddwd", Int(16)},
@@ -505,17 +516,21 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
Expr b = matches[1];
a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b);
- internal_assert(a.defined());
- internal_assert(b.defined());
+ if (!a.defined() || !b.defined()) { continue; }

- value = call_overloaded_intrin(op->type, p.intrin, {a, b});
- if (value) {
- if (init.defined()) {
- Value *x = value;
- Value *y = codegen(init);
- value = builder->CreateAdd(x, y);
+ if (p.flags & Pattern::CombineInit) {
+ value = call_overloaded_intrin(op->type, p.intrin, {init, a, b});
+ if (value) { return; }
+ } else {
+ value = call_overloaded_intrin(op->type, p.intrin, {a, b});
+ if (value) {
+ if (init.defined()) {
+ Value *x = value;
+ Value *y = codegen(init);
+ value = builder->CreateAdd(x, y);
+ }
+ return;
}
- return;
}
}
}
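A note on the Pattern::CombineInit flag added above (commentary, not part of the diff): vdpbf16ps accumulates into its destination register, so when the reduction carries an init value it can be passed straight through as the accumulator operand instead of emitting the intrinsic followed by a separate vector add. A scalar sketch of the two code paths, with made-up helper names standing in for the intrinsics:

// Stand-in for a fused dot-product intrinsic that takes the accumulator,
// as the new "dot_product" wrappers do.
float dot_with_init(float init, float a0, float a1, float b0, float b1) {
    return init + a0 * b0 + a1 * b1;
}

// Stand-in for a reduction intrinsic with no accumulator operand
// (the pre-existing pmaddwd-style patterns).
float dot_no_init(float a0, float a1, float b0, float b1) {
    return a0 * b0 + a1 * b1;
}

int main() {
    float init = 0.5f;
    // Pattern::CombineInit path: one call, init folded into the intrinsic.
    float fused = dot_with_init(init, 1.0f, 2.0f, 3.0f, 4.0f);
    // Fallback path: call the intrinsic, then add init separately.
    float split = dot_no_init(1.0f, 2.0f, 3.0f, 4.0f) + init;
    return fused == split ? 0 : 1;
}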
25 changes: 25 additions & 0 deletions src/runtime/x86_avx512.ll
@@ -17,3 +17,28 @@ define weak_odr <4 x i16> @vcvtneps2bf16x4(<4 x float> %arg) nounwind alwaysinl
}

declare <8 x i16> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float>, <8 x i16>, <4 x i1>)

; The bf16 dot product intrinsics combine the bf16 pairs into single i32 elements, so bitcast the inputs to match.
define weak_odr <16 x float> @dpbf16psx16(<16 x float> %init, <32 x i16> %a, <32 x i16> %b) nounwind alwaysinline {
%1 = bitcast <32 x i16> %a to <16 x i32>
%2 = bitcast <32 x i16> %b to <16 x i32>
%3 = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %init, <16 x i32> %1, <16 x i32> %2)
ret <16 x float> %3
}
declare <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float>, <16 x i32>, <16 x i32>)

define weak_odr <8 x float> @dpbf16psx8(<8 x float> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline {
%1 = bitcast <16 x i16> %a to <8 x i32>
%2 = bitcast <16 x i16> %b to <8 x i32>
%3 = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %init, <8 x i32> %1, <8 x i32> %2)
ret <8 x float> %3
}
declare <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float>, <8 x i32>, <8 x i32>)

define weak_odr <4 x float> @dpbf16psx4(<4 x float> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
%1 = bitcast <8 x i16> %a to <4 x i32>
%2 = bitcast <8 x i16> %b to <4 x i32>
%3 = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %init, <4 x i32> %1, <4 x i32> %2)
ret <4 x float> %3
}
declare <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float>, <4 x i32>, <4 x i32>)
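For reference (not from the commit), the per-lane arithmetic these wrappers expose: each i32 element holds a pair of bf16 values, and every output lane accumulates the two pairwise f32 products from a and b. A scalar sketch, treating bf16 as the upper 16 bits of a binary32 value and ignoring the instruction's exact rounding and denormal behavior:

#include <cstdint>
#include <cstring>

// Widen a bf16 value (stored in a uint16_t) to float: bf16 is the top
// 16 bits of an IEEE-754 binary32.
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = uint32_t(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// One dpbf16ps-style update over n output lanes (n = 4, 8, or 16 for the
// xmm/ymm/zmm wrappers above): acc[i] += a[2i]*b[2i] + a[2i+1]*b[2i+1].
void dpbf16ps_ref(float *acc, const uint16_t *a, const uint16_t *b, int n) {
    for (int i = 0; i < n; ++i) {
        acc[i] += bf16_to_f32(a[2 * i]) * bf16_to_f32(b[2 * i]) +
                  bf16_to_f32(a[2 * i + 1]) * bf16_to_f32(b[2 * i + 1]);
    }
}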
8 changes: 8 additions & 0 deletions test/correctness/simd_op_check.cpp
@@ -512,6 +512,14 @@ class SimdOpCheck : public SimdOpCheckTest {
check("vcvtneps2bf16*ymm", 16, cast(BFloat(16), f32_1));
check("vcvtneps2bf16*xmm", 8, cast(BFloat(16), f32_1));
check("vcvtneps2bf16*xmm", 4, cast(BFloat(16), f32_1));

{
// 16 bit, 2 element dot product
RDom r(0, 2);
check("vdpbf16ps*zmm", 16, sum(f32(in_bf16(2 * x + r)) * in_bf16(2 * x + r + 32)));
check("vdpbf16ps*ymm", 8, sum(f32(in_bf16(2 * x + r)) * in_bf16(2 * x + r + 32)));
check("vdpbf16ps*xmm", 4, sum(f32(in_bf16(2 * x + r)) * in_bf16(2 * x + r + 32)));
}
}
}

5 changes: 3 additions & 2 deletions test/correctness/simd_op_check.h
@@ -29,6 +29,7 @@ class SimdOpCheckTest {

ImageParam in_f32{Float(32), 1, "in_f32"};
ImageParam in_f64{Float(64), 1, "in_f64"};
ImageParam in_bf16{BFloat(16), 1, "in_bf16"};
ImageParam in_i8{Int(8), 1, "in_i8"};
ImageParam in_u8{UInt(8), 1, "in_u8"};
ImageParam in_i16{Int(16), 1, "in_i16"};
@@ -38,8 +39,8 @@
ImageParam in_i64{Int(64), 1, "in_i64"};
ImageParam in_u64{UInt(64), 1, "in_u64"};

- const std::vector<ImageParam> image_params{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
- const std::vector<Argument> arg_types{in_f32, in_f64, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
+ const std::vector<ImageParam> image_params{in_f32, in_f64, in_bf16, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
+ const std::vector<Argument> arg_types{in_f32, in_f64, in_bf16, in_i8, in_u8, in_i16, in_u16, in_i32, in_u32, in_i64, in_u64};
int W;
int H;
