fix missing-prototypes warnings in torch_cpu (Part 6) (pytorch#101845)
This PR fixes more missing-prototypes violations in the torch_cpu source following PRs pytorch#100053, pytorch#100147, pytorch#100245, pytorch#100849 and pytorch#101788

Pull Request resolved: pytorch#101845
Approved by: https://github.com/albanD
cyyever authored and pytorchmergebot committed Jun 15, 2023
1 parent e75f799 commit f290042
Showing 26 changed files with 132 additions and 100 deletions.
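Almost every hunk below applies the same fix: a function defined in a .cpp file with external linkage but no prior declaration trips clang's -Wmissing-prototypes, and the cure is either to mark the definition static (when it is only used inside that file) or to make a declaration visible before the definition. A minimal sketch of the warning and both fixes, using made-up names rather than anything from this PR:

// warning_demo.cpp -- hypothetical translation unit, for illustration only
#include <cstdint>

// External linkage, no previous declaration:
// clang -Wmissing-prototypes reports "no previous prototype for function 'scale'".
int64_t scale(int64_t x) { return 2 * x; }

// Fix A: internal linkage; this is what most hunks in this PR do.
static int64_t scale_internal(int64_t x) { return 2 * x; }

// Fix B: keep external linkage but declare the function first (normally via a header).
int64_t scale_exported(int64_t x);
int64_t scale_exported(int64_t x) { return 2 * x; }

Fix B is the route taken by the new prototypes in the xnnpack headers further down and, via added ATen/ops/*_native.h includes, by files whose operators must stay externally visible.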
10 changes: 5 additions & 5 deletions aten/src/ATen/core/ivalue.cpp
@@ -64,6 +64,11 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
_fastEqualsForContainer);
}

+std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
+out << v.qualifiedClassName() << "." << v.name();
+return out;
+}

bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
}
@@ -763,11 +768,6 @@ IValueComparator getGreaterThanComparator(const IValue& v) {
};
}

-std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
-out << v.qualifiedClassName() << "." << v.name();
-return out;
-}

std::ostream& operator<<(std::ostream & out, const IValue & v) {
auto formatter = [&](std::ostream& out, const IValue& v) {
out << v;
13 changes: 9 additions & 4 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -171,15 +171,15 @@ Tensor arange(
return at::arange_out(result, start, end, step);
}

-Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
+static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
return at::arange_out(result, start, end, /*step=*/1);
}

Tensor& arange_out(const Scalar& end, Tensor& result) {
return at::arange_out(result, /*start=*/0, end, /*step=*/1);
}

-Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
+static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
return at::arange_out(result, start, end, /*step=*/1);
}

@@ -189,14 +189,14 @@ Tensor _dim_arange(const Tensor& like, int64_t dim) {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ complex / polar ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-void complex_check_floating(const Tensor& a, const Tensor& b) {
+static void complex_check_floating(const Tensor& a, const Tensor& b) {
TORCH_CHECK((a.scalar_type() == kFloat || a.scalar_type() == kDouble || a.scalar_type() == kHalf) &&
(b.scalar_type() == kFloat || b.scalar_type() == kDouble || b.scalar_type() == kHalf),
"Expected both inputs to be Half, Float or Double tensors but got ",
a.scalar_type(), " and ", b.scalar_type());
}

-void complex_check_dtype(
+static void complex_check_dtype(
const Tensor& result,
const Tensor& a,
const Tensor& b) {
@@ -352,7 +352,12 @@ Tensor& empty_out(IntArrayRef size,
return self.to(ScalarType::n, non_blocking); \
}

+// Some scalar types in CAST_OP have no declarations, they may be unused in Pytorch.
+// But we keep them and ignore the warning here until verified in the future.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DEFINE_CAST_OP)
+#pragma clang diagnostic pop

#undef DEFINE_CAST_OP

64 changes: 27 additions & 37 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -691,7 +691,7 @@ static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {

#define IMPLEMENT_FLOAT_KERNEL(op) \
inline namespace CPU_CAPABILITY { \
-void op##_kernel(TensorIteratorBase& iter) { \
+static void op##_kernel(TensorIteratorBase& iter) { \
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
constexpr int64_t grain_size = 2048; \
@@ -715,6 +715,19 @@ static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
} \
REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

+#define STATIC_IMPLEMENT_COMPLEX_KERNEL(op) \
+inline namespace CPU_CAPABILITY { \
+static void op##_kernel(TensorIteratorBase& iter) { \
+TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
+constexpr int64_t grain_size = 2048; \
+iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size); \
+}); \
+iter.cast_outputs(); \
+} \
+} \
+REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)

} // CPU_CAPABILITY namespace

REGISTER_DISPATCH(rsqrt_stub, &CPU_CAPABILITY::rsqrt_kernel);
@@ -761,51 +774,28 @@ REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bes
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);

-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(acos)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(asin)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(atan)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(acos)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(asin)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(atan)
IMPLEMENT_FLOAT_KERNEL(ceil)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(cos)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(cos)
IMPLEMENT_FLOAT_KERNEL(erf)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(erfc)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(erfinv)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(exp)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(expm1)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(exp)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(expm1)
IMPLEMENT_FLOAT_KERNEL(floor)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log10)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log1p)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log2)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log10)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log1p)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log2)
IMPLEMENT_FLOAT_KERNEL(i0)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(round)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(sin)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(sin)
IMPLEMENT_COMPLEX_KERNEL(sqrt)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(tan)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(tanh)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(tan)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(tanh)
IMPLEMENT_FLOAT_KERNEL(trunc)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(lgamma)

} // namespace at::native
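For reference, the new STATIC_IMPLEMENT_COMPLEX_KERNEL(acos) above expands to roughly the following (IMPLEMENT_ITERATOR_LAMBDA and the dispatch macro expand further). The only change from IMPLEMENT_COMPLEX_KERNEL is the static on the kernel, which is safe because the function is only referenced by the REGISTER_DISPATCH line in the same translation unit:

inline namespace CPU_CAPABILITY {
static void acos_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "acos_vml_cpu", [&]() {
    constexpr int64_t grain_size = 2048;
    iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(acos), grain_size);
  });
  iter.cast_outputs();
}
}
REGISTER_DISPATCH(acos_stub, &CPU_CAPABILITY::acos_kernel)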
14 changes: 7 additions & 7 deletions aten/src/ATen/native/mkldnn/Conv.cpp
@@ -27,47 +27,47 @@ Tensor mkldnn_convolution(
TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}

-Tensor mkldnn_convolution_backward_input(
+static Tensor mkldnn_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
}

-std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
+static std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
}

-std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
+static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
}

REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);

-Tensor mkldnn_convolution_transpose(
+static Tensor mkldnn_convolution_transpose(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support");
}

-Tensor mkldnn_convolution_transpose_backward_input(
+static Tensor mkldnn_convolution_transpose_backward_input(
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support");
}

-std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
+static std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support");
}

-std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
+static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, std::array<bool,3> output_mask) {
3 changes: 2 additions & 1 deletion aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -6,6 +6,7 @@
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
+#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
@@ -34,7 +35,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
}

-std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
+static std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
const Tensor& input,
IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
double eps, bool inplace) {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mkldnn/TensorShape.cpp
@@ -106,7 +106,7 @@ namespace at {
namespace native {


-Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
+static Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
return mkldnn_view(self, C10_AS_INTARRAYREF_SLOW(size));
}

1 change: 1 addition & 0 deletions aten/src/ATen/native/prim_native_functions.cpp
@@ -5,6 +5,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/is_nonzero_native.h>
+#include <ATen/ops/_foobar_native.h>
#include <ATen/ops/_test_functorch_fallback_native.h>
#endif

1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/cpu/Pooling.cpp
@@ -21,6 +21,7 @@
#include <ATen/ops/quantized_max_pool1d_native.h>
#include <ATen/ops/quantized_max_pool2d.h>
#include <ATen/ops/quantized_max_pool2d_native.h>
+#include <ATen/ops/quantized_max_pool3d_native.h>
#endif

#include <algorithm>
15 changes: 2 additions & 13 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -32,6 +32,7 @@
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
+#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/abs_native.h>
@@ -464,7 +465,7 @@ CREATE_UNARY_UFUNC(tan);
CREATE_UNARY_UFUNC(tanh);
CREATE_UNARY_UFUNC(trunc);
CREATE_UNARY_UFUNC(conj_physical);
-CREATE_UNARY_UFUNC(relu);
+static CREATE_UNARY_UFUNC(relu);

// With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
// to unresolved overload.
@@ -776,18 +777,6 @@ Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
1.0);
}

-Tensor _sparse_csr_addmm(
-const Tensor& t,
-const SparseCsrTensor& sparse,
-const Tensor& dense,
-const Scalar& beta,
-const Scalar& alpha) {
-// _sparse_addmm forward is functionally equivalent to addmm; it's
-// just the backward that is different. This technically does an
-// unnecessary redispatch, I was too lazy to make it not do that
-return at::addmm(t, sparse, dense, beta, alpha);
-}

// Functions for element-wise addition.
Tensor add_sparse_csr(
const Tensor& self,
5 changes: 5 additions & 0 deletions aten/src/ATen/native/sparse/SparseUnaryOps.cpp
@@ -188,7 +188,12 @@ COALESCED_UNARY_UFUNC(sqrt);
COALESCED_UNARY_UFUNC(tan);
COALESCED_UNARY_UFUNC(tanh);
COALESCED_UNARY_UFUNC(trunc);
+// relu function has no declaration, it may be unused in Pytorch.
+// But we keep it and ignore the warning here until verified in the future.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
COALESCED_UNARY_UFUNC(relu);
+#pragma clang diagnostic pop

COALESCED_UNARY_UFUNC_NO_INPLACE(signbit);
COALESCED_UNARY_UFUNC_NO_INPLACE(isneginf);
3 changes: 2 additions & 1 deletion aten/src/ATen/native/xnnpack/Activation.cpp
@@ -1,6 +1,7 @@
#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
+#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/utils/Factory.h>

namespace at {
@@ -18,7 +19,7 @@ bool use_hardswish(
true;
}

-Tensor& hardswish_impl(Tensor& input, Tensor& output) {
+static Tensor& hardswish_impl(Tensor& input, Tensor& output) {
using namespace internal;

xnn_operator_t hardswish_op{};
3 changes: 2 additions & 1 deletion aten/src/ATen/native/xnnpack/AveragePooling.cpp
@@ -1,7 +1,8 @@
#ifdef USE_XNNPACK

-#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/utils/Factory.h>
+#include <ATen/native/xnnpack/Common.h>
+#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/xnnpack/Pooling.h>

namespace at {
1 change: 1 addition & 0 deletions aten/src/ATen/native/xnnpack/ChannelShuffle.cpp
@@ -1,6 +1,7 @@
#ifdef USE_XNNPACK

#include <ATen/native/xnnpack/Common.h>
+#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/utils/Factory.h>

namespace at {
3 changes: 2 additions & 1 deletion aten/src/ATen/native/xnnpack/Convolution.cpp
@@ -2,11 +2,12 @@

#include <vector>

-#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/utils/ParamUtils.h>
+#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Convolution.h>
+#include <ATen/native/xnnpack/Engine.h>
#include <c10/util/irange.h>

namespace at {
9 changes: 9 additions & 0 deletions aten/src/ATen/native/xnnpack/Convolution.h
@@ -62,6 +62,15 @@ Tensor run(ContextConv2D& context, const Tensor& input);

} // namespace convolution2d
} // namespace internal

+Tensor convolution2d(
+const Tensor& input,
+const Tensor& weight,
+const Tensor& bias,
+const IntArrayRef padding,
+const IntArrayRef stride,
+const IntArrayRef dilation,
+const int64_t groups);
} // namespace xnnpack
} // namespace native
} // namespace at
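The prototypes added to this header and to Linear.h below cover functions that presumably cannot become static because they are referenced from other translation units; declaring them in a shared header both silences the warning and records the intended interface. A minimal generic sketch with invented names, not the xnnpack API:

// widget.h -- invented example
#pragma once
namespace widget {
int frob(int x);  // prototype visible to the definition and to callers
}

// widget.cpp
#include "widget.h"  // the definition now has a previous prototype: no warning
namespace widget {
int frob(int x) { return x + 1; }
}

// caller.cpp
#include "widget.h"
int use_frob() { return widget::frob(41); }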
10 changes: 10 additions & 0 deletions aten/src/ATen/native/xnnpack/Linear.h
@@ -32,6 +32,16 @@ ContextLinear create(
Tensor run(const ContextLinear& context, const Tensor& input);
} // namespace linear
} // namespace internal

+bool use_linear(
+const Tensor& input,
+const Tensor& weight,
+const Tensor& bias);

+Tensor linear(
+const Tensor& input,
+const Tensor& weight,
+const Tensor& bias);
} // namespace xnnpack
} // namespace native
} // namespace at
