forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_quantization_configs.py
64 lines (51 loc) · 1.88 KB
/
build_quantization_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
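Generate a text dump of the default (native) quantization backend config.
The output is written to quantization_backend_configs/default_backend_config.txt
next to this script, for use in the documentation.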
"""
import os.path
import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
entry_to_pretty_str,
remove_boolean_dispatch_from_name,
)
# Directory (next to this script) that receives the generated output.
# Despite the constant's name, the artifact written below is a text file.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
)

# Create the directory if it doesn't exist. exist_ok=True replaces the
# exists()-then-mkdir() pair, which could fail if the directory appeared
# between the check and the creation.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Full path of the file the pretty-printed backend config is written to.
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)
def _sort_key_func(entry):
    """Return the sort key for one backend config entry.

    Args:
        entry: a config dict with a "pattern" key. The pattern may be a
            (possibly nested) tuple, a string, or another object that
            ``torch.typename`` can name.

    Returns:
        The last dot-separated component of the pattern's name, lower-cased
        and with underscores removed.
    """
    pattern = entry["pattern"]
    # For tuple patterns, keep taking the last element until it is no
    # longer a tuple; that innermost element determines the key.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


# Build the complete output before touching the file: opening with "w"
# truncates it, so a failure while generating the config must not leave an
# empty file behind.
native_backend_config_dict = get_native_backend_config_dict()
configs = native_backend_config_dict["configs"]
configs.sort(key=_sort_key_func)
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

with open(output_path, "w") as f:
    f.write(entries)