// IntReprQuant.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/int_repr_native.h>
#endif
namespace at {
namespace native {
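// int_repr() for quantized CUDA tensors: returns a new tensor of the
// underlying integer type (e.g. kByte for a kQUInt8 input) that holds the
// raw stored quantized values, without dequantizing them.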
Tensor int_repr_quantized_cuda(const Tensor& self) {
  Tensor dst;
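  // Dispatch over the quantized dtypes (qint8, quint8, qint32). The macro
  // defines scalar_t as the quantized scalar type and underlying_t /
  // UNDERLYING_TYPE as the plain integer type stored underneath it.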
  AT_DISPATCH_QINT_TYPES(self.scalar_type(), "int_repr_quantized_cuda", [&]() {
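    // Allocate the output with the underlying integer dtype, matching the
    // input's shape and suggested memory format.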
    dst = at::empty(
        self.sizes(),
        self.options().dtype(UNDERLYING_TYPE),
        self.suggest_memory_format());
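    // The input is quantized while the output is a plain integer tensor, so
    // disable TensorIterator's default same-dtype check.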
    auto iter = TensorIteratorConfig()
                    .check_all_same_dtype(false)
                    .add_output(dst)
                    .add_input(self)
                    .build();
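    // Launch an elementwise CUDA kernel that copies each quantized scalar's
    // raw stored integer (its val_ member) into the output.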
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t value) -> underlying_t {
      return value.val_;
    });
  });
  return dst;
}
} // namespace native
} // namespace at
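
// A minimal usage sketch (added for illustration, not part of the original
// file), assuming a CUDA build with quantization support:
//
//   auto x = at::rand({4}, at::kCUDA);
//   auto q = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0,
//                                    at::kQUInt8);
//   auto raw = q.int_repr();  // kByte tensor holding the stored uint8 values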