"""ModelCache class for caching model outputs."""
import collections.abc as c
import hashlib
import json
import logging
import sys
from collections import defaultdict
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from datasets import Dataset
from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
from .data_models import (
GenerativeModelOutput,
HashableDict,
SingleGenerativeModelOutput,
)
from .logging_utils import get_pbar, log, log_once
class ModelCache:
    """A cache for model outputs.

    Attributes:
        model_cache_dir:
            The directory to store the cache in.
        cache_path:
            The path to the cache file.
        cache:
            The model output cache.
        max_generated_tokens:
            The maximum number of tokens to generate for each example.
        progress_bar:
            Whether to show a progress bar when caching model outputs.
        store_metadata:
            Whether to store metadata for the model outputs.
        indent_json_when_saving:
            Whether to indent the JSON when saving the cache.
    """

    def __init__(
        self,
        model_cache_dir: "Path",
        cache_name: str,
        max_generated_tokens: int,
        progress_bar: bool,
        store_metadata: bool,
        indent_json_when_saving: bool,
    ) -> None:
        """Initialise the model output cache.

        Args:
            model_cache_dir:
                The directory to store the cache in.
            cache_name:
                The name of the cache file.
            max_generated_tokens:
                The maximum number of tokens to generate for each example.
            progress_bar:
                Whether to show a progress bar when caching model outputs.
            store_metadata:
                Whether to store metadata for the model outputs.
            indent_json_when_saving:
                Whether to indent the JSON when saving the cache.
        """
        self.model_cache_dir = model_cache_dir
        self.model_cache_dir.mkdir(parents=True, exist_ok=True)
        # Slashes in the cache name would be interpreted as directories, so we
        # replace them
        self.cache_path = self.model_cache_dir / cache_name.replace("/", "--")
        self.max_generated_tokens = max_generated_tokens
        self.progress_bar = progress_bar
        self.store_metadata = store_metadata
        self.indent_json_when_saving = indent_json_when_saving
        self.cache: dict[str, SingleGenerativeModelOutput] = dict()

    def _write_empty_cache_file(self) -> None:
        """(Re-)initialise the on-disk cache as an empty JSON object."""
        self.cache_path.parent.mkdir(exist_ok=True, parents=True)
        with self.cache_path.open("w") as f:
            json.dump(
                dict(),
                f,
                indent=2 if self.indent_json_when_saving else None,
                ensure_ascii=False,
            )

    def load(self) -> None:
        """Load the model output cache from disk.

        If the cache file does not exist, or exists but cannot be parsed as
        JSON, it is (re-)initialised as an empty cache.
        """
        if not self.cache_path.exists():
            self._write_empty_cache_file()
        try:
            with self.cache_path.open() as f:
                json_cache = json.load(f)
        except json.JSONDecodeError:
            log(
                f"Failed to load the cache from {self.cache_path}. The cache will be "
                f"re-initialised.",
                level=logging.WARNING,
            )
            json_cache = dict()
            self._write_empty_cache_file()

        # Re-assemble the stored flat dicts into model output objects. The three
        # known keys are extracted explicitly; all remaining keys are treated as
        # metadata.
        cache: dict[str, SingleGenerativeModelOutput] = dict()
        for key, value_dict in json_cache.items():
            sequence = value_dict.pop("sequence", None)
            predicted_label = value_dict.pop("predicted_label", None)
            scores = value_dict.pop("scores", None)
            cache[key] = SingleGenerativeModelOutput(
                sequence=sequence,
                predicted_label=predicted_label,
                scores=scores,
                metadata=HashableDict(value_dict),
            )
        self.cache = cache

    def save(self) -> None:
        """Save the model output cache to disk."""
        # Unpack the metadata into the top level of each entry, to get a flat
        # dict to dump
        dumpable_cache: dict[str, dict] = dict()
        for key, value in self.cache.items():
            value_dict = asdict(value)
            metadata = value_dict.pop("metadata", dict())
            if metadata is None:
                metadata = dict()
            value_dict |= metadata
            # Put the index first, purely for readability of the JSON file
            if "index" in metadata:
                value_dict = {"index": metadata.pop("index")} | value_dict
            dumpable_cache[key] = value_dict
        try:
            self.cache_path.parent.mkdir(exist_ok=True, parents=True)
            self.cache_path.write_text(
                json.dumps(
                    dumpable_cache,
                    indent=2 if self.indent_json_when_saving else None,
                    ensure_ascii=False,
                )
            )
        except KeyError:
            # BUG FIX: this previously logged "Failed to load the cache from
            # ...", which was misleading during a save
            log(
                f"Failed to save the cache to {self.cache_path}. The cache will be "
                f"re-initialised.",
                level=logging.WARNING,
            )
            self.cache = dict()
            self._write_empty_cache_file()

    def _hash_key(self, key: str | c.Sequence[dict[str, str]]) -> str:
        """Hash the key to use as an index in the cache.

        Args:
            key:
                The key to hash.

        Returns:
            The hashed key.
        """
        # The key may be an unhashable list of message dicts, so we hash its
        # string representation
        return hashlib.md5(string=str(key).encode()).hexdigest()

    def __getitem__(
        self, key: str | c.Sequence[dict[str, str]]
    ) -> SingleGenerativeModelOutput:
        """Get an item from the cache.

        Args:
            key:
                The key to use to index the cache.

        Returns:
            The model output.

        Raises:
            KeyError:
                If the key is not in the cache.
        """
        hashed_key = self._hash_key(key=key)
        return self.cache[hashed_key]

    def __setitem__(
        self, key: str | c.Sequence[dict[str, str]], value: SingleGenerativeModelOutput
    ) -> None:
        """Set an item in the cache.

        Args:
            key:
                The key to use to index the cache.
            value:
                The value to set in the cache.
        """
        hashed_key = self._hash_key(key=key)
        self.cache[hashed_key] = value

    def remove(self) -> None:
        """Remove the cache from memory and delete it from disk."""
        self.cache_path.unlink()
        del self.cache

    def __contains__(self, key: str | c.Sequence[dict[str, str]]) -> bool:
        """Check if a key is in the cache.

        Args:
            key:
                The key to check.

        Returns:
            Whether the key is in the cache.
        """
        hashed_key = self._hash_key(key=key)
        return hashed_key in self.cache

    def add_to_cache(
        self, model_inputs: dict, model_output: GenerativeModelOutput
    ) -> None:
        """Add the model input/output to the cache.

        Args:
            model_inputs:
                The model inputs, keyed by column name. Must contain either a
                "messages" or a "text" column; any other columns are treated as
                metadata.
            model_output:
                The model output.
        """
        input_column = "messages" if "messages" in model_inputs else "text"
        if self.store_metadata:
            metadata = deepcopy(model_inputs)
            # BUG FIX: this previously popped the *non*-input column
            # (`"messages" if "messages" != input_column else "text"`), which
            # left the inputs themselves inside the metadata and duplicated them
            # in the cache. We remove both potential input columns so that only
            # genuine metadata columns remain.
            metadata.pop("messages", None)
            metadata.pop("text", None)
        model_inputs = model_inputs[input_column]

        # Double check that the number of inputs and outputs match
        if not len(model_inputs) == len(model_output.sequences):
            log(
                f"Number of model inputs ({len(model_inputs)}) does not match the "
                f"number of model outputs ({len(model_output.sequences)}). We will not "
                f"cache the model outputs.",
                level=logging.WARNING,
            )
            return

        # Store the generated sequences in the cache, one by one
        with get_pbar(
            iterable=model_inputs,
            desc="Caching model outputs",
            disable=hasattr(sys, "_called_from_test") or not self.progress_bar,
        ) as pbar:
            for sample_idx, model_input in enumerate(pbar):
                # Extract the scores from the model output, to be cached. We only
                # store the scores if the generated sequence is short enough to
                # be a classification task, to save space.
                if (
                    model_output.scores is not None
                    and self.max_generated_tokens
                    <= NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
                ):
                    scores = model_output.scores[sample_idx]
                else:
                    if model_output.scores is not None:
                        log_once(
                            "The generated sequence is longer than the maximum "
                            "length for classification. Not caching the scores.",
                            level=logging.DEBUG,
                        )
                    scores = None

                # Pick out this sample's slice of each metadata column
                if self.store_metadata:
                    single_metadata = HashableDict(
                        {
                            metadata_column: metadata_values[sample_idx]
                            for metadata_column, metadata_values in metadata.items()
                        }
                    )
                else:
                    single_metadata = None

                self[model_input] = SingleGenerativeModelOutput(
                    sequence=model_output.sequences[sample_idx],
                    predicted_label=(
                        model_output.predicted_labels[sample_idx]
                        if model_output.predicted_labels is not None
                        else None
                    ),
                    scores=scores,
                    metadata=single_metadata,
                )
