Skip to content

Commit 3e2c6dd

Browse files
update jax to 0.4.37 (#948)
update BlockSpec usage in tpu_attention; use TYPE_CHECKING for BuildDatasetFn in input_fake; add todo for BuildDatasetFn
1 parent b125f00 commit 3e2c6dd

File tree

7 files changed

+66
-56
lines changed

7 files changed

+66
-56
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Change Log
22

3+
## 0.1.5
4+
5+
* Changes
6+
* Upgrade Jax from 0.4.33 to 0.4.37.
7+
38
## 0.1.4
49

510
* Changes

axlearn/common/flash_attention/tpu_attention.py

+25-25
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
690690
)
691691
out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
692692
out_shape = [out_shape]
693-
out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))]
693+
out_specs = [pl.BlockSpec((block_b, 1, block_q, head_dim), o_index_map)]
694694

695695
if block_k != kv_seq_len:
696696
m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
@@ -703,8 +703,8 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
703703
if save_residuals:
704704
out_specs = [
705705
*out_specs,
706-
pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
707-
pl.BlockSpec(lm_index_map, (block_b, 1, block_q, MIN_BLOCK_SIZE)),
706+
pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
707+
pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
708708
]
709709
l = jax.ShapeDtypeStruct(
710710
(batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
@@ -718,7 +718,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
718718
out_shape = (*out_shape, None, None)
719719

720720
ab_block_spec = (
721-
pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major)) if ab is not None else None
721+
pl.BlockSpec((block_b, 1, block_q, block_k_major), ab_index_map) if ab is not None else None
722722
)
723723

724724
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -741,9 +741,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
741741
next_kv_index = kv_seq_index
742742
return (batch_index, 0, next_kv_index)
743743

744-
q_segment_ids_spec = pl.BlockSpec(q_segment_ids_index_map, (block_b, block_q, NUM_LANES))
744+
q_segment_ids_spec = pl.BlockSpec((block_b, block_q, NUM_LANES), q_segment_ids_index_map)
745745
kv_segment_ids_spec = pl.BlockSpec(
746-
kv_segment_ids_index_map, (block_b, NUM_SUBLANES, block_k_major)
746+
(block_b, NUM_SUBLANES, block_k_major), kv_segment_ids_index_map
747747
)
748748

749749
q_segment_ids = jax.lax.broadcast_in_dim(
@@ -764,9 +764,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
764764
)
765765

766766
in_specs = [
767-
pl.BlockSpec(q_index_map, (block_b, 1, block_q, head_dim)),
768-
pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
769-
pl.BlockSpec(kv_index_map, (block_b, 1, block_k_major, head_dim)),
767+
pl.BlockSpec((block_b, 1, block_q, head_dim), q_index_map),
768+
pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
769+
pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
770770
ab_block_spec,
771771
q_segment_ids_spec,
772772
kv_segment_ids_spec,
@@ -861,7 +861,7 @@ def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
861861

862862
return (batch_index, head_index, next_q_index, 0)
863863

864-
qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
864+
qo_spec = pl.BlockSpec((1, 1, block_q_major, head_dim), qo_index_map)
865865
assert qo_spec.block_shape is not None
866866
assert q.ndim == len(qo_spec.block_shape)
867867
do_spec = qo_spec
@@ -870,20 +870,20 @@ def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
870870
def kv_index_map(batch_index, head_index, kv_seq_index, _):
871871
return (batch_index, head_index, kv_seq_index, 0)
872872

873-
kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
873+
kv_spec = pl.BlockSpec((1, 1, block_k_major, head_dim), kv_index_map)
874874
assert kv_spec.block_shape is not None
875875
assert k.ndim == len(kv_spec.block_shape)
876876
assert v.ndim == len(kv_spec.block_shape)
877877

878878
def lm_index_map(batch_index, head_index, _, q_seq_index):
879879
return (batch_index, head_index, q_seq_index, 0)
880880

881-
lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
881+
lm_spec = pl.BlockSpec((1, 1, block_q_major, MIN_BLOCK_SIZE), lm_index_map)
882882
assert lm_spec.block_shape is not None
883883
assert l.ndim == len(lm_spec.block_shape)
884884
assert m.ndim == len(lm_spec.block_shape)
885885

886-
di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
886+
di_spec = pl.BlockSpec((1, 1, block_q_major, MIN_BLOCK_SIZE), qo_index_map)
887887
assert di_spec.block_shape is not None
888888
assert di.ndim == len(di_spec.block_shape)
889889

@@ -896,7 +896,7 @@ def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
896896
)
897897

