"""Utility functions related to parsing of safetensors metadata of models."""

import logging
from pathlib import Path

from huggingface_hub import get_safetensors_metadata
from huggingface_hub.errors import NotASafetensorsRepoError

from .logging_utils import log_once
from .utils import get_hf_token, internet_connection_available


def get_num_params_from_safetensors_metadata(
    model_id: str, revision: str, api_key: str | None
) -> int | None:
    """Fetch a model's parameter count from its safetensors metadata on the Hub.

    Args:
        model_id:
            The model ID.
        revision:
            The revision of the model.
        api_key:
            The API key to use for authentication with the Hugging Face Hub. Can be
            None if no API key is needed.

    Returns:
        The number of parameters, or None if the metadata could not be found.
    """
    # Offline sessions cannot reach the Hub, and locally stored models have no
    # Hub metadata to query, so we bail out early in both cases.
    if not internet_connection_available() or Path(model_id).exists():
        return None

    token = get_hf_token(api_key=api_key)
    try:
        metadata = get_safetensors_metadata(
            repo_id=model_id, revision=revision, token=token
        )
    except NotASafetensorsRepoError:
        log_once(
            "The number of parameters could not be determined for the model "
            f"{model_id}, since the model is not stored in the safetensors format. "
            "If this is your own model, then you can use this Hugging Face Space to "
            "convert your model to the safetensors format: "
            "https://huggingface.co/spaces/safetensors/convert.",
            level=logging.WARNING,
        )
        return None

    # The metadata maps each dtype to its parameter count; an empty mapping
    # means the repository exposed no usable count at all.
    dtype_counts = metadata.parameter_count
    if not dtype_counts:
        log_once(
            "Failed to determine the number of parameters for the model "
            f"{model_id}, even though the model is stored in the safetensors "
            "format. Please report this issue at "
            "https://github.com/EuroEval/EuroEval/issues.",
            level=logging.WARNING,
        )
        return None

    # With several dtype entries we note the ambiguity (at debug level) and
    # pick the largest entry, matching the original behaviour.
    if len(dtype_counts) > 1:
        log_once(
            f"The model {model_id} has multiple parameter count entries in its "
            f"safetensors metadata: {dtype_counts}. Using the largest one.",
            level=logging.DEBUG,
        )
    return max(dtype_counts.values())