"""Dataset loading functions. Maybe should move to :mod:`~flexeval.io`."""
import json
import logging
import pathlib
import random as rd
import sqlite3
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from flexeval.classes.dataset import Dataset
from flexeval.classes.message import Message
from flexeval.classes.thread import Thread
from flexeval.classes.tool_call import ToolCall
from flexeval.classes.turn import Turn
from flexeval.schema.evalrun_schema import FileDataSource, FileFormatEnum
logger = logging.getLogger(__name__)
[docs]
def load_thread_to_dataset(
thread_id: str | int,
thread: dict,
dataset: Dataset,
eval_run_thread_id: str | None = None,
) -> Thread:
if "input" not in thread:
raise ValueError(
f"Expected thread format is a dictionary containing at least an 'input' key. Instead, we found: {thread.keys()}"
)
# extract any metadata
thread_metadata = thread.copy()
del thread_metadata["input"]
context = []
thread_input = thread["input"]
# Get system prompt used in the thread - assuming only 1
for message in thread_input:
if message["role"] == "system":
system_prompt = message["content"]
break
else:
system_prompt = None
if system_prompt is not None:
# Add the system prompt as context
context.append({"role": "system", "content": system_prompt})
thread_object: Thread = Thread.create(
dataset=dataset,
jsonl_thread_id=thread_id,
eval_run_thread_id=eval_run_thread_id,
system_prompt=system_prompt,
metadata=json.dumps(thread_metadata),
)
# Create messages
index_in_thread = 0
for message in thread_input:
if not isinstance(message, dict):
raise ValueError(
f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
)
role = message.get("role", None)
if role != "system":
# System message shouldn't be added as a separate message
system_prompt_for_this_message = ""
if role != "user":
system_prompt_for_this_message = system_prompt
message_metadata = message.copy()
if "content" in message_metadata:
del message_metadata["content"]
if "role" in message_metadata:
del message_metadata["role"]
Message.create(
dataset=dataset,
thread=thread_object,
index_in_thread=index_in_thread,
role=role,
content=message.get("content", None),
context=json.dumps(context),
is_flexeval_completion=False,
system_prompt=system_prompt_for_this_message,
metadata=json.dumps(message_metadata),
)
# Update context
context.append({"role": role, "content": message.get("content", None)})
index_in_thread += 1
add_turns(thread_object)
return thread_object
[docs]
def load_file(
dataset: Dataset,
data_source: FileDataSource,
max_n_conversation_threads: int | None = None,
nb_evaluations_per_thread: int | None = 1,
):
if data_source.format == FileFormatEnum.jsonl:
load_jsonl(
dataset=dataset,
filename=data_source.path,
max_n_conversation_threads=max_n_conversation_threads,
nb_evaluations_per_thread=nb_evaluations_per_thread,
)
elif data_source.format == FileFormatEnum.langgraph_sqlite:
load_langgraph_sqlite(
dataset=dataset,
filename=data_source.path,
max_n_conversation_threads=max_n_conversation_threads,
nb_evaluations_per_thread=nb_evaluations_per_thread,
)
else:
raise ValueError("Format not yet supported.")
[docs]
def load_iterable(
dataset: Dataset,
iterable,
):
for thread_id, thread in enumerate(iterable):
load_thread_to_dataset(thread_id, thread, dataset)
[docs]
def load_jsonl(
dataset: Dataset,
filename: str | pathlib.Path,
max_n_conversation_threads: int | None = None,
nb_evaluations_per_thread: int | None = 1,
):
with open(filename, "r") as infile:
contents = infile.read() # will be a big string
all_lines = contents.splitlines()
# Each row is a single row of the jsonl file
# That means it has 'input' as a key, and a list of dictionaries as values
# per line
if max_n_conversation_threads is None:
max_n_conversation_threads = len(all_lines)
if max_n_conversation_threads <= len(all_lines):
selected_thread_ids = rd.sample(
list(range(len(all_lines))), max_n_conversation_threads
)
else:
logger.debug(
f"You requested up to '{max_n_conversation_threads}' conversations but only '{len(all_lines)}' are present in Jsonl dataset at '{filename}'."
)
selected_thread_ids = list(range(len(all_lines)))
### should duplicate the select threads nb_evaluations_per_thread times
if nb_evaluations_per_thread is None:
nb_evaluations_per_thread = 1
for thread_id, thread in enumerate(all_lines):
if thread_id in selected_thread_ids:
thread_json = json.loads(thread)
for thread_eval_run_id in range(
max(1, nb_evaluations_per_thread)
): # duplicate stored threads to enable averaged per-object evaluations
eval_run_thread_id = f"{thread_id}_{thread_eval_run_id}"
load_thread_to_dataset(
thread_id, thread_json, dataset, eval_run_thread_id
)
# TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
[docs]
def load_langgraph_sqlite(
dataset: Dataset,
filename: str,
max_n_conversation_threads: int | None = None,
nb_evaluations_per_thread: int | None = 1,
):
"""Load conversations from a LangGraph SQLite checkpoint database.
Reads the final checkpoint for each thread and extracts the cumulative
message list from channel_values.messages. Compatible with langgraph >= 1.0.
"""
serializer = JsonPlusSerializer()
with sqlite3.connect(filename) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
verify_checkpoints_table_exists(cursor)
cursor.execute("PRAGMA wal_checkpoint(FULL);")
# Get distinct thread IDs
cursor.execute("SELECT DISTINCT thread_id FROM checkpoints")
thread_ids = cursor.fetchall()
nb_threads = len(thread_ids)
if max_n_conversation_threads is None:
max_n_conversation_threads = nb_threads
if max_n_conversation_threads <= nb_threads:
selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
else:
logger.debug(
f"You requested up to '{max_n_conversation_threads}' conversations "
f"but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
)
selected_thread_ids = thread_ids
for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
for thread_id_row in selected_thread_ids:
lg_thread_id = thread_id_row[0]
# Get the final checkpoint (highest step) for this thread
cursor.execute(
"""
SELECT *, json_extract(metadata, '$.step') as step
FROM checkpoints
WHERE thread_id = ?
ORDER BY json_extract(metadata, '$.step') DESC
LIMIT 1
""",
(lg_thread_id,),
)
final_row = cursor.fetchone()
if final_row is None:
logger.warning(f"No checkpoints found for thread '{lg_thread_id}'")
continue
checkpoint = serializer.loads_typed(
(final_row["type"], final_row["checkpoint"])
)
lg_messages = checkpoint.get("channel_values", {}).get("messages", [])
if not lg_messages:
logger.warning(
f"No messages in final checkpoint for thread '{lg_thread_id}'"
)
continue
thread = Thread.create(
dataset=dataset,
langgraph_thread_id=lg_thread_id,
eval_run_thread_id=f"{lg_thread_id}_{thread_eval_run_id}",
)
# Map message types to FlexEval roles
# Tools are counted as assistant per existing convention
context = []
system_prompt = None
tool_calls_dict = {}
tool_responses_dict = {}
tool_additional_kwargs_dict = {}
for index_in_thread, msg in enumerate(lg_messages):
msg_type = msg.type # 'human', 'ai', 'tool'
role = "user" if msg_type == "human" else "assistant"
content = msg.content
# Extract tool call info
tool_calls = getattr(msg, "tool_calls", []) or []
tool_call_ids = [tc["id"] for tc in tool_calls]
response_meta = getattr(msg, "response_metadata", {}) or {}
token_usage = response_meta.get("token_usage", {})
additional_kwargs = getattr(msg, "additional_kwargs", {}) or {}
Message.create(
dataset=dataset,
thread=thread,
index_in_thread=index_in_thread,
role=role,
content=content,
context=json.dumps(context),
is_flexeval_completion=False,
system_prompt=system_prompt,
# language model stats
tool_calls=json.dumps(tool_calls),
tool_call_ids=tool_call_ids,
n_tool_calls=len(tool_calls),
prompt_tokens=token_usage.get("prompt_tokens"),
completion_tokens=token_usage.get("completion_tokens"),
model_name=response_meta.get("model_name"),
# langgraph metadata
langgraph_ts=checkpoint.get("ts"),
langgraph_thread_id=lg_thread_id,
langgraph_checkpoint_id=final_row["checkpoint_id"],
langgraph_parent_checkpoint_id=final_row[
"parent_checkpoint_id"
],
langgraph_metadata=final_row["metadata"],
langgraph_message_type=msg_type,
langgraph_type=msg_type,
)
# Build context for next message
context.append({"role": role, "content": content})
# Track tool calls and responses for ToolCall creation
if msg_type == "tool":
tool_call_id = getattr(msg, "tool_call_id", None)
if tool_call_id:
tool_responses_dict[tool_call_id] = content
else:
for tc in tool_calls:
tool_calls_dict[tc["id"]] = tc
tool_additional_kwargs_dict[tc["id"]] = additional_kwargs
# Create turns from messages
add_turns(thread)
# Create ToolCall objects by matching calls to responses
for tool_call_id, tool_call_vals in tool_calls_dict.items():
if tool_call_id not in tool_responses_dict:
raise ValueError(
f"Found a tool call without a tool response! id='{tool_call_id}'"
)
matching_message = [
m
for m in thread.messages
if tool_call_id in (m.tool_call_ids or [])
][0]
ToolCall.create(
dataset=dataset,
thread=thread,
turn=matching_message.turn,
message=matching_message,
function_name=tool_call_vals.get("name"),
args=json.dumps(tool_call_vals.get("args")),
additional_kwargs=json.dumps(
tool_additional_kwargs_dict.get(tool_call_id)
),
tool_call_id=tool_call_id,
response_content=tool_responses_dict.get(tool_call_id),
)
[docs]
def add_turns(thread: Thread):
# Add turn labels
# Step 1 - add placeholder_turn_id to each message
message_roles = []
for message in thread.messages:
message_roles.append({"id": message.id, "role": message.role})
message_placeholder_ids, turn_dict = get_turns(thread=thread)
# Step 2 - Create turns, plus a mapping between the placeholder ids and the created ids
turns = {}
index_in_thread = 0
for placeholder_turn_id, role in turn_dict.items(): # turns.items():
t = Turn.create(
dataset=thread.dataset,
thread=thread,
index_in_thread=index_in_thread,
role=role,
)
# map placeholder id to turn object
turns[placeholder_turn_id] = t
index_in_thread += 1
# Step 3 - add placeholder ids to messages
# Can use zip since entries in message_list correspond to thread.messages
# NOTE: ANR: I don't follow how the message_list was supposed to work below.
for ml, message in zip(message_placeholder_ids, thread.messages):
# Is this going to work? No idea
message.turn = turns[ml]
# message.is_final_turn_in_input = ml.get("is_final_turn_in_input", False)
message.save()
[docs]
def verify_checkpoints_table_exists(cursor):
# double check that the 'checkpoints' table exists
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND name='checkpoints'
""")
result = cursor.fetchone()
# Assert that the result is not None, meaning the table exists
assert result is not None, "Table 'checkpoints' does not exist in the database."
[docs]
def get_turns(thread: Thread):
"""We're defining a turn as a list of 1 or more consequtive outputs
by the same role, where the role is either 'user', or 'assistant/tool'.
In other words, we would parse as follows:
TURN 1 - user
TURN 2 - assistant
TURN 3 - user
TURN 4 - assistant
TURN 4 - tool
TURN 4 - assistant
TURN 5 - user
"""
# these are all treated as belonging to the same 'turn'
machine_labels = ["assistant", "ai", "tool"]
turn_id = 1
previous_role = ""
# TODO: Make a message list here, store the placeholder ids, and update to the real turn ids; save at end
message_placeholder_ids = []
for turnentry_id, entry in enumerate(thread.messages): # enumerate(input_list):
current_role = entry.role # entry.get("role", None)
# entry["role"] = current_role
# if your role matches a previous, don't increment turn_id
if (current_role in machine_labels and previous_role in machine_labels) or (
current_role not in machine_labels and previous_role not in machine_labels
):
pass # TODO: clean up the condition to avoid the empty if
# previous_role = current_role
# entry["placholder_turn_id"] = turn_id
else:
turn_id += 1
# entry["placholder_turn_id"] = turn_id
# previous_role = current_role
# entry.turn_id = turn_id
message_placeholder_ids.append(turn_id)
previous_role = current_role
# entry.save()
# NOTE: ANR seems like this could be optimized - e.g., set all
# to false, then do a select query for just the ones where turn_id column is turn_id. That would also
# reduce the number of saves to the database.
# label final entry
# ANR: moved up the turn_id_roles bit here to avoid iterating twice
turn_id_roles = {}
for message_placehold_id, entry in zip(
message_placeholder_ids, thread.messages
): # input_list:
turn_id_roles[message_placehold_id] = entry.role
entry.is_final_turn_in_input = message_placehold_id == turn_id
entry.save() # Could optimize this to avoid saving twice
return message_placeholder_ids, turn_id_roles