Source code for flexeval.classes.message

import copy
import json
import logging

import peewee as pw
from playhouse.shortcuts import model_to_dict

from flexeval.classes.base import BaseModel
from flexeval.classes.dataset import Dataset
from flexeval.classes.eval_set_run import EvalSetRun
from flexeval.classes.thread import Thread
from flexeval.classes.turn import Turn
from flexeval.configuration import completion_functions

logger = logging.getLogger(__name__)


[docs] class Message(BaseModel): """Holds a single component of a single turn Corresponds to one output of a node in LangGraph or one Turn in jsonl """ id = pw.IntegerField(primary_key=True) evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="messages") dataset = pw.ForeignKeyField(Dataset, backref="messages") thread = pw.ForeignKeyField(Thread, backref="messages") index_in_thread = pw.IntegerField() # must be null=True because we're adding it after create() turn = pw.ForeignKeyField(Turn, null=True, backref="messages") role = pw.TextField() # user or assistant - 'tools' are counted as assistants content = pw.TextField() context = pw.TextField(null=True) # Previous messages # helpers system_prompt = pw.TextField(null=True) is_flexeval_completion = pw.BooleanField(null=True) is_final_turn_in_input = pw.BooleanField(null=True) langgraph_print = pw.TextField(null=True) # language model stats tool_callslanggraph_print = pw.TextField(null=True) tool_call_ids = pw.TextField(null=True) n_tool_calls = pw.IntegerField(null=True) prompt_tokens = pw.IntegerField(null=True) completion_tokens = pw.IntegerField(null=True) model_name = pw.TextField(null=True) # langgraph metadata langgraph_ts = pw.TextField(null=True) langgraph_step = pw.IntegerField(null=True) langgraph_thread_id = pw.TextField(null=True) langgraph_checkpoint_id = pw.TextField(null=True) langgraph_parent_checkpoint_id = pw.TextField(null=True) langgraph_node = pw.TextField(null=True) langgraph_message_type = pw.TextField(null=True) langgraph_type = pw.TextField(null=True) langgraph_invocation_id = pw.TextField(null=True) # putting these at the end so the database is easier to browse langgraph_checkpoint = pw.TextField(null=True) langgraph_metadata = pw.TextField(null=True) def __init__(self, **kwargs): super().__init__(**kwargs) self.metrics_to_evaluate = [] def get_completion(self, include_system_prompt=False): # only get a completion if this is the final turn - we probably don't want to branch from mid-conversation if self.is_final_turn_in_input: completion_config = json.loads(self.evalsetrun.completion_llm) completion_fn_name = completion_config.get("function_name", None) completion_function_kwargs = completion_config.get("kwargs", None) # Check if the function name exists in the global namespace and call it if hasattr(completion_functions, completion_fn_name) and hasattr( completion_functions, completion_fn_name ): completion_function = getattr( completion_functions, completion_fn_name, None ) completion = completion_function( conversation_history=self.get_formatted_prompt( include_system_prompt=False ), **completion_function_kwargs, ) else: logger.warning( "In completion_functions.py: No callable function named " + completion_fn_name + " found." ) completion = None # "completion" will be the output of an existing completion function # which generally means it'll have a structure like this # {"choices": [{"message": {"content": "hi", "role": "assistant"}}]} result = model_to_dict(self, exclude=[self.id]) result["evalsetrun"] = self.evalsetrun result["dataset"] = self.dataset result["datasetrow"] = self.datasetrow result["turn_number"] = self.turn_number + 1 result["role"] = "assistant" result["context"] = self.get_formatted_prompt(include_system_prompt=False) result["is_final_turn_in_input"] = False # b/c it's not in input self.is_final_turn_in_input = False result["is_completion"] = True result["completion"] = completion result["model"] = completion.get("model", None) result["prompt_tokens"] = completion.get("usage", {}).get( "prompt_tokens", None ) / len(completion.get("choices", [1])) result["completion_tokens"] = completion.get("usage", {}).get( "completion_tokens", None ) / len( completion.get("choices", [1]) ) # TODO - use tiktoken here instead?? this will just give the average result_list = [] for ix, choice in enumerate(completion["choices"]): temp = copy.deepcopy(result) temp["tool_used"] = choice["message"].get("tool_calls", None) temp["turn"] = [choice["message"]] temp["content"] = choice["message"]["content"] temp["completion_number"] = ix + 1 result_list.append(temp) return result_list else: return None def get_formatted_prompt(self, include_system_prompt=False) -> list[dict[str, str]]: formatted_prompt = [] if include_system_prompt: formatted_prompt.append({"role": "system", "content": self.system_prompt}) context = json.loads(self.context) if len(context) > 0: formatted_prompt += context # TODO - we might just want a subset of this formatted_prompt.append({"role": self.role, "content": self.content}) # for t in json.loads(self.turn): # formatted_prompt.append({"role": t["role"], "content": t["content"]}) return formatted_prompt def format_input_for_rubric(self): input = self.get_formatted_prompt() output_minus_completion = "" for i in input[:-1]: output_minus_completion += f"{i['role']}: {i['content']}\n" completion = f"{input[-1]['content']}" output = output_minus_completion + completion tool_call_text = "" for tc in self.toolcalls: tool_call_text += """ Function name: {function_name} Input arguments: {args} Function output: {response_content} """.format( function_name=tc.function_name, args=tc.args, response_content=tc.response_content, ) # output - all turns # output_minus_completion - all turns except the last # completion - last turn # tool_call_text - all tool calls return output, output_minus_completion, completion, tool_call_text def get_content(self) -> str: return self.content def get_context(self, include_system_prompt=False) -> list[dict[str, str]]: context = json.loads(self.context) if not include_system_prompt: context = [ cur_dict for cur_dict in context if cur_dict.get("role") != "system" ] return context