Source code for flexeval.function_types
"""Inspection utilities that use type hints to determine the appropriate object to pass to a function metric.
See :mod:`~flexeval.schema.eval_schema`.
"""
import inspect
import logging
import types
import typing
from collections.abc import Callable, Iterable
from flexeval.classes import message, thread, tool_call, turn
from flexeval.schema import eval_schema
# Union of every FlexEval object type a metric function's first parameter may
# be annotated with.
AnyFunctionObjectInput = typing.Union[
    turn.Turn,
    message.Message,
    thread.Thread,
    tool_call.ToolCall,
]

# The same four types as a set, derived from the Union above so the two
# definitions cannot drift apart.
FLEXEVAL_TYPE_SET: set[type] = set(typing.get_args(AnyFunctionObjectInput))

logger = logging.getLogger(__name__)
[docs]
def is_callable_valid_for_metric_level(
    metric_function: Callable, metric_level: eval_schema.MetricLevel
) -> bool:
    """Check whether ``metric_function`` may be applied at ``metric_level``.

    Args:
        metric_function (Callable): The metric callable to check.
        metric_level (eval_schema.MetricLevel): The level to test. It is
            compared against the level names produced by
            ``get_valid_levels_for_callable``.

    Returns:
        bool: True if ``metric_level`` is among the levels permitted by the
        callable's first-parameter annotation.
    """
    return metric_level in get_valid_levels_for_callable(metric_function)
[docs]
def get_valid_levels_for_callable(metric_function: Callable) -> set[str]:
    """Compute the metric levels a callable supports from its first parameter.

    The annotation of the callable's first parameter is expanded into a set of
    accepted types, and each type enables one or more levels:

    - a FlexEval type (``Turn``, ``Message``, ``Thread``, ``ToolCall``)
      enables the level with the same class name;
    - ``str`` enables ``Message``, ``Turn`` and ``Thread``;
    - ``list`` enables ``Turn`` and ``Thread``;
    - ``dict`` enables ``ToolCall``.

    Args:
        metric_function (Callable): A callable, probably one available via EvalRun.

    Returns:
        set[str]: Valid values for MetricItem.metric_level (empty if nothing matches).
    """
    accepted = get_first_parameter_types(metric_function)

    levels = {
        flexeval_type.__name__
        for flexeval_type in FLEXEVAL_TYPE_SET
        if flexeval_type in accepted
    }

    # Non-FlexEval annotations and the levels each one unlocks.
    primitive_levels = {
        str: ("Message", "Turn", "Thread"),
        list: ("Turn", "Thread"),
        dict: ("ToolCall",),
    }
    for primitive, unlocked in primitive_levels.items():
        if primitive in accepted:
            levels.update(unlocked)

    return levels
[docs]
def get_first_parameter_types(metric_function: Callable) -> set[type]:
    """Return the types accepted by the first parameter of ``metric_function``.

    The first parameter's annotation is expanded with
    ``get_acceptable_arg_types``.

    String annotations (e.g. produced by ``from __future__ import annotations``)
    are evaluated when possible so they resolve to real type objects. If that
    evaluation fails, the raw, unevaluated annotations are used instead.

    Args:
        metric_function (Callable): The callable to inspect.

    Returns:
        set[type]: The accepted types. The set is empty when the callable takes
        no parameters or its first parameter has no annotation.

    Raises:
        ValueError, TypeError: Propagated from ``inspect.signature`` when no
        signature can be obtained for ``metric_function``.
    """
    try:
        # eval_str=True turns postponed (string) annotations into objects.
        # Without it, such metrics would never match any type.
        signature = inspect.signature(metric_function, eval_str=True)
    except (NameError, AttributeError, SyntaxError, TypeError):
        # An annotation could not be evaluated at runtime. Fall back to the
        # unevaluated annotations, which was the previous behaviour.
        signature = inspect.signature(metric_function)

    # A default of None avoids leaking StopIteration for zero-arg callables.
    first_parameter = next(iter(signature.parameters.values()), None)
    if first_parameter is None:
        logger.debug("Function '%s' takes no parameters.", metric_function)
        return set()

    input_type = first_parameter.annotation
    if input_type is inspect.Parameter.empty:
        logger.debug(
            "Function '%s' has a first parameter with no type annotation.",
            metric_function,
        )
        return set()
    return get_acceptable_arg_types(input_type)
[docs]
def get_acceptable_arg_types(input_type: type) -> set[type]:
# Note: we don't support NewType annotations yet
origin_type = typing.get_origin(input_type)
if origin_type is typing.Annotated:
# unpack Annotated types
input_type = typing.get_args(input_type)[0]
origin_type = typing.get_origin(input_type)
if origin_type in (typing.Union, types.UnionType):
union_arg_type_sets = [
get_acceptable_arg_types(type_arg)
for type_arg in typing.get_args(input_type)
]
return set.union(*union_arg_type_sets)
else: # not a union type
if origin_type is not None:
# e.g. input_type=list[str], origin_type=list
return {origin_type}
else:
# e.g. input_type=list, origin_type=list
if input_type is list or input_type is Iterable:
logger.warning(
"Type hint {input_type} lacks the detail that would allow us to determine the specific objects it accepts."
)
return {input_type}
[docs]
def join_all_contents_to_string(content: list[dict] | typing.Any) -> str:
"""
content is a list of dictionaries whose keys include 'content'.
Returns a string with all the 'content' entries concatenated together,
separated by newline.
"""
if isinstance(content, list):
content = "\n".join([item.get("content", "") for item in content])
return content