euroeval.utils
source module euroeval.utils
Utility functions to be used in other scripts.
Classes
-
HiddenPrints — Context manager which removes all terminal output.
Functions
-
create_model_cache_dir — Create cache directory for a model.
-
resolve_model_path — Resolve the path to the directory containing the model config files and weights.
-
clear_memory — Clears the memory of unused items.
-
enforce_reproducibility — Ensures reproducibility of experiments.
-
block_terminal_output — Blocks libraries from writing output to the terminal.
-
get_class_by_name — Get a class by its name.
-
get_min_cuda_compute_capability — Gets the lowest CUDA compute capability.
-
internet_connection_available — Checks if internet connection is available by pinging google.com.
-
raise_if_model_output_contains_nan_values — Raise an exception if the model output contains NaN values.
-
scramble — Scramble a string in a bijective manner.
-
unscramble — Unscramble a string in a bijective manner.
-
log_once — Log a message once.
-
get_package_version — Get the version of a package.
-
safe_run — Run a coroutine, ensuring that the event loop is always closed when we're done.
-
add_semaphore_and_catch_exception — Run a coroutine with a semaphore.
-
extract_json_dict_from_string — Extract a JSON dictionary from a string.
-
get_hf_token — Get the Hugging Face token.
-
extract_multiple_choice_labels — Extract multiple choice labels from a prompt.
-
split_model_id — Split a model ID into its components.
source create_model_cache_dir(cache_dir: str, model_id: str) → str
Create cache directory for a model.
Parameters
-
cache_dir : str — The cache directory.
-
model_id : str — The model ID.
Returns
-
str — The path to the cache directory.
source resolve_model_path(download_dir: str) → str
Resolve the path to the directory containing the model config files and weights.
Parameters
-
download_dir : str — The download directory.
Returns
-
str — The path to the model.
Raises
-
InvalidModel — If the model path is not valid, or if required files are missing.
source clear_memory() → None
Clears the memory of unused items.
source enforce_reproducibility(seed: int = 4242) → np.random.Generator
Ensures reproducibility of experiments.
Parameters
-
seed : int — Seed for the random number generator.
source block_terminal_output() → None
Blocks libraries from writing output to the terminal.
This filters warnings from some libraries, sets the logging level to ERROR for some
libraries, disables tokeniser progress bars when using Hugging Face tokenisers, and
disables most of the logging from the transformers library.
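The first two steps described above (filtering warnings and raising logging levels) can be sketched with the standard library alone. This is a minimal illustration, not the module's implementation: the function name `quiet_libraries`, the warning categories, and the logger names are assumptions chosen for the example.

```python
import logging
import warnings


def quiet_libraries(logger_names: list[str]) -> None:
    """Silence common warnings and restrict the named loggers to ERROR and above."""
    # Hypothetical choice of categories; the real function may filter others.
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)
    # After this loop, only ERROR and CRITICAL records pass through these loggers.
    for name in logger_names:
        logging.getLogger(name).setLevel(logging.ERROR)


# Logger names are illustrative; getLogger works even if the library is absent.
quiet_libraries(["transformers", "datasets"])
```

Setting the level on a library's top-level logger also affects its child loggers, since they inherit the effective level unless they set their own.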
source get_class_by_name(class_name: str | list[str], module_name: str) → t.Type | None
Get a class by its name.
Parameters
-
class_name : str | list[str] — The name of the class, written in kebab-case. The corresponding class name must be the same, but written in PascalCase, and lying in a module with the same name, but written in snake_case. If a list of strings is passed, the first class that is found is returned.
-
module_name : str — The name of the module where the class is located.
Returns
-
t.Type | None — The class. If the class is not found, None is returned.
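The naming convention described above (kebab-case name, snake_case module, PascalCase class) can be illustrated as follows. This is a sketch under assumptions: `find_class`, `kebab_to_pascal` and `kebab_to_snake` are hypothetical helpers, and treating `module_name` as a package that contains the snake_case module is a guess about the lookup, not a documented fact.

```python
from __future__ import annotations

import importlib
import typing as t


def kebab_to_pascal(name: str) -> str:
    """Convert e.g. 'named-entity-recognition' to 'NamedEntityRecognition'."""
    return "".join(part.capitalize() for part in name.split("-"))


def kebab_to_snake(name: str) -> str:
    """Convert e.g. 'named-entity-recognition' to 'named_entity_recognition'."""
    return name.replace("-", "_")


def find_class(class_name: str | list[str], package: str) -> t.Type | None:
    """Return the first class found among the candidates, or None."""
    names = [class_name] if isinstance(class_name, str) else class_name
    for name in names:
        try:
            module = importlib.import_module(f"{package}.{kebab_to_snake(name)}")
        except ModuleNotFoundError:
            continue  # Try the next candidate name.
        candidate = getattr(module, kebab_to_pascal(name), None)
        if isinstance(candidate, type):
            return candidate
    return None


# Standard-library stand-in: the module email.message defines the class Message.
message_class = find_class(["does-not-exist", "message"], "email")
```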
source get_min_cuda_compute_capability() → float | None
Gets the lowest CUDA compute capability.
Returns
-
float | None — Device capability as float, or None if CUDA is not available.
source internet_connection_available() → bool
Checks if internet connection is available by pinging google.com.
Returns
-
bool — Whether or not internet connection is available.
Raises
-
e
source class HiddenPrints()
Context manager which removes all terminal output.
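One common way to build such a context manager is to point `sys.stdout` and `sys.stderr` at the null device for the duration of the block. The class below, `SilentOutput`, is a hypothetical stand-in written for illustration; it is not the implementation of HiddenPrints and may differ from it in what it suppresses.

```python
import os
import sys


class SilentOutput:
    """Discard everything written to stdout and stderr inside the with-block."""

    def __enter__(self):
        # Remember the current streams so they can be restored on exit.
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._devnull = open(os.devnull, "w")
        sys.stdout = sys.stderr = self._devnull
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the streams even if the block raised an exception.
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self._devnull.close()


with SilentOutput():
    print("this line is discarded")
```

This approach only affects Python-level writes; output written directly to the underlying file descriptors by native code is not captured.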
source raise_if_model_output_contains_nan_values(model_output: Predictions) → None
Raise an exception if the model output contains NaN values.
Parameters
-
model_output : Predictions — The model output to check.
Raises
-
If the model output contains NaN values.
source scramble(text: str) → str
Scramble a string in a bijective manner.
Parameters
-
text : str — The string to scramble.
Returns
-
str — The scrambled string.
source unscramble(scrambled_text: str) → str
Unscramble a string in a bijective manner.
Parameters
-
scrambled_text : str — The scrambled string to unscramble.
Returns
-
str — The unscrambled string.
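To show what "bijective" means for this pair of functions, here is a sketch that permutes character positions with a fixed seed and then inverts the permutation. The names `scramble_sketch` and `unscramble_sketch`, the permutation approach, and the seed value are assumptions for illustration; the documentation above does not specify how the real functions scramble text.

```python
import random


def _permutation(length: int, seed: int = 4242) -> list[int]:
    """Return a deterministic permutation of range(length)."""
    indices = list(range(length))
    random.Random(seed).shuffle(indices)
    return indices


def scramble_sketch(text: str) -> str:
    """Reorder the characters of text according to a fixed permutation."""
    permutation = _permutation(len(text))
    return "".join(text[i] for i in permutation)


def unscramble_sketch(scrambled_text: str) -> str:
    """Invert scramble_sketch by sending each character back to its origin."""
    permutation = _permutation(len(scrambled_text))
    original = [""] * len(scrambled_text)
    for position, source_index in enumerate(permutation):
        original[source_index] = scrambled_text[position]
    return "".join(original)


round_trip = unscramble_sketch(scramble_sketch("an example string"))
```

Because the permutation depends only on the string length and the seed, unscrambling needs no extra state, and every string maps to exactly one scrambled string and back.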
source log_once(message: str, level: int = logging.INFO) → None
Log a message once.
This is ensured by caching the input/output pairs of this function, using the
functools.cache decorator.
Parameters
-
message : str — The message to log.
-
level : int — The logging level. Defaults to logging.INFO.
Raises
-
ValueError
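The caching approach mentioned above can be sketched in a few lines. The names `log_once_sketch` and the logger name are hypothetical, and the sketch leaves out the validation that leads to the ValueError listed above.

```python
import functools
import logging

logger = logging.getLogger("log_once_sketch")


@functools.cache
def log_once_sketch(message: str, level: int = logging.INFO) -> None:
    """Log message at level; repeated calls with the same arguments do nothing."""
    # functools.cache stores the result (None) per argument tuple, so the body
    # only runs the first time a given (message, level) combination is seen.
    logger.log(level, message)


log_once_sketch("Loading model...")
log_once_sketch("Loading model...")  # Served from the cache, nothing is logged.
```

Note that the cache key is built from the arguments as passed, so calling with the default level omitted and with the same level given explicitly counts as two different calls.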
source get_package_version(package_name: str) → str | None
Get the version of a package.
Parameters
-
package_name : str — The name of the package.
Returns
-
str | None — The version of the package, or None if the package is not installed.
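A lookup with this behaviour can be written using `importlib.metadata` from the standard library. The sketch below, with the hypothetical name `package_version_sketch`, shows one way to do it and is not necessarily how the function above is implemented.

```python
from __future__ import annotations

import importlib.metadata


def package_version_sketch(package_name: str) -> str | None:
    """Return the installed version of package_name, or None if it is missing."""
    try:
        return importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        return None


# The package name here is deliberately one that should not be installed.
missing = package_version_sketch("surely-not-an-installed-package-xyz")
```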
source safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) → T
Run a coroutine, ensuring that the event loop is always closed when we're done.
Parameters
-
coroutine : t.Coroutine[t.Any, t.Any, T] — The coroutine to run.
Returns
-
T — The result of the coroutine.
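The idea of always closing the event loop can be expressed with a `try`/`finally` around `run_until_complete`. This is a minimal sketch under the assumption that a fresh loop is created per call; `safe_run_sketch` and `add` are hypothetical names, and the real function may manage loops differently.

```python
import asyncio
import typing as t

T = t.TypeVar("T")


def safe_run_sketch(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
    """Run coroutine on a fresh event loop and close the loop afterwards."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coroutine)
    finally:
        # Runs whether the coroutine returned or raised.
        loop.close()


async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)  # Yield to the event loop once.
    return a + b


total = safe_run_sketch(add(1, 2))
```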
source async add_semaphore_and_catch_exception(coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore) → T | Exception
Run a coroutine with a semaphore.
Parameters
-
coroutine : t.Coroutine[t.Any, t.Any, T] — The coroutine to run.
-
semaphore : asyncio.Semaphore — The semaphore to use.
Returns
-
T | Exception — The result of the coroutine.
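The return type `T | Exception` suggests that exceptions are returned as values rather than raised. A sketch of that pattern follows; `run_with_semaphore`, `work` and `main` are hypothetical names, and which exception types the real function catches is an assumption.

```python
from __future__ import annotations

import asyncio
import typing as t

T = t.TypeVar("T")


async def run_with_semaphore(
    coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore
) -> T | Exception:
    """Await coroutine while holding semaphore; return any exception it raises."""
    async with semaphore:
        try:
            return await coroutine
        except Exception as exc:
            return exc


async def work(n: int) -> int:
    await asyncio.sleep(0)
    if n == 2:
        raise ValueError("boom")
    return n * 10


async def main() -> list[int | Exception]:
    semaphore = asyncio.Semaphore(2)  # At most two coroutines run concurrently.
    tasks = (run_with_semaphore(work(n), semaphore) for n in range(4))
    return await asyncio.gather(*tasks)


results = asyncio.run(main())
```

Returning the exception lets `asyncio.gather` finish all remaining coroutines, leaving the caller to inspect each result with `isinstance(result, Exception)`.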
source extract_json_dict_from_string(s: str) → dict | None
Extract a JSON dictionary from a string.
Parameters
-
s : str — The string to extract the JSON dictionary from.
Returns
-
dict | None — The extracted JSON dictionary, or None if no JSON dictionary could be found.
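One simple way to get this behaviour is to search for a brace-delimited span and hand it to `json.loads`. The sketch below uses the hypothetical name `extract_json_dict_sketch`; the greedy regular expression is an assumption and may not match how the real function locates the dictionary.

```python
from __future__ import annotations

import json
import re


def extract_json_dict_sketch(s: str) -> dict | None:
    """Return the first parseable JSON object spanning the outermost braces."""
    # Greedy match from the first '{' to the last '}' in the string.
    match = re.search(r"\{.*\}", s, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group())
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


answer = extract_json_dict_sketch('Sure! {"label": "positive"} Hope that helps.')
```

A greedy match fails on strings that contain several separate JSON objects, since the span between the first `{` and the last `}` is then not valid JSON; a production version would need a more careful scan.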
source get_hf_token(api_key: str | None) → str | bool
Get the Hugging Face token.
Parameters
-
api_key : str | None — The API key to use as the Hugging Face token. If None, we will try to extract it in other ways.
Returns
-
str | bool — The Hugging Face token, or True if no token is set but the user is logged in, or False if no token is set and the user is not logged in.
source extract_multiple_choice_labels(prompt: str, candidate_labels: list[str]) → list[str]
Extract multiple choice labels from a prompt.
Parameters
-
prompt : str — The prompt to extract the labels from.
-
candidate_labels : list[str] — The candidate labels to look for in the prompt.
Returns
-
list[str] — The extracted labels.
Raises
source split_model_id(model_id: str) → ModelIdComponents
Split a model ID into its components.
Parameters
-
model_id : str — The model ID to split.
Returns
-
ModelIdComponents — The split model ID.
Raises
-
If the model ID is not valid.