
euroeval.task_group_utils.sequence_classification


"""Utility functions related to the sequence-classification task group."""

import logging
import re
import typing as t

import Levenshtein
import numpy as np

from ..enums import TaskGroup
from ..exceptions import InvalidBenchmark
from ..utils import log_once, raise_if_model_output_contains_nan_values

if t.TYPE_CHECKING:
    from datasets.arrow_dataset import Dataset
    from transformers.trainer_utils import EvalPrediction

    from ..data_models import DatasetConfig, GenerativeModelOutput
    from ..types import Labels, Predictions


logger = logging.getLogger("euroeval")


def compute_metrics(
    model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
    dataset_config: "DatasetConfig",
    dataset: "Dataset",
) -> dict[str, float]:
    """Compute the metrics needed for evaluation.

    Args:
        model_outputs_and_labels:
            The first sequence contains the model outputs and the second sequence
            contains the true labels.
        dataset_config:
            The configuration of the dataset.
        dataset:
            The dataset used for evaluation. This is only used in case any additional
            metadata is used to compute the metrics.

    Returns:
        A dictionary with the names of the metrics as keys and the metric values as
        values.
    """
    model_outputs, labels = model_outputs_and_labels
    label2id = {label: idx for idx, label in dataset_config.id2label.items()}

    # If the model output is a pair, then the first element corresponds to the model
    # predictions
    if isinstance(model_outputs, tuple) and len(model_outputs) == 2:
        model_outputs = model_outputs[0]

    model_output_dtype = np.asarray(model_outputs).dtype
    if model_output_dtype in [np.float16, np.float32, np.float64]:
        predictions = np.asarray(model_outputs).argmax(axis=-1)
    else:
        predictions = model_outputs

    assert not isinstance(model_outputs, tuple)
    raise_if_model_output_contains_nan_values(model_output=model_outputs)

    prompt_label_to_label_mapping = {
        prompt_label: label
        for label, prompt_label in dataset_config.prompt_label_mapping.items()
    }
    predictions = [
        (
            label2id[prompt_label_to_label_mapping[pred.lower()]]
            if isinstance(pred, str)
            else pred
        )
        for pred in predictions
    ]

    label_ids = [
        label2id[label.lower()] if isinstance(label, str) else label for label in labels
    ]

    results: dict[str, float] = dict()
    for metric in dataset_config.task.metrics:
        score: float | None = metric(
            predictions=predictions, references=label_ids, dataset=dataset
        )

        # The metric returns None if we are running on multi-GPU and the current
        # process is not the main process
        if score is not None:
            results[metric.name] = score

    return results
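

# Illustrative sketch (not part of the original module): how string predictions from a
# generative model are converted into label IDs before the metrics are computed. The
# label names and mappings below are hypothetical examples.
def _example_prediction_to_label_id() -> None:
    """Show the prompt-label inversion used in `compute_metrics`."""
    # Hypothetical dataset configuration values
    id2label = {0: "negative", 1: "positive"}
    prompt_label_mapping = {"negative": "negativ", "positive": "positiv"}

    label2id = {label: idx for idx, label in id2label.items()}
    prompt_label_to_label_mapping = {
        prompt_label: label for label, prompt_label in prompt_label_mapping.items()
    }

    # A generative model predicts the localised prompt label, which is first mapped back
    # to the English label and then to the label ID
    predictions = ["positiv", "negativ"]
    prediction_ids = [
        label2id[prompt_label_to_label_mapping[pred.lower()]] for pred in predictions
    ]
    assert prediction_ids == [1, 0]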


def extract_labels_from_generation(
    input_batch: dict[str, list],
    model_output: "GenerativeModelOutput",
    dataset_config: "DatasetConfig",
    first_label_token_mapping: dict[str, str] | bool,
) -> list[str]:
    """Extract the predicted labels from the generated output.

    Args:
        input_batch:
            The input batch, where the keys are the feature names and the values
            are lists with the feature values.
        model_output:
            The raw generated output of the model.
        dataset_config:
            The configuration of the dataset.
        first_label_token_mapping:
            A mapping from labels to the first token in each label, or alternatively a
            Boolean value indicating whether the model should output scores (if a
            mapping is provided then the model will always output scores).

    Returns:
        The predicted labels.

    Raises:
        InvalidBenchmark:
            If the task requires log probabilities, but the model did not output them,
            or if the model outputted log probabilities but the first label token
            mapping is not provided.
    """
    if model_output.scores is not None:
        if first_label_token_mapping is False:
            raise InvalidBenchmark(
                "The model outputted logprobs, but the first label token mapping is "
                "not provided, which is not supported."
            )
        labels = get_closest_logprobs_labels(
            generation_logprobs=model_output.scores,
            dataset_config=dataset_config,
            first_label_token_mapping=first_label_token_mapping,
        )
        if labels is not None:
            return labels
        elif dataset_config.task.requires_logprobs:
            raise InvalidBenchmark(
                "This task requires the model to output logprobs, and this model "
                "does not seem to be able to do that. Skipping the evaluation."
            )

    # Get the candidate labels, which are the labels that the model can predict
    candidate_labels = [
        dataset_config.prompt_label_mapping[lbl]
        for lbl in dataset_config.id2label.values()
    ]

    new_predicted_labels: list[str] = list()
    for idx, predicted_label in enumerate(model_output.sequences):
        # Special case for multiple choice classification: in this case we dynamically
        # restrict the candidate labels to the labels mentioned in the prompt
        if dataset_config.task.task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
            prompt = input_batch["text"][idx]
            sample_candidate_labels = [
                candidate_label
                for candidate_label in candidate_labels
                if re.search(
                    pattern=rf"\b{candidate_label}. ",
                    string=prompt,
                    flags=re.IGNORECASE,
                )
                is not None
            ]
        else:
            sample_candidate_labels = candidate_labels

        # If the prediction includes a boxed answer, use that instead of the full
        # generation
        if (m := re.search(r"boxed\{(.*?)\}", predicted_label)) is not None:
            predicted_label = m.group(1)

        # We set the word edit distance weights such that we heavily penalise insertions
        # and substitutions, so that the model only gets credit for a label if it has
        # actually included that label in its output.
        insertion_weight = 1000
        deletion_weight = 1
        substitution_weight = 1000

        # Compute the word edit distances between the predicted label and all candidate
        # labels
        edit_distances = [
            Levenshtein.distance(
                s1=predicted_label.lower(),
                s2=candidate_label.lower(),
                weights=(insertion_weight, deletion_weight, substitution_weight),
            )
            for candidate_label in sample_candidate_labels
        ]

        # If even the closest candidate label requires insertions or substitutions, then
        # none of the candidate labels appear in the model output, so we assume that
        # something is wrong with the model output and raise an error
        if min(edit_distances) > 100:
            raise InvalidBenchmark(
                f"No candidate labels found for the predicted label "
                f"{predicted_label!r}. This likely means that the model output is "
                "completely off, and we cannot extract any labels from it. Please "
                "check the model output and the candidate labels."
            )

        # Pick the label with the smallest word edit distance to the predicted label
        best_candidate_label = sample_candidate_labels[np.argmin(edit_distances).item()]
        new_predicted_labels.append(best_candidate_label)

    return new_predicted_labels
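

# Illustrative sketch (not part of the original module): why the asymmetric Levenshtein
# weights above work for picking a candidate label out of a longer generation. Deleting
# surplus characters from the generation is cheap, whereas inserting or substituting
# characters (i.e., making up parts of a label) is prohibitively expensive. The labels
# and generation below are hypothetical examples.
def _example_weighted_edit_distance() -> None:
    """Show how candidate labels are ranked in `extract_labels_from_generation`."""
    candidate_labels = ["positive", "negative"]
    generation = "The sentiment of the text is positive."

    insertion_weight, deletion_weight, substitution_weight = 1000, 1, 1000
    edit_distances = [
        Levenshtein.distance(
            generation.lower(),
            candidate_label.lower(),
            weights=(insertion_weight, deletion_weight, substitution_weight),
        )
        for candidate_label in candidate_labels
    ]

    # "positive" occurs verbatim in the generation, so it only needs cheap deletions,
    # while "negative" needs at least one insertion or substitution and scores above 100
    assert edit_distances[0] < 100 < edit_distances[1]
    assert candidate_labels[np.argmin(edit_distances).item()] == "positive"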


def get_closest_logprobs_labels(
    generation_logprobs: list[list[list[tuple[str, float]]]],
    dataset_config: "DatasetConfig",
    first_label_token_mapping: dict[str, str] | t.Literal[True],
) -> list[str] | None:
    """Get the labels with the highest predicted logprob value.

    In case a candidate label is split into multiple tokens, we only use the first
    token to compute the logprob value. E.g., if the candidate label "positive" is
    tokenised as ["pos", "itive"], we only use the logprob value of "pos" to
    represent the logprob value of the entire label.

    Args:
        generation_logprobs:
            The logprobs of the generated tokens, for all samples in the batch. Of shape
            (batch_size, num_tokens, num_logprobs).
        dataset_config:
            The configuration of the dataset.
        first_label_token_mapping:
            A mapping from labels to the first token in each label, or alternatively a
            `True` value indicating that the model should output logprobs.

    Returns:
        The predicted labels, or None if labels could not be extracted.

    Raises:
        InvalidBenchmark:
            If any of the candidate labels are missing from the first label token
            mapping.
    """
    english_labels = list(dataset_config.id2label.values())
    english2local = dataset_config.prompt_label_mapping
    candidate_labels = [english2local[lbl].lower() for lbl in english_labels]

    output_labels: list[str] = list()
    for sample in generation_logprobs:
        for logprob_list in sample:
            generated_labels = [
                re.sub(pattern=r"^[^a-zæøåüöä0-9]+$", repl="", string=label.lower())
                for label, _ in logprob_list
            ]
            generated_labels = [label for label in generated_labels if label != ""]

            # We use the first generated label that uniquely matches a candidate label
            # as the output label
            output_label: str | None = None
            for generated_label in generated_labels:
                # Get the candidate labels. If we have a first label token mapping, we
                # use it to get the candidate labels. Otherwise, we check if any of the
                # labels start with the generated label.
                if isinstance(first_label_token_mapping, dict):
                    if any(
                        candidate_label not in first_label_token_mapping
                        for candidate_label in candidate_labels
                    ):
                        raise InvalidBenchmark(
                            "There is a label not present in the first label token "
                            "mapping - this should never happen! Please report this "
                            "issue to the EuroEval team at "
                            "github.com/EuroEval/EuroEval/issues."
                        )

                    candidate_output_labels = {
                        candidate_label
                        for candidate_label in candidate_labels
                        if generated_label == first_label_token_mapping[candidate_label]
                    }
                else:
                    candidate_output_labels = {
                        candidate_label
                        for candidate_label in candidate_labels
                        if candidate_label.startswith(generated_label)
                    }

                # If the generated label is a numeral (e.g., "1", "2", "3") and there is
                # a matching candidate label, we only keep the full match
                if re.match(r"^\d+$", generated_label) and any(
                    candidate_label == generated_label
                    for candidate_label in candidate_output_labels
                ):
                    candidate_output_labels = {
                        candidate_label
                        for candidate_label in candidate_output_labels
                        if candidate_label == generated_label
                    }

                # If we can uniquely determine the output label, we break the loop.
                if len(candidate_output_labels) == 1:
                    output_label = candidate_output_labels.pop()
                    break

                # If we have multiple candidate labels, we cannot uniquely determine the
                # output label, so we abandon extracting the labels using logprobs and
                # fall back to using word edit distance.
                elif len(candidate_output_labels) > 1:
                    log_once(
                        "Multiple candidate labels found for the generated label "
                        f"{generated_label!r}: {candidate_output_labels}. This means "
                        "that using logprobs to extract the labels is not reliable, "
                        "and we will instead fall back to extracting the labels "
                        "using word edit distance.",
                        level=logging.DEBUG,
                    )
                    return None

                # If no candidate label is found, we first check if any of the labels
                # start with the generated label. This could be the case if the first
                # label token mapping is inaccurate or incomplete, for instance if
                # 'pos' is in the first label token mapping, but the model outputted
                # 'posit'. If this is the case then we cannot trust the first label
                # token mapping, and we fall back to using word edit distance.
                # Otherwise, the generated label is just bad, and we skip to the next
                # generated label.
                elif len(candidate_output_labels) == 0:
                    candidate_output_labels_starting_with_generated_label = [
                        candidate_label
                        for candidate_label in candidate_labels
                        if candidate_label.startswith(generated_label)
                    ]
                    if candidate_output_labels_starting_with_generated_label:
                        log_once(
                            f"No candidate label found for the generated label "
                            f"{generated_label!r}, but there are candidate labels "
                            f"starting with it: "
                            f"{candidate_output_labels_starting_with_generated_label}. "
                            "This means that the first label token mapping is not "
                            "reliable, and we will instead fall back to extracting "
                            "the labels using word edit distance.",
                            level=logging.DEBUG,
                        )
                        return None

            # If we did not find any candidate label for any of the generated labels, we
            # assume that something is wrong with the model output, and we fall back to
            # using word edit distance to extract the labels
            else:
                log_once(
                    f"No candidate label found for any of the generated labels "
                    f"{generated_labels}. This means that using logprobs to extract "
                    "the labels is not reliable, and we will instead fall back to "
                    "extracting the labels using word edit distance.",
                    level=logging.DEBUG,
                )
                return None

            if output_label is not None:
                output_labels.append(output_label)
                break
        else:
            if len(sample) == 0:
                log_once(
                    "The model outputted an empty string, so no candidate labels could "
                    f"be determined. Using the first label, {candidate_labels[0]!r}, "
                    "as the output label.",
                    level=logging.INFO,
                )
            else:
                log_once(
                    "Could not find a candidate label for any of the generated "
                    f"labels in the sample {sample}. Using the first label, "
                    f"{candidate_labels[0]!r}, as the output label.",
                    level=logging.INFO,
                )
            output_labels.append(candidate_labels[0])

    assert len(output_labels) == len(generation_logprobs)
    return output_labels
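

# Illustrative sketch (not part of the original module): how the first generated token
# is matched against the first label token mapping in `get_closest_logprobs_labels`.
# The tokens, logprob values and mapping below are hypothetical examples.
def _example_first_label_token_matching() -> None:
    """Show the first-token matching used in `get_closest_logprobs_labels`."""
    # Hypothetical logprobs for the first generated token of one sample, shaped like a
    # single entry of `generation_logprobs`: a list of (token, logprob) pairs, sorted
    # by decreasing logprob
    logprob_list = [("Pos", -0.1), ("Neg", -2.3)]

    candidate_labels = ["positive", "negative"]
    first_label_token_mapping = {"positive": "pos", "negative": "neg"}

    # The highest-ranked token that uniquely matches a candidate label determines the
    # output label
    output_label: str | None = None
    for token, _ in logprob_list:
        matches = {
            candidate_label
            for candidate_label in candidate_labels
            if token.lower() == first_label_token_mapping[candidate_label]
        }
        if len(matches) == 1:
            output_label = matches.pop()
            break
    assert output_label == "positive"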