Add llama 3 tokenizer (#850)
* Add llama 3 tokenizer

Add a new version called V3_TIKTOKEN.

Other edits based on suggestions.

* Handle special tokens like other vocabularies.

* Use encode instead of encode_batch.
sychen52 authored Jan 18, 2025
1 parent ad14de3 commit 9996f34
Showing 85 changed files with 4,629 additions and 334 deletions.
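
Per the commit message, the new vocabulary version V3_TIKTOKEN is tiktoken-based, special tokens are handled like other vocabulary entries, and single strings are encoded with `encode` rather than `encode_batch`. The sketch below is illustrative only and is not the code added by this commit: the BPE file path, the split pattern, and the shortened special-token list are assumptions; consult Meta's reference tokenizer and the axlearn sources for the authoritative values.

```python
# Hedged sketch (not the axlearn implementation): assemble a Llama 3 style
# tiktoken vocabulary so that special tokens live in the same id space as
# ordinary BPE ranks, then encode a single string with `encode`.
import tiktoken
from tiktoken.load import load_tiktoken_bpe

# Hypothetical path to the Llama 3 tiktoken ranks file.
BPE_PATH = "tokenizer.model"

# bytes -> rank mapping for the base vocabulary.
mergeable_ranks = load_tiktoken_bpe(BPE_PATH)

# A small subset of the real special tokens, for illustration. They are
# appended after the ordinary ranks, so downstream code can treat them like
# any other vocabulary entry.
special_tokens = ["<|begin_of_text|>", "<|end_of_text|>", "<|eot_id|>"]
special_token_ids = {
    tok: len(mergeable_ranks) + i for i, tok in enumerate(special_tokens)
}

enc = tiktoken.Encoding(
    name="llama3_sketch",
    # Split pattern: an assumption here; the authoritative pattern ships
    # with the reference tokenizer.
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}"
            r"| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens=special_token_ids,
)

# Single-string encoding with `encode` rather than `encode_batch`;
# allowed_special="all" lets special tokens map to their own ids.
ids = enc.encode("<|begin_of_text|>Hello, Llama 3!", allowed_special="all")
print(ids)
```

In this sketch a special token such as <|begin_of_text|> simply resolves to an id at the end of the id space, so vocabulary handling downstream needs no special casing.
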
@@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
@@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
@@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())

Large diffs are not rendered by default.

@@ -0,0 +1,9 @@
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
@@ -0,0 +1,10 @@
====================weight_decay_scale root.optimizer====================
decoder/emb/token_emb/weight: 1
decoder/output_norm/scale: 1
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/norm/scale: 1