Source code for flexeval.data_loader

"""Dataset loading functions. Maybe should move to :mod:`~flexeval.io`."""

import json
import logging
import pathlib
import random as rd
import sqlite3

from langchain.load.dump import dumps
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

from flexeval.classes.dataset import Dataset
from flexeval.classes.message import Message
from flexeval.classes.thread import Thread
from flexeval.classes.tool_call import ToolCall
from flexeval.classes.turn import Turn

logger = logging.getLogger(__name__)


def load_jsonl(
    dataset: Dataset,
    filename: str | pathlib.Path,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load conversation threads from a JSONL file into ``dataset``.

    Each non-blank line must be a JSON object whose ``"input"`` key holds a list
    of message dicts (``role``, ``content`` and optional ``metadata``). For every
    selected line, ``nb_evaluations_per_thread`` :class:`Thread` rows are created.
    Each is populated with :class:`Message` rows for its non-system messages and
    then grouped into turns with :func:`add_turns`.

    Args:
        dataset: Dataset the created rows belong to; its ``evalsetrun`` is
            attached to every created row.
        filename: Path to the JSONL file.
        max_n_conversation_threads: Maximum number of lines (threads) to load,
            sampled at random with :func:`random.sample`. ``None`` loads all.
        nb_evaluations_per_thread: Number of copies of each selected thread to
            create. ``None`` or values below 1 are treated as 1.
    """
    # JSON text exchanged between systems is UTF-8; don't rely on the locale.
    with open(filename, "r", encoding="utf-8") as infile:
        all_lines = infile.read().splitlines()

    # Blank lines (e.g. a trailing empty line) are not threads and would make
    # json.loads fail. Skip them, but keep the physical line index as the id.
    candidate_ids = [i for i, line in enumerate(all_lines) if line.strip()]
    n_available = len(candidate_ids)

    if max_n_conversation_threads is None:
        max_n_conversation_threads = n_available
    if max_n_conversation_threads <= n_available:
        # A set makes the per-line membership test below O(1).
        selected_thread_ids = set(
            rd.sample(candidate_ids, max_n_conversation_threads)
        )
    else:
        logger.debug(
            f"You requested up to {max_n_conversation_threads} conversations but only {n_available} are present in Jsonl dataset at '{filename}'."
        )
        selected_thread_ids = set(candidate_ids)

    # Each selected thread is stored this many times so that evaluation results
    # can be averaged.
    n_runs = max(1, nb_evaluations_per_thread or 1)

    for thread_id, line in enumerate(all_lines):
        if thread_id not in selected_thread_ids:
            continue
        # Parse once per line (not once per run), and before creating any rows,
        # so a malformed line does not leave an orphan Thread behind.
        thread_input = json.loads(line)["input"]
        for thread_eval_run_id in range(n_runs):
            thread_object = Thread.create(
                evalsetrun=dataset.evalsetrun,
                dataset=dataset,
                jsonl_thread_id=thread_id,
                eval_run_thread_id=f"{thread_id}_{thread_eval_run_id}",
            )
            _add_jsonl_messages(dataset, thread_object, thread_input)
            add_turns(thread_object)


def _add_jsonl_messages(dataset, thread_object, thread_input):
    """Create one Message row per non-system message of a JSONL thread."""
    # Assumes at most one system message per thread; the first one found wins.
    system_prompt = next(
        (m.get("content") for m in thread_input if m.get("role") == "system"),
        None,
    )

    # Context is the running conversation history serialized into each Message.
    context = []
    if system_prompt is not None:
        context.append({"role": "system", "content": system_prompt})

    index_in_thread = 0
    for message in thread_input:
        role = message.get("role", None)
        if role == "system":
            # The system prompt is stored as context, not as a separate message.
            continue
        content = message.get("content", None)
        Message.create(
            evalsetrun=dataset.evalsetrun,
            dataset=dataset,
            thread=thread_object,
            index_in_thread=index_in_thread,
            role=role,
            content=content,
            context=json.dumps(context),
            metadata=message.get("metadata", None),
            is_flexeval_completion=False,
            # User messages are stored with an empty system prompt.
            system_prompt="" if role == "user" else system_prompt,
        )
        context.append({"role": role, "content": content})
        index_in_thread += 1
# TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
def load_langgraph_sqlite(
    dataset: Dataset,
    filename: str,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load conversation threads from a LangGraph SQLite checkpoint database.

    Reads the ``checkpoints`` table and picks up to
    ``max_n_conversation_threads`` distinct ``thread_id`` values at random. For
    each one it creates ``nb_evaluations_per_thread`` :class:`Thread` rows. Each
    thread is then:

    * populated with :class:`Message` rows built from the ``writes`` recorded
      in each checkpoint row's metadata;
    * grouped into turns with :func:`add_turns`;
    * given one :class:`ToolCall` row per tool call, paired with its tool
      response.

    Args:
        dataset: Dataset the created rows belong to; its ``evalsetrun`` is
            attached to every created row.
        filename: Path to the SQLite database file.
        max_n_conversation_threads: Maximum number of threads to load, sampled
            at random with :func:`random.sample`. ``None`` loads all of them.
        nb_evaluations_per_thread: Number of copies of each selected thread to
            create. ``None`` or values below 1 are treated as 1.

    Raises:
        AssertionError: If there is no ``checkpoints`` table, or a tool call has
            no matching tool response.
        Exception: If a checkpoint's metadata ``source`` is neither ``"input"``
            nor ``"loop"``.
    """
    serializer = JsonPlusSerializer()
    # Treat None (or values below 1) as a single run, matching load_jsonl.
    n_runs = max(1, nb_evaluations_per_thread or 1)

    # sqlite3's connection context manager only commits/rolls back and does not
    # close the connection, so close it explicitly.
    conn = sqlite3.connect(filename)
    try:
        # Lets rows be indexed by column name as well as by position.
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        verify_checkpoints_table_exists(cursor)

        # Sync database: run a FULL write-ahead-log checkpoint before reading.
        cursor.execute("PRAGMA wal_checkpoint(FULL);")

        # Threads (aka conversations): one row per distinct thread_id.
        cursor.execute("select distinct thread_id from checkpoints")
        thread_ids = cursor.fetchall()
        nb_threads = len(thread_ids)
        if max_n_conversation_threads is None:
            max_n_conversation_threads = nb_threads
        if max_n_conversation_threads <= nb_threads:
            selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
        else:
            logger.debug(
                f"You requested up to {max_n_conversation_threads} conversations but only {nb_threads} are present in Sqlite dataset at '{filename}'."
            )
            selected_thread_ids = thread_ids
        # Lazy %-formatting; also safe when no thread was selected.
        logger.debug(
            "Selected LangGraph thread ids: %s",
            [row[0] for row in selected_thread_ids],
        )

        # Duplicate stored threads so evaluation results can be averaged.
        for thread_eval_run_id in range(n_runs):
            for thread_id_row in selected_thread_ids:
                langgraph_thread_id = thread_id_row[0]
                thread = Thread.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    langgraph_thread_id=langgraph_thread_id,
                    eval_run_thread_id=f"{langgraph_thread_id}_{thread_eval_run_id}",
                )
                # Bind the id as a parameter instead of splicing it into the SQL.
                cursor.execute(
                    "select * from checkpoints where thread_id = ?",
                    (langgraph_thread_id,),
                )
                _load_langgraph_thread(
                    dataset, thread, cursor.fetchall(), serializer
                )
    finally:
        conn.close()


