1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237 | """Preprocessing utilities for custom dataset column mapping."""
import collections.abc as c
import functools
import typing as t
from .enums import TaskGroup
from .exceptions import InvalidBenchmark
if t.TYPE_CHECKING:
from datasets import DatasetDict
def merge_input_and_choices(
    example: dict,
    input_column: str,
    choices_column: "str | list[str]",
    choices_label: str,
) -> dict:
    """Merge input text and choices into a single text field.

    Newlines inside the input text and inside each choice are replaced with spaces
    and surrounding whitespace is stripped, so that the merged text has the form::

        <input_text>
        <choices_label>:
        a. <choice_0>
        b. <choice_1>
        ...

    Choices are labelled with the lowercase letters ``a`` to ``z``; any choices
    beyond the 26th are left out of the merged text. Choices that are not strings
    (e.g. numbers) are converted with ``str`` before being formatted.

    Args:
        example:
            A single dataset example with at least the ``input_column`` and
            the column(s) named by ``choices_column``. It is modified in place.
        input_column:
            The name of the column containing the input text.
        choices_column:
            Either the name of a single column containing a list of answer-choice
            strings, or a list of column names each containing a single answer-choice
            string.
        choices_label:
            The language-specific label for the choices section (e.g. ``"Choices"``).

    Returns:
        The example with a new ``"text"`` key containing the merged input and formatted
        choices.
    """
    input_text = example[input_column].replace("\n", " ").strip()

    if isinstance(choices_column, list):
        choices = [example[col] for col in choices_column]
    else:
        choices = example[choices_column]

    # The choice is cleaned outside of the f-string, since a backslash inside an
    # f-string replacement field is a syntax error before Python 3.12.
    option_lines = []
    for letter, choice in zip("abcdefghijklmnopqrstuvwxyz", choices):
        cleaned_choice = str(choice).replace("\n", " ").strip()
        option_lines.append(f"{letter}. {cleaned_choice}")
    options = "\n".join(option_lines)

    example["text"] = f"{input_text}\n{choices_label}:\n{options}"
    return example
def build_preprocessing_func(
dataset_name: str,
task_group: "TaskGroup",
input_column: str,
target_column: str | None,
choices_column: "str | list[str] | None",
choices_label: str,
) -> "c.Callable[[DatasetDict], DatasetDict]":
"""Build a preprocessing function from column mapping arguments.
The returned function renames or merges columns in a DatasetDict to match the
framework's standard column names:
- If ``input_column`` differs from ``"text"`` (without ``choices_column``), it is
renamed to ``"text"``.
- If ``choices_column`` is given, ``input_column`` and ``choices_column`` are merged
into a single ``"text"`` column formatted as::
<input_text>
<choices_label>:
a. <choice_0>
b. <choice_1>
...
- If ``target_column`` is given, it is renamed to the task-group standard:
``"labels"`` for token classification, ``"target_text"`` for text-to-text, and
``"label"`` for everything else.
Args:
dataset_name:
The name of the dataset, used in error messages.
task_group:
The task group, used to determine the standard target column name.
input_column:
Column to rename to ``"text"``. When combined with ``choices_column``, the
two are merged into a formatted ``"text"`` column instead. Defaults to
``"text"`` (no rename).
target_column:
Column to rename to the task-appropriate standard target column name.
choices_column:
Either the name of a single column containing a list of answer-choice
strings, or a list of column names each containing a single answer-choice
string, to merge with the input column.
choices_label:
The language-specific label for the choices section (e.g. ``"Choices"``).
Returns:
A callable that accepts a ``DatasetDict`` and returns a preprocessed
``DatasetDict``.
"""
# Determine the standard target column for the task group
if target_column is not None:
if task_group == TaskGroup.TOKEN_CLASSIFICATION:
std_target = "labels"
elif task_group == TaskGroup.TEXT_TO_TEXT:
std_target = "target_text"
else:
std_target = "label"
else:
std_target = None
def preprocessing_func(dataset: "DatasetDict") -> "DatasetDict":
"""Apply column mapping and merging to all splits in the dataset.
Validates that configured columns exist in all splits before processing, then
renames or merges columns according to the configuration passed to
:func:`build_preprocessing_func`.
Args:
dataset:
The dataset to preprocess.
Returns:
The preprocessed dataset with columns renamed or merged as configured.
Raises:
InvalidBenchmark:
If a configured input or target column is absent from all splits.
"""
# Normalize choices_column to a list for uniform handling
if isinstance(choices_column, list):
choices_cols: list[str] | None = choices_column
elif choices_column is not None:
choices_cols = [choices_column]
else:
choices_cols = None
# Validate that the configured columns exist in all splits
if input_column != "text":
input_found = all(
input_column in split.column_names for split in dataset.values()
)
if not input_found:
raise InvalidBenchmark(
f"The dataset is configured with an input column "
f"{input_column!r}, but this column was not found in all splits "
f"for the dataset {dataset_name!r}."
)
if choices_cols is not None:
for col in choices_cols:
col_found = all(col in split.column_names for split in dataset.values())
if not col_found:
raise InvalidBenchmark(
f"The dataset is configured with a choices column "
f"{col!r}, but this column was not found in all splits "
f"for the dataset {dataset_name!r}."
)
if target_column is not None:
target_found = all(
target_column in split.column_names for split in dataset.values()
)
if not target_found:
raise InvalidBenchmark(
f"The dataset is configured with a target column "
f"{target_column!r}, but this column was not found in all splits "
f"for the dataset {dataset_name!r}."
)
for split_name, split in dataset.items():
if choices_cols is not None:
def _fix_mc_label_column(example: dict) -> dict:
"""Ensure multiple choice labels are lowercase letters.
Args:
example:
The example to fix.
Returns:
The fixed example.
"""
if isinstance(choices_column, list):
choices = [example[col] for col in choices_column]
else:
choices = example[choices_column]
label = example[target_column]
if label in choices:
example[target_column] = "abcdefghijklmnopqrstuvwxyz"[
choices.index(label)
]
if isinstance(example[target_column], int):
example[target_column] = "abcdefghijklmnopqrstuvwxyz"[label]
if isinstance(example[target_column], str):
example[target_column] = example[target_column].lower()
return example
# If the label is the full choice string, then we convert it to the
# appropriate letter
if target_column is not None:
split = split.map(_fix_mc_label_column)
# Handle input column (optionally merging with choices)
merge_fn = functools.partial(
merge_input_and_choices,
input_column=input_column,
choices_column=choices_column,
choices_label=choices_label,
)
split = split.map(merge_fn)
cols_to_drop = [
col
for col in [input_column, *choices_cols]
if col in split.column_names and col != "text"
]
if cols_to_drop:
split = split.remove_columns(cols_to_drop)
elif input_column != "text":
if "text" in split.column_names:
split = split.remove_columns(["text"])
split = split.rename_column(input_column, "text")
# Handle target column renaming
if (
std_target is not None
and target_column is not None
and target_column != std_target
):
if std_target in split.column_names:
split = split.remove_columns([std_target])
split = split.rename_column(target_column, std_target)
dataset[split_name] = split # pyrefly: ignore[unsupported-operation]
return dataset
return preprocessing_func
|