Skip to content

euroeval.split_utils

[docs] module euroeval.split_utils

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Utilities for detecting and mapping dataset splits."""

from pathlib import Path

from huggingface_hub import HfApi


def find_split(splits: list[str], keyword: str) -> str | None:
    """Return the shortest split name containing `keyword`, or None.

    Args:
        splits:
            A list of split names.
        keyword:
            The keyword to search for.

    Returns:
        The shortest split name containing `keyword`, or None if no such split
            exists.
    """
    candidates = sorted([s for s in splits if keyword in s.lower()], key=len)
    return candidates[0] if candidates else None


def get_repo_split_names(hf_api: HfApi, dataset_id: str) -> list[str] | None:
    """Extract split names from a Hugging Face dataset repo.

    Args:
        hf_api:
            The Hugging Face API object.
        dataset_id:
            The ID of the dataset to get the split names for.

    Returns:
        A list of split names, or None if the split names are not available.
    """
    dataset_info = hf_api.dataset_info(repo_id=dataset_id)

    if (
        dataset_info.card_data is not None
        and hasattr(dataset_info.card_data, "dataset_info")
        and "splits" in dataset_info.card_data.dataset_info
    ):
        return [
            split["name"] for split in dataset_info.card_data.dataset_info["splits"]
        ]

    # If we don't have access to the split names directly, we look at the data files,
    # since they tend to be of the form "data/test-00000-of-00001.parquet"
    elif dataset_info.siblings is not None:
        parquet_file_names = [
            sibling.rfilename
            for sibling in dataset_info.siblings
            if sibling.rfilename.endswith(".parquet")
        ]
        split_names = [Path(fname).stem.split("-")[0] for fname in parquet_file_names]
        if split_names:
            return split_names

    return None


def get_repo_splits(
    hf_api: HfApi, dataset_id: str
) -> tuple[str | None, str | None, str | None]:
    """Look up the train, validation and test split names of a dataset repo.

    The split names of the repo are fetched with `get_repo_split_names`, and each of
    the keywords "train", "val" and "test" is then matched against them with
    `find_split`.

    Args:
        hf_api:
            The Hugging Face API object.
        dataset_id:
            The ID of the dataset to get the split names for.

    Returns:
        A 3-tuple (train_split, val_split, test_split), where each element is the
            name of the matching split, or None if there is no matching split. All
            three are None if the split names of the repo are not available.
    """
    split_names = get_repo_split_names(hf_api=hf_api, dataset_id=dataset_id)
    if split_names is None:
        return None, None, None

    train_split = find_split(splits=split_names, keyword="train")
    val_split = find_split(splits=split_names, keyword="val")
    test_split = find_split(splits=split_names, keyword="test")
    return train_split, val_split, test_split