forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAffineQuantizer.cu
271 lines (242 loc) · 9.22 KB
/
AffineQuantizer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <cmath>
#include <ATen/native/cuda/Loops.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/any.h>
#include <ATen/ops/gt.h>
#include <ATen/ops/lt.h>
#endif
namespace at {
namespace native {
namespace {
// Validates that every entry of `zero_points` lies inside the representable
// range [std::numeric_limits<T>::min(), std::numeric_limits<T>::max()] of the
// quantized underlying integer type T.
//
// Params:
//   fn_name     - name of the calling kernel; prefixed to the error message.
//   zero_points - tensor of zero points to validate.
// Throws (via TORCH_CHECK) if any zero point is below the lower bound or
// above the upper bound. The lower bound is checked first, so when both are
// violated the "below lower bound" error is reported.
//
// NOTE(review): each `.item()` reads a reduced scalar back to the host; if
// `zero_points` lives on the GPU this presumably synchronizes — confirm.
template <typename T>
void check_zero_points_cuda(
    const std::string& fn_name,
    const Tensor& zero_points) {
  constexpr int64_t qmin = std::numeric_limits<T>::min();
  constexpr int64_t qmax = std::numeric_limits<T>::max();
  // Check the lower bound and fail fast before running the separate
  // upper-bound reduction, so an invalid input costs only one host read.
  const bool zp_within_lower =
      at::any(at::lt(zero_points, qmin)).item().equal(false);
  // TORCH_CHECK concatenates its message arguments verbatim, so an explicit
  // separator is needed between the function name and the message text.
  TORCH_CHECK(
      zp_within_lower,
      fn_name,
      ": zero_point is below lower bound.");
  const bool zp_within_upper =
      at::any(at::gt(zero_points, qmax)).item().equal(false);
  TORCH_CHECK(
      zp_within_upper,
      fn_name,
      ": zero_point is above upper bound.");
}
// Quantizes `rtensor` into `qtensor` on the GPU with a single (per-tensor)
// affine mapping:
//   q = clamp(nearbyint(r / scale) + zero_point, lowest, highest)
// where [lowest, highest] is the range of the quantized type's underlying
// integer. The kernel reads each source element as float.
//
// `qtensor` is registered both as the output and as an input of the
// TensorIterator; the input slot only supplies a value of the quantized
// element type whose `val_` is overwritten and returned.
void quantize_tensor_per_tensor_affine_cuda(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point) {
  AT_DISPATCH_QINT_TYPES(
      qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cuda", [&]() {
        constexpr int64_t kLowest = std::numeric_limits<underlying_t>::min();
        constexpr int64_t kHighest = std::numeric_limits<underlying_t>::max();
        auto iter = TensorIteratorConfig()
                        .check_all_same_dtype(false)
                        .add_output(qtensor)
                        .add_input(rtensor)
                        .add_input(qtensor)
                        .build();
        auto quantize_one =
            [=] GPU_LAMBDA(float src, scalar_t dst) -> scalar_t {
          // float / double is evaluated in double precision.
          const double rounded = std::nearbyint(src / scale);
          const int64_t unclamped = static_cast<int64_t>(rounded + zero_point);
          dst.val_ = std::min<int64_t>(
              std::max<int64_t>(unclamped, kLowest), kHighest);
          return dst;
        };
        gpu_kernel(iter, quantize_one);
      });
}
// Dequantizes `qtensor` into `rtensor` on the GPU with a single (per-tensor)
// affine mapping: r = (float(q) - zero_point) * scale.
// The kernel returns float for each output element.
void dequantize_tensor_per_tensor_affine_cuda(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point) {
  AT_DISPATCH_QINT_TYPES(
      qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cuda", [&]() {
        auto iter = TensorIteratorConfig()
                        .check_all_same_dtype(false)
                        .add_output(rtensor)
                        .add_input(qtensor)
                        .build();
        auto dequantize_one = [=] GPU_LAMBDA(scalar_t q) -> float {
          // The integer value is converted to float before the zero point is
          // subtracted; the product with the double scale is narrowed on return.
          const float shifted = static_cast<float>(q.val_) - zero_point;
          return shifted * scale;
        };
        gpu_kernel(iter, dequantize_one);
      });
}
// Quantizes `rtensor` into `qtensor` on the GPU with one (scale, zero_point)
// pair per slice along `axis`:
//   q = clamp(nearbyint(r / scale[c]) + zero_point[c], lowest, highest)
// The kernel reads each scale as double and each zero point as int64_t.
// Zero points are range-checked against the quantized type before launch.
// NOTE(review): `axis` is used unchecked as an index; presumably the caller
// has already validated 0 <= axis < rtensor.dim() — confirm.
void quantize_tensor_per_channel_affine_cuda(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "quantize_tensor_per_channel_affine_cuda";

  // View shape: size 1 on every dimension except `axis`, so the per-channel
  // parameters broadcast along all other dimensions of `rtensor`.
  std::vector<int64_t> channel_shape(rtensor.dim(), 1);
  channel_shape[axis] = rtensor.size(axis);
  const auto as_channel_view = [&channel_shape](const Tensor& params) {
    return native::_unsafe_view(params, channel_shape);
  };
  auto scales_view = as_channel_view(scales);
  auto zero_points_view = as_channel_view(zero_points);

  auto iter = TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(qtensor)
                  .add_input(rtensor)
                  .add_input(qtensor)
                  .add_input(scales_view)
                  .add_input(zero_points_view)
                  .build();

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    check_zero_points_cuda<underlying_t>(fn_name, zero_points);
    constexpr int64_t kLowest = std::numeric_limits<underlying_t>::min();
    constexpr int64_t kHighest = std::numeric_limits<underlying_t>::max();
    // Intended to match _quantize_per_channel_ref_nd in test_quantized_tensor.py.
    gpu_kernel(
        iter,
        [=] GPU_LAMBDA(
            float src,
            scalar_t dst,
            double channel_scale,
            int64_t channel_zero_point) -> scalar_t {
          const double rounded = std::nearbyint(src / channel_scale);
          const int64_t unclamped =
              static_cast<int64_t>(rounded + channel_zero_point);
          dst.val_ = std::min<int64_t>(
              std::max<int64_t>(unclamped, kLowest), kHighest);
          return dst;
        });
  });
}
// Dequantizes `qtensor` into `rtensor` on the GPU with one (scale, zero_point)
// pair per slice along `axis`: r = float(q - zero_point[c]) * scale[c].
// The kernel reads each scale as double and each zero point as int64_t.
// Zero points are range-checked against the quantized type before launch.
// NOTE(review): `axis` is used unchecked as an index; presumably validated by
// the caller — confirm.
void dequantize_tensor_per_channel_affine_cuda(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine_cuda";

  // Size 1 everywhere except along `axis` of the output tensor.
  std::vector<int64_t> channel_shape(rtensor.dim(), 1);
  channel_shape[axis] = rtensor.size(axis);
  auto scales_view = native::_unsafe_view(scales, channel_shape);
  auto zero_points_view = native::_unsafe_view(zero_points, channel_shape);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    check_zero_points_cuda<underlying_t>(fn_name, zero_points);
    auto iter = TensorIteratorConfig()
                    .check_all_same_dtype(false)
                    .add_output(rtensor)
                    .add_input(qtensor)
                    .add_input(scales_view)
                    .add_input(zero_points_view)
                    .build();
    auto dequantize_one = [=] GPU_LAMBDA(
                              scalar_t q,
                              double channel_scale,
                              int64_t channel_zero_point) -> float {
      // Subtract in 64-bit integer arithmetic, then convert to float.
      const int64_t centered = q.val_ - channel_zero_point;
      return static_cast<float>(centered) * channel_scale;
    };
    gpu_kernel(iter, dequantize_one);
  });
}
// Quantizes `rtensor` into `qtensor` on the GPU with per-channel parameters
// that the kernel reads as float:
//   q = clamp(lrintf(r * (1 / scale[c]) + zero_point[c]), lowest, highest)
// Unlike the integer-zero-point variant, the zero point is added *before*
// rounding, and the division is replaced by a multiply with the reciprocal.
// Zero points are range-checked against the quantized type before launch.
// NOTE(review): `axis` is used unchecked as an index; presumably validated by
// the caller — confirm.
void quantize_tensor_per_channel_float_qparams_cuda(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "quantize_tensor_per_channel_float_qparams_cuda";

  // Size 1 everywhere except along `axis`, so the parameters broadcast.
  std::vector<int64_t> channel_shape(rtensor.dim(), 1);
  channel_shape[axis] = rtensor.size(axis);
  const auto as_channel_view = [&channel_shape](const Tensor& params) {
    return native::_unsafe_view(params, channel_shape);
  };
  auto scales_view = as_channel_view(scales);
  auto zero_points_view = as_channel_view(zero_points);

  auto iter = TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(qtensor)
                  .add_input(rtensor)
                  .add_input(qtensor)
                  .add_input(scales_view)
                  .add_input(zero_points_view)
                  .build();

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    check_zero_points_cuda<underlying_t>(fn_name, zero_points);
    constexpr int64_t kLowest = std::numeric_limits<underlying_t>::min();
    constexpr int64_t kHighest = std::numeric_limits<underlying_t>::max();
    // Intended to match _quantize_per_channel_ref_nd in test_quantized_tensor.py.
    auto quantize_one = [=] GPU_LAMBDA(
                            float src,
                            scalar_t dst,
                            float channel_scale,
                            float channel_zero_point) -> scalar_t {
      const float inv_scale = 1.0f / channel_scale;
      // Kept as one expression so floating-point evaluation is unchanged.
      const int64_t rounded = lrintf(src * inv_scale + channel_zero_point);
      dst.val_ = std::min<int64_t>(
          std::max<int64_t>(rounded, kLowest), kHighest);
      return dst;
    };
    gpu_kernel(iter, quantize_one);
  });
}
// Dequantizes `qtensor` into `rtensor` on the GPU with per-channel parameters
// that the kernel reads as float: r = (float(q) - zero_point[c]) * scale[c].
// Zero points are range-checked against the quantized type before launch.
// NOTE(review): `axis` is used unchecked as an index; presumably validated by
// the caller — confirm.
void dequantize_tensor_per_channel_float_qparams_cuda(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  static constexpr auto fn_name = "dequantize_tensor_per_channel_float_qparams_cuda";

  // Size 1 everywhere except along `axis` of the output tensor.
  std::vector<int64_t> channel_shape(rtensor.dim(), 1);
  channel_shape[axis] = rtensor.size(axis);
  auto scales_view = native::_unsafe_view(scales, channel_shape);
  auto zero_points_view = native::_unsafe_view(zero_points, channel_shape);

  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
    check_zero_points_cuda<underlying_t>(fn_name, zero_points);
    auto iter = TensorIteratorConfig()
                    .check_all_same_dtype(false)
                    .add_output(rtensor)
                    .add_input(qtensor)
                    .add_input(scales_view)
                    .add_input(zero_points_view)
                    .build();
    auto dequantize_one = [=] GPU_LAMBDA(
                              scalar_t q,
                              float channel_scale,
                              float channel_zero_point) -> float {
      const float centered = static_cast<float>(q.val_) - channel_zero_point;
      return centered * channel_scale;
    };
    gpu_kernel(iter, dequantize_one);
  });
}
} // anonymous namespace
// Bind each CUDA implementation defined in the anonymous namespace above to
// its dispatch stub, so the device-generic quantize/dequantize entry points
// route CUDA tensors here.
// NOTE(review): the stubs are presumably declared in
// ATen/native/quantized/AffineQuantizer.h (included at the top) — confirm.

// Per-tensor affine (single double scale, int64_t zero point).
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_stub,
&quantize_tensor_per_tensor_affine_cuda);
REGISTER_DISPATCH(
dequantize_tensor_per_tensor_affine_stub,
&dequantize_tensor_per_tensor_affine_cuda);
// Per-channel affine (kernel reads double scales, int64_t zero points).
REGISTER_DISPATCH(
quantize_tensor_per_channel_affine_stub,
&quantize_tensor_per_channel_affine_cuda);
REGISTER_DISPATCH(
dequantize_tensor_per_channel_affine_stub,
&dequantize_tensor_per_channel_affine_cuda);
// Per-channel with float scales and float zero points.
REGISTER_DISPATCH(
quantize_tensor_per_channel_float_qparams_stub,
&quantize_tensor_per_channel_float_qparams_cuda);
REGISTER_DISPATCH(
dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cuda);
} // namespace native
} // namespace at