-
Notifications
You must be signed in to change notification settings - Fork 280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add default compiler options for v6e #887
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,12 @@ | |
# This module must not depend on any jax/axlearn modules so that | ||
# importing this module does not result in initializing jax. | ||
import re | ||
from typing import Any, Union | ||
from typing import Any, Dict, Union | ||
|
||
|
||
def default_xla_options( | ||
*, instance_type: str, num_slices: int, backend: str | ||
) -> dict[str, Union[str, bool]]: | ||
) -> dict[str, Union[str, bool, int]]: | ||
"""Return the default flags for the given instance type and backend. | ||
|
||
These options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)` | ||
|
@@ -31,7 +31,7 @@ def default_xla_options( | |
if backend != "tpu": | ||
raise NotImplementedError(backend) | ||
version = infer_tpu_version(infer_tpu_type(instance_type)) | ||
options = dict( | ||
options: Dict[str, Union[int, str, bool]] = dict( | ||
xla_tpu_spmd_rng_bit_generator_unsafe=True, # SPMD partition-aware RngBitGenerator. | ||
xla_tpu_enable_latency_hiding_scheduler="true", # Try to schedule ops efficiently. | ||
xla_tpu_perform_spmd_cse_prevention="false", | ||
|
@@ -44,6 +44,68 @@ def default_xla_options( | |
xla_enable_async_all_gather="true", # Allow async all-gather. | ||
xla_enable_async_collective_permute="true", # Allow async collective permute. | ||
) | ||
if version == "v6e": | ||
options.update( | ||
# Change to 16GB. The default is 4GB which is too small for larger models. This | ||
# cause the step time to be double. You should increase this | ||
# further if you see "Allocator failed to allocate". A feature | ||
# to dynamically allocate may come later: b/380514965 | ||
megascale_grpc_premap_memory_bytes=17179869184, | ||
# Improved performance for v6e. | ||
xla_tpu_scoped_vmem_limit_kib=98304, | ||
xla_tpu_enable_async_collective_fusion=True, | ||
xla_tpu_enable_async_collective_fusion_fuse_all_gather=True, | ||
xla_tpu_enable_async_collective_fusion_multiple_steps=True, | ||
xla_tpu_overlap_compute_collective_tc=True, | ||
xla_enable_async_all_gather=True, | ||
# Host offloading flags | ||
xla_tpu_enable_all_experimental_scheduler_features=True, | ||
# Flag to enable memory tracking scheduling. The default AUTO only enables | ||
# it in some situations. Not needed if | ||
# xla_tpu_enable_all_experimental_scheduler_features is set to true already. | ||
xla_tpu_enable_scheduler_memory_pressure_tracking=True, | ||
# Flag controlling the maximum number of overlapping host offloadings. | ||
xla_tpu_host_transfer_overlap_limit=24, | ||
# Flag to enable the aggressive removal of opt-barriers. | ||
xla_tpu_aggressive_opt_barrier_removal=True, | ||
# Flag to enable more aggressive scheduling for async ops, such as pushing | ||
# the async start to the beginning of the loop body. | ||
xla_lhs_prioritize_async_depth_over_stall=True, | ||
# Flag to enable pipelining of cross-DCN all-gathers. | ||
xla_tpu_enable_ag_backward_pipelining=True, | ||
xla_should_allow_loop_variant_parameter_in_chain=True, | ||
xla_should_add_loop_invariant_op_in_chain=True, | ||
# Flag controlling the maximum number of overlapping cross-DCN send/recv. | ||
xla_max_concurrent_host_send_recv=100, | ||
# Flag controlling the HBM memory limit as a percentage of the total HBM size. | ||
# Default value is 95. Can tune up or down to give more or less memory for the | ||
# scheduler. The scheduler favors more on less memory usage when it's under | ||
# memory pressure, instead of hiding latency by overlapping more computations | ||
# and communications. | ||
xla_tpu_scheduler_percent_shared_memory_limit=90, | ||
# Flag controlling the number of times the scheduler is run if the scheduled | ||
# peak memory usage exceeds the initial memory limit, by setting memory limit | ||
# to 90% of the previous memory limit each time. Default value is 1. Sometimes | ||
# when the scheduler thinks it goes out memory, it may not actually happen due | ||
# to other factors controlled by other compiler passes, or the initial memory | ||
# limit is already set too low. Cutting the memory limit to 90% of previous one | ||
# though, may make the scheduler weighting too much on the memory usage instead | ||
# of latency side. | ||
xla_latency_hiding_scheduler_rerun=2, | ||
xla_tpu_use_enhanced_launch_barrier=True, | ||
# Sparsecore offloading for all reduce. | ||
# Uncomment below flags to enable it. | ||
# xla_sc_disable_megacore_partitioning=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Ethanlm do you recall why you commented all them off? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We weren't sure if this would perform better for everyone. It's a newer feature as well. So Ethan recommended not enabling by default. For both llama 2 70b and 405b, we saw much better performance with it being enabled though. |
||
# xla_tpu_use_tc_device_shape_on_sc=True, | ||
# tpu_use_continuations=True, | ||
# xla_jf_crs_combiner_threshold_count=10, | ||
# xla_sc_enable_instruction_fusion="false", | ||
# xla_sc_disjoint_spmem="false", | ||
# xla_tpu_enable_sparse_core_collective_offload_all_reduce=True, | ||
) | ||
# This flag can be removed after upgrading to Jax 0.4.38. | ||
# Uncomment for sparsecore offloading. | ||
# options["2a886c8_chip_config_name"] = "megachip_tccontrol" | ||
if num_slices > 1: | ||
# Support multiple TPU slices connected over a data center network. | ||
options.update( | ||
|
@@ -59,12 +121,16 @@ def default_xla_options( | |
|
||
# Validate options. Will never fail if this function is implemented correctly. | ||
for k, v in options.items(): | ||
assert v in [True, False, "true", "false"], (k, v) | ||
try: | ||
int(v) | ||
continue | ||
except ValueError: | ||
assert v in [True, False, "true", "false", "megachip_tccontrol"], (k, v) | ||
|
||
return options | ||
|
||
|
||
def xla_flags_from_options(xla_options: dict[str, Union[str, bool]]) -> str: | ||
def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str: | ||
"""Convert an XLA options dict suitable for | ||
`jitted_fn.lower(...).compile(compiler_options=xla_options)` | ||
to XLA flags suitable for the `XLA_FLAGS` environment variable. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs These need to be "true" instead of
True
, otherwise it will be converted to 1 byxla_flags_from_options
and then failThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's what I had originally, but pytype doesn't like that and will fail. So I see 2 options: