euroeval.metrics.tool_calling

"""Tool calling metric."""

import collections.abc as c
import json
import logging
import typing as t

from ..logging_utils import log_once

if t.TYPE_CHECKING:
    from datasets.arrow_dataset import Dataset

    from ..data_models import BenchmarkConfig, DatasetConfig

from ..metrics.base import Metric


class ToolCallingAccuracy(Metric):
    """Metric for tool calling."""

    def __call__(
        self,
        predictions: c.Sequence,
        references: c.Sequence,
        dataset: "Dataset",
        dataset_config: "DatasetConfig",
        benchmark_config: "BenchmarkConfig",
    ) -> float | None:
        """Calculate tool calling accuracy.

        Args:
            predictions:
                Predicted "labels", i.e. tool calls in this context.
            references:
                Ground truth data. Note that the format differs from the
                predictions, since the ground truth contains lists of acceptable
                outputs rather than a single "truth".
            dataset:
                The dataset, used for tool information such as required arguments.
            dataset_config:
                Part of the interface; not used here.
            benchmark_config:
                Part of the interface; not used here.

        Returns:
            The accuracy score, or None if any of the predictions, references or
            dataset["function"] sequences is empty, meaning a score could not be
            computed.
        """
        function_descriptions = [json.loads(f) for f in dataset["function"]]
        results = []
        for x in zip(predictions, references, function_descriptions):
            results.append(_evaluate_function_toolcall_response(*x))
        if not results:
            return None
        else:
            return sum(results) / len(results)


def _evaluate_function_toolcall_response(
    pred_calls_str: str, ref_calls_str: str, descriptions: list[dict]
) -> bool:
    """Logic to evaluate tool call response against reference (ground truth).

    Args:
        pred_calls_str:
            Predicted function calls as a JSON string.
        ref_calls_str:
            Reference function calls as a JSON string.
        descriptions:
            Function descriptions (stored in the dataset and given as input to
            the models).

    Returns:
        True on success, False on failure.
    """
    # try to deserialize the prediction
    try:
        pred_calls_dict = json.loads(pred_calls_str)
        assert isinstance(pred_calls_dict, dict)
        assert "tool_calls" in pred_calls_dict
        pred_calls = pred_calls_dict["tool_calls"]
    except (json.JSONDecodeError, AssertionError):
        return False

    ref_calls = json.loads(ref_calls_str)

    # the number of predicted function calls should match the number of reference calls
    if len(pred_calls) != len(ref_calls):
        return False

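    # Expected shapes, as inferred from the checks below:
    #   pred_call: {"function": "<name>", "arguments": {"<arg>": <value>, ...}}
    #   ref_call:  {"<name>": {"<arg>": [<accepted values>, ...], ...}}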
    for pred_call, ref_call, description in zip(pred_calls, ref_calls, descriptions):
        # each predicted function call should be a dict
        if not isinstance(pred_call, dict):
            return False

        # get predicted function name
        if "function" not in pred_call:
            log_once(
                "Tool call prediction did not contain required keyword 'function'.",
                level=logging.DEBUG,
            )
            return False
        else:
            pred_name: str = pred_call["function"]

        # get predicted arguments
        if "arguments" not in pred_call:
            log_once(
                "Tool call prediction did not contain required keyword 'arguments'.",
                level=logging.DEBUG,
            )
            return False
        else:
            pred_args: dict = pred_call["arguments"]

        ref_name: str
        ref_args: dict
        # each reference call is wrapped in an extra single-item mapping by the BFCL format
        ref_name, ref_args = list(ref_call.items())[0]

        # did we predict the right function to call?
        if pred_name != ref_name:
            return False

        # get the required arguments from the function description
        parameters = description.get("parameters", None)
        required_args = (
            parameters.get("required", None) if isinstance(parameters, dict) else None
        )

        for key, values in ref_args.items():
            # we only care about required arguments
            if required_args and key not in required_args:
                continue
            # every required argument must be predicted, and its value must be
            # among the accepted reference values
            if key not in pred_args or pred_args[key] not in values:
                return False
    return True


tool_calling_accuracy = ToolCallingAccuracy(
    name="tool_calling_accuracy", pretty_name="Tool Calling Accuracy"
)
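
The sketch below is purely illustrative and not part of the module: it shows the JSON shapes that _evaluate_function_toolcall_response expects, as inferred from the parsing logic above. The get_weather tool, its arguments and the extra "name" field in the description are hypothetical. The call should evaluate to True, since the predicted city matches one of the accepted reference values and the unit argument is skipped because it is not listed under the description's required parameters.

import json

from euroeval.metrics.tool_calling import _evaluate_function_toolcall_response

# Predicted tool calls, wrapped in a "tool_calls" key as the metric expects.
prediction = json.dumps(
    {
        "tool_calls": [
            {
                "function": "get_weather",
                "arguments": {"city": "Copenhagen", "unit": "celsius"},
            }
        ]
    }
)

# Reference calls: each is a single-item mapping from function name to a dict
# of argument names mapped to lists of accepted values.
reference = json.dumps(
    [{"get_weather": {"city": ["Copenhagen", "København"], "unit": ["celsius"]}}]
)

# One description per expected call; only "parameters" -> "required" is read by
# the metric, the remaining fields are hypothetical.
descriptions = [
    {
        "name": "get_weather",
        "parameters": {
            "properties": {"city": {"type": "string"}, "unit": {"type": "string"}},
            "required": ["city"],
        },
    }
]

assert _evaluate_function_toolcall_response(prediction, reference, descriptions)

The public tool_calling_accuracy instance applies the same per-example check across the whole dataset and averages the boolean results into the final accuracy score.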