Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions scripts/eval-runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@
except ImportError:
serialize_remote_eval_parameters_container = None
from braintrust.util import eprint

try:
from braintrust.util import bt_iscoroutinefunction
except ImportError:
def bt_iscoroutinefunction(f):
return (
inspect.iscoroutinefunction(f)
or inspect.isasyncgenfunction(f)
or getattr(f, "_BT_IS_ASYNC", False)
)

from braintrust.span_identifier_v4 import parse_parent
except Exception as exc: # pragma: no cover - runtime guard
print(
Expand Down Expand Up @@ -998,29 +1009,44 @@ def wrap_task(

takes_hooks = task_signature is not None and len(task_signature.parameters) >= 2

async def wrapped_task(input, hooks):
result = None
try:
if takes_hooks:
result = task(input, hooks)
else:
result = task(input)
if inspect.isawaitable(result):
result = await result
return result
finally:
if progress_cb is not None:
progress_cb("increment", None)
if stream_results and result is not None:
try:
hooks.report_progress({
"format": "code",
"output_type": "completion",
"event": "json_delta",
"data": json.dumps(result),
})
except Exception:
pass
def maybe_report(result: Any, hooks: Any) -> None:
if progress_cb is not None:
progress_cb("increment", None)
if stream_results and result is not None:
try:
hooks.report_progress({
"format": "code",
"output_type": "completion",
"event": "json_delta",
"data": json.dumps(result),
})
except Exception:
pass

if bt_iscoroutinefunction(task):
async def wrapped_task(input, hooks):
result = None
try:
if takes_hooks:
result = task(input, hooks)
else:
result = task(input)
if inspect.isawaitable(result):
result = await result
return result
finally:
maybe_report(result, hooks)
else:
def wrapped_task(input, hooks):
result = None
try:
if takes_hooks:
result = task(input, hooks)
else:
result = task(input)
return result
finally:
maybe_report(result, hooks)

if hasattr(task, "__name__"):
setattr(wrapped_task, "__name__", getattr(task, "__name__"))
Expand Down
74 changes: 74 additions & 0 deletions tests/eval_fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,80 @@ fn eval_python_callable_list_data_preserves_parallel_scorers() {
);
}

#[test]
fn eval_python_sync_task_progress_wrapper_preserves_parallel_tasks() {
let _guard = test_lock();
let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let fixtures_root = root.join("tests").join("evals");
let fixture_dir = fixtures_root.join("py").join("sync_task_parallelization");
let python = match ensure_python_env(&fixtures_root.join("py")) {
Some(python) => python,
None => {
if required_runtimes().contains("python") {
panic!("python runtime unavailable for sync task parallelization test");
}
eprintln!(
"Skipping eval_python_sync_task_progress_wrapper_preserves_parallel_tasks (python runtime unavailable)."
);
return;
}
};

let bt_path = bt_binary_path(&root);
let out_file = std::env::temp_dir().join(format!(
"bt-sync-task-parallel-{}.txt",
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock before epoch")
.as_nanos()
));

let output = Command::new(&bt_path)
.arg("eval")
.arg("--num-workers")
.arg("4")
.arg("--runner")
.arg(&python)
.arg("eval_sync_task_parallelization.py")
.current_dir(&fixture_dir)
.env("BT_EVAL_LOCAL", "1")
.env("BT_SYNC_TASK_PARALLEL_OUT", &out_file)
.env("BT_SYNC_TASK_SLEEP_S", "0.5")
.env(
"BRAINTRUST_API_KEY",
std::env::var("BRAINTRUST_API_KEY").unwrap_or_else(|_| "local".to_string()),
)
.output()
.expect("run bt eval python sync task parallelization");

let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
assert!(
output.status.success(),
"bt eval sync task parallelization should succeed.\nstdout:\n{stdout}\nstderr:\n{stderr}"
);

let contents = fs::read_to_string(&out_file).unwrap_or_default();
let _ = fs::remove_file(&out_file);
let events: Vec<&str> = contents
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
let first_end = events
.iter()
.position(|line| line.contains(" end "))
.expect("expected at least one end event");
let starts_before_first_end = events[..first_end]
.iter()
.filter(|line| line.contains(" start "))
.count();

assert!(
starts_before_first_end > 1,
"sync eval tasks should overlap under --num-workers; got event log {events:?}.\nstdout:\n{stdout}\nstderr:\n{stderr}"
);
}

fn read_fixture_config(path: &Path) -> FixtureConfig {
let raw = fs::read_to_string(path).expect("read fixture.json");
serde_json::from_str(&raw).expect("parse fixture.json")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import threading
import time
from pathlib import Path

from braintrust import Eval, Score


def log_event(event, index):
out_file = os.environ["BT_SYNC_TASK_PARALLEL_OUT"]
with Path(out_file).open("a") as f:
f.write(f"{time.time():.6f} {event} {index} {threading.get_ident()}\n")


def task(input):
index = input["index"]
log_event("start", index)
time.sleep(float(os.getenv("BT_SYNC_TASK_SLEEP_S", "0.5")))
log_event("end", index)
return {"ok": True}


def score(input, output, expected):
return Score(name="ok", score=1.0)


Eval(
"sync-task-parallelization",
data=[{"input": {"index": index}, "expected": {}} for index in range(4)],
task=task,
scores=[score],
experiment_name="sync-task-parallelization",
max_concurrency=4,
)
Loading