@@ -690,7 +690,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
690
690
)
691
691
out_shape = jax .ShapeDtypeStruct (shape = q .shape , dtype = q .dtype )
692
692
out_shape = [out_shape ]
693
- out_specs = [pl .BlockSpec (o_index_map , (block_b , 1 , block_q , head_dim ))]
693
+ out_specs = [pl .BlockSpec ((block_b , 1 , block_q , head_dim ), o_index_map )]
694
694
695
695
if block_k != kv_seq_len :
696
696
m_scratch = pltpu .VMEM ((block_b , 1 , block_q , MIN_BLOCK_SIZE ), jnp .float32 )
@@ -703,8 +703,8 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
703
703
if save_residuals :
704
704
out_specs = [
705
705
* out_specs ,
706
- pl .BlockSpec (lm_index_map , (block_b , 1 , block_q , MIN_BLOCK_SIZE )),
707
- pl .BlockSpec (lm_index_map , (block_b , 1 , block_q , MIN_BLOCK_SIZE )),
706
+ pl .BlockSpec ((block_b , 1 , block_q , MIN_BLOCK_SIZE ), lm_index_map ),
707
+ pl .BlockSpec ((block_b , 1 , block_q , MIN_BLOCK_SIZE ), lm_index_map ),
708
708
]
709
709
l = jax .ShapeDtypeStruct (
710
710
(batch_size , num_heads , q_seq_len , MIN_BLOCK_SIZE ), dtype = jnp .float32
@@ -718,7 +718,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
718
718
out_shape = (* out_shape , None , None )
719
719
720
720
ab_block_spec = (
721
- pl .BlockSpec (ab_index_map , (block_b , 1 , block_q , block_k_major )) if ab is not None else None
721
+ pl .BlockSpec ((block_b , 1 , block_q , block_k_major ), ab_index_map ) if ab is not None else None
722
722
)
723
723
724
724
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -741,9 +741,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
741
741
next_kv_index = kv_seq_index
742
742
return (batch_index , 0 , next_kv_index )
743
743
744
- q_segment_ids_spec = pl .BlockSpec (q_segment_ids_index_map , (block_b , block_q , NUM_LANES ))
744
+ q_segment_ids_spec = pl .BlockSpec ((block_b , block_q , NUM_LANES ), q_segment_ids_index_map )
745
745
kv_segment_ids_spec = pl .BlockSpec (
746
- kv_segment_ids_index_map , (block_b , NUM_SUBLANES , block_k_major )
746
+ (block_b , NUM_SUBLANES , block_k_major ), kv_segment_ids_index_map
747
747
)
748
748
749
749
q_segment_ids = jax .lax .broadcast_in_dim (
@@ -764,9 +764,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
764
764
)
765
765
766
766
in_specs = [
767
- pl .BlockSpec (q_index_map , (block_b , 1 , block_q , head_dim )),
768
- pl .BlockSpec (kv_index_map , (block_b , 1 , block_k_major , head_dim )),
769
- pl .BlockSpec (kv_index_map , (block_b , 1 , block_k_major , head_dim )),
767
+ pl .BlockSpec ((block_b , 1 , block_q , head_dim ), q_index_map ),
768
+ pl .BlockSpec ((block_b , 1 , block_k_major , head_dim ), kv_index_map ),
769
+ pl .BlockSpec ((block_b , 1 , block_k_major , head_dim ), kv_index_map ),
770
770
ab_block_spec ,
771
771
q_segment_ids_spec ,
772
772
kv_segment_ids_spec ,
@@ -861,7 +861,7 @@ def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
861
861
862
862
return (batch_index , head_index , next_q_index , 0 )
863
863
864
- qo_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , head_dim ))
864
+ qo_spec = pl .BlockSpec ((1 , 1 , block_q_major , head_dim ), qo_index_map )
865
865
assert qo_spec .block_shape is not None
866
866
assert q .ndim == len (qo_spec .block_shape )
867
867
do_spec = qo_spec
@@ -870,20 +870,20 @@ def qo_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
870
870
def kv_index_map (batch_index , head_index , kv_seq_index , _ ):
871
871
return (batch_index , head_index , kv_seq_index , 0 )
872
872
873
- kv_spec = pl .BlockSpec (kv_index_map , (1 , 1 , block_k_major , head_dim ))
873
+ kv_spec = pl .BlockSpec ((1 , 1 , block_k_major , head_dim ), kv_index_map )
874
874
assert kv_spec .block_shape is not None
875
875
assert k .ndim == len (kv_spec .block_shape )
876
876
assert v .ndim == len (kv_spec .block_shape )
877
877
878
878
def lm_index_map (batch_index , head_index , _ , q_seq_index ):
879
879
return (batch_index , head_index , q_seq_index , 0 )
880
880
881
- lm_spec = pl .BlockSpec (lm_index_map , (1 , 1 , block_q_major , MIN_BLOCK_SIZE ))
881
+ lm_spec = pl .BlockSpec ((1 , 1 , block_q_major , MIN_BLOCK_SIZE ), lm_index_map )
882
882
assert lm_spec .block_shape is not None
883
883
assert l .ndim == len (lm_spec .block_shape )
884
884
assert m .ndim == len (lm_spec .block_shape )
885
885
886
- di_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , MIN_BLOCK_SIZE ))
886
+ di_spec = pl .BlockSpec ((1 , 1 , block_q_major , MIN_BLOCK_SIZE ), qo_index_map )
887
887
assert di_spec .block_shape is not None
888
888
assert di .ndim == len (di_spec .block_shape )
889
889
@@ -896,7 +896,7 @@ def ab_index_map(batch_index, head_index, kv_seq_index, q_seq_index):
896
896
)
897
897
898
898
dab_spec = (
899
- pl .BlockSpec (ab_index_map , (1 , 1 , block_q_major , block_k_major )) if ab is not None else None
899
+ pl .BlockSpec ((1 , 1 , block_q_major , block_k_major ), ab_index_map ) if ab is not None else None
900
900
)
901
901
902
902
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -919,9 +919,9 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
919
919
del head_index
920
920
return (batch_index , 0 , kv_seq_index )
921
921
922
- q_segment_ids_spec = pl .BlockSpec (q_segment_ids_index_map , (1 , block_q_major , NUM_LANES ))
922
+ q_segment_ids_spec = pl .BlockSpec ((1 , block_q_major , NUM_LANES ), q_segment_ids_index_map )
923
923
kv_segment_ids_spec = pl .BlockSpec (
924
- kv_segment_ids_index_map , (1 , NUM_SUBLANES , block_k_major )
924
+ (1 , NUM_SUBLANES , block_k_major ), kv_segment_ids_index_map
925
925
)
926
926
927
927
q_segment_ids = jax .lax .broadcast_in_dim (
@@ -962,7 +962,7 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
962
962
def dkv_index_map (batch_index , head_index , kv_seq_index , _ ):
963
963
return (batch_index , head_index , kv_seq_index , 0 )
964
964
965
- dkv_spec = pl .BlockSpec (dkv_index_map , (1 , 1 , block_k_major , head_dim ))
965
+ dkv_spec = pl .BlockSpec ((1 , 1 , block_k_major , head_dim ), dkv_index_map )
966
966
out_specs = [dkv_spec , dkv_spec ]
967
967
scratch_shapes = [
968
968
pltpu .VMEM ((block_k_major , head_dim ), jnp .float32 ), # type: ignore
@@ -1050,7 +1050,7 @@ def _flash_attention_bwd_dq(
1050
1050
def qo_index_map (batch_index , head_index , q_seq_index , _ ):
1051
1051
return (batch_index , head_index , q_seq_index , 0 )
1052
1052
1053
- qo_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , head_dim ))
1053
+ qo_spec = pl .BlockSpec ((1 , 1 , block_q_major , head_dim ), qo_index_map )
1054
1054
do_spec = qo_spec
1055
1055
1056
1056
def kv_index_map (batch_index , head_index , q_seq_index , kv_seq_index ):
@@ -1066,20 +1066,20 @@ def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
1066
1066
next_kv_index = kv_seq_index
1067
1067
return (batch_index , head_index , next_kv_index , 0 )
1068
1068
1069
- kv_spec = pl .BlockSpec (kv_index_map , (1 , 1 , block_k_major , head_dim ))
1069
+ kv_spec = pl .BlockSpec ((1 , 1 , block_k_major , head_dim ), kv_index_map )
1070
1070
assert kv_spec .block_shape is not None
1071
1071
assert k .ndim == len (kv_spec .block_shape )
1072
1072
assert v .ndim == len (kv_spec .block_shape )
1073
1073
1074
1074
def lm_index_map (batch_index , head_index , q_seq_index , _ ):
1075
1075
return (batch_index , head_index , q_seq_index , 0 )
1076
1076
1077
- lm_spec = pl .BlockSpec (lm_index_map , (1 , 1 , block_q_major , MIN_BLOCK_SIZE ))
1077
+ lm_spec = pl .BlockSpec ((1 , 1 , block_q_major , MIN_BLOCK_SIZE ), lm_index_map )
1078
1078
assert lm_spec .block_shape is not None
1079
1079
assert l .ndim == len (lm_spec .block_shape )
1080
1080
assert m .ndim == len (lm_spec .block_shape )
1081
1081
1082
- di_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , MIN_BLOCK_SIZE ))
1082
+ di_spec = pl .BlockSpec ((1 , 1 , block_q_major , MIN_BLOCK_SIZE ), qo_index_map )
1083
1083
assert di_spec .block_shape is not None
1084
1084
assert di .ndim == len (di_spec .block_shape )
1085
1085
@@ -1092,7 +1092,7 @@ def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
1092
1092
)
1093
1093
1094
1094
dab_spec = (
1095
- pl .BlockSpec (ab_index_map , (1 , 1 , block_q_major , block_k_major )) if ab is not None else None
1095
+ pl .BlockSpec ((1 , 1 , block_q_major , block_k_major ), ab_index_map ) if ab is not None else None
1096
1096
)
1097
1097
1098
1098
q_segment_ids_spec = kv_segment_ids_spec = None
@@ -1117,9 +1117,9 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
1117
1117
next_kv_index = kv_seq_index
1118
1118
return (batch_index , 0 , next_kv_index )
1119
1119
1120
- q_segment_ids_spec = pl .BlockSpec (q_segment_ids_index_map , (1 , block_q_major , NUM_LANES ))
1120
+ q_segment_ids_spec = pl .BlockSpec ((1 , block_q_major , NUM_LANES ), q_segment_ids_index_map )
1121
1121
kv_segment_ids_spec = pl .BlockSpec (
1122
- kv_segment_ids_index_map , (1 , NUM_SUBLANES , block_k_major )
1122
+ (1 , NUM_SUBLANES , block_k_major ), kv_segment_ids_index_map
1123
1123
)
1124
1124
1125
1125
q_segment_ids = jax .lax .broadcast_in_dim (
@@ -1156,7 +1156,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
1156
1156
jax .ShapeDtypeStruct (q .shape , q .dtype ),
1157
1157
jax .ShapeDtypeStruct (ab .shape , ab .dtype ) if ab is not None else None ,
1158
1158
]
1159
- dq_spec = pl .BlockSpec (qo_index_map , (1 , 1 , block_q_major , head_dim ))
1159
+ dq_spec = pl .BlockSpec ((1 , 1 , block_q_major , head_dim ), qo_index_map )
1160
1160
out_specs = [
1161
1161
dq_spec ,
1162
1162
dab_spec ,
0 commit comments