Source code for flexeval.function_types
"""Inspection utilities that use type hints to determine the appropriate object to pass to a function metric.
See :mod:`~flexeval.schema.eval_schema`.
"""
import inspect
import logging
import types
import typing
from collections.abc import Callable, Iterable
from flexeval.classes import message, thread, tool_call, turn
from flexeval.schema import eval_schema
# Union of the FlexEval objects that may be handed to a function metric before
# any coercion; used as the type of ``input_object`` in get_function_input.
AnyFunctionObjectInput = typing.Union[
turn.Turn,
message.Message,
thread.Thread,
tool_call.ToolCall,
]
# The same four classes as a set. get_valid_levels_for_callable iterates it and
# uses each class's __name__ as the corresponding metric level name.
FLEXEVAL_TYPE_SET: set[type] = {
turn.Turn,
message.Message,
thread.Thread,
tool_call.ToolCall,
}
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
def is_callable_valid_for_metric_level(
    metric_function: Callable, metric_level: eval_schema.MetricLevel
) -> bool:
    """Report whether metric_function can be used at the given metric_level.

    Args:
        metric_function (Callable): The function metric to check.
        metric_level (eval_schema.MetricLevel): The level to test against.

    Returns:
        bool: True if metric_level is among the levels reported by
            get_valid_levels_for_callable for this function.
    """
    return metric_level in get_valid_levels_for_callable(metric_function)
[docs]
def get_valid_levels_for_callable(metric_function: Callable) -> set[str]:
    """Work out which metric_level values a callable supports.

    The decision is based solely on the type annotation of the callable's
    first parameter.

    Args:
        metric_function (Callable): A callable, probably one available via EvalRun

    Returns:
        set[str]: Valid values for MetricItem.metric_level
    """
    accepted = get_first_parameter_types(metric_function)

    # A FlexEval class in the annotation enables the level named after it.
    levels = {cls.__name__ for cls in FLEXEVAL_TYPE_SET if cls in accepted}

    # Built-in types enable every level whose object can be coerced to them.
    if str in accepted:
        levels.update(("Message", "Turn", "Thread"))
    if list in accepted:
        levels.update(("Turn", "Thread"))
    if dict in accepted:
        levels.add("ToolCall")
    return levels
[docs]
def get_first_parameter_types(metric_function: Callable) -> set[type]:
    """Return the types accepted by the first parameter of metric_function.

    Args:
        metric_function (Callable): The callable whose signature is inspected.

    Returns:
        set[type]: The types derived from the first parameter's annotation by
            get_acceptable_arg_types. Empty if the callable takes no
            parameters or its first parameter has no annotation.
    """
    parameters = inspect.signature(metric_function).parameters
    # Supply a default so a zero-parameter callable yields "no known types"
    # instead of leaking a bare StopIteration to the caller.
    first_parameter = next(iter(parameters.values()), None)
    if first_parameter is None:
        return set()

    input_type = first_parameter.annotation
    # inspect.Parameter.empty is the public name of the "no annotation" marker.
    if input_type is inspect.Parameter.empty:
        logger.debug(
            "Function %s's first parameter has no type annotation.",
            metric_function,
        )
        return set()
    return get_acceptable_arg_types(input_type)
[docs]
def get_acceptable_arg_types(input_type: type) -> set[type]:
# Note: we don't support NewType annotations yet
origin_type = typing.get_origin(input_type)
if origin_type is typing.Annotated:
# unpack Annotated types
input_type = typing.get_args(input_type)[0]
origin_type = typing.get_origin(input_type)
if origin_type in (typing.Union, types.UnionType):
union_arg_type_sets = [
get_acceptable_arg_types(type_arg)
for type_arg in typing.get_args(input_type)
]
return set.union(*union_arg_type_sets)
else: # not a union type
if origin_type is not None:
# e.g. input_type=list[str], origin_type=list
return {origin_type}
else:
# e.g. input_type=list, origin_type=list
if input_type is list or input_type is Iterable:
logger.warning(
"Type hint {input_type} lacks the detail that would allow us to determine the specific objects it accepts."
)
return {input_type}
[docs]
def get_function_input(
    metric_function: Callable,
    metric_level: eval_schema.MetricLevel,
    input_object: AnyFunctionObjectInput,
    context_only: bool,
) -> AnyFunctionObjectInput | str | dict | list:
    """Coerce input_object to a type accepted by metric_function at this metric_level.

    Args:
        metric_function (Callable): Function to invoke with the returned input.
        metric_level (eval_schema.MetricLevel): The metric level at which metric_function is being invoked.
        input_object (AnyFunctionObjectInput): The input_object to be coerced, or passed as-is if accepted by metric_function.
        context_only (bool): Determines how strings and lists are converted. See schema documentation.

    Raises:
        ValueError: If metric_level is not a valid level, or if the function
            accepts at least one declared type, but
            it's a type we don't support at all e.g. set or
            it's a type we don't support at this metric_level.

    Returns:
        AnyFunctionObjectInput | str | dict | list: The coerced input for metric_function.
    """
    if metric_level not in eval_schema.VALID_METRIC_LEVELS:
        raise ValueError(
            f"metric_level '{metric_level}' not one of the valid levels: {eval_schema.VALID_METRIC_LEVELS}"
        )

    input_type = type(input_object)
    accepted_parameter_types = get_first_parameter_types(metric_function)

    if not accepted_parameter_types:
        logger.debug(
            f"Metric function {metric_function}'s first parameter has no type hint, so we can't determine if a type transformation needs to be applied."
        )
        return input_object

    if input_type in accepted_parameter_types:
        # no transformation necessary; the function accepts the type we already have
        return input_object

    if dict in accepted_parameter_types and metric_level == "ToolCall":
        return input_object.get_dict_representation()

    if list in accepted_parameter_types and metric_level in ("Turn", "Thread"):
        if context_only:
            return input_object.get_context()
        # this is on a single turn - pass in the parsed list
        return input_object.get_content()

    if str in accepted_parameter_types:
        if metric_level == "ToolCall":
            raise ValueError(
                "Functions that accept strings can't be used for tool calls. Accept a dict (or a flexeval.classes.tool_call.ToolCall) instead."
            )
        if context_only:
            # join together all previous turns
            return join_all_contents_to_string(input_object.get_context())
        # current turn only
        return join_all_contents_to_string(input_object.get_content())

    # the function accepts at least one declared type, but either:
    # - it's a type we don't support at all e.g. set
    # - it's a type we don't support at this metric_level
    # Entries without __name__ (e.g. a string forward-reference annotation)
    # fall back to repr() so this path always raises the intended ValueError.
    # Sorting keeps the message stable despite set iteration order.
    accepted_names = ", ".join(
        sorted(
            getattr(accepted_type, "__name__", repr(accepted_type))
            for accepted_type in accepted_parameter_types
        )
    )
    raise ValueError(
        f"For metric level '{metric_level}', can't coerce {input_type.__name__} for function {metric_function} to accepted parameter type(s) '{accepted_names}'."
    )
[docs]
def join_all_contents_to_string(content: list[dict] | typing.Any) -> str:
"""
content is a list of dictionaries whose keys include 'content'.
Returns a string with all the 'content' entries concatenated together,
separated by newline.
"""
if isinstance(content, list):
content = "\n".join([item.get("content", "") for item in content])
return content