Skip to content

Commit

Permalink
Merge pull request #52 from Lewington-pitsos/remove_batch_size_SaeVisConfig
Browse files Browse the repository at this point in the history

Remove batch size sae vis config
  • Loading branch information
callummcdougall authored Jul 4, 2024
2 parents fe82b75 + 7151f62 commit 4e82228
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 18 deletions.
15 changes: 6 additions & 9 deletions demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@
"sae_vis_config = SaeVisConfig(\n",
" hook_point = utils.get_act_name(\"post\", 0),\n",
" features = range(64),\n",
" batch_size = 2048,\n",
" verbose = True,\n",
")\n",
"\n",
Expand All @@ -286,7 +285,7 @@
" encoder = encoder,\n",
" encoder_B = encoder_B,\n",
" model = model,\n",
" tokens = all_tokens, # type: ignore\n",
" tokens = all_tokens[: 2048], # type: ignore\n",
" cfg = sae_vis_config,\n",
")\n",
"\n",
Expand Down Expand Up @@ -370,16 +369,16 @@
"feature_vis_config_custom_layout = SaeVisConfig(\n",
" hook_point = utils.get_act_name(\"post\", 0),\n",
" features = range(16),\n",
" batch_size = 1024,\n",
" feature_centric_layout = layout,\n",
")\n",
"\n",
"token_subset = all_tokens[: 1024] # type: ignore\n",
"# Generate data\n",
"sae_vis_data_custom = SaeVisData.create(\n",
" encoder = encoder,\n",
" encoder_B = encoder_B,\n",
" model = model,\n",
" tokens = all_tokens[:, :64], # type: ignore\n",
" tokens = token_subset[:, :64], # type: ignore\n",
" cfg = feature_vis_config_custom_layout,\n",
")\n",
"\n",
Expand Down Expand Up @@ -485,14 +484,13 @@
"feature_vis_config_gpt = SaeVisConfig(\n",
" hook_point = hook_point,\n",
" features = test_feature_idx_gpt,\n",
" batch_size = 8192,\n",
" verbose = True,\n",
")\n",
"\n",
"sae_vis_data_gpt = SaeVisData.create(\n",
" encoder = gpt2_sae,\n",
" model = gpt2,\n",
" tokens = all_tokens_gpt, # type: ignore\n",
" tokens = all_tokens_gpt[: 8192], # type: ignore\n",
" cfg = feature_vis_config_gpt,\n",
")\n",
"\n",
Expand Down Expand Up @@ -535,7 +533,6 @@
"sae_vis_config = SaeVisConfig(\n",
" hook_point = utils.get_act_name(\"post\", 0),\n",
" features = range(256),\n",
" batch_size = 2048,\n",
" verbose = True,\n",
")\n",
"\n",
Expand All @@ -544,7 +541,7 @@
" encoder = encoder,\n",
" encoder_B = encoder_B,\n",
" model = model,\n",
" tokens = all_tokens, # type: ignore\n",
"    tokens = all_tokens[: 2048], # type: ignore\n",
" cfg = sae_vis_config,\n",
")"
]
Expand Down Expand Up @@ -736,7 +733,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.10.13"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
3 changes: 0 additions & 3 deletions sae_vis/data_config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,6 @@ def default_prompt_centric_layout(cls) -> "SaeVisLayoutConfig":
SAE_CONFIG_DICT = dict(
hook_point="The hook point to use for the SAE",
features="The set of features which we'll be gathering data for. If an integer, we only get data for 1 feature",
batch_size="The number of sequences we'll gather data for. If supplied then it can't be larger than `tokens[0]`, \
if not then we use all of `tokens`",
minibatch_size_tokens="The minibatch size we'll use to split up the full batch during forward passes, to avoid \
OOMs.",
minibatch_size_features="The feature minibatch size we'll use to split up our features, to avoid OOM errors",
Expand All @@ -423,7 +421,6 @@ class SaeVisConfig:
# Data
hook_point: str | None = None
features: int | Iterable[int] | None = None
batch_size: int | None = None
minibatch_size_features: int = 256
minibatch_size_tokens: int = 64

Expand Down
6 changes: 0 additions & 6 deletions sae_vis/data_fetching_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,6 @@ def get_feature_data(
sae_vis_data = SaeVisData()
time_logs = defaultdict(float)

# Slice tokens, if we're only doing a subset of them
if cfg.batch_size is None:
tokens = tokens
else:
tokens = tokens[: cfg.batch_size]

# Get a feature list (need to deal with the case where `cfg.features` is an int, or None)
if cfg.features is None:
assert isinstance(encoder.cfg.d_hidden, int)
Expand Down

0 comments on commit 4e82228

Please sign in to comment.