Skip to content

euroeval.preprocessing

[docs] module euroeval.preprocessing

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""Preprocessing utilities for custom dataset column mapping."""

import collections.abc as c
import functools
import typing as t

from .enums import TaskGroup
from .exceptions import InvalidBenchmark

if t.TYPE_CHECKING:
    from datasets import DatasetDict


def merge_input_and_choices(
    example: dict,
    input_column: str,
    choices_column: "str | list[str]",
    choices_label: str,
) -> dict:
    """Merge input text and choices into a single text field.

    Newlines inside the input text and inside each choice are replaced by single
    spaces and surrounding whitespace is stripped, so that every choice occupies
    exactly one line of the merged text. Choices are lettered ``a``, ``b``, ... in
    the order they are given.

    Args:
        example:
            A single dataset example with at least the ``input_column`` and
            the column(s) named by ``choices_column``. It is modified in place.
        input_column:
            The name of the column containing the input text.
        choices_column:
            Either the name of a single column containing a list of answer-choice
            strings, or a list of column names each containing a single answer-choice
            string.
        choices_label:
            The language-specific label for the choices section (e.g. ``"Choices"``).

    Returns:
        The same ``example`` object, with a new ``"text"`` key containing the merged
        input and formatted choices.
    """

    def _flatten(text: str) -> str:
        """Collapse ``text`` onto a single line and trim surrounding whitespace."""
        return text.replace("\n", " ").strip()

    if isinstance(choices_column, list):
        # One column per choice: gather them in the configured order
        choices = [example[col] for col in choices_column]
    else:
        # A single column that already holds the sequence of choices
        choices = example[choices_column]

    # The flattening is done through a helper rather than inline in the f-string:
    # a backslash escape inside an f-string replacement field is only valid syntax
    # from Python 3.12 onwards.
    # NOTE: ``zip`` stops at the shorter iterable, so any choices beyond the 26th
    # are silently left out of the merged text.
    options = "\n".join(
        f"{letter}. {_flatten(choice)}"
        for letter, choice in zip("abcdefghijklmnopqrstuvwxyz", choices)
    )
    example["text"] = f"{_flatten(example[input_column])}\n{choices_label}:\n{options}"
    return example


def build_preprocessing_func(
    dataset_name: str,
    task_group: "TaskGroup",
    input_column: str,
    target_column: str | None,
    choices_column: "str | list[str] | None",
    choices_label: str,
) -> "c.Callable[[DatasetDict], DatasetDict]":
    """Build a preprocessing function from column mapping arguments.

    The returned function renames or merges columns in a DatasetDict to match the
    framework's standard column names:

    - If ``input_column`` differs from ``"text"`` (without ``choices_column``), it is
      renamed to ``"text"``.
    - If ``choices_column`` is given, ``input_column`` and ``choices_column`` are merged
      into a single ``"text"`` column formatted as::

          <input_text>
          <choices_label>:
          a. <choice_0>
          b. <choice_1>
          ...

    - If ``target_column`` is given, it is renamed to the task-group standard:
      ``"labels"`` for token classification, ``"target_text"`` for text-to-text, and
      ``"label"`` for everything else.

    Args:
        dataset_name:
            The name of the dataset, used in error messages.
        task_group:
            The task group, used to determine the standard target column name.
        input_column:
            Column to rename to ``"text"``. When combined with ``choices_column``, the
            two are merged into a formatted ``"text"`` column instead. Defaults to
            ``"text"`` (no rename).
        target_column:
            Column to rename to the task-appropriate standard target column name.
        choices_column:
            Either the name of a single column containing a list of answer-choice
            strings, or a list of column names each containing a single answer-choice
            string, to merge with the input column.
        choices_label:
            The language-specific label for the choices section (e.g. ``"Choices"``).

    Returns:
        A callable that accepts a ``DatasetDict`` and returns a preprocessed
        ``DatasetDict``.
    """
    # Determine the standard target column for the task group
    if target_column is not None:
        if task_group == TaskGroup.TOKEN_CLASSIFICATION:
            std_target = "labels"
        elif task_group == TaskGroup.TEXT_TO_TEXT:
            std_target = "target_text"
        else:
            std_target = "label"
    else:
        std_target = None

    def preprocessing_func(dataset: "DatasetDict") -> "DatasetDict":
        """Apply column mapping and merging to all splits in the dataset.

        Validates that configured columns exist in all splits before processing, then
        renames or merges columns according to the configuration passed to
        :func:`build_preprocessing_func`.

        Args:
            dataset:
                The dataset to preprocess.

        Returns:
            The preprocessed dataset with columns renamed or merged as configured.

        Raises:
            InvalidBenchmark:
                If a configured input or target column is absent from all splits.
        """
        # Normalize choices_column to a list for uniform handling
        if isinstance(choices_column, list):
            choices_cols: list[str] | None = choices_column
        elif choices_column is not None:
            choices_cols = [choices_column]
        else:
            choices_cols = None

        # Validate that the configured columns exist in all splits
        if input_column != "text":
            input_found = all(
                input_column in split.column_names for split in dataset.values()
            )
            if not input_found:
                raise InvalidBenchmark(
                    f"The dataset is configured with an input column "
                    f"{input_column!r}, but this column was not found in all splits "
                    f"for the dataset {dataset_name!r}."
                )
        if choices_cols is not None:
            for col in choices_cols:
                col_found = all(col in split.column_names for split in dataset.values())
                if not col_found:
                    raise InvalidBenchmark(
                        f"The dataset is configured with a choices column "
                        f"{col!r}, but this column was not found in all splits "
                        f"for the dataset {dataset_name!r}."
                    )
        if target_column is not None:
            target_found = all(
                target_column in split.column_names for split in dataset.values()
            )
            if not target_found:
                raise InvalidBenchmark(
                    f"The dataset is configured with a target column "
                    f"{target_column!r}, but this column was not found in all splits "
                    f"for the dataset {dataset_name!r}."
                )

        for split_name, split in dataset.items():
            # Handle input column (optionally merging with choices)
            if choices_cols is not None:
                merge_fn = functools.partial(
                    merge_input_and_choices,
                    input_column=input_column,
                    choices_column=choices_column,
                    choices_label=choices_label,
                )
                split = split.map(merge_fn)
                cols_to_drop = [
                    col
                    for col in [input_column, *choices_cols]
                    if col in split.column_names and col != "text"
                ]
                if cols_to_drop:
                    split = split.remove_columns(cols_to_drop)
            elif input_column != "text":
                if "text" in split.column_names:
                    split = split.remove_columns(["text"])
                split = split.rename_column(input_column, "text")

            # Handle target column renaming
            if (
                std_target is not None
                and target_column is not None
                and target_column != std_target
            ):
                if std_target in split.column_names:
                    split = split.remove_columns([std_target])
                split = split.rename_column(target_column, std_target)

            dataset[split_name] = split

        return dataset

    return preprocessing_func