898898
dab_spec = (
899-
pl.BlockSpec(ab_index_map, (1, 1, block_q_major, block_k_major)) if ab is not None else None
899+
pl.BlockSpec((1, 1, block_q_major, block_k_major), ab_index_map) if ab is not None else None
900900
)
901901

902902
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -919,9 +919,9 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
919919
del head_index
920920
return (batch_index, 0, kv_seq_index)
921921

922-
q_segment_ids_spec = pl.BlockSpec(q_segment_ids_index_map, (1, block_q_major, NUM_LANES))
922+
q_segment_ids_spec = pl.BlockSpec((1, block_q_major, NUM_LANES), q_segment_ids_index_map)
923923
kv_segment_ids_spec = pl.BlockSpec(
924-
kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major)
924+
(1, NUM_SUBLANES, block_k_major), kv_segment_ids_index_map
925925
)
926926

927927
q_segment_ids = jax.lax.broadcast_in_dim(
@@ -962,7 +962,7 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
962962
def dkv_index_map(batch_index, head_index, kv_seq_index, _):
963963
return (batch_index, head_index, kv_seq_index, 0)
964964

965-
dkv_spec = pl.BlockSpec(dkv_index_map, (1, 1, block_k_major, head_dim))
965+
dkv_spec = pl.BlockSpec((1, 1, block_k_major, head_dim), dkv_index_map)
966966
out_specs = [dkv_spec, dkv_spec]
967967
scratch_shapes = [
968968
pltpu.VMEM((block_k_major, head_dim), jnp.float32), # type: ignore
@@ -1050,7 +1050,7 @@ def _flash_attention_bwd_dq(
10501050
def qo_index_map(batch_index, head_index, q_seq_index, _):
10511051
return (batch_index, head_index, q_seq_index, 0)
10521052

1053-
qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
1053+
qo_spec = pl.BlockSpec((1, 1, block_q_major, head_dim), qo_index_map)
10541054
do_spec = qo_spec
10551055

10561056
def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
@@ -1066,20 +1066,20 @@ def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
10661066
next_kv_index = kv_seq_index
10671067
return (batch_index, head_index, next_kv_index, 0)
10681068

1069-
kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
1069+
kv_spec = pl.BlockSpec((1, 1, block_k_major, head_dim), kv_index_map)
10701070
assert kv_spec.block_shape is not None
10711071
assert k.ndim == len(kv_spec.block_shape)
10721072
assert v.ndim == len(kv_spec.block_shape)
10731073

10741074
def lm_index_map(batch_index, head_index, q_seq_index, _):
10751075
return (batch_index, head_index, q_seq_index, 0)
10761076

1077-
lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
1077+
lm_spec = pl.BlockSpec((1, 1, block_q_major, MIN_BLOCK_SIZE), lm_index_map)
10781078
assert lm_spec.block_shape is not None
10791079
assert l.ndim == len(lm_spec.block_shape)
10801080
assert m.ndim == len(lm_spec.block_shape)
10811081

1082-
di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
1082+
di_spec = pl.BlockSpec((1, 1, block_q_major, MIN_BLOCK_SIZE), qo_index_map)
10831083
assert di_spec.block_shape is not None
10841084
assert di.ndim == len(di_spec.block_shape)
10851085

@@ -1092,7 +1092,7 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
10921092
)
10931093

10941094
dab_spec = (
1095-
pl.BlockSpec(ab_index_map, (1, 1, block_q_major, block_k_major)) if ab is not None else None
1095+
pl.BlockSpec((1, 1, block_q_major, block_k_major), ab_index_map) if ab is not None else None
10961096
)
10971097

10981098
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -1117,9 +1117,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
11171117
next_kv_index = kv_seq_index
11181118
return (batch_index, 0, next_kv_index)
11191119

1120-
q_segment_ids_spec = pl.BlockSpec(q_segment_ids_index_map, (1, block_q_major, NUM_LANES))
1120+
q_segment_ids_spec = pl.BlockSpec((1, block_q_major, NUM_LANES), q_segment_ids_index_map)
11211121
kv_segment_ids_spec = pl.BlockSpec(
1122-
kv_segment_ids_index_map, (1, NUM_SUBLANES, block_k_major)
1122+
(1, NUM_SUBLANES, block_k_major), kv_segment_ids_index_map
11231123
)
11241124

11251125
q_segment_ids = jax.lax.broadcast_in_dim(
@@ -1156,7 +1156,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
11561156
jax.ShapeDtypeStruct(q.shape, q.dtype),
11571157
jax.ShapeDtypeStruct(ab.shape, ab.dtype) if ab is not None else None,
11581158
]
1159-
dq_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
1159+
dq_spec = pl.BlockSpec((1, 1, block_q_major, head_dim), qo_index_map)
11601160
out_specs = [
11611161
dq_spec,
11621162
dab_spec,

axlearn/common/input_fake.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,20 @@
44

55
import json
66
from collections.abc import Iterable, Sequence
7-
from typing import Any, Optional, Union
7+
from typing import TYPE_CHECKING, Any, Optional, Union
88

99
import jax
1010
import numpy as np
1111
import tensorflow as tf
1212

1313
from axlearn.common.config import REQUIRED, Required, config_class
14-
from axlearn.common.input_tf_data import BuildDatasetFn
1514
from axlearn.common.module import Module
1615
from axlearn.common.utils import Nested, Tensor, as_numpy_array, as_tensor
1716

17+
if TYPE_CHECKING:
18+
# TODO(markblee): replace with generic "dataset" definition
19+
from axlearn.common.input_tf_data import BuildDatasetFn
20+
1821

1922
class EmptyInput(Module):
2023
"""Produces empty inputs."""
@@ -225,7 +228,7 @@ def fake_source(
225228
repeat: int = 1,
226229
spec: Optional[dict[str, tf.TypeSpec]] = None,
227230
shuffle_buffer_size: Optional[int] = None,
228-
) -> BuildDatasetFn:
231+
) -> "BuildDatasetFn":
229232
if len(examples) == 0:
230233
raise ValueError("examples cannot be empty")
231234

@@ -257,7 +260,7 @@ def fake_text_source(
257260
is_training: bool,
258261
shuffle_buffer_size: Optional[int] = None,
259262
batch_size: int = 2,
260-
) -> BuildDatasetFn:
263+
) -> "BuildDatasetFn":
261264
return fake_source(
262265
is_training=is_training,
263266
examples=[
@@ -271,7 +274,7 @@ def fake_text_source(
271274
)
272275

273276

274-
def fake_serialized_json_source(examples: Sequence[dict[str, Any]]) -> BuildDatasetFn:
277+
def fake_serialized_json_source(examples: Sequence[dict[str, Any]]) -> "BuildDatasetFn":
275278
"""Returns a BuildDatasetFn that returns a dataset of jsonlines of examples.
276279
277280
Args:
@@ -301,7 +304,7 @@ def fake_text2text_source(
301304
target_key: str = "target_text",
302305
is_training: bool,
303306
shuffle_buffer_size: Optional[int] = None,
304-
) -> BuildDatasetFn:
307+
) -> "BuildDatasetFn":
305308
return fake_source(
306309
is_training=is_training,
307310
examples=[
@@ -324,7 +327,7 @@ def fake_glue_source(
324327
num_examples: Optional[int] = None,
325328
shuffle_buffer_size: Optional[int] = None,
326329
spec: Optional[dict[str, tf.TypeSpec]] = None,
327-
) -> BuildDatasetFn:
330+
) -> "BuildDatasetFn":
328331
if isinstance(input_key, str):
329332
input_key = [input_key]
330333
if num_examples is None:
@@ -352,7 +355,7 @@ def fake_classification_source(
352355
is_training: bool,
353356
classes: Sequence[str],
354357
shuffle_buffer_size: Optional[int] = None,
355-
) -> BuildDatasetFn:
358+
) -> "BuildDatasetFn":
356359
num_classes = len(classes)
357360
return fake_source(
358361
is_training=is_training,
@@ -376,7 +379,7 @@ def fake_classification_source_instruct_lm(
376379
shuffle_buffer_size: Optional[int] = None,
377380
eoa_text: str = "<eoa>",
378381
eob_text: str = "<eob>",
379-
) -> BuildDatasetFn:
382+
) -> "BuildDatasetFn":
380383
"""Returns a BuildDatasetFn containing fake classification examples in the InstructLM format.
381384
382385
Args:
@@ -418,7 +421,7 @@ def fake_speech_source(
418421
num_examples: int = 100,
419422
speech_key: str = "speech",
420423
shuffle_buffer_size: Optional[int] = None,
421-
) -> BuildDatasetFn:
424+
) -> "BuildDatasetFn":
422425
"""Fake speech data source.
423426
424427
Args:

axlearn/common/trainer_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from absl import flags, logging
2424
from absl.testing import absltest, parameterized
2525
from jax import numpy as jnp
26-
from jax._src import pjit as pjit_lib
26+
from jax._src.interpreters import pxla
2727
from jax.experimental import checkify
2828

2929
from axlearn.common import (
@@ -520,9 +520,9 @@ def mock_compile_train_step(*args, compiler_options=None, **kwargs):
520520
trainer, "compile_train_step", side_effect=mock_compile_train_step
521521
) as mocked_compile_fn:
522522
# pylint: disable=protected-access
523-
start_cache_hits = pjit_lib._pjit_lower_cached.cache_info().hits
523+
start_cache_hits = pxla._cached_lowering_to_hlo.cache_info().hits
524524
output_a = trainer.run(prng_key=jax.random.PRNGKey(123))
525-
end_cache_hits = pjit_lib._pjit_lower_cached.cache_info().hits
525+
end_cache_hits = pxla._cached_lowering_to_hlo.cache_info().hits
526526
# pylint: enable=protected-access
527527
if platform == "tpu":
528528
if not enable_python_cache:
@@ -1160,7 +1160,7 @@ def initialize_parameters_recursively(
11601160
cfg = self.config
11611161
if cfg.kind == "chex":
11621162
param = struct_test.Chex(
1163-
field_d=jnp.array(0),
1163+
field_d=jnp.array(4),
11641164
field_b=jnp.array(1),
11651165
field_a=jnp.array(2),
11661166
field_c=jnp.array(3),

axlearn/common/update_transformation.py

+1
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,5 @@ def mask_tree(tree: dict, *, keep: dict, mask_value: Any) -> dict:
261261
lambda should_keep, leaf: leaf if should_keep else mask_value,
262262
keep,
263263
tree,
264+
is_leaf=lambda x: x is None,
264265
)

docs/01-start.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ brew install bazelisk
5353
# This was tested using clang version 15 - you may get non-working wheels with earlier versions of clang.
5454
mkdir ~/builds && git clone https://github.com/tensorflow/text.git ~/builds/text
5555
# Install tensorflow prior to building.
56-
pip install 'tensorflow==2.16.1'
57-
cd ~/builds/text && git checkout 0f9f6df5b4da19bc7a734ba05fc4fa12bccbedbe
56+
pip install 'tensorflow==2.17.1'
57+
cd ~/builds/text && git checkout v2.17.0
5858

5959
# Build tensorflow-text.
6060
./oss_scripts/run_build.sh
61-
pip install ./tensorflow_text-2.16.1-cp310-cp310-macosx_*_arm64.whl
61+
pip install ./tensorflow_text-2.17.0-cp310-cp310-macosx_*_arm64.whl
6262
```
6363
</details>
6464

0 commit comments

Comments
 (0)