def _load_langgraph_thread(dataset, thread, completion_list, serializer):
    """Create Message, Turn and ToolCall rows for one thread's checkpoint rows."""
    # Context and system prompt are reset at the start of every thread.
    context = []
    system_prompt = None
    # tool_call_id -> tool call dict taken from the message that made the call
    tool_calls_dict = {}
    # tool_call_id -> content of the "tool" message answering that call
    tool_responses_dict = {}
    # tool_call_id -> additional_kwargs of the message that made the call
    tool_additional_kwargs_dict = {}

    for completion_row in completion_list:
        # metadata is the state update recorded for this row
        metadata = json.loads(completion_row["metadata"])
        if metadata.get("writes") is None:
            continue
        # checkpoint is the full deserialized state for this row
        checkpoint = serializer.loads_typed(
            (completion_row["type"], completion_row["checkpoint"])
        )
        update_dict, role, system_prompt = _langgraph_update_from_metadata(
            metadata, checkpoint, system_prompt
        )

        # Add system prompt as first thing in context if not already present
        if not context:
            context.append({"role": "system", "content": system_prompt})

        # Iterate through nodes - there is probably only 1
        for node, value in update_dict.items():
            if not isinstance(value, dict) or "messages" not in value:
                continue
            messages = value["messages"]
            # A single message may be stored as a dict (seen since ~Feb 2025).
            message_list = [messages] if isinstance(messages, dict) else messages
            # NOTE(review): this restarts at 0 for every node update, so it is the
            # position within one write rather than within the whole thread
            # (unlike load_jsonl). Kept as-is; confirm whether that is intended.
            index_in_thread = 0
            for message in message_list:
                kwargs = message.get("kwargs", {})
                if role == "user":
                    # Human input wraps the serialized message inside "content".
                    content = (
                        kwargs.get("content", {})
                        .get("kwargs", {})
                        .get("content", None)
                    )
                else:
                    content = kwargs.get("content", None)
                tool_calls = kwargs.get("tool_calls", [])
                response_metadata = kwargs.get("response_metadata", {})
                token_usage = response_metadata.get("token_usage", {})
                langgraph_message_type = message["id"][-1]

                Message.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    thread=thread,
                    index_in_thread=index_in_thread,
                    role=role,
                    content=content,
                    context=json.dumps(context),
                    is_flexeval_completion=False,
                    system_prompt=system_prompt,
                    # language model stats
                    tool_calls=json.dumps(tool_calls),
                    tool_call_ids=[tc["id"] for tc in tool_calls],
                    n_tool_calls=len(tool_calls),
                    prompt_tokens=token_usage.get("prompt_tokens"),
                    completion_tokens=token_usage.get("completion_tokens"),
                    model_name=response_metadata.get("model_name"),
                    # langgraph metadata
                    langgraph_ts=checkpoint.get("ts"),
                    langgraph_step=metadata.get("step"),
                    langgraph_thread_id=completion_row["thread_id"],
                    langgraph_checkpoint_id=completion_row["checkpoint_id"],
                    langgraph_parent_checkpoint_id=completion_row[
                        "parent_checkpoint_id"
                    ],
                    # Re-dumped because the checkpoint was deserialized above
                    langgraph_checkpoint=dumps(checkpoint),
                    langgraph_metadata=completion_row["metadata"],
                    langgraph_node=node,
                    langgraph_message_type=langgraph_message_type,
                    langgraph_type=kwargs.get("type"),
                    # special property of state
                    langchain_print=kwargs.get("additional_kwargs", {}).get(
                        "print", False
                    ),
                )

                # Update the context for the next Message
                context.append(
                    {
                        "role": role,
                        "content": content,
                        "langgraph_role": langgraph_message_type,
                    }
                )

                # Record tool call info so calls and responses can be matched later
                if kwargs.get("type") == "tool":
                    tool_responses_dict[kwargs.get("tool_call_id")] = kwargs.get(
                        "content", ""
                    )
                else:
                    for tool_call in tool_calls:
                        tool_calls_dict[tool_call["id"]] = tool_call
                        tool_additional_kwargs_dict[tool_call["id"]] = kwargs.get(
                            "additional_kwargs", {}
                        )
                index_in_thread += 1

    # Turns must exist before ToolCalls are created, because each ToolCall is
    # attached to the turn of the message that issued it.
    add_turns(thread)
    _create_langgraph_tool_calls(
        dataset,
        thread,
        tool_calls_dict,
        tool_responses_dict,
        tool_additional_kwargs_dict,
    )


