Skip to content

Commit

Permalink
re-enable int8 for api change (#579)
Browse files: browse the repository at this point in the history
zhuhaozhe authored May 20, 2022
1 parent 0b286b0 commit 0bded92
Showing 4 changed files with 11,187 additions and 5,308 deletions.
38 changes: 35 additions & 3 deletions models/recommendation/pytorch/dlrm/product/dlrm_s_pytorch.py
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,10 @@
# intel
import intel_extension_for_pytorch as ipex
from torch.utils import ThroughputBenchmark

# int8
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
from intel_extension_for_pytorch.quantization import prepare, convert
# For distributed run
import extend_distributed as ext_dist

@@ -403,12 +407,18 @@ def trace_model(args, dlrm, test_ld):
dlrm.emb_l.bfloat16()
dlrm = ipex.optimize(dlrm, dtype=torch.bfloat16, inplace=True)
elif args.int8:
conf = ipex.quantization.QuantConf(args.int8_configure)
dlrm = ipex.quantization.convert(dlrm, conf, (X, lS_o, lS_i))
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
prepared_dlrm = prepare(dlrm, qconfig, example_inputs=(X, lS_o, lS_i), inplace=True)
prepared_dlrm.load_qconf_summary(qconf_summary = args.int8_configure)
dlrm = convert(prepared_dlrm)
else:
dlrm = ipex.optimize(dlrm, dtype=torch.float, inplace=True)
if args.int8:
dlrm = freeze(dlrm)
dlrm = torch.jit.trace(dlrm, [X, lS_o, lS_i])
dlrm = torch.jit.freeze(dlrm)
dlrm(X, lS_o, lS_i)
dlrm(X, lS_o, lS_i)
else:
with torch.cpu.amp.autocast(enabled=args.bf16):
dlrm = torch.jit.trace(dlrm, (X, lS_o, lS_i), check_trace=True)
@@ -666,6 +676,7 @@ def run():
parser.add_argument("--ipex-merged-emb", action="store_true", default=False)
parser.add_argument("--num-warmup-iters", type=int, default=1000)
parser.add_argument("--int8", action="store_true", default=False)
parser.add_argument("--calibration", action="store_true", default=False)
parser.add_argument("--int8-configure", type=str, default="./int8_configure.json")
parser.add_argument("--dist-backend", type=str, default="ccl")

@@ -743,6 +754,7 @@ def run():
sigmoid_top=ln_top.size - 2,
loss_threshold=args.loss_threshold,
)

if args.ipex_merged_emb:
dlrm.emb_l = ipex.nn.modules.MergedEmbeddingBagWithSGD.from_embeddingbag_list(dlrm.emb_l, lr=args.learning_rate)
dlrm.need_linearize_indices_and_offsets = torch.BoolTensor([False])
@@ -805,6 +817,26 @@ def run():
print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))

ext_dist.barrier()

if args.calibration:
assert args.load_model != "", "need load weight to do calibration"
dlrm.eval()
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
for j, inputBatch in enumerate(train_ld):
X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
example_inputs = (X, lS_o, lS_i)
prepared_dlrm = prepare(dlrm, qconfig, example_inputs=example_inputs, inplace=True)
break

for j, inputBatch in enumerate(train_ld):
prepared_dlrm(X, lS_o, lS_i)
if j == 2:
break
prepared_dlrm.save_qconf_summary(qconf_summary = args.int8_configure)
print("calibration done, save config file to ", args.int8_configure)
exit()

print("time/loss/accuracy (if enabled):")

if args.bf16 and not args.inference_only:
Loading

0 comments on commit 0bded92

Please sign in to comment.