Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
from .evaluation_test import evaluation_test
from .types import RolloutProcessor, RolloutProcessorConfig
from .utils import evaluate

__all__ = [
"default_agent_rollout_processor",
"default_no_op_rollout_processor",
"default_single_turn_rollout_processor",
"RolloutProcessor",
"RolloutProcessorConfig",
"evaluate",
"evaluation_test",
]
49 changes: 34 additions & 15 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from eval_protocol.pytest.types import (
Dataset,
DatasetPathParam,
EvaluationInputParam,
EvaluationTestMode,
InputMessagesParam,
InputParam,
ModelParam,
RolloutInputParam,
RolloutProcessor,
RolloutProcessorConfig,
TestFunction,
Expand All @@ -32,8 +33,9 @@ def evaluation_test(
input_messages: Optional[List[InputMessagesParam]] = None,
input_dataset: Optional[List[DatasetPathParam]] = None,
dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
input_params: Optional[List[InputParam]] = None,
rollout_input_params: Optional[List[RolloutInputParam]] = None,
rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
aggregation_method: AggregationMethod = "mean",
threshold_of_success: Optional[float] = None,
num_runs: int = 1,
Expand All @@ -56,8 +58,9 @@ def evaluation_test(
to a list of EvaluationRows if you have a custom dataset format.
dataset_adapter: Function to convert the input dataset to a list of
EvaluationRows. This is useful if you have a custom dataset format.
input_params: Generation parameters for the model.
rollout_input_params: Generation parameters for the rollout.
rollout_processor: Function used to perform the rollout.
evaluation_test_kwargs: Kwargs for the evaluation function.
aggregation_method: How to aggregate scores across rows.
threshold_of_success: If set, fail the test if the aggregated score is
below this threshold.
Expand Down Expand Up @@ -104,12 +107,19 @@ def execute_with_params(
test_func: TestFunction,
row: EvaluationRow | None = None,
input_dataset: List[EvaluationRow] | None = None,
evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
):
kwargs = {}
if input_dataset is not None:
kwargs["rows"] = input_dataset
if row is not None:
kwargs["row"] = row
if evaluation_test_kwargs is not None:
if "row" in evaluation_test_kwargs:
raise ValueError("'row' is a reserved parameter for the evaluation function")
if "rows" in evaluation_test_kwargs:
raise ValueError("'rows' is a reserved parameter for the evaluation function")
kwargs.update(evaluation_test_kwargs)
return execute_function(test_func, **kwargs)

# Calculate all possible combinations of parameters
Expand All @@ -118,21 +128,23 @@ def generate_combinations():

# Handle optional parameters with defaults
datasets: List[Optional[DatasetPathParam]] = input_dataset if input_dataset is not None else [None] # type: ignore
params: List[Optional[InputParam]] = input_params if input_params is not None else [None] # type: ignore
params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
messages: List[Optional[InputMessagesParam]] = input_messages if input_messages is not None else [None] # type: ignore
kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore

# Generate all combinations
for m in model:
for ds in datasets:
for ip in params:
for im in messages:
# Skip combinations that don't make sense
# If we have a dataset, we should have params for rollout
if ds is not None and ip is None:
continue
# If we have messages but no dataset, that's fine
# If we have no dataset and no messages, that's also fine
combinations.append((m, ds, ip, im))
for etk in kwargs:
# Skip combinations that don't make sense
# If we have a dataset, we should have params for rollout
if ds is not None and ip is None:
continue
# If we have messages but no dataset, that's fine
# If we have no dataset and no messages, that's also fine
combinations.append((m, ds, ip, im, etk))

return combinations

Expand All @@ -141,27 +153,31 @@ def generate_combinations():
# Create parameter tuples for pytest.mark.parametrize
param_tuples = []
for combo in combinations:
model_name, dataset, params, messages = combo
model_name, dataset, params, messages, etk = combo
param_tuple = [model_name]
if input_dataset is not None:
param_tuple.append(dataset)
if input_params is not None:
if rollout_input_params is not None:
param_tuple.append(params)
if input_messages is not None:
param_tuple.append(messages)
if evaluation_test_kwargs is not None:
param_tuple.append(etk)
param_tuples.append(tuple(param_tuple))

# For batch mode, use the original parameter names
test_param_names = ["model"]
if input_dataset is not None:
test_param_names.append("dataset_path")
if input_params is not None:
if rollout_input_params is not None:
test_param_names.append("input_params")
if input_messages is not None:
test_param_names.append("input_messages")
if evaluation_test_kwargs is not None:
test_param_names.append("evaluation_test_kwargs")

# Create wrapper function with exact signature that pytest expects
def create_wrapper_with_signature():
def create_wrapper_with_signature() -> Callable:
# Create the function body that will be used
def wrapper_body(**kwargs):
model_name = kwargs["model"]
Expand Down Expand Up @@ -193,6 +209,7 @@ def wrapper_body(**kwargs):
result = execute_with_params(
test_func,
row=row,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if result is None or not isinstance(result, EvaluationRow):
raise ValueError(
Expand All @@ -204,6 +221,7 @@ def wrapper_body(**kwargs):
results = execute_with_params(
test_func,
input_dataset=input_dataset,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if results is None:
raise ValueError(
Expand Down Expand Up @@ -234,6 +252,7 @@ def wrapper_body(**kwargs):

wrapper = create_wrapper_with_signature()
wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)
wrapper.original_evaluation_test_func = test_func

return wrapper

Expand Down
5 changes: 3 additions & 2 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
InputParam = Dict[str, Any]
RolloutInputParam = Dict[str, Any]
InputMessagesParam = List[Message]
EvaluationInputParam = Dict[str, Any]

Dataset = List[EvaluationRow]

Expand All @@ -37,7 +38,7 @@
@dataclass
class RolloutProcessorConfig:
model: ModelParam
input_params: InputParam # optional input parameters for inference
input_params: RolloutInputParam # optional input parameters for inference
mcp_config_path: str # for agent rollout processor


Expand Down
12 changes: 0 additions & 12 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,6 @@ def execute_function(func: Callable, **kwargs) -> Any:
return results


def evaluate(
    rows: List[EvaluationRow], reward_fn: Callable[..., EvaluateResult], **kwargs: Any
) -> List[EvaluationRow]:
    """Score every row with ``reward_fn`` and return the rows in input order.

    Each row is mutated in place: the value returned by ``reward_fn`` is
    stored on ``row.evaluation_result``.

    Args:
        rows: Rows to score; each must expose ``messages`` and ``ground_truth``.
        reward_fn: Callable invoked as
            ``reward_fn(messages=..., ground_truth=..., **kwargs)``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``reward_fn``.

    Returns:
        A new list holding the same (now scored) row objects.
    """

    def _attach_result(item: EvaluationRow) -> EvaluationRow:
        # Rows are scored one at a time, in order, exactly as they are given.
        item.evaluation_result = reward_fn(
            messages=item.messages, ground_truth=item.ground_truth, **kwargs
        )
        return item

    return [_attach_result(item) for item in rows]


AggregationMethod = Literal["mean", "max", "min"]


Expand Down
5 changes: 5 additions & 0 deletions eval_protocol/rewards/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,11 @@ def math_reward(
require_units: bool = False,
**kwargs: Any,
) -> EvaluateResult:
"""
NOTE: This is the deprecated/old way of creating an eval in Eval Protocol.
What used to be the @reward_function decorator is now the @evaluation_test
decorator with the mode="pointwise" parameter.
"""
if (
not messages
or not isinstance(messages[-1], Message)
Expand Down
39 changes: 17 additions & 22 deletions tests/pytest/test_markdown_highlighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import evaluation_test, default_single_turn_rollout_processor, evaluate
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test


def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
Expand All @@ -21,17 +21,27 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
]


def markdown_format_evaluate(messages: List[Message], ground_truth: Optional[str] = None, **kwargs) -> EvaluateResult:
@evaluation_test(
input_dataset=["tests/pytest/data/markdown_dataset.jsonl"],
dataset_adapter=markdown_dataset_to_evaluation_row,
model=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
threshold_of_success=1.0,
rollout_processor=default_single_turn_rollout_processor,
num_runs=1,
mode="pointwise",
)
def test_markdown_highlighting_evaluation(row: EvaluationRow) -> EvaluationRow:
"""
Evaluation function that checks if the model's response contains the required number of formatted sections.
"""

assistant_response = messages[-1].content
assistant_response = row.messages[-1].content

if not assistant_response:
return EvaluateResult(score=0.0, reason="❌ No assistant response found")

required_highlights = int(ground_truth)
required_highlights = int(row.ground_truth)

# Check if the response contains the required number of formatted sections
# e.g. **bold** or *italic*
Expand All @@ -50,26 +60,11 @@ def markdown_format_evaluate(messages: List[Message], ground_truth: Optional[str
meets_requirement = actual_count >= required_highlights

if meets_requirement:
return EvaluateResult(
row.evaluation_result = EvaluateResult(
score=1.0, reason=f"✅ Found {actual_count} highlighted sections (required: {required_highlights})"
)
else:
return EvaluateResult(
row.evaluation_result = EvaluateResult(
score=0.0, reason=f"❌ Only found {actual_count} highlighted sections (required: {required_highlights})"
)


@evaluation_test(
input_dataset=["tests/pytest/data/markdown_dataset.jsonl"],
dataset_adapter=markdown_dataset_to_evaluation_row,
model=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
input_params=[{"temperature": 0.0, "max_tokens": 4096}],
threshold_of_success=1.0,
rollout_processor=default_single_turn_rollout_processor,
num_runs=1,
)
def test_markdown_highlighting_evaluation(rows: List[EvaluationRow]) -> List[EvaluationRow]:
"""
Test markdown highlighting validation using batch mode with evaluate().
"""
return evaluate(rows, markdown_format_evaluate)
return row
73 changes: 65 additions & 8 deletions tests/pytest/test_pytest_math_example.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,76 @@
from typing import List
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluate, evaluation_test
from examples.math_example.main import evaluate as math_evaluate
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
from eval_protocol.rewards.math import math_reward
from examples.math_example.main import check_think_answer_format
from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row


@evaluation_test(
input_dataset=["development/gsm8k_sample.jsonl"],
dataset_adapter=gsm8k_to_evaluation_row,
model=["accounts/fireworks/models/kimi-k2-instruct"],
input_params=[{"temperature": 0.0}],
rollout_input_params=[{"temperature": 0.0}],
max_dataset_rows=5,
threshold_of_success=0.0,
rollout_processor=default_single_turn_rollout_processor,
mode="pointwise",
evaluation_test_kwargs=[
{"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}}
],
)
def test_math_dataset(rows: List[EvaluationRow]) -> List[EvaluationRow]:
"""Run math evaluation on sample dataset using pytest interface."""
return evaluate(rows, math_evaluate)
def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
    """
    Evaluate math problem solving, reporting both accuracy and format.

    Two criteria are computed from the conversation in ``row``:
    - Numerical accuracy, via ``math_reward``
    - Format compliance, via ``check_think_answer_format``
      (<think>...</think><answer>...</answer> structure)

    Only accuracy determines the overall score; format compliance is recorded
    as a separate metric and does not change the score.

    Args:
        row: EvaluationRow containing the conversation messages and ground truth
        **kwargs: Additional parameters. ``math_reward_kwargs`` (optional dict)
            is forwarded to ``math_reward``; when it is missing or None,
            ``math_reward`` is called with no extra keyword arguments.

    Returns:
        The same EvaluationRow with ``evaluation_result`` set.
    """
    # Get the assistant's response; messages may be dicts or Message objects.
    assistant_message = row.messages[-1]
    if isinstance(assistant_message, dict):
        # `or ""` also covers an explicit `"content": None` entry.
        assistant_response = assistant_message.get("content") or ""
    else:
        assistant_response = assistant_message.content or ""

    # Tolerate a parametrization that supplies no math_reward_kwargs instead of
    # failing with a KeyError before any evaluation happens.
    math_reward_kwargs = kwargs.get("math_reward_kwargs") or {}

    # Evaluate numerical accuracy using the built-in reward function
    accuracy_result = math_reward(messages=row.messages, ground_truth=row.ground_truth, **math_reward_kwargs)

    # Evaluate format compliance (looking for <think>...</think><answer>...</answer> format)
    format_correct = check_think_answer_format(assistant_response)
    format_score = 1.0 if format_correct else 0.0

    # For math_example only accuracy matters: the overall score is the accuracy
    # score, and format is reported purely as an informational metric.
    combined_score = accuracy_result.score

    # Create metrics structure expected by tests
    metrics = {
        "accuracy_reward": MetricResult(
            score=accuracy_result.score,
            reason=f"Numerical accuracy: {accuracy_result.reason}",
            is_score_valid=True,
        ),
        "format_reward": MetricResult(
            score=format_score,
            reason=f"Format compliance: {'correct' if format_correct else 'incorrect'} <think>...</think><answer>...</answer> structure",
            is_score_valid=True,
        ),
    }

    row.evaluation_result = EvaluateResult(
        score=combined_score,
        reason=f"Combined score: {combined_score:.2f} (accuracy: {accuracy_result.score:.2f}, format: {format_score:.2f})",
        metrics=metrics,
    )
    return row
Loading
Loading