Add ReadOptions args to _make_autoregressive_inputs (#931)
* Add ReadOptions args to _make_autoregressive_inputs

* use read_options as args instead
RsEnts authored Jan 17, 2025
1 parent 4858070 commit ad14de3
Showing 1 changed file with 12 additions and 2 deletions.
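
For context, the new `read_options` argument is forwarded to the step that converts the grain pipeline into a `grain.IterDataset`: `num_threads` controls the number of reader threads and `prefetch_buffer_size` the number of elements buffered ahead of the consumer. A minimal standalone sketch of the underlying grain call (assuming a grain release that exports `MapDataset` and `ReadOptions` from `grain.python`; the toy dataset is illustrative):

    import grain.python as grain

    # Toy MapDataset; real pipelines are built from a data source.
    ds = grain.MapDataset.range(8).map(lambda x: {"target_labels": x})

    # to_iter_dataset accepts ReadOptions; these values mirror the new
    # defaults in this commit (num_threads=1, prefetch_buffer_size=16).
    it_ds = ds.to_iter_dataset(
        read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=16)
    )
    for example in it_ds:
        print(example["target_labels"])
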
axlearn/common/input_grain_lm.py (12 additions, 2 deletions)
@@ -7,6 +7,7 @@
 import sys
 from typing import Optional, Protocol
 
+import grain.python as grain
 import numpy as np
 from grain._src.python.dataset.transformations.prefetch import MultiprocessPrefetchIterDataset
@@ -30,6 +31,7 @@ def _make_autoregressive_inputs(
     max_len: int,
     input_key: str = "target_labels",
     split_fn: Optional[ConfigOr[_SplitFn]] = None,
+    read_options: grain.ReadOptions = grain.ReadOptions(num_threads=1, prefetch_buffer_size=16),
     window_size: int = 1,
 ) -> Dataset:
     """Produces `input_ids` autoregressively from `target_labels`.
@@ -44,6 +46,8 @@
         input_key: Input key containing `target_labels`.
         split_fn: A callable taking flat input IDs and producing batched IDs of shape [-1, max_len].
             If None, returns the flat input IDs unchanged.
+        read_options: grain.ReadOptions which includes num_threads and prefetch_buffer_size. It is
+            used to convert the pipeline to grain.IterDataset.
         window_size: Window size. If > 1, also packs.
 
     Returns:
@@ -67,7 +71,10 @@ def process_example_fn(example: SequenceOr[dict[str, Tensor]]) -> dict[str, Tensor]:
     # Batch as lists to avoid ragged.
     ds = ds.batch(window_size, drop_remainder=False, batch_fn=list)
     ds = ds.map(process_example_fn)
-    ds = input_grain.maybe_to_iter_dataset(ds)
+    ds = input_grain.maybe_to_iter_dataset(
+        ds,
+        read_options=read_options,
+    )
     # After processing, we have non-ragged np.arrays, so we can unbatch.
     ds = input_grain.unbatch(ds)
     return ds
@@ -99,6 +106,7 @@ def text_to_lm_training_input(
     max_len: int,
     window_size: int = 128,
     max_padding_fraction: float = 1,
+    read_options: grain.ReadOptions = grain.ReadOptions(num_threads=1, prefetch_buffer_size=16),
 ) -> Dataset:
     """Returns a function that generates training inputs for language models from raw text.
@@ -114,6 +122,8 @@
         max_padding_fraction: The maximum fraction of a batch example that we are willing to pad.
             E.g. if this is 0.5 then we will pad an example with >= 0.5 * max_len viable tokens,
             else drop it entirely.
+        read_options: grain.ReadOptions which includes num_threads and prefetch_buffer_size. It is
+            used to convert the pipeline to grain.IterDataset.
 
     Returns:
         A `grain.IterDataset` with potentially different cardinality than the input dataset.
@@ -136,7 +146,7 @@ def text_to_lm_training_input(
     ds = input_grain.rekey(ds, key_map={"target_labels": "text"})
     # Flatten, roll, split.
     ds = _make_autoregressive_inputs(
-        ds, max_len=max_len, window_size=window_size, split_fn=split_fn
+        ds, max_len=max_len, window_size=window_size, split_fn=split_fn, read_options=read_options
     )
     ds = input_grain_text.count_num_bytes(
         ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes"
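
In practice this lets callers tune reader parallelism without touching the pipeline internals. A hedged caller-side sketch (hypothetical wrapper; `text_to_lm_training_input`'s other arguments such as `ds` and `vocab` are assumed from the rest of `input_grain_lm.py`, which this diff does not show):

    import grain.python as grain

    from axlearn.common import input_grain_lm

    def build_training_input(ds, vocab):
        """Hypothetical wrapper showing where read_options plugs in."""
        return input_grain_lm.text_to_lm_training_input(
            ds,
            vocab=vocab,
            max_len=4096,
            window_size=128,
            # Override the new defaults to read with more threads and a
            # deeper prefetch buffer.
            read_options=grain.ReadOptions(num_threads=4, prefetch_buffer_size=64),
        )
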
