125 | """Functions related to the loading of the data."""
import logging
import sys
import time
from datasets import Dataset, DatasetDict, load_dataset
from datasets.exceptions import DatasetsError
from huggingface_hub.errors import HfHubHTTPError
from numpy.random import Generator
from requests import ReadTimeout
from .data_models import BenchmarkConfig, DatasetConfig
from .exceptions import HuggingFaceHubDown, InvalidBenchmark
from .utils import unscramble
logger = logging.getLogger("euroeval")
def load_data(
rng: Generator, dataset_config: "DatasetConfig", benchmark_config: "BenchmarkConfig"
) -> list[DatasetDict]:
"""Load the raw bootstrapped datasets.
Args:
rng:
The random number generator to use.
dataset_config:
The configuration for the dataset.
benchmark_config:
The configuration for the benchmark.
Returns:
A list of bootstrapped datasets, one for each iteration.
Raises:
InvalidBenchmark:
If the dataset cannot be loaded.
HuggingFaceHubDown:
If the Hugging Face Hub is down.
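
    Example:
        A minimal, hypothetical sketch of a call to this function; the seed
        and the two config objects below are illustrative stand-ins, not
        values defined in this module:

            import numpy as np

            rng = np.random.default_rng(seed=4242)
            datasets = load_data(
                rng=rng,
                dataset_config=dataset_config,
                benchmark_config=benchmark_config,
            )
            assert len(datasets) == benchmark_config.num_iterations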
"""
    dataset = load_raw_data(
        dataset_config=dataset_config, cache_dir=benchmark_config.cache_dir
    )
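
    # If we are not evaluating on the test split, use the validation split as
    # a stand-in for it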
    if not benchmark_config.evaluate_test_split:
        dataset["test"] = dataset["val"]

    # Remove empty examples from the datasets
    for text_feature in ["tokens", "text"]:
        if text_feature in dataset["train"].features:
            dataset = dataset.filter(lambda x: len(x[text_feature]) > 0)

    # If we are testing then truncate the test set
    if hasattr(sys, "_called_from_test"):
        dataset["test"] = dataset["test"].select(range(1))

    # Bootstrap the splits
    bootstrapped_splits: dict[str, list[Dataset]] = dict()
    for split in ["train", "val", "test"]:
        bootstrap_indices = rng.integers(
            0,
            len(dataset[split]),
            size=(benchmark_config.num_iterations, len(dataset[split])),
        )
        bootstrapped_splits[split] = [
            dataset[split].select(bootstrap_indices[idx])
            for idx in range(benchmark_config.num_iterations)
        ]

    datasets = [
        DatasetDict(
            {
                split: bootstrapped_splits[split][idx]
                for split in ["train", "val", "test"]
            }
        )
        for idx in range(benchmark_config.num_iterations)
    ]

    return datasets


def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> DatasetDict:
"""Load the raw dataset.
Args:
dataset_config:
The configuration for the dataset.
cache_dir:
The directory to cache the dataset.
Returns:
The dataset.
"""
    num_attempts = 5
    for _ in range(num_attempts):
        try:
            dataset = load_dataset(
                path=dataset_config.huggingface_id,
                cache_dir=cache_dir,
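                # The Hub token is stored scrambled in the source, presumably
                # to keep it out of plain text; `unscramble` restores it at
                # runtime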
                token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_"),
            )
            break
        except (FileNotFoundError, DatasetsError, ConnectionError, ReadTimeout):
            logger.warning(
                f"Failed to load dataset {dataset_config.huggingface_id!r}. "
                "Retrying..."
            )
            time.sleep(1)
            continue
        except HfHubHTTPError:
            raise HuggingFaceHubDown()
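    # The `else` branch of a for-loop only runs if the loop finished without
    # a `break`, i.e. if every attempt to load the dataset failed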
    else:
        raise InvalidBenchmark(
            f"Failed to load dataset {dataset_config.huggingface_id!r} after "
            f"{num_attempts} attempts."
        )

    assert isinstance(dataset, DatasetDict)  # type: ignore[used-before-def]
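    # Every dataset is expected to ship "train", "val" and "test" splits;
    # fail fast if any of them is missing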
required_keys = ["train", "val", "test"]
missing_keys = [key for key in required_keys if key not in dataset]
if missing_keys:
raise InvalidBenchmark(
"The dataset is missing the following required splits: "
f"{', '.join(missing_keys)}"
)
return DatasetDict({key: dataset[key] for key in required_keys})