euroeval.split_utils
[docs]
module
euroeval.split_utils
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Utilities for detecting and mapping dataset splits."""
from huggingface_hub import HfApi
def find_split(splits: list[str], keyword: str) -> str | None:
"""Return the shortest split name containing `keyword`, or None.
Args:
splits:
A list of split names.
keyword:
The keyword to search for.
Returns:
The shortest split name containing `keyword`, or None if no such split
exists.
"""
candidates = sorted([s for s in splits if keyword in s.lower()], key=len)
return candidates[0] if candidates else None
def get_repo_split_names(hf_api: HfApi, dataset_id: str) -> list[str]:
    """List the split names declared in a Hugging Face dataset repo.

    The names are read from the "splits" entry of the `dataset_info` section of
    the repo's card data, as returned by `hf_api.dataset_info`.

    Args:
        hf_api:
            The Hugging Face API object.
        dataset_id:
            The ID of the dataset to get the split names for.

    Returns:
        A list of split names, in the order they appear in the card data.
    """
    repo_info = hf_api.dataset_info(repo_id=dataset_id)
    # NOTE(review): assumes the card data exposes a single `dataset_info` mapping
    # with a "splits" list of {"name": ...} entries - confirm for repos without
    # card metadata or with several configurations.
    split_entries = repo_info.card_data.dataset_info["splits"]
    return [entry["name"] for entry in split_entries]
def get_repo_splits(
    hf_api: HfApi, dataset_id: str
) -> tuple[str | None, str | None, str | None]:
    """Resolve the train, validation and test split names of a dataset repo.

    The repo's split names are fetched once and then searched for the keywords
    "train", "val" and "test", in that order.

    Args:
        hf_api:
            The Hugging Face API object.
        dataset_id:
            The ID of the dataset to get the split names for.

    Returns:
        A 3-tuple (train_split, val_split, test_split) where each element is either
        the name of the matching split or None if no such split exists.
    """
    available = get_repo_split_names(hf_api=hf_api, dataset_id=dataset_id)
    train_split = find_split(splits=available, keyword="train")
    val_split = find_split(splits=available, keyword="val")
    test_split = find_split(splits=available, keyword="test")
    return train_split, val_split, test_split
|