def _langgraph_update_from_metadata(metadata, checkpoint, system_prompt):
    """Normalise a checkpoint row's ``writes`` into ``{node: {"messages": [...]}}``.

    Returns ``(update_dict, role, system_prompt)``:

    * ``role`` is ``"user"`` for human input (``source == "input"``) and
      ``"assistant"`` for graph output (``source == "loop"``).
    * ``system_prompt`` is the incoming prompt, possibly replaced by one found
      in this row.
    """
    source = metadata.get("source")
    if source == "input":
        # The first message written at step -1 is treated as the thread's system
        # prompt rather than a user message; it does not seem to be set elsewhere.
        messages = []
        for position, msg in enumerate(metadata["writes"]["__start__"]["messages"]):
            if position == 0 and metadata["step"] == -1:
                system_prompt = msg["kwargs"]["content"]
            else:
                # Wrap in the same shape LangGraph uses for machine messages.
                messages.append(
                    {
                        "id": ["HumanMessage"],
                        "kwargs": {"content": msg, "type": "human"},
                    }
                )
        return {"input": {"messages": messages}}, "user", system_prompt
    if source == "loop":
        # Already a mapping of node -> {"messages": [...]}.
        # 'system_prompt' is only present if it is part of the LangGraph state.
        checkpoint_system_prompt = checkpoint.get("channel_values", {}).get(
            "system_prompt"
        )
        if checkpoint_system_prompt is not None:
            system_prompt = checkpoint_system_prompt
        return metadata.get("writes"), "assistant", system_prompt
    raise Exception(f"Unhandled input condition! here is the metadata: {metadata}")


