Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default compiler options for v6e #887

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions axlearn/common/compiler_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
# This module must not depend on any jax/axlearn modules so that
# importing this module does not result in initializing jax.
import re
from typing import Any, Union
from typing import Any, Dict, Union


def default_xla_options(
*, instance_type: str, num_slices: int, backend: str
) -> dict[str, Union[str, bool]]:
) -> dict[str, Union[str, bool, int]]:
"""Return the default flags for the given instance type and backend.

These options can be passed to `jitted_fn.lower(...).compile(compiler_options=...)`
Expand All @@ -31,7 +31,7 @@ def default_xla_options(
if backend != "tpu":
raise NotImplementedError(backend)
version = infer_tpu_version(infer_tpu_type(instance_type))
options = dict(
options: Dict[str, Union[int, str, bool]] = dict(
xla_tpu_spmd_rng_bit_generator_unsafe=True, # SPMD partition-aware RngBitGenerator.
xla_tpu_enable_latency_hiding_scheduler="true", # Try to schedule ops efficiently.
xla_tpu_perform_spmd_cse_prevention="false",
Expand All @@ -44,6 +44,68 @@ def default_xla_options(
xla_enable_async_all_gather="true", # Allow async all-gather.
xla_enable_async_collective_permute="true", # Allow async collective permute.
)
if version == "v6e":
options.update(
# Change to 16GB. The default is 4GB which is too small for larger models. This
# cause the step time to be double. You should increase this
# further if you see "Allocator failed to allocate". A feature
# to dynamically allocate may come later: b/380514965
megascale_grpc_premap_memory_bytes=17179869184,
# Improved performance for v6e.
xla_tpu_scoped_vmem_limit_kib=98304,
xla_tpu_enable_async_collective_fusion=True,
xla_tpu_enable_async_collective_fusion_fuse_all_gather=True,
Copy link
Contributor

@Ethanlm Ethanlm Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs These need to be "true" instead of True, otherwise it will be converted to 1 by xla_flags_from_options and then fail

I1212 21:41:11.645763 132689403348992 launch.py:112] LIBTPU_INIT_ARGS='--xla_tpu_spmd_rng_bit_generator_unsafe=1 --xla_tpu_enable_latency_hiding_scheduler=true --xla_tpu_perform_spmd_cse_prevention=false --megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_enable_async_collective_fusion=1 --xla_tpu_enable_async_collective_fusion_fuse_all_gather=1 --xla_tpu_enable_async_collective_fusion_multiple_steps=1 --xla_tpu_overlap_compute_collective_tc=1 --xla_enable_async_all_gather=1 --xla_tpu_enable_all_experimental_scheduler_features=1 --xla_tpu_enable_scheduler_memory_pressure_tracking=1 --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=1 --xla_lhs_prioritize_async_depth_over_stall=1 --xla_tpu_enable_ag_backward_pipelining=1 --xla_should_allow_loop_variant_parameter_in_chain=1 --xla_should_add_loop_invariant_op_in_chain=1 --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=90 --xla_latency_hiding_scheduler_rerun=2 --xla_tpu_use_enhanced_launch_barrier=1'
2024-12-12 21:41:14.669321: I external/tsl/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
I1212 21:41:14.670917 132689403348992 distributed.py:119] Connecting to JAX distributed service on ethanli-fuji-70b-v2-test1-job-0-0.ethanli-fuji-70b-v2-test1:8476
2024-12-12 21:41:14.999583: I external/xla/xla/pjrt/distributed/client.cc:135] Connected to distributed JAX controller
ERROR: Illegal value '1' specified for flag 'xla_tpu_enable_async_collective_fusion_fuse_all_gather'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_enable_async_all_gather'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_tpu_enable_scheduler_memory_pressure_tracking'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_tpu_aggressive_opt_barrier_removal'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_lhs_prioritize_async_depth_over_stall'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_should_allow_loop_variant_parameter_in_chain'; expected one of true/enabled, false/disabled or auto
ERROR: Illegal value '1' specified for flag 'xla_should_add_loop_invariant_op_in_chain'; expected one of true/enabled, false/disabled or auto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's what I had originally, but pytype doesn't like that and will fail. So I see 2 options:

  1. Change behavior of axlearn that converts True boolean to 1 and let it return "true" instead.
  2. Ignore pytype check

xla_tpu_enable_async_collective_fusion_multiple_steps=True,
xla_tpu_overlap_compute_collective_tc=True,
xla_enable_async_all_gather=True,
# Host offloading flags
xla_tpu_enable_all_experimental_scheduler_features=True,
# Flag to enable memory tracking scheduling. The default AUTO only enables
# it in some situations. Not needed if
# xla_tpu_enable_all_experimental_scheduler_features is set to true already.
xla_tpu_enable_scheduler_memory_pressure_tracking=True,
# Flag controlling the maximum number of overlapping host offloadings.
xla_tpu_host_transfer_overlap_limit=24,
# Flag to enable the aggressive removal of opt-barriers.
xla_tpu_aggressive_opt_barrier_removal=True,
# Flag to enable more aggressive scheduling for async ops, such as pushing
# the async start to the beginning of the loop body.
xla_lhs_prioritize_async_depth_over_stall=True,
# Flag to enable pipelining of cross-DCN all-gathers.
xla_tpu_enable_ag_backward_pipelining=True,
xla_should_allow_loop_variant_parameter_in_chain=True,
xla_should_add_loop_invariant_op_in_chain=True,
# Flag controlling the maximum number of overlapping cross-DCN send/recv.
xla_max_concurrent_host_send_recv=100,
# Flag controlling the HBM memory limit as a percentage of the total HBM size.
# Default value is 95. Can tune up or down to give more or less memory for the
# scheduler. The scheduler favors more on less memory usage when it's under
# memory pressure, instead of hiding latency by overlapping more computations
# and communications.
xla_tpu_scheduler_percent_shared_memory_limit=90,
# Flag controlling the number of times the scheduler is run if the scheduled
# peak memory usage exceeds the initial memory limit, by setting memory limit
# to 90% of the previous memory limit each time. Default value is 1. Sometimes
# when the scheduler thinks it goes out memory, it may not actually happen due
# to other factors controlled by other compiler passes, or the initial memory
# limit is already set too low. Cutting the memory limit to 90% of previous one
# though, may make the scheduler weighting too much on the memory usage instead
# of latency side.
xla_latency_hiding_scheduler_rerun=2,
xla_tpu_use_enhanced_launch_barrier=True,
# Sparsecore offloading for all reduce.
# Uncomment below flags to enable it.
# xla_sc_disable_megacore_partitioning=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Ethanlm do you recall why you commented all them off?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We weren't sure if this would perform better for everyone. It's a newer feature as well. So Ethan recommended not enabling by default. For both llama 2 70b and 405b, we saw much better performance with it being enabled though.

# xla_tpu_use_tc_device_shape_on_sc=True,
# tpu_use_continuations=True,
# xla_jf_crs_combiner_threshold_count=10,
# xla_sc_enable_instruction_fusion="false",
# xla_sc_disjoint_spmem="false",
# xla_tpu_enable_sparse_core_collective_offload_all_reduce=True,
)
# This flag can be removed after upgrading to Jax 0.4.38.
# Uncomment for sparsecore offloading.
# options["2a886c8_chip_config_name"] = "megachip_tccontrol"
if num_slices > 1:
# Support multiple TPU slices connected over a data center network.
options.update(
Expand All @@ -59,12 +121,16 @@ def default_xla_options(

# Validate options. Will never fail if this function is implemented correctly.
for k, v in options.items():
assert v in [True, False, "true", "false"], (k, v)
try:
int(v)
continue
except ValueError:
assert v in [True, False, "true", "false", "megachip_tccontrol"], (k, v)

return options


def xla_flags_from_options(xla_options: dict[str, Union[str, bool]]) -> str:
def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str:
"""Convert an XLA options dict suitable for
`jitted_fn.lower(...).compile(compiler_options=xla_options)`
to XLA flags suitable for the `XLA_FLAGS` environment variable.
Expand Down
Loading