add support for different layer order in conformer (apple#568)
* Open-source MoE (#547)

* MoE

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Xianzhi Du <[email protected]>

* Supports return_aux in PipelinedTransformerLayer. (apple#557)

* link fix (#561)

Co-authored-by: Xianzhi Du <[email protected]>

* GKE GPU A3 with TCPX support (apple#517)

* add GKE GPU support to axlearn

* add volumes and initContainer

* finish up the pod spec

* add GKE runner for GPU

* extend != append duhh

* fix volume mount change

* add local queue for internal cluster

* move kueue annotation to jobset

* introduce gpu container image

* add ld_library_path

* add env variables for a3

* ensure replicas of jobset is 1

* automatically set distributed coordinator as env variables

* change NCCL_DEBUG to warn

* install gpu jax using pyproject.toml

* address comments from Mark

* fix missing sidecar

* remove accidental double quote

* add default XLA flags

* hardcode max step to 1000

* fix sidecar command

* fix sidecar termination

* allow passing queue name through flag

* port over remaining xla flags

* only use bare minimum xla flags

* address Mark's nit on using env_vars.update

* remove flags that make perf worse

* remove tpu np provisioner from gpu runner

* add mesh rule

* Revert "hardcode max step to 1000"

This reverts commit 8cda4f91414c00deb28c7f15d54d183076101d8b.

* add doc for queue flag

* fix punctuation and add link to tcpx docs

* more punctuation

* document NCCL env vars

* throw error if GCS mount is set

* throw error when pre provisioner is enabled

* add testing coverage for GPUGKEJob

* add basic gke_runner tests

* add more gke_runner tests

* address pr comments

* add space

* fix missing .

* add missing space

* fix pytype error in job_test.py

* update golden configs

* [inference] remove unnecessary mocks and use absolute imports (apple#563)

Co-authored-by: guoli-yin <[email protected]>

* Minor style changes. (apple#564)

* add support for different layer order in conformer

* update

* address review feedback

* fix formatting

* fix black

---------

Co-authored-by: xianzhi <[email protected]>
Co-authored-by: Xianzhi Du <[email protected]>
Co-authored-by: Mark Lee <[email protected]>
Co-authored-by: Sam Stoelinga <[email protected]>
Co-authored-by: Guoli Yin <[email protected]>
Co-authored-by: guoli-yin <[email protected]>
Co-authored-by: Yongqiang Wang <[email protected]>
8 people authored and qdavid1 committed Dec 11, 2024
1 parent cf9469c commit 8161abe
Showing 2 changed files with 45 additions and 8 deletions.
26 changes: 24 additions & 2 deletions axlearn/common/conformer.py
@@ -17,7 +17,7 @@
https://github.com/tensorflow/lingvo/blob/d2f1e1b3cccdac8f73ae20f86afb03560b1c176d/lingvo/core/conformer_layer.py
"""

from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

from jax import numpy as jnp

@@ -202,6 +202,13 @@ class Config(BaseLayer.Config):
)
lconv: LConvLayer.Config = LConvLayer.default_config()
norm: LayerNorm.Config = LayerNorm.default_config()
# Layer order. If None, defaults to "mhsa_before_lconv", i.e., the conformer layer order
# specified in https://arxiv.org/abs/2005.08100.
# If not None, specifies the layer order with respect to conv (lconv) and multi-head self
# attention (mhsa), e.g., "lconv_before_mhsa" as shown in Figure 1 of https://arxiv.org/pdf/2011.10798.
layer_order: Optional[
Literal["lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv"]
] = None

# Config for computing relative position embeddings for range [-seq_len + 1, seq_len - 1].
# It should only be used when attention is of class MultiheadAttention.
@@ -253,6 +260,11 @@ def __init__(self, cfg: Config, *, parent: Module):
f"cfg.right_context must be greater or equal to 0, get {cfg.right_context}."
)

if cfg.layer_order is not None:
supported_layer_order = ["lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv"]
if cfg.layer_order not in supported_layer_order:
raise ValueError(f"Only {supported_layer_order} is allowed, got {cfg.layer_order}")

def forward(self, inputs: Tensor, *, paddings: Tensor) -> Tensor:
"""Computes ConformerLayer outputs.
@@ -265,6 +277,13 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> Tensor:
"""
cfg = self.config
x = inputs

layer_order = cfg.layer_order
if layer_order is None:
layer_order = "mhsa_before_lconv"

if layer_order == "lconv_before_ff":
x = self.lconv(x, paddings=paddings)
x = self.ff_start(x)
attention_logit_biases = compute_attention_logit_biases(
paddings=paddings,
@@ -275,8 +294,11 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> Tensor:
# ToDo(zhiyunlu): test limited context mask with rel_pos_emb.
if self.config.rel_pos_emb:
attention_logit_biases = self.rel_pos_emb(attention_logit_biases)
if layer_order == "lconv_before_mhsa":
x = self.lconv(x, paddings=paddings)
x = self.self_attention(target=x, attention_logit_biases=attention_logit_biases).data
x = self.lconv(x, paddings=paddings)
if layer_order == "mhsa_before_lconv":
x = self.lconv(x, paddings=paddings)
x = self.ff_end(x)
x = self.norm(x)
return x
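
For reference, a minimal sketch of how the new layer_order option might be set, mirroring the pattern used in conformer_test.py below; the input_dim and num_heads values are illustrative and not taken from this commit.

# Assumed usage sketch: configure a ConformerLayer with an explicit layer order.
from axlearn.common.conformer import ConformerLayer

cfg = ConformerLayer.default_config().set(
    name="conformer",
    input_dim=256,  # illustrative dimension, not from the commit
    # One of "lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv",
    # or None to keep the default ("mhsa_before_lconv").
    layer_order="lconv_before_mhsa",
)
cfg.self_attention.attention.num_heads = 4  # illustrative head count
layer = cfg.instantiate(parent=None)
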
27 changes: 21 additions & 6 deletions axlearn/common/conformer_test.py
@@ -108,36 +108,51 @@ def test_respect_paddings(self, is_training):
# Check that the outputs are the same despite differences in padding.
assert_allclose(outputs[0, :num_tokens], outputs[1, :num_tokens])

def test_repeated_conformer_config(self):
@parameterized.parameters(None, "lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv")
def test_repeated_conformer_config(self, layer_order):
"""Tests RepeatedConformerLayer config.
It tests that the ConformerLayer default config is correctly set in RepeatedConformerLayer.
"""
dim, num_heads = 6, 2
cfg = RepeatedConformerLayer.default_config().set(
name="repeat_conformer", input_dim=dim, num_layers=3
name="repeat_conformer",
input_dim=dim,
num_layers=3,
)
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.layer_order = layer_order
for ff_cfg in (cfg.layer.ff_start, cfg.layer.ff_end):
self.assertEqual(ff_cfg.hidden_dim.scale, 4)
self.assertEqual(ff_cfg.residual_weight, 0.5)
self.assertEqual(ff_cfg.activation, "nn.silu")
self.assertEqual(cfg.layer.self_attention.attention.input_linear.layer.bias, True)

@parameterized.parameters((True, True), (False, True), (True, False), (False, False))
def test_repeated_conformer_forward(self, checkpoint_self_attention, checkpoint_feed_forward):
@parameterized.product(
checkpoint_self_attention=(True, False),
checkpoint_feed_forward=(True, False),
layer_order=(None, "lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv"),
)
def test_repeated_conformer_forward(
self, checkpoint_self_attention, checkpoint_feed_forward, layer_order
):
"""Tests RepeatedConformerLayer."""
dim, num_heads = 6, 2
# Create a conformer layer.
cfg = ConformerLayer.default_config().set(name="conformer", input_dim=dim)
cfg = ConformerLayer.default_config().set(
name="conformer", input_dim=dim, layer_order=layer_order
)
cfg.self_attention.attention.num_heads = num_heads
layer = cfg.instantiate(parent=None) # type: ConformerLayer

# Create a Repeat Conformer layer.
num_layers = 5
repeat_cfg = RepeatedConformerLayer.default_config().set(
name="repeat_conformer", input_dim=dim, num_layers=num_layers
name="repeat_conformer",
input_dim=dim,
num_layers=num_layers,
)
repeat_cfg.layer.layer_order = layer_order
repeat_cfg.layer.self_attention.attention.num_heads = num_heads
repeat_cfg.layer.remat_spec = build_remat_spec(
repeat_cfg,