def _create_langgraph_tool_calls(
    dataset, thread, tool_calls_dict, tool_responses_dict, tool_additional_kwargs_dict
):
    """Create one ToolCall row per recorded tool call, paired with its response."""
    for tool_call_id, tool_call_vals in tool_calls_dict.items():
        # Raised explicitly (not via assert) so the check survives python -O.
        if tool_call_id not in tool_responses_dict:
            raise AssertionError(
                f"Found a tool call without a tool response! id: {tool_call_id}"
            )
        # The message that issued this call; add_turns has already set its turn.
        matching_message = [
            m for m in thread.messages if tool_call_id in m.tool_call_ids
        ][0]
        ToolCall.create(
            evalsetrun=dataset.evalsetrun,
            dataset=dataset,
            thread=thread,
            turn=matching_message.turn,
            message=matching_message,
            function_name=tool_call_vals.get("name"),
            args=json.dumps(tool_call_vals.get("args")),
            additional_kwargs=json.dumps(
                tool_additional_kwargs_dict.get(tool_call_id)
            ),
            tool_call_id=tool_call_id,
            response_content=tool_responses_dict.get(tool_call_id),
        )
## Add system prompt if available?
def add_turns(thread: Thread):
    """Create :class:`Turn` rows for ``thread`` and link each message to its turn.

    Turn boundaries come from :func:`get_turns`, which also sets and saves each
    message's ``is_final_turn_in_input``. One Turn is created per placeholder
    turn id, in the order ``get_turns`` returns them. Its ``index_in_thread``
    counts from 0, and its ``role`` is the one ``get_turns`` recorded for that
    turn. Every message is then saved with its ``turn`` set.
    """
    message_placeholder_ids, turn_dict = get_turns(thread=thread)

    # Map each placeholder id from get_turns to the Turn row created for it.
    turns = {}
    for index_in_thread, (placeholder_turn_id, role) in enumerate(turn_dict.items()):
        turns[placeholder_turn_id] = Turn.create(
            evalsetrun=thread.evalsetrun,
            dataset=thread.dataset,
            thread=thread,
            index_in_thread=index_in_thread,
            role=role,
        )

    # get_turns built the placeholder list by iterating thread.messages, so the
    # two can be zipped (assumes thread.messages yields the same order each time).
    for placeholder_turn_id, message in zip(message_placeholder_ids, thread.messages):
        message.turn = turns[placeholder_turn_id]
        message.save()
def verify_checkpoints_table_exists(cursor):
    """Check that the SQLite database behind ``cursor`` has a ``checkpoints`` table.

    Args:
        cursor: An open cursor on a SQLite database (e.g. ``sqlite3.Cursor``).

    Raises:
        AssertionError: If no table named ``checkpoints`` exists. It is raised
            explicitly rather than via ``assert`` so the check still runs under
            ``python -O``.
    """
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='checkpoints'"
    )
    if cursor.fetchone() is None:
        raise AssertionError("Table 'checkpoints' does not exist in the database.")
def get_turns(thread: "Thread"):
    """Assign placeholder turn ids to the messages of ``thread``.

    A turn is a run of one or more consecutive messages from the same side of
    the conversation. One side is the machine (role ``"assistant"``, ``"ai"`` or
    ``"tool"``); the other is everything else, e.g. ``"user"``. For example::

        TURN 1 - user
        TURN 2 - assistant
        TURN 3 - user
        TURN 4 - assistant
        TURN 4 - tool
        TURN 4 - assistant
        TURN 5 - user

    Side effect: every message's ``is_final_turn_in_input`` is set to whether
    it belongs to the last turn, and the message is saved.

    Args:
        thread: Object whose ``messages`` attribute yields the thread's messages
            in order; each message needs ``role`` and ``save()``.

    Returns:
        A pair ``(message_placeholder_ids, turn_id_roles)``:

        * ``message_placeholder_ids``: the placeholder turn id of each message,
          in the same order as ``thread.messages``.
        * ``turn_id_roles``: maps each placeholder id to the role of the last
          message in that turn.

        Ids grow by one per turn. They start at 1 when the first message is
        non-machine and at 2 when it is a machine message.
    """
    machine_labels = {"assistant", "ai", "tool"}
    # Materialise once so both passes see the same message objects.
    messages = list(thread.messages)

    turn_id = 1
    # The "previous role" starts out non-machine, so a leading user message
    # stays in turn 1.
    previous_is_machine = False
    message_placeholder_ids = []
    for message in messages:
        is_machine = message.role in machine_labels
        # A new turn starts whenever the speaking side changes.
        if is_machine != previous_is_machine:
            turn_id += 1
        message_placeholder_ids.append(turn_id)
        previous_is_machine = is_machine

    # After the loop turn_id is the last turn's id; flag that turn's messages.
    turn_id_roles = {}
    for placeholder_id, message in zip(message_placeholder_ids, messages):
        turn_id_roles[placeholder_id] = message.role
        message.is_final_turn_in_input = placeholder_id == turn_id
        message.save()
    return message_placeholder_ids, turn_id_roles