# entropy_cache.py
# Assigns entropy values with a given model to the dataset in the order it appears.
import argparse
import bisect
import inspect
import os
import pickle
import sys
from collections import defaultdict
from functools import partial

import numpy as np
import pandas as pd
import scipy.stats
import torch
import yaml
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from tabulate import tabulate
from torch.utils.data import DataLoader, SequentialSampler, BatchSampler, ConcatDataset
from tqdm import tqdm

from model.data import timestamp_space_calculation
from model.model import EHRAuditGPT2, EHRAuditRWKV, EHRAuditLlama
from model.modules import EHRAuditPretraining, EHRAuditDataModule, collate_fn, worker_fn
from model.vocab import EHRVocab, EHRAuditTokenizer
# FYI: this is a quick-and-dirty way of identifying the columns; it will need to be
# changed if the tabularization changes.
METRIC_NAME_COL = 0
PAT_ID_COL = 1
ACCESS_TIME_COL = 2
# Get arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model", type=int, default=None, help="Index of the model to use for evaluation."
)
parser.add_argument(
    "--val",
    action="store_true",
    help="Run with the validation dataset instead of the test set.",
)
parser.add_argument(
    "--reset_cache",
    action="store_true",
    help="Whether to reset the cache before analysis.",
)
parser.add_argument(
    "--debug",
    action="store_true",
    help="Whether to run with a single thread.",
)
parser.add_argument(
    "--verify",
    action="store_true",
    help="Whether to verify that the entropy values line up with the dataset.",
)
if __name__ == "__main__":
    # To add if needed: delete cached entropy values
    args = parser.parse_args()

    # Get the list of models from the config file
    config_path = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "config.yaml")
    )
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Use the first path prefix that is actually mounted on this machine.
    path_prefix = ""
    for prefix in config["path_prefix"]:
        if os.path.exists(prefix):
            path_prefix = prefix
            break
    if path_prefix == "":
        raise RuntimeError("No valid drive mounted.")

    model_paths = os.path.normpath(
        os.path.join(path_prefix, config["pretrained_model_path"])
    )
    # Get a recursive list of subdirectories
    model_list = []
    for root, dirs, files in os.walk(model_paths):
        # If there's a .bin file, it's a model
        if any(file.endswith(".bin") for file in files):
            # Append the last three directories to the model list
            model_list.append(os.path.join(*root.split(os.sep)[-3:]))
    if len(model_list) == 0:
        raise ValueError(f"No models found in {model_paths}")
    model_list = sorted(model_list)
    if args.model is None:
        print("Select a model to evaluate:")
        for i, model in enumerate(model_list):
            print(f"{i}: {model}")
        model_idx = int(input("Model index >>> "))
    else:
        model_idx = args.model
    model_name = model_list[model_idx]
    model_path = os.path.normpath(
        os.path.join(path_prefix, config["pretrained_model_path"], model_name)
    )
    # Get the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the vocabulary and data module
    vocab = EHRVocab(
        vocab_path=os.path.normpath(os.path.join(path_prefix, config["vocab_path"]))
    )
    dm = EHRAuditDataModule(
        yaml_config_path=config_path,
        vocab=vocab,
        batch_size=1,  # Just one sample at a time
        reset_cache=args.reset_cache,
        debug=args.debug,
    )
    if args.reset_cache:
        dm.prepare_data()
    dm.setup()
    types = {
        "gpt2": EHRAuditGPT2,
        "rwkv": EHRAuditRWKV,
        "llama": EHRAuditLlama,
    }
    model_type = model_list[model_idx].split(os.sep)[0]
    model = types[model_type].from_pretrained(model_path, vocab=vocab)
    model.loss.reduction = "none"  # Keep per-token losses rather than averaging them
    model.to(device)
    model_name = "-".join(model_list[model_idx].split(os.sep)[0:2])

    # Evaluate the validation and test sets together, in their original order.
    val_test_dset = ConcatDataset(dm.val_dataset.datasets + dm.test_dataset.datasets)
    dl = torch.utils.data.DataLoader(
        val_test_dset,
        num_workers=0,
        batch_size=1,
        collate_fn=partial(collate_fn, n_positions=dm.n_positions),
        worker_init_fn=partial(worker_fn, seed=dm.seed),
        pin_memory=True,
        shuffle=False,
        batch_sampler=BatchSampler(
            SequentialSampler(val_test_dset), batch_size=1, drop_last=False
        ),
    )
    # Provider => row => field => entropy
    whole_set_entropy_map = defaultdict(
        lambda: defaultdict(
            lambda: {"METRIC_NAME": pd.NA, "PAT_ID": pd.NA, "ACCESS_TIME": pd.NA}
        )
    )
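    # For example (illustrative values only; real keys come from the datasets):
    #   whole_set_entropy_map["PROV123"][42] == {"METRIC_NAME": 1.7, "PAT_ID": 2.3, "ACCESS_TIME": 0.9}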
    cur_provider = None
    providers_seen = set()
    providers_cached = set()
    pbar = tqdm(enumerate(dl), total=len(dl))
    model_path_name = f"entropy-{model_name}.csv"

    # See what providers have been cached already.
    for provider in os.listdir(
        os.path.normpath(os.path.join(path_prefix, config["audit_log_path"]))
    ):
        if os.path.exists(
            os.path.normpath(
                os.path.join(
                    path_prefix, config["audit_log_path"], provider, model_path_name
                )
            )
        ):
            providers_cached.add(provider)
    for batch_idx, batch in pbar:
        input_ids, labels = batch
        with torch.no_grad():
            # Find the eos index (the position just before the first -100 / ignored label)
            nonzeros = (labels.view(-1) == -100).nonzero(as_tuple=True)
            if len(nonzeros[0]) == 0:
                eos_index = len(labels.view(-1)) - 1
            else:
                eos_index = nonzeros[0][0].item() - 1

            # Copy the labels and targets
            # input_ids_c = torch.zeros_like(input_ids)
            # labels_c = labels.clone()
            # Set the labels to -100, zero out the input_ids
            # labels_c[:, :] = -100

            # Get the index of the current row in the whole df
            dset_idx = bisect.bisect_right(dl.dataset.cumulative_sizes, batch_idx)
            dset_start_idx = (
                dl.dataset.cumulative_sizes[dset_idx - 1] if dset_idx > 0 else 0
            )
            dset = dl.dataset.datasets[dset_idx]
            provider = dset.provider
            # if provider in providers_cached:
            #     continue
            if provider != cur_provider:
                providers_seen.add(provider)
                pbar.set_postfix({"providers": len(providers_seen)})
                if cur_provider is not None:
                    # Save the entropy map for the provider we just finished
                    prov_path = os.path.normpath(
                        os.path.join(path_prefix, config["audit_log_path"], cur_provider)
                    )
                    # Convert the entropy map to a df with its keys as indices
                    entropy_df = pd.DataFrame.from_dict(
                        whole_set_entropy_map[cur_provider], orient="index"
                    )
                    entropy_df.to_csv(
                        os.path.normpath(os.path.join(prov_path, model_path_name))
                    )
                    providers_cached.add(cur_provider)
                # Always set this
                cur_provider = provider
            ce_current = []
            row_len = len(vocab.field_ids) - 1  # Exclude special fields
            row_count = eos_index // row_len  # No need to offset for eos
            seq_start = dset.seqs_indices[batch_idx - dset_start_idx][0]
            if row_count <= 1:  # Not applicable
                whole_set_entropy_map[provider][seq_start]["METRIC_NAME"] = pd.NA
                whole_set_entropy_map[provider][seq_start]["PAT_ID"] = pd.NA
                whole_set_entropy_map[provider][seq_start]["ACCESS_TIME"] = pd.NA
                continue

            # NOTE: Next-token generation != next-row generation.
            # This means that we include the next two tokens in the input to avoid EOS predictions.
            # Add an NA for the first row, which has no preceding row to predict it from.
            whole_set_entropy_map[provider][seq_start]["METRIC_NAME"] = pd.NA
            whole_set_entropy_map[provider][seq_start]["PAT_ID"] = pd.NA
            whole_set_entropy_map[provider][seq_start]["ACCESS_TIME"] = pd.NA
            # Calculate the cross entropy for the whole sequence in one forward pass
            output = model(input_ids.to(device), labels=labels.to(device), return_dict=True)
            loss = output.loss.cpu().numpy()
            for i in range(1, row_count):
                # input_ids_start = (i - 1) * row_len
                # input_ids_end = input_ids_start + row_len
                # input_ids_end_extra = input_ids_end + row_len
                # # Get the current row
                # input_ids_c[:, input_ids_start:input_ids_end_extra] = input_ids[
                #     :, input_ids_start:input_ids_end_extra
                # ]
                # # Labels are next row.
                # labels_row_start = i * row_len
                # labels_row_end = labels_row_start + row_len
                # labels_c[:, labels_row_start:labels_row_end] = labels[
                #     :, labels_row_start:labels_row_end
                # ]
                # The loss is indexed [field, row]: the METRIC_NAME prediction for row i
                # comes from the end of row i - 1, while PAT_ID and ACCESS_TIME are
                # predicted within row i itself.
                metric_loss = loss[METRIC_NAME_COL, i - 1]
                patient_loss = loss[PAT_ID_COL, i]
                time_loss = loss[ACCESS_TIME_COL, i]
                whole_row_idx = seq_start + i
                # +1 to account for the first row being the header
                whole_set_entropy_map[provider][whole_row_idx]["METRIC_NAME"] = metric_loss
                whole_set_entropy_map[provider][whole_row_idx]["PAT_ID"] = patient_loss
                whole_set_entropy_map[provider][whole_row_idx]["ACCESS_TIME"] = time_loss
    # Upon completion, save the entropy map for every provider
    for dset_count, (provider, entropy_map) in enumerate(whole_set_entropy_map.items()):
        prov_path = os.path.normpath(
            os.path.join(path_prefix, config["audit_log_path"], provider)
        )
        # Convert the entropy map to a df with its keys as indices
        entropy_df = pd.DataFrame.from_dict(entropy_map, orient="index")
        if args.verify:
            # Make sure all the sequence ranges line up.
            seqs_indices = dl.dataset.datasets[dset_count].seqs_indices
            for i, (start, stop) in enumerate(seqs_indices):
                if len(entropy_df.loc[start:stop, :]) != stop - start + 1:
                    raise ValueError(
                        f"Sequence range {start}:{stop} does not line up with entropy_df."
                    )
                if not entropy_df.loc[start, :].isna().all():
                    raise ValueError(
                        f"First value is not NA for sequence range {start}:{stop}."
                    )
        entropy_df.to_csv(os.path.normpath(os.path.join(prov_path, model_path_name)))
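
# Example invocations (illustrative only; the valid --model indices depend on which
# checkpoints exist under pretrained_model_path in config.yaml):
#   python entropy_cache.py --model 0 --verify
#   python entropy_cache.py --reset_cache --debug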