diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index 1a6618e5..a8f149a8 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -37,9 +37,9 @@ def db_path(self) -> str: return self._db_path def upsert_row(self, data: dict) -> None: - rollout_id = data["rollout_id"] - if "rollout_id" not in data: - raise ValueError("rollout_id is required to upsert a row") + rollout_id = (data.get("execution_metadata") or {}).get("rollout_id") + if rollout_id is None: + raise ValueError("execution_metadata.rollout_id is required to upsert a row") if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists(): self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute() else: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index b3d48b6c..79c4490d 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -237,6 +237,30 @@ class EvalMetadata(BaseModel): passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold") +class ExecutionMetadata(BaseModel): + """Metadata about the execution of the evaluation.""" + + invocation_id: Optional[str] = Field( + default_factory=generate_id, + description="The ID of the invocation that this row belongs to.", + ) + + experiment_id: Optional[str] = Field( + default_factory=generate_id, + description="The ID of the experiment that this row belongs to.", + ) + + rollout_id: Optional[str] = Field( + default_factory=generate_id, + description="The ID of the rollout that this row belongs to.", + ) + + run_id: Optional[str] = Field( + None, + description=("The ID of the run that this row belongs to."), + ) + + class RolloutStatus(BaseModel): """Status of the rollout.""" @@ -281,26 +305,6 @@ class EvaluationRow(BaseModel): description="The status of the rollout.", ) 
- invocation_id: Optional[str] = Field( - default_factory=generate_id, - description="The ID of the invocation that this row belongs to.", - ) - - experiment_id: Optional[str] = Field( - default_factory=generate_id, - description="The ID of the experiment that this row belongs to.", - ) - - rollout_id: Optional[str] = Field( - default_factory=generate_id, - description="The ID of the rollout that this row belongs to.", - ) - - run_id: Optional[str] = Field( - None, - description=("The ID of the run that this row belongs to."), - ) - # Ground truth reference (moved from EvaluateResult to top level) ground_truth: Optional[str] = Field( default=None, description="Optional ground truth reference for this evaluation." @@ -311,6 +315,11 @@ class EvaluationRow(BaseModel): default=None, description="The evaluation result for this row/trajectory." ) + execution_metadata: ExecutionMetadata = Field( + default_factory=ExecutionMetadata, + description="Metadata about the execution of the evaluation.", + ) + # LLM usage statistics usage: Optional[CompletionUsage] = Field( default=None, description="Token usage statistics from LLM calls during execution." 
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index f36fee9b..ef516f6b 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -383,8 +383,8 @@ def _log_eval_error( row.input_metadata.session_data["mode"] = mode # Initialize eval_metadata for each row row.eval_metadata = eval_metadata - row.experiment_id = experiment_id - row.invocation_id = invocation_id + row.execution_metadata.experiment_id = experiment_id + row.execution_metadata.invocation_id = invocation_id # has to be done in the pytest main process since it's # used to determine whether this eval has stopped @@ -409,11 +409,11 @@ def _log_eval_error( # apply new run_id to fresh_dataset for row in fresh_dataset: - row.run_id = run_id + row.execution_metadata.run_id = run_id # generate new rollout_id for each row for row in fresh_dataset: - row.rollout_id = generate_id() + row.execution_metadata.rollout_id = generate_id() # log the fresh_dataset for row in fresh_dataset: diff --git a/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py b/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py index 897691b2..7689d085 100644 --- a/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py +++ b/tests/dataset_logger/test_sqlite_dataset_logger_adapter.py @@ -24,7 +24,7 @@ def test_update_log_and_read(): logger = SqliteDatasetLoggerAdapter(store=store) logger.log(row) - saved = logger.read(row.rollout_id)[0] + saved = logger.read(row.execution_metadata.rollout_id)[0] assert row.messages == saved.messages assert row.input_metadata == saved.input_metadata @@ -42,7 +42,7 @@ def test_create_log_and_read(): row = EvaluationRow(input_metadata=input_metadata, messages=messages) logger.log(row) - saved = logger.read(rollout_id=row.rollout_id)[0] + saved = logger.read(rollout_id=row.execution_metadata.rollout_id)[0] assert row.messages == saved.messages assert row.input_metadata == saved.input_metadata diff --git 
a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py index 4deda27c..0131bcbe 100644 --- a/tests/pytest/test_pytest_ids.py +++ b/tests/pytest/test_pytest_ids.py @@ -12,8 +12,8 @@ def __init__(self): self._rows: dict[str, EvaluationRow] = {} def log(self, row: EvaluationRow): - print(row.run_id, row.rollout_id) - self._rows[row.rollout_id] = row + print(row.execution_metadata.run_id, row.execution_metadata.rollout_id) + self._rows[row.execution_metadata.rollout_id] = row def read(self): return list(self._rows.values()) @@ -76,10 +76,10 @@ def test_evaluation_test_decorator_ids_single(monkeypatch): logger=InMemoryLogger(), ) def eval_fn(row: EvaluationRow) -> EvaluationRow: - unique_run_ids.add(row.run_id) - unique_experiment_ids.add(row.experiment_id) - unique_rollout_ids.add(row.rollout_id) - unique_invocation_ids.add(row.invocation_id) + unique_run_ids.add(row.execution_metadata.run_id) + unique_experiment_ids.add(row.execution_metadata.experiment_id) + unique_rollout_ids.add(row.execution_metadata.rollout_id) + unique_invocation_ids.add(row.execution_metadata.invocation_id) unique_row_ids.add(row.input_metadata.row_id) return row diff --git a/tests/pytest/test_pytest_mcp_config.py b/tests/pytest/test_pytest_mcp_config.py index d9b46c58..c1b55d51 100644 --- a/tests/pytest/test_pytest_mcp_config.py +++ b/tests/pytest/test_pytest_mcp_config.py @@ -20,7 +20,7 @@ ] ], rollout_processor=default_agent_rollout_processor, - model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-20b"], mode="pointwise", mcp_config_path="tests/pytest/mcp_configurations/mock_discord_mcp_config.json", ) diff --git a/vite-app/src/GlobalState.tsx b/vite-app/src/GlobalState.tsx index 9669fe4c..67fa7fc0 100644 --- a/vite-app/src/GlobalState.tsx +++ b/vite-app/src/GlobalState.tsx @@ -14,10 +14,10 @@ export class GlobalState { upsertRows(dataset: EvaluationRow[]) { dataset.forEach((row) => { - if (!row.rollout_id) { + 
if (!row.execution_metadata?.rollout_id) { return; } - this.dataset[row.rollout_id] = row; + this.dataset[row.execution_metadata.rollout_id] = row; }); } diff --git a/vite-app/src/components/EvaluationRow.tsx b/vite-app/src/components/EvaluationRow.tsx index d16b06d2..fdeaf03c 100644 --- a/vite-app/src/components/EvaluationRow.tsx +++ b/vite-app/src/components/EvaluationRow.tsx @@ -133,10 +133,10 @@ const IdSection = observer(({ data }: { data: EvaluationRowType }) => ( )); @@ -197,7 +197,7 @@ const ExpandedContent = observer( export const EvaluationRow = observer( ({ row }: { row: EvaluationRowType; index: number }) => { - const rolloutId = row.rollout_id; + const rolloutId = row.execution_metadata?.rollout_id; const isExpanded = state.isRowExpanded(rolloutId); const toggleExpanded = () => state.toggleRowExpansion(rolloutId); @@ -226,7 +226,7 @@ export const EvaluationRow = observer( {/* Rollout ID */} - + {/* Model */} diff --git a/vite-app/src/components/EvaluationTable.tsx b/vite-app/src/components/EvaluationTable.tsx index f45f16c5..fb470b9e 100644 --- a/vite-app/src/components/EvaluationTable.tsx +++ b/vite-app/src/components/EvaluationTable.tsx @@ -20,7 +20,7 @@ const TableBody = observer( {paginatedData.map((row, index) => ( diff --git a/vite-app/src/types/eval-protocol.ts b/vite-app/src/types/eval-protocol.ts index 3fb00d5c..b18697f1 100644 --- a/vite-app/src/types/eval-protocol.ts +++ b/vite-app/src/types/eval-protocol.ts @@ -94,15 +94,19 @@ export const RolloutStatusSchema = z.object({ error_message: z.string().optional().describe('Error message if the rollout failed.') }); +export const ExecutionMetadataSchema = z.object({ + invocation_id: z.string().optional().describe('The ID of the invocation that this row belongs to.'), + experiment_id: z.string().optional().describe('The ID of the experiment that this row belongs to.'), + rollout_id: z.string().optional().describe('The ID of the rollout that this row belongs to.'), + run_id: 
z.string().optional().describe('The ID of the run that this row belongs to.'), +}); + export const EvaluationRowSchema = z.object({ messages: z.array(MessageSchema).describe('List of messages in the conversation/trajectory.'), tools: z.array(z.record(z.string(), z.any())).optional().describe('Available tools/functions that were provided to the agent.'), input_metadata: InputMetadataSchema.describe('Metadata related to the input (dataset info, model config, session data, etc.).'), rollout_status: RolloutStatusSchema.default({ status: 'finished' }).describe('The status of the rollout.'), - invocation_id: z.string().optional().describe('The ID of the invocation that this row belongs to.'), - experiment_id: z.string().optional().describe('The ID of the experiment that this row belongs to.'), - rollout_id: z.string().optional().describe('The ID of the rollout that this row belongs to.'), - run_id: z.string().optional().describe('The ID of the run that this row belongs to.'), + execution_metadata: ExecutionMetadataSchema.optional().describe('Metadata about the execution of the evaluation.'), ground_truth: z.string().optional().describe('Optional ground truth reference for this evaluation.'), evaluation_result: EvaluateResultSchema.optional().describe('The evaluation result for this row/trajectory.'), usage: CompletionUsageSchema.optional().describe('Token usage statistics from LLM calls during execution.'), diff --git a/vite-app/src/util/pivot.test.ts b/vite-app/src/util/pivot.test.ts index ec26063e..4875e7d3 100644 --- a/vite-app/src/util/pivot.test.ts +++ b/vite-app/src/util/pivot.test.ts @@ -189,7 +189,7 @@ describe('computePivot', () => { const res = computePivot({ data: rows, - rowFields: ['$.eval_metadata.name', '$.experiment_id'], + rowFields: ['$.eval_metadata.name', '$.execution_metadata.experiment_id'], columnFields: ['$.input_metadata.completion_params.model'], valueField: '$.evaluation_result.score', aggregator: 'avg',