Skip to content

Commit ff119d0

Browse files
authored Jan 30, 2024
fix qat tests (#61211) (#61284)
1 parent ac1702b commit ff119d0

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed
 

‎test/quantization/test_post_training_quantization_mobilenetv1.py

+62-8
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525

2626
import paddle
2727
from paddle.dataset.common import download
28+
from paddle.io import Dataset
2829
from paddle.static.log_helper import get_logger
2930
from paddle.static.quantization import PostTrainingQuantization
3031

@@ -116,6 +117,33 @@ def val(data_dir=DATA_DIR):
116117
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
117118

118119

120+
class ImageNetDataset(Dataset):
    """Map-style dataset over the entries listed in ``val_list.txt``.

    Every line of ``<data_dir>/val_list.txt`` is stripped and split on
    whitespace. The first field is joined onto ``data_dir`` to form the image
    path; the second field is handed to ``process_image`` together with it.

    Args:
        data_dir: Directory that contains ``val_list.txt`` and the files it
            references.
        shuffle: If true, the stripped lines are shuffled once, in place,
            with ``np.random.shuffle`` while the dataset is constructed.
        need_label: If true, ``__getitem__`` yields ``(data, label)``;
            otherwise only ``data``.
    """

    def __init__(self, data_dir=DATA_DIR, shuffle=False, need_label=False):
        super().__init__()
        self.need_label = need_label
        self.data_dir = data_dir

        list_path = os.path.join(data_dir, 'val_list.txt')
        with open(list_path) as list_file:
            entries = list(map(str.strip, list_file))

        # Shuffle whole lines before splitting so each path stays paired
        # with its own label field.
        if shuffle:
            np.random.shuffle(entries)

        self.data = list(map(str.split, entries))

    def __getitem__(self, index):
        """Return the processed sample at ``index``.

        The result is ``data`` alone, or ``(data, label)`` with the label
        wrapped in a one-element ``int64`` array when ``need_label`` is set.
        """
        entry = self.data[index]
        image_path = os.path.join(self.data_dir, entry[0])
        image, label = process_image(
            [image_path, entry[1]],
            mode='val',
            color_jitter=False,
            rotate=False,
        )
        if not self.need_label:
            return image
        return image, np.array([label]).astype('int64')

    def __len__(self):
        """Return the number of entries read from the list file."""
        return len(self.data)
145+
146+
119147
class TestPostTrainingQuantization(unittest.TestCase):
120148
def setUp(self):
121149
self.int8_download = 'int8/download'
@@ -267,7 +295,7 @@ def run_program(
267295
throughput = cnt / np.sum(periods)
268296
latency = np.average(periods)
269297
acc1 = np.sum(test_info) / cnt
270-
return (throughput, latency, acc1)
298+
return (throughput, latency, acc1, feed_dict)
271299

272300
def generate_quantized_model(
273301
self,
@@ -284,6 +312,7 @@ def generate_quantized_model(
284312
batch_nums=1,
285313
onnx_format=False,
286314
deploy_backend=None,
315+
feed_name="inputs",
287316
):
288317
try:
289318
os.system("mkdir " + self.int8_model)
@@ -293,11 +322,30 @@ def generate_quantized_model(
293322

294323
place = paddle.CPUPlace()
295324
exe = paddle.static.Executor(place)
296-
val_reader = val()
325+
image = paddle.static.data(
326+
name=feed_name[0], shape=[None, 3, 224, 224], dtype='float32'
327+
)
328+
feed_list = [image]
329+
if len(feed_name) == 2:
330+
label = paddle.static.data(
331+
name='label', shape=[None, 1], dtype='int64'
332+
)
333+
feed_list.append(label)
334+
335+
val_dataset = ImageNetDataset(need_label=len(feed_list) == 2)
336+
data_loader = paddle.io.DataLoader(
337+
val_dataset,
338+
places=place,
339+
feed_list=feed_list,
340+
drop_last=False,
341+
return_list=False,
342+
batch_size=2,
343+
shuffle=False,
344+
)
297345

298346
ptq = PostTrainingQuantization(
299347
executor=exe,
300-
sample_generator=val_reader,
348+
data_loader=data_loader,
301349
model_dir=model_path,
302350
model_filename=model_filename,
303351
params_filename=params_filename,
@@ -348,7 +396,12 @@ def run_test(
348396
model, infer_iterations * batch_size
349397
)
350398
)
351-
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
399+
(
400+
fp32_throughput,
401+
fp32_latency,
402+
fp32_acc1,
403+
feed_name,
404+
) = self.run_program(
352405
model_path,
353406
model_filename,
354407
params_filename,
@@ -370,14 +423,15 @@ def run_test(
370423
batch_nums,
371424
onnx_format,
372425
deploy_backend,
426+
feed_name,
373427
)
374428

375429
_logger.info(
376430
"Start INT8 inference for {} on {} images ...".format(
377431
model, infer_iterations * batch_size
378432
)
379433
)
380-
(int8_throughput, int8_latency, int8_acc1) = self.run_program(
434+
(int8_throughput, int8_latency, int8_acc1, _) = self.run_program(
381435
self.int8_model,
382436
model_filename,
383437
params_filename,
@@ -421,7 +475,7 @@ def test_post_training_kl_mobilenetv1(self):
421475
is_use_cache_file = False
422476
is_optimize_model = True
423477
diff_threshold = 0.025
424-
batch_nums = 1
478+
batch_nums = 2
425479
self.run_test(
426480
model,
427481
'inference.pdmodel',
@@ -607,7 +661,7 @@ def test_post_training_onnx_format_mobilenetv1_tensorrt(self):
607661
is_optimize_model = False
608662
onnx_format = True
609663
diff_threshold = 0.05
610-
batch_nums = 2
664+
batch_nums = 12
611665
deploy_backend = "tensorrt"
612666
self.run_test(
613667
model,
@@ -650,7 +704,7 @@ def test_post_training_onnx_format_mobilenetv1_mkldnn(self):
650704
is_optimize_model = False
651705
onnx_format = True
652706
diff_threshold = 0.05
653-
batch_nums = 1
707+
batch_nums = 12
654708
deploy_backend = "mkldnn"
655709
self.run_test(
656710
model,

‎test/quantization/test_post_training_quantization_resnet50.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ def run_program(
113113
throughput = cnt / np.sum(periods)
114114
latency = np.average(periods)
115115
acc1 = np.sum(test_info) / cnt
116-
return (throughput, latency, acc1)
116+
return (throughput, latency, acc1, feed_dict)
117117

118118

119119
class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingForResnet50):

0 commit comments

Comments
 (0)
Please sign in to comment.