-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for SD3 testing along with a refactor of the suite. #266
base: main
Are you sure you want to change the base?
Changes from all commits
bd6c512
1f8fa6c
afbf4b7
ef3087a
b235dcb
10f6db3
71947a7
133ec67
db550fd
e255915
306d2a0
c34665d
2e4c194
0d52c6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
{ | ||
"config_name": "cpu_llvm_task", | ||
"iree_compile_flags" : [ | ||
"--iree-hal-target-backends=llvm-cpu", | ||
"--iree-llvmcpu-target-cpu-features=host", | ||
"--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", | ||
"--iree-llvmcpu-distribution-size=32", | ||
"--iree-opt-const-eval=false", | ||
"--iree-llvmcpu-enable-ukernels=all", | ||
"--iree-global-opt-enable-quantized-matmul-reassociation" | ||
], | ||
"iree_run_module_flags": [ | ||
"--device=local-task", | ||
"--parameters=model=real_weights.irpa" | ||
], | ||
"skip_compile_tests": [], | ||
"skip_run_tests": [], | ||
"expected_compile_failures": [], | ||
"expected_run_failures": [] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"config_name": "cpu_llvm_task", | ||
"iree_compile_flags" : [ | ||
"--iree-hal-target-backends=llvm-cpu", | ||
"--iree-llvmcpu-target-cpu-features=host", | ||
"--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", | ||
"--iree-llvmcpu-distribution-size=32", | ||
"--iree-opt-const-eval=false", | ||
"--iree-llvmcpu-enable-ukernels=all", | ||
"--iree-global-opt-enable-quantized-matmul-reassociation" | ||
], | ||
"iree_run_module_flags": [ | ||
"--device=local-task", | ||
"--parameters=model=real_weights.irpa" | ||
], | ||
"skip_compile_tests": [], | ||
"skip_run_tests": [], | ||
"expected_compile_failures": [ | ||
"sdxl-scheduled-unet-3-tank", | ||
"sd3-mmdit" | ||
], | ||
"expected_run_failures": [] | ||
} |
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -319,7 +319,12 @@ def __init__(self, spec, **kwargs): | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
self.run_args = ["iree-run-module", f"--module={vmfb_name}"] | ||||||||||||||||||||||||||
self.run_args.extend(self.spec.iree_run_module_flags) | ||||||||||||||||||||||||||
self.run_args.append(f"--flagfile={self.spec.data_flagfile_name}") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Expand the data flagfile for better logging and so that environment variables can be used. | ||||||||||||||||||||||||||
flag_file_path = f"{self.test_cwd}/{self.spec.data_flagfile_name}" | ||||||||||||||||||||||||||
file = open(flag_file_path) | ||||||||||||||||||||||||||
for line in file: | ||||||||||||||||||||||||||
self.run_args.append(line.rstrip()) | ||||||||||||||||||||||||||
Comment on lines
+323
to
+327
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Comment style: start with an uppercase character, end with a period. Also fix typo and adjust wording.
Suggested change
Note that flagfiles are a requirement for some commands and environments. For example, certain terminals have character length limits around 512 or so characters for commands and putting flags in files works around that. |
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def runtest(self): | ||||||||||||||||||||||||||
# TODO(scotttodd): log files needed by the test (remote files / git LFS) | ||||||||||||||||||||||||||
|
@@ -385,6 +390,8 @@ def test_compile(self): | |||||||||||||||||||||||||
compile_env["IREE_TEST_PATH_EXTENSION"] = os.getenv( | ||||||||||||||||||||||||||
"IREE_TEST_PATH_EXTENSION", default=str(self.test_cwd) | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# expand environment variable for logging | ||||||||||||||||||||||||||
path_extension = compile_env["IREE_TEST_PATH_EXTENSION"] | ||||||||||||||||||||||||||
cmd = subprocess.list2cmdline(self.compile_args) | ||||||||||||||||||||||||||
cmd = cmd.replace("${IREE_TEST_PATH_EXTENSION}", f"{path_extension}") | ||||||||||||||||||||||||||
|
@@ -401,8 +408,15 @@ def test_compile(self): | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def test_run(self): | ||||||||||||||||||||||||||
run_env = os.environ.copy() | ||||||||||||||||||||||||||
cmd = subprocess.list2cmdline(self.run_args) | ||||||||||||||||||||||||||
run_env["IREE_TEST_BACKEND"] = os.getenv( | ||||||||||||||||||||||||||
"IREE_TEST_BACKEND", default="none" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# expand environment variable for logging | ||||||||||||||||||||||||||
backend = run_env["IREE_TEST_BACKEND"] | ||||||||||||||||||||||||||
cmd = subprocess.list2cmdline(self.run_args) | ||||||||||||||||||||||||||
cmd = cmd.replace("${IREE_TEST_BACKEND}", f"{backend}") | ||||||||||||||||||||||||||
Comment on lines
+411
to
+418
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is too sketchy IMO. Flagfiles should be runnable as-is, and this is adding an extra indirection that will be too difficult to reproduce outside of a CI environment. Any tests needing this behavior should be using a mechanism other than this conftest.py. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let me elaborate a bit... This iree_tests subproject is home to:
For large test suites following a standardized style (ONNX unit tests, ONNX models, StableHLO models, JAX programs, etc.), For SDXL, SD3, llama, and other models that we're giving special attention, we should be testing both the out of the box import -> compile -> run path that fits that mold and a curated path like https://github.com/iree-org/iree/blob/main/experimental/regression_suite/tests/pregenerated/test_llama2.py or https://github.com/nod-ai/sdxl-scripts .
At the point where a model needs a carve-out in a config.json file, an environment variable, a nonstandard file (spec file), or changes to conftest.py, it is too complex/different and should be given its own separate test.
We can share test code between the standardized path and custom model tests where it makes sense to do so. In particular, the "compile a program" and "run a program" parts could be fixtures (as they are in https://github.com/iree-org/iree/blob/main/experimental/regression_suite/ireers/fixtures.py). We should have a common way for those stages to run, with the same logging format and the same error messages when tests are unexpectedly passing, newly failing, etc. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, this makes sense. I will take a look at an alternate path for the custom models. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can still use the same downloading utilities (download_remote_files.py) to fetch all the sources though, right? We'll just have a different test_cases.json for the script to parse in the custom path? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can keep the current downloading, or we could follow the IDK, and I can't take a deep context switch right now to think through the design. I'd start with that test_llama2.py and adapt it to the separate repo model (test sources, inputs, outputs in one repo, test configurations in another). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, no worries, I'd rather stick to a unified downloading process (making use of the same tools) to a cache in the test-suite and a unified compile/runtime through a fixture, as you suggested, so we are using the same tools across the whole repo where we could. 
So, I will proceed with that for now |
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# TODO(scotttodd): expand flagfile(s) | ||||||||||||||||||||||||||
logging.getLogger().info( | ||||||||||||||||||||||||||
f"Launching run command:\n" # | ||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--input=1x77x2xi64=@inference_input.0.bin | ||
--input=1x77x2xi64=@inference_input.1.bin | ||
--input=1x77x2xi64=@inference_input.2.bin | ||
--input=1x77x2xi64=@inference_input.3.bin | ||
--input=1x77x2xi64=@inference_input.4.bin | ||
--input=1x77x2xi64=@inference_input.5.bin | ||
--expected_output=2x154x4096xf32=@inference_output.0.bin | ||
--expected_output=2x2048xf32=@inference_output.1.bin | ||
--expected_f32_threshold=0.15f |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--input="1x77x2xi64" | ||
--input="1x77x2xi64" | ||
--input="1x77x2xi64" | ||
--input="1x77x2xi64" | ||
--input="1x77x2xi64" | ||
--input="1x77x2xi64" | ||
--parameters=splats.irpa |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
{ | ||
"file_format": "test_cases_v0", | ||
"test_cases": [ | ||
{ | ||
"name": "splats", | ||
"runtime_flagfile": "splat_data_flags.txt", | ||
"remote_files": [] | ||
}, | ||
{ | ||
"name": "real_weights", | ||
"runtime_flagfile": "real_weights_data_flags.txt", | ||
"remote_files": [ | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.0.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.1.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.2.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.3.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.4.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_input.5.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_output.0.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/inference_output.1.bin", | ||
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sd3-prompt-encoder/real_weights.irpa" | ||
] | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--input=1x64xi64=@inference_input.0.bin | ||
--input=1x64xi64=@inference_input.1.bin | ||
--input=1x64xi64=@inference_input.2.bin | ||
--input=1x64xi64=@inference_input.3.bin | ||
--expected_output=2x64x2048xf16=@inference_output.0.bin | ||
--expected_output=2x1280xf16=@inference_output.1.bin | ||
--expected_f16_threshold=1.0f |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--parameters=model=real_weights.irpa | ||
--input=2x16x128x128xf16=@inference_input.0.bin | ||
--input=2x154x4096xf16=@inference_input.1.bin | ||
--input=2x2048xf16=@inference_input.2.bin | ||
--input=1xf16=@inference_input.3.bin | ||
--expected_output=2x16x128x128xf32=@inference_output.0.bin | ||
--expected_f16_threshold=1.0f |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--input="2x16x128x128xf16" | ||
--input="2x154x4096xf16" | ||
--input="2x2048xf16" | ||
--input="1xf16" | ||
--parameters=splats.irpa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is going to exclude all these special models, they should be moved to a different subdirectory.
The `config_*_models` configs were intended to test a set of models that were imported into the test suite in a uniform way, and each of these models is special in some way.