[ONEDNN][BC-breaking] update onednn from v2.7.3 to v3.1.1 (pytorch#97957)

**Summary**
Update oneDNN from v2.7.3 to v3.1.1.
This update is BC-breaking because some APIs changed on the oneDNN side. The changes cover (a sketch of one such API change follows this list):
- PyTorch code that calls oneDNN directly
- The `third_party/ideep` submodule, to adapt to oneDNN's new APIs
- CMake files, to fix build issues
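As a concrete illustration, one of the BC-breaking changes is that `post_ops::append_eltwise` (used here through ideep) dropped its leading scale argument in v3.x. A minimal sketch, with a placeholder algorithm and alpha/beta values:

```cpp
#include <ideep.hpp>

// Build a post-op attribute the v3.x way (sketch; eltwise_relu is a placeholder).
ideep::attr_t make_relu_post_op_attr() {
  ideep::post_ops po;
  // oneDNN v2.7.x took a leading scale argument:
  //   po.append_eltwise(/*scale=*/1.0f, ideep::algorithm::eltwise_relu, 0.f, 0.f);
  // oneDNN v3.1.x drops it:
  po.append_eltwise(ideep::algorithm::eltwise_relu, /*alpha=*/0.f, /*beta=*/0.f);
  ideep::attr_t attr;
  attr.set_post_ops(po);
  return attr;
}
```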

**Test plan**
Build issues and correctness are covered by CI checks.
For performance, we ran TorchBench models to verify there is no regression. Below is a comparison before and after the oneDNN update.
![image](https://github.com/pytorch/pytorch/assets/12522207/415a4ff0-7566-40c6-aed0-24997a475b0e)

Note:
- Base commit of PyTorch: da322ea
- CPU: Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz (Ice Lake)

Pull Request resolved: pytorch#97957
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Xia-Weiwen authored and pytorchmergebot committed Aug 25, 2023
1 parent ff37f60 commit 97a291f
Showing 19 changed files with 164 additions and 189 deletions.
3 changes: 2 additions & 1 deletion aten/src/ATen/native/RNN.cpp
@@ -91,7 +91,8 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
   };
   return input.options().backend() == at::Backend::CPU &&
       is_cpu_backend(params) && is_cpu_backend(hx) &&
-      (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16);
+      (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) &&
+      input.numel() != 0;
 #endif
   return false;
 }
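A note on the hunk above (my reading, not stated in the patch): the new `input.numel() != 0` guard keeps empty inputs off the oneDNN RNN path, so they fall back to the native implementation; presumably oneDNN v3.x does not handle zero-sized RNN inputs. A hypothetical illustration:

```cpp
// Hypothetical example: a zero-sized input no longer selects the oneDNN path.
at::Tensor input = at::zeros({0, 3, 8}, at::kFloat);  // empty batch
// use_mkldnn(input, params, hx) now returns false since input.numel() == 0,
// so the native RNN implementation handles the empty tensor instead.
```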
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mkldnn/Conv.cpp
@@ -470,7 +470,7 @@ Tensor mkldnn_convolution_pointwise_binary(
     ideep::post_ops po;
     po.append_binary(it_binary->second, other_desc);
     if (unary_attr_value != "none") {
-      po.append_eltwise(1.0, unary_alg, 0.f, 0.f);
+      po.append_eltwise(unary_alg, 0.f, 0.f);
     }
     op_attr.set_post_ops(po);
 
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mkldnn/Utils.cpp
@@ -139,7 +139,7 @@ AttrFunction attr_func_hardsigmoid =
       ideep::attr_t attr;
       ideep::post_ops po;
       po.append_eltwise(
-          1.0f, ideep::algorithm::eltwise_hardsigmoid, 1.0f / 6.0f, 0.5f);
+          ideep::algorithm::eltwise_hardsigmoid, 1.0f / 6.0f, 0.5f);
       attr.set_post_ops(po);
       return attr;
     };
41 changes: 2 additions & 39 deletions aten/src/ATen/native/quantized/cpu/OnednnUtils.h
@@ -55,20 +55,7 @@ struct LinearPrimitiveCache : PrimitiveCache {
     this->param = param;
   }
 
-  LinearPrimitiveCache(
-      const PrimitiveCacheKey& key,
-      const LinearParams& param,
-      const ideep::tensor& bias) {
-    this->key = key;
-    this->param = param;
-    if (!bias.is_empty()) {
-      expected_bias =
-          bias.reorder_if_differ_in(param.pd.bias_desc(), param.bias_attr);
-    }
-  }
-
   LinearParams param;
-  ideep::tensor expected_bias;
 
   // For dynamic qlinear, scale and zero point
   // are set at execution time. So we only need to compare
@@ -84,64 +71,40 @@ struct LinearPrimitiveCache : PrimitiveCache {
   LinearParams& get_param() {
     return param;
   }
-
-  ideep::tensor& get_expected_bias() {
-    return expected_bias;
-  }
 };
 
 struct ConvPrimitiveCache : PrimitiveCache {
   ConvPrimitiveCache() {}
 
   ConvPrimitiveCache(
       const PrimitiveCacheKey& key,
-      const ConvParams& params,
-      const ideep::tensor& bias) {
+      const ConvParams& params) {
     this->key = key;
     this->params = params;
-    if (!bias.is_empty()) {
-      this->expected_bias =
-          bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
-    }
   }
 
-  ideep::tensor expected_bias;
   ConvParams params;
 
   ConvParams& get_params() {
     return params;
   }
-
-  ideep::tensor& get_bias() {
-    return expected_bias;
-  }
 };
 
 struct DeconvPrimitiveCache : PrimitiveCache {
   DeconvPrimitiveCache() {}
 
   DeconvPrimitiveCache(
       const PrimitiveCacheKey& key,
-      const DeconvParams& params,
-      const ideep::tensor& bias) {
+      const DeconvParams& params) {
     this->key = key;
     this->params = params;
-    if (!bias.is_empty()) {
-      this->expected_bias =
-          bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
-    }
   }
 
   DeconvParams params;
-  ideep::tensor expected_bias;
 
   DeconvParams& get_params() {
     return params;
   }
-
-  ideep::tensor& get_bias() {
-    return expected_bias;
-  }
 };
 
 enum PostOps {
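Why the cache structs shrink (a reading of this diff, not an authoritative note): previously each primitive cache pre-reordered the bias into the primitive's expected layout and stored it as `expected_bias`. With oneDNN v3.x the bias is passed directly to `compute` at execution time, as the qconv/qlinear hunks below show, so only the key and params need caching. A minimal sketch of the resulting cache-hit path:

```cpp
// Sketch of the cache-hit path after this change (mirrors qconv.cpp below);
// src, weights, b, and dst are ideep::tensor values owned by the caller.
if (get_conv_cache().hit(cache_key)) {
  auto& params = get_conv_cache().get_params();
  // The bias b is passed as-is; no cached expected_bias reorder remains.
  ideep::convolution_forward::compute<false, false>(params, src, weights, b, dst);
}
```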
21 changes: 8 additions & 13 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1148,8 +1148,8 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
 
   // has_accum: extra input besides the conv to do conv add fusion.
   bool has_accum = accum.has_value() ? true : false;
+  auto& ctx = at::globalContext();
   if (has_accum) {
-    auto& ctx = at::globalContext();
     func_name += "_add";
     TORCH_CHECK(
         !transpose(),
@@ -1172,8 +1172,7 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
   auto src_data_type = dnnl::memory::data_type::u8;
   auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
       kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
-  ideep::tensor src;
-  src.init(src_desc, act_contig.data_ptr());
+  ideep::tensor src(src_desc, act_contig.data_ptr());
   // weights & bias
   ideep::tensor& weights = *(weight_.get());
   bool with_bias = bias_.has_value();
@@ -1262,11 +1261,9 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
     // The true scale and zero point is stored in ideep::scale_t(scale_size, inv_output_scale) and dst_zero_points.
     dst.set_scale(accum_scale);
     dst.set_zero_point(accum_zero_points);
-  } else {
-    op_attr = kReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+  } else if (kReluFused) {
+    op_attr = ideep::attr_t::fuse_relu();
   }
-  // Since src zero point is unknown, set runtime value here
-  op_attr.set_zero_points(DNNL_ARG_SRC, ideep::utils::tensor_zp_mask(1), src_zero_points);
 
   // Bias might be modified outside (e.g. by quantization bias correction).
   // If so, update the prepacked bias as well.
@@ -1290,15 +1287,14 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
           dnnl::algorithm::deconvolution_direct,
           dnnl::prop_kind::forward_inference,
           ideep::u8s8, ideep::engine::cpu_engine());
-      get_deconv_cache() = DeconvPrimitiveCache(cache_key, params, b);
+      get_deconv_cache() = DeconvPrimitiveCache(cache_key, params);
       auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
       weights = weights.reorder_if_differ_in(expected_weight_desc);
     });
     if (get_deconv_cache().hit(cache_key)) {
       DeconvParams& params = get_deconv_cache().get_params();
-      auto& expected_bias = get_deconv_cache().get_bias();
       ideep::convolution_transpose_forward::compute<false, false>(
-          params, src, weights, expected_bias, dst);
+          params, src, weights, b, dst);
     } else {
       ideep::convolution_transpose_forward::compute(
           src, weights, b, dst_dims, dst,
@@ -1323,15 +1319,14 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
           op_attr, dnnl::algorithm::convolution_direct,
           dnnl::prop_kind::forward_inference,
           ideep::u8s8, ideep::engine::cpu_engine());
-      get_conv_cache() = ConvPrimitiveCache(cache_key, params, b);
+      get_conv_cache() = ConvPrimitiveCache(cache_key, params);
       auto expected_weight_desc = ideep::tensor::desc(params.pd.weights_desc(), groups());
       weights = weights.reorder_if_differ_in(expected_weight_desc);
     });
     // If hit, use cached data. If miss, fall back to normal path.
     if (get_conv_cache().hit(cache_key)) {
       auto& params = get_conv_cache().get_params();
-      auto& expected_bias = get_conv_cache().get_bias();
-      ideep::convolution_forward::compute<false, false>(params, src, weights, expected_bias, dst);
+      ideep::convolution_forward::compute<false, false>(params, src, weights, b, dst);
     } else {
       ideep::convolution_forward::compute(
           src, weights, b, dst_dims, dst,
5 changes: 1 addition & 4 deletions aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
@@ -398,10 +398,7 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
   }
 
   // Set runtime src zero point
-  auto src_zero_point = {DNNL_RUNTIME_S32_VAL};
-  op_attr.set_zero_points(DNNL_ARG_SRC,
-                          ideep::utils::tensor_zp_mask(src_zero_point.size()),
-                          src_zero_point);
+  op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* zero_points_mask= */0);
   at::Tensor weight_copy;
   ideep::tensor::desc w_desc;
   ideep::dims dims_iohw, dims_giohw;
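The zero-point handling shows the v2-to-v3 attribute API shift most directly: in v2.x a runtime source zero point was declared by passing the `DNNL_RUNTIME_S32_VAL` sentinel together with a mask from `ideep::utils::tensor_zp_mask`; in v3.x only the mask is declared when the primitive is created, and the actual values are bound at execution time. A side-by-side sketch condensed from this hunk:

```cpp
#include <ideep.hpp>

ideep::attr_t op_attr;
// oneDNN v2.7.x: declare a runtime src zero point via a sentinel value.
//   auto src_zero_point = {DNNL_RUNTIME_S32_VAL};
//   op_attr.set_zero_points(DNNL_ARG_SRC,
//                           ideep::utils::tensor_zp_mask(src_zero_point.size()),
//                           src_zero_point);
// oneDNN v3.1.x: declare only the mask; values are supplied at execution.
op_attr.set_zero_points_mask(DNNL_ARG_SRC, /* zero_points_mask= */ 0);
```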
5 changes: 2 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -847,13 +847,12 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
         params, x, w, b, y,
         src_scales, weights_scales, dst_scales,
         src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
-    get_cache() = LinearPrimitiveCache(cache_key, params, b);
+    get_cache() = LinearPrimitiveCache(cache_key, params);
     w = w.reorder_if_differ_in(params.pd.weights_desc());
   });
   if (get_cache().hit(cache_key)) {
     LinearParams& params = get_cache().get_param();
-    auto& expected_bias = get_cache().get_expected_bias();
-    ideep::matmul_forward::compute<false, false>(params, x, w, expected_bias, y);
+    ideep::matmul_forward::compute<false, false>(params, x, w, b, y);
   } else {
     ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
         dst_scales, src_zero_point, dst_zero_point,
7 changes: 5 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -246,8 +246,11 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
     auto weight_copy = weight.clone();
     ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr());
     wgt.transpose_(0, 1); // ONEDNN requires transposed weight
-    auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), dnnl::memory::data_type::s8,
-                                                               dnnl::memory::data_type::u8);
+    auto src_dims = ideep::dims(); // Unknown when prepacking
+    ideep::attr_t op_attr;
+    op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+    auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), src_dims, dnnl::memory::data_type::s8,
+                                                               dnnl::memory::data_type::u8, op_attr);
     ideep::tensor exp_wgt(w_desc);
     exp_wgt.feed_from(wgt);
     ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt));
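Note the related query-API change above: in v3.x, `ideep::matmul_forward::expected_weights_desc` also takes the source dims (which may be empty/unknown at prepack time) and the op attributes, so the zero-point mask must be set before querying the packed weight layout. A condensed sketch of the new prepack flow, using the names from the hunk:

```cpp
// Condensed from the hunk above; wgt is the transposed s8 weight tensor.
auto src_dims = ideep::dims();                   // unknown at prepack time
ideep::attr_t op_attr;
op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);   // runtime src zero point
auto w_desc = ideep::matmul_forward::expected_weights_desc(
    wgt.get_dims(), src_dims,
    dnnl::memory::data_type::s8,                 // weight dtype
    dnnl::memory::data_type::u8,                 // activation dtype
    op_attr);
ideep::tensor exp_wgt(w_desc);                   // allocate in expected layout
exp_wgt.feed_from(wgt);                          // reorder/copy the weights
```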
@@ -63,6 +63,7 @@ class IDEEPInt8FullyConnectedOp final : public IDEEPOperator {
       Y_.init({{X.get_dim(0), filter.get_dim(0)}, idtype::f32});
     }
 
+    X_in = X_in.to_public();
     if (InputSize() > BIAS) {
       ideep::inner_product_forward::compute(
           X_in, filter_, bias_, Y_);
27 changes: 6 additions & 21 deletions cmake/Modules/FindMKLDNN.cmake
@@ -17,22 +17,24 @@ IF(NOT MKLDNN_FOUND)
   SET(MKLDNN_INCLUDE_DIR)
 
   SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
-  SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/third_party/oneDNN")
+  SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
   IF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)
     MESSAGE("-- Will build oneDNN Graph")
     SET(LLGA_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
     SET(BUILD_ONEDNN_GRAPH ON)
+    SET(ONEDNN_BUILD_GRAPH ON CACHE BOOL "" FORCE)
   ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER)
 
   FIND_PACKAGE(BLAS)
   FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
-  FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
+  FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
   IF(NOT MKLDNN_INCLUDE_DIR)
+    MESSAGE("MKLDNN_INCLUDE_DIR not found")
     EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
     FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
   ENDIF(NOT MKLDNN_INCLUDE_DIR)
   IF(BUILD_ONEDNN_GRAPH)
-    FIND_PATH(LLGA_INCLUDE_DIR oneapi/dnnl/dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include)
+    FIND_PATH(LLGA_INCLUDE_DIR dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
   ENDIF(BUILD_ONEDNN_GRAPH)
 
   IF(NOT IDEEP_INCLUDE_DIR OR NOT MKLDNN_INCLUDE_DIR)
@@ -93,21 +95,7 @@ IF(NOT MKLDNN_FOUND)
     ENDIF()
   ENDIF()
 
-  IF(BUILD_ONEDNN_GRAPH)
-    ADD_SUBDIRECTORY(${LLGA_ROOT})
-    IF(NOT TARGET dnnl_graph)
-      MESSAGE("Failed to include LLGA target")
-      RETURN()
-    ENDIF(NOT TARGET dnnl_graph)
-
-    IF(CMAKE_COMPILER_IS_GNUCC)
-      TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-maybe-uninitialized)
-      TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-strict-overflow)
-      TARGET_COMPILE_OPTIONS(dnnl_graph PRIVATE -Wno-error=strict-overflow)
-    ENDIF(CMAKE_COMPILER_IS_GNUCC)
-  ELSE(BUILD_ONEDNN_GRAPH)
-    ADD_SUBDIRECTORY(${MKLDNN_ROOT})
-  ENDIF(BUILD_ONEDNN_GRAPH)
+  ADD_SUBDIRECTORY(${MKLDNN_ROOT})
 
   IF(NOT TARGET dnnl)
     MESSAGE("Failed to include MKL-DNN target")
@@ -120,9 +108,6 @@
       TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
     ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
     LIST(APPEND MKLDNN_LIBRARIES ${MKL_OPENMP_LIBRARY})
-    IF(BUILD_ONEDNN_GRAPH)
-      LIST(APPEND MKLDNN_LIBRARIES "$<TARGET_FILE:dnnl_graph>")
-    ENDIF(BUILD_ONEDNN_GRAPH)
     LIST(APPEND MKLDNN_LIBRARIES dnnl)
 
     SET(MKLDNN_FOUND TRUE)
12 changes: 0 additions & 12 deletions cmake/public/mkldnn.cmake
@@ -16,15 +16,3 @@ set_property(
 set_property(
   TARGET caffe2::mkldnn PROPERTY INTERFACE_LINK_LIBRARIES
   ${MKLDNN_LIBRARIES})
-if(BUILD_ONEDNN_GRAPH)
-  if(NOT TARGET caffe2::dnnl_graph)
-    add_library(caffe2::dnnl_graph INTERFACE IMPORTED)
-  endif()
-
-  set_property(
-    TARGET caffe2::dnnl_graph PROPERTY INTERFACE_INCLUDE_DIRECTORIES
-    ${MKLDNN_INCLUDE_DIR})
-  set_property(
-    TARGET caffe2::dnnl_graph PROPERTY INTERFACE_LINK_LIBRARIES
-    ${MKLDNN_LIBRARIES})
-endif()
1 change: 1 addition & 0 deletions test/test_nn.py
@@ -3295,6 +3295,7 @@ def perm_fn(x):
         self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
         np.testing.assert_allclose(result, ref_output, atol=1e-5)
 
+    @set_default_dtype(torch.double)
     def test_transformerdecoderlayer_gelu(self):
         # this is a deterministic test for TransformerDecoderLayer with gelu activation
         d_model = 4
(Diffs for the remaining changed files were not loaded in this capture.)