Skip to content

Commit aa31b82

Browse files
author
Dylan Huang
committed
run single turn rollouts in parallel
1 parent b749f9c commit aa31b82

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed
Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,42 @@
1+
import asyncio
12
from typing import List
23

3-
from openai import OpenAI
4+
from openai import AsyncOpenAI
45

56
from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
6-
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
7-
from eval_protocol.pytest.types import Dataset, ModelParam, RolloutProcessorConfig
7+
from eval_protocol.models import EvaluationRow, Message
8+
from eval_protocol.pytest.types import RolloutProcessorConfig
89

910

async def default_single_turn_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Generate one assistant response per row, processing all rows concurrently.

    Each row's existing messages are sent as a chat-completion request to a
    Fireworks-hosted model (via the OpenAI-compatible API), and the assistant
    reply is appended to that row's message list.

    Args:
        rows: Evaluation rows to process; each must contain at least one message.
        config: Rollout configuration supplying the model name (``config.model``)
            and extra completion parameters (``config.input_params``).

    Returns:
        A list of new ``EvaluationRow`` objects in the same order as ``rows``,
        each with the assistant response appended to its messages.

    Raises:
        ValueError: If any row has an empty message list.
    """
    api_key = get_fireworks_api_key()
    api_base = get_fireworks_api_base()
    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")

    async def process_row(row: EvaluationRow) -> EvaluationRow:
        """Run one chat completion for a single row and return the updated row."""
        if not row.messages:
            raise ValueError("Messages is empty. Please provide a non-empty dataset")

        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

        response = await client.chat.completions.create(
            model=config.model, messages=messages_payload, **config.input_params
        )
        # The API may return a null content field; normalize to empty string.
        assistant_content = response.choices[0].message.content or ""
        messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]

        # Carry over every other row field (ground_truth, metadata, ...) unchanged.
        return EvaluationRow(
            messages=messages,
            **row.model_dump(exclude={"messages"}),
        )

    try:
        # Fan out all rows concurrently; asyncio.gather preserves input order
        # and propagates the first exception raised by any task.
        return await asyncio.gather(*(process_row(row) for row in rows))
    finally:
        # Release the client's HTTP connection pool — the original version
        # created the AsyncOpenAI client and never closed it.
        await client.close()

0 commit comments

Comments
 (0)