Skip to content

euroeval.string_utils

[docs] module euroeval.string_utils

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Utility functions related to string manipulation or structuring."""

import collections.abc as c
import logging
import re
import typing as t

import demjson3
import numpy as np

from .exceptions import InvalidBenchmark, InvalidModel
from .logging_utils import log

if t.TYPE_CHECKING:
    from .data_models import ModelIdComponents


def scramble(text: str) -> str:
    """Scramble a string in a bijective manner.

    Args:
        text:
            The string to scramble.

    Returns:
        The scrambled string.
    """
    rng = np.random.default_rng(seed=4242)
    permutation = rng.permutation(x=len(text))
    scrambled = "".join(text[i] for i in permutation)
    return scrambled


def unscramble(scrambled_text: str) -> str:
    """Unscramble a string in a bijective manner.

    Args:
        scrambled_text:
            The scrambled string to unscramble.

    Returns:
        The unscrambled string.
    """
    rng = np.random.default_rng(seed=4242)
    permutation = rng.permutation(x=len(scrambled_text))
    inverse_permutation = np.argsort(permutation)
    unscrambled = "".join(scrambled_text[i] for i in inverse_permutation)
    return unscrambled


def extract_json_dict_from_string(s: str) -> dict | None:
    """Extract a JSON dictionary from a string.

    Args:
        s:
            The string to extract the JSON dictionary from.

    Returns:
        The extracted JSON dictionary, or None if no JSON dictionary could be found.
    """
    json_regex = r"\{[^{}]*?\}"
    if (json_match := re.search(pattern=json_regex, string=s, flags=re.DOTALL)) is None:
        log(
            "The model output does not contain any JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {s!r}",
            level=logging.DEBUG,
        )
        return None
    json_string = json_match.group()
    try:
        json_output = demjson3.decode(txt=json_string)
    except demjson3.JSONDecodeError:
        log(
            "The model output is not valid JSON, so cannot parse it. Skipping. "
            f"Here is the output: {json_string!r}",
            level=logging.DEBUG,
        )
        return None
    if not isinstance(json_output, dict):
        log(
            "The model output is not a JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {json_string!r}",
            level=logging.DEBUG,
        )
        return None
    elif not all(isinstance(key, str) for key in json_output.keys()):
        log(
            "The model output is not a JSON dictionary with string keys, "
            "so cannot parse it. Skipping. Here is the output: "
            f"{json_string!r}",
            level=logging.DEBUG,
        )
        return None
    return json_output


def extract_multiple_choice_labels(
    prompt: str, candidate_labels: c.Sequence[str]
) -> c.Sequence[str]:
    """Extract multiple choice labels from a prompt.

    Args:
        prompt:
            The prompt to extract the labels from.
        candidate_labels:
            The candidate labels to look for in the prompt.

    Returns:
        The extracted labels.
    """
    sample_candidate_labels: list[str] = list()
    for candidate_label in candidate_labels:
        candidate_label_match = re.search(
            pattern=rf"\b{candidate_label}\. ", string=prompt, flags=re.IGNORECASE
        )
        if candidate_label_match is not None:
            sample_candidate_labels.append(candidate_label)
    if not sample_candidate_labels:
        raise InvalidBenchmark(
            "Could not extract any candidate labels from the prompt. Please ensure "
            "that the candidate labels are present in the prompt, each followed by a "
            "dot and a space (e.g., 'a. '). The candidate labels are: "
            f"{', '.join(candidate_labels)}. Here is the prompt: {prompt!r}"
        )
    return sample_candidate_labels


def split_model_id(model_id: str) -> "ModelIdComponents":
    """Split a model ID into its components.

    Args:
        model_id:
            The model ID to split.

    Returns:
        The split model ID.

    Raises:
        If the model ID is not valid.
    """
    # Importing here to avoid circular imports
    from .data_models import ModelIdComponents

    # Attempt to extract the model ID, revision, and param using regex
    model_id_match = re.match(pattern=r"^[^@#]+", string=model_id)
    revision_match = re.search(pattern=r"@([^@#]+)", string=model_id)
    param_match = re.search(pattern=r"#([^@#]+)", string=model_id)

    # If we cannot extract the model ID, raise an error
    if model_id_match is None:
        raise InvalidModel(f"The model ID {model_id!r} is not valid.")
    model_id = model_id_match.group()

    # Extract the revision and param and return the result
    revision = revision_match.group(1) if revision_match is not None else "main"
    param = param_match.group(1) if param_match is not None else None
    return ModelIdComponents(model_id=model_id, revision=revision, param=param)