# Module: euroeval.custom_dataset_configs
"""Load custom dataset configs.

This module provides the main entry point for loading dataset configurations from
Hugging Face repositories, including Python-based configs. YAML-specific loading
logic lives in the `yaml_config` module.
"""

import importlib.util
import logging
import sys
from pathlib import Path
from types import ModuleType

from huggingface_hub import HfApi

from .data_models import DatasetConfig
from .logging_utils import log_once
from .split_utils import get_repo_splits
from .utils import get_hf_token
from .yaml_config import load_yaml_config


def load_custom_datasets_module(custom_datasets_file: Path) -> ModuleType | None:
    """Load the custom datasets module if it exists.

    Args:
        custom_datasets_file:
            The path to the custom datasets module.

    Returns:
        The custom datasets module, or None if it does not exist.
    """
    if custom_datasets_file.exists():
        spec = importlib.util.spec_from_file_location(
            name="custom_datasets_module", location=str(custom_datasets_file.resolve())
        )
        if spec is None:
            log_once(
                message=(
                    "Could not load the spec for the custom datasets file from "
                    f"{custom_datasets_file.resolve()}."
                ),
                level=logging.ERROR,
            )
            return None
        module = importlib.util.module_from_spec(spec=spec)
        if spec.loader is None:
            log_once(
                message=(
                    "Could not load the module for the custom datasets file from "
                    f"{custom_datasets_file.resolve()}."
                ),
                level=logging.ERROR,
            )
            return None
        spec.loader.exec_module(module)
        return module
    return None


def try_get_dataset_config_from_repo(
    dataset_id: str,
    api_key: str | None,
    cache_dir: Path,
    trust_remote_code: bool,
    run_with_cli: bool,
) -> DatasetConfig | None:
    """Try to get a dataset config from a Hugging Face dataset repository.

    A YAML config file (`eval.yaml`) takes priority, since it can be loaded
    without executing any remote code. When no YAML file is found we fall back
    to `euroeval_config.py`, which requires `trust_remote_code=True`.

    Args:
        dataset_id:
            The ID of the dataset to get the config for.
        api_key:
            The Hugging Face API key to use to check if the repositories have
            custom dataset configs.
        cache_dir:
            The directory to store the cache in.
        trust_remote_code:
            Whether to trust remote code. Only required when loading a Python
            config (`euroeval_config.py`). YAML configs never require this flag.
        run_with_cli:
            Whether the code is being run with the CLI.

    Returns:
        The dataset config if it exists, otherwise None.
    """
    hub = HfApi(token=get_hf_token(api_key=api_key))

    # Bail out early if there is no such dataset repository at all.
    if not hub.repo_exists(repo_id=dataset_id, repo_type="dataset"):
        return None

    files_in_repo = list(
        hub.list_repo_files(repo_id=dataset_id, repo_type="dataset", revision="main")
    )

    # Prefer the safe, declarative YAML config when available.
    if "eval.yaml" in files_in_repo:
        return load_yaml_config(hf_api=hub, dataset_id=dataset_id, cache_dir=cache_dir)

    # Otherwise fall back to the Python config, which needs remote-code trust.
    return load_python_config(
        hf_api=hub,
        dataset_id=dataset_id,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
        run_with_cli=run_with_cli,
    )


def load_python_config(
    hf_api: HfApi,
    dataset_id: str,
    cache_dir: Path,
    trust_remote_code: bool,
    run_with_cli: bool,
) -> DatasetConfig | None:
    """Load a dataset config from a euroeval_config.py file in a Hugging Face repo.

    Args:
        hf_api:
            The Hugging Face API object.
        dataset_id:
            The ID of the dataset to get the config for.
        cache_dir:
            The directory to store the cache in.
        trust_remote_code:
            Whether to trust remote code.
        run_with_cli:
            Whether the code is being run with the CLI.

    Returns:
        The dataset config if it exists, otherwise None.
    """
    files_in_repo = list(
        hf_api.list_repo_files(repo_id=dataset_id, repo_type="dataset", revision="main")
    )

    # Without a Python config file there is nothing left to load, since the
    # YAML path has already been ruled out by the caller.
    if "euroeval_config.py" not in files_in_repo:
        log_once(
            message=(
                f"Dataset {dataset_id} does not have a euroeval_config.py or a YAML "
                "config file (eval.yaml), so we cannot load it. Skipping."
            ),
            level=logging.WARNING,
        )
        return None

    # Executing a repo-provided Python file is remote code execution, so we
    # refuse outright unless the user has opted in.
    if not trust_remote_code:
        if run_with_cli:
            hint = "the --trust-remote-code flag"
        else:
            hint = "`trust_remote_code=True`"
        log_once(
            message=(
                f"The dataset {dataset_id} exists on the Hugging Face Hub and has a "
                "euroeval_config.py file, but remote code is not allowed. Please "
                f"rerun this with {hint} if you trust the code in this "
                "repository."
            ),
            level=logging.ERROR,
        )
        sys.exit(1)

    # Download the config file into a per-dataset cache directory.
    config_dir = cache_dir / "external_dataset_configs" / dataset_id
    config_dir.mkdir(parents=True, exist_ok=True)
    hf_api.hf_hub_download(
        repo_id=dataset_id,
        repo_type="dataset",
        filename="euroeval_config.py",
        local_dir=config_dir,
        local_dir_use_symlinks=False,
    )

    config_module = load_custom_datasets_module(
        custom_datasets_file=config_dir / "euroeval_config.py"
    )
    if config_module is None:
        return None

    # Collect every DatasetConfig instance defined at module level; exactly
    # one is expected.
    found_configs = [
        obj for obj in config_module.__dict__.values() if isinstance(obj, DatasetConfig)
    ]
    if not found_configs:
        return None
    if len(found_configs) > 1:
        log_once(
            message=(
                f"Dataset {dataset_id} has multiple dataset configurations. Please "
                "ensure that only a single DatasetConfig is defined in the "
                "`euroeval_config.py` file."
            ),
            level=logging.WARNING,
        )
        return None

    train_split, val_split, test_split = get_repo_splits(
        hf_api=hf_api, dataset_id=dataset_id
    )

    # A test split is mandatory for evaluation.
    if test_split is None:
        log_once(
            message=(
                f"Dataset {dataset_id} does not have a test split, so we cannot load "
                "it. Please ensure that the dataset has a test split."
            ),
            level=logging.ERROR,
        )
        return None

    # If only a validation split exists, promote it to the training split.
    if train_split is None and val_split is not None:
        log_once(
            message=(
                f"Dataset {dataset_id!r} has no training split. Using the validation "
                f"split {val_split!r} as the training split instead."
            ),
            level=logging.DEBUG,
        )
        train_split, val_split = val_split, None

    config = found_configs[0]
    config.name = dataset_id
    config.pretty_name = dataset_id
    config.source = dataset_id
    config.train_split = train_split
    config.val_split = val_split
    config.test_split = test_split
    return config