def split_dataset_into_cached_and_non_cached(
    dataset: "Dataset", cache: ModelCache
) -> tuple["Dataset", "Dataset"]:
    """Split a dataset into a cached and non-cached part.

    Args:
        dataset:
            The dataset to split.
        cache:
            The model output cache.

    Returns:
        The cached and non-cached parts of the dataset.
    """
    # Get the sample indices of the non-cached examples, which are unique with
    # respect to the input column. Uniqueness is tracked via the string
    # representation of the inputs, matching how the cache hashes its keys; this
    # keeps the duplicate check O(1) per sample (the inputs themselves may be
    # unhashable lists of message dicts), instead of the previous O(n) list
    # membership test.
    input_column = "messages" if "messages" in dataset.column_names else "text"
    seen_texts: set[str] = set()
    unique_non_cached_ids: set[int] = set()
    for idx, dataset_text in enumerate(dataset[input_column]):
        if dataset_text not in cache and str(dataset_text) not in seen_texts:
            unique_non_cached_ids.add(idx)
            seen_texts.add(str(dataset_text))

    # The cached examples are the ones that are not in the non-cached examples.
    # This means that if the dataset has duplicates, only a single copy of the
    # duplicate will be put in the non-cached part, and the rest in the cached
    # part. The indices are sorted so that the split is deterministic, as set
    # iteration order is arbitrary.
    cached_ids = set(range(len(dataset))) - unique_non_cached_ids
    cached = dataset.select(sorted(cached_ids))
    non_cached = dataset.select(sorted(unique_non_cached_ids))
    assert isinstance(cached, Dataset), (
        f"Expected the cached dataset to be a Dataset, but got {type(cached)}"
    )
    assert isinstance(non_cached, Dataset), (
        f"Expected the non-cached dataset to be a Dataset, but got {type(non_cached)}"
    )
    return cached, non_cached
