Caching model values
bcwarner committed Sep 21, 2023
1 parent 7ce8a0f commit 076370e
Showing 3 changed files with 77 additions and 34 deletions.
2 changes: 1 addition & 1 deletion entropy.py
@@ -1042,7 +1042,7 @@ def plot(self):
ce_current = []

row_len = len(vocab.field_ids) - 1 # Exclude special fields
row_count = (eos_index - 1) // row_len
row_count = eos_index // row_len
if row_count <= 1: # Not applicable
continue

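As a quick sanity check on the fix above (assuming, per the comment on the same change in entropy_cache.py, that eos_index is the position of the EOS token and each row spans row_len field tokens):

row_len = 3          # fields per row, special fields excluded
eos_index = 6        # two complete rows; EOS immediately follows token 5
assert eos_index // row_len == 2         # new expression: correct count
assert (eos_index - 1) // row_len == 1   # old expression: off by one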
92 changes: 65 additions & 27 deletions entropy_cache.py
@@ -6,19 +6,20 @@
import pickle
import sys
from collections import defaultdict
from functools import partial

import pandas as pd
import scipy.stats
import torch
import yaml
from matplotlib.axes import Axes
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler
from tqdm import tqdm
from tabulate import tabulate
from matplotlib import pyplot as plt

from model.model import EHRAuditGPT2, EHRAuditRWKV, EHRAuditLlama
from model.modules import EHRAuditPretraining, EHRAuditDataModule
from model.modules import EHRAuditPretraining, EHRAuditDataModule, collate_fn, worker_fn
from model.data import timestamp_space_calculation
from model.vocab import EHRVocab, EHRAuditTokenizer
import tikzplotlib
@@ -51,6 +52,11 @@
action="store_true",
help="Whether to run with single thread.",
)
parser.add_argument(
"--verify",
action="store_true",
help="Whether to verify the entropy values line up with the dataset.",
)
args = parser.parse_args()
# Get the list of models from the config file
config_path = os.path.normpath(
@@ -126,22 +132,30 @@
model = types[model_type].from_pretrained(model_path, vocab=vocab)
model.loss.reduction = "none"
model.to(device)

train_dl = dm.train_dataloader()
val_dl = dm.val_dataloader()
test_dl = dm.test_dataloader()

train_dl.shuffle = False
val_dl.shuffle = False
test_dl.shuffle = False
model_name = "-".join(model_list[model_idx].split(os.sep)[0:2])

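# Single-sequence batches in dataset order: batch_idx below then indexes
# directly into dm.test_dataset, which the bisect lookup in iter_dl assumes.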
dl = torch.utils.data.DataLoader(
dm.test_dataset,
num_workers=0,
batch_size=1,
collate_fn=partial(collate_fn, n_positions=dm.n_positions),
worker_init_fn=partial(worker_fn, seed=dm.seed),
pin_memory=True,
shuffle=False,
batch_sampler=BatchSampler(SequentialSampler(dm.test_dataset), batch_size=1, drop_last=False),
)

def iter_dl(dl):
# Provider => row => field => entropy
whole_set_entropy_map = defaultdict(lambda:
defaultdict(lambda:
{"METRIC_NAME": pd.NA, "PAT_ID": pd.NA, "ACCESS_TIME": pd.NA})
)
for batch_idx, batch in tqdm(enumerate(dl), total=len(dl)):

cur_provider = None
providers_seen = set()
pbar = tqdm(enumerate(dl), total=len(dl))
for batch_idx, batch in pbar:
input_ids, labels = batch

with torch.no_grad():
@@ -158,20 +172,33 @@ def iter_dl(dl):
# Set the labels to -100, zero out the input_ids
labels_c[:, :] = -100

# Get the index of the current row in the whole df
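# bisect_right over ConcatDataset.cumulative_sizes finds the sub-dataset
# containing batch_idx, e.g. with cumulative sizes [3, 7]:
# batch_idx 2 -> dataset 0, batch_idx 3 -> dataset 1.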
dset_idx = bisect.bisect_right(dl.dataset.cumulative_sizes, batch_idx)
dset_start_idx = dl.dataset.cumulative_sizes[dset_idx - 1] if dset_idx > 0 else 0
dset = dl.dataset.datasets[dset_idx]
provider = dset.provider

ce_current = []
row_len = len(vocab.field_ids) - 1 # Exclude special fields
row_count = (eos_index - 1) // row_len
row_count = eos_index // row_len # No need to offset for eos
if row_count <= 1: # Not applicable
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["METRIC_NAME"] = pd.NA
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["PAT_ID"] = pd.NA
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["ACCESS_TIME"] = pd.NA
continue

# NOTE: Next-token generation != next-row generation
# This means that we include the next two tokens in the input to avoid EOS predictions.
loss_pos = model.loss.col_ids_labels.transpose(0, 1).flatten()

# Get the index of the current row in the whole df
dset_idx = bisect.bisect_right(dl.dataset.cumulative_sizes, batch_idx)
dset = dl.dataset.datasets[dset_idx]
provider = dset.provider
if provider != cur_provider:
providers_seen.add(provider)
cur_provider = provider
pbar.set_postfix({"providers": len(providers_seen)})

# Add a NA for the first row.
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["METRIC_NAME"] = pd.NA
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["PAT_ID"] = pd.NA
whole_set_entropy_map[provider][dset.seqs_indices[batch_idx - dset_start_idx][0]]["ACCESS_TIME"] = pd.NA
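# (a sequence's first row has no preceding context, so no entropy is defined for it)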

for i in range(0, row_count):
input_ids_start = i * row_len
@@ -187,15 +214,6 @@
labels_c[:, labels_row_start:labels_row_end] = labels[
:, labels_row_start:labels_row_end
]
#if i > 0:
# labels_c[
# :, input_ids_start:input_ids_end
# ] = -100 # Eliminate previous row.

# if i >= window_size:
# old_row_start = (i - window_size) * row_len
# old_row_end = old_row_start + row_len
# input_ids_c[:, old_row_start:old_row_end] = 0

# Calculate the cross entropy
output = model(input_ids_c.to(device), labels=labels_c.to(device), return_dict=True)
@@ -204,12 +222,32 @@
patient_loss = loss[PAT_ID_COL, i]
time_loss = loss[ACCESS_TIME_COL, i]

whole_row_idx = dl.dataset.seqs_indices[dset_idx] + i
whole_row_idx = dset.seqs_indices[batch_idx - dset_start_idx][0] + (i + 1)
# +1 to account for the first row being the header

whole_set_entropy_map[provider][whole_row_idx]["METRIC_NAME"] = metric_loss
whole_set_entropy_map[provider][whole_row_idx]["PAT_ID"] = patient_loss
whole_set_entropy_map[provider][whole_row_idx]["ACCESS_TIME"] = time_loss

# Upon completion, save the entropy map
for dset_count, (provider, entropy_map) in enumerate(whole_set_entropy_map.items()):
prov_path = os.path.normpath(os.path.join(path_prefix, config["audit_log_path"], provider))

# Convert entropy_map to a DataFrame with its keys as indices
entropy_df = pd.DataFrame.from_dict(entropy_map, orient="index")
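# orient="index" turns each {row_idx: {col: val}} entry into a row keyed by
# row_idx, so the index lines up with seqs_indices for the --verify check below.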

if args.verify:
# Make sure all the sequence ranges line up.
seqs_indices = dl.dataset.datasets[dset_count].seqs_indices
for i, (start, stop) in enumerate(seqs_indices):
if len(entropy_df.loc[start:stop, :]) != stop - start + 1:
raise ValueError(f"Sequence range {start}:{stop} does not line up with entropy_df.")
if not entropy_df.loc[start, :].isna().all():
raise ValueError(f"First value is not NA for sequence range {start}:{stop}.")

entropy_df.to_csv(os.path.normpath(os.path.join(prov_path, f"entropy-{model_name}.csv")))


iter_dl(dl)


17 changes: 11 additions & 6 deletions model/data.py
@@ -161,18 +161,23 @@ def load_from_log(self):
seqs = []
seqs_indices = []
for shift in seqs_shifts:
seq_start_idx = shift.index[0]
seq_start_idx = 0 #shift.index[0]
# Reset the index
for i, row in shift.iterrows():
for i, (_, row) in enumerate(shift.iterrows()):
if row[self.timestamp_col] > sep_sec:
shift.loc[i, self.timestamp_col] = 0 # Reset the time delta to 0.
seq_end_idx = i - 1
new_seq = shift.loc[seq_start_idx:seq_end_idx, :].copy()
seqs.append(new_seq)
new_seq = shift.iloc[seq_start_idx:seq_end_idx, :].copy()
seq_start_idx = i
if len(new_seq) == 0:
continue
new_seq.iloc[0, new_seq.columns.get_loc(self.timestamp_col)] = 0
seqs.append(new_seq)


# Append the last shift
seqs.append(shift.loc[seq_start_idx:, :].copy())
last_seq = shift.iloc[seq_start_idx:, :].copy()
last_seq.iloc[0, last_seq.columns.get_loc(self.timestamp_col)] = 0
seqs.append(last_seq)

for seq in seqs:
seqs_indices.append((seq.index[0], seq.index[-1]))
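The splitting logic above, in isolation: a shift is cut into sequences wherever the time delta exceeds sep_sec, and the first delta of each sequence is zeroed. A toy sketch with a made-up delta column and threshold (boundary handling simplified relative to the diff):

import pandas as pd

sep_sec = 60
shift = pd.DataFrame({"delta": [0, 5, 120, 3, 400, 2]})
seqs, start = [], 0
for i in range(len(shift)):
    if shift["delta"].iloc[i] > sep_sec:
        seq = shift.iloc[start:i].copy()   # close the current sequence
        start = i                          # next sequence begins at the gap
        if len(seq) == 0:
            continue
        seq.iloc[0, seq.columns.get_loc("delta")] = 0
        seqs.append(seq)
last = shift.iloc[start:].copy()           # append the trailing sequence
last.iloc[0, last.columns.get_loc("delta")] = 0
seqs.append(last)
# -> three sequences: rows 0-1, 2-3, 4-5, each starting with delta 0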