def load_cached_model_outputs(
    cached_dataset: "Dataset", cache: ModelCache
) -> GenerativeModelOutput:
    """Load the cached model outputs.

    Args:
        cached_dataset:
            The dataset containing the cached examples.
        cache:
            The model output cache.

    Returns:
        The model output containing the cached sequences.
    """
    input_column = "messages" if "messages" in cached_dataset.column_names else "text"
    cached_model_outputs: c.Sequence[SingleGenerativeModelOutput] = [
        cache[prompt] for prompt in cached_dataset[input_column]
    ]
    cached_sequences = [model_output.sequence for model_output in cached_model_outputs]

    # If there are no cached outputs, or the outputs were cached without scores
    # (we assume the first output is representative), return the sequences alone.
    # The empty-dataset guard fixes a previous IndexError on
    # `cached_model_outputs[0]`.
    if not cached_model_outputs or cached_model_outputs[0].scores is None:
        return GenerativeModelOutput(sequences=cached_sequences)

    cached_scores = [model_output.scores or [] for model_output in cached_model_outputs]
    return GenerativeModelOutput(sequences=cached_sequences, scores=cached_scores)
def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
    """Create cache directory for a model.

    Args:
        cache_dir:
            The cache directory.
        model_id:
            The model ID.

    Returns:
        The path to the cache directory.
    """
    # Local models use their own directory directly as the cache directory
    if Path(model_id).is_dir():
        log_once(
            f"Since the model {model_id!r} is a local model, we will use the model "
            "directory directly as the model cache directory.",
            level=logging.DEBUG,
        )
        return model_id

    # Hub models get a dedicated directory under `cache_dir`, derived from the
    # model ID with slashes replaced, as they would otherwise be interpreted as
    # subdirectories
    sanitised_model_id = model_id.replace("/", "--")
    model_cache_dir = (Path(cache_dir) / "model_cache" / sanitised_model_id).as_posix()
    log_once(
        f"Using the model cache directory {model_cache_dir!r} for the model "
        f"{model_id!r}.",
        level=logging.DEBUG,
    )
    return model_cache_dir