|
| 1 | +import asyncio |
1 | 2 | from typing import List |
2 | 3 |
|
3 | | -from openai import OpenAI |
| 4 | +from openai import AsyncOpenAI |
4 | 5 |
|
5 | 6 | from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key |
6 | | -from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message |
7 | | -from eval_protocol.pytest.types import Dataset, ModelParam, RolloutProcessorConfig |
| 7 | +from eval_protocol.models import EvaluationRow, Message |
| 8 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
8 | 9 |
|
9 | 10 |
|
async def default_single_turn_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Generate one assistant response per row from a Fireworks model, concurrently.

    Each row's message history is sent as a single chat-completion request and
    the assistant reply is appended to a copy of the row. All requests are
    issued concurrently with ``asyncio.gather``; the returned list preserves
    the input order.

    Args:
        rows: Evaluation rows to process; each must have a non-empty
            ``messages`` list.
        config: Supplies the model name (``config.model``) and any extra
            completion parameters (``config.input_params``).

    Returns:
        A new list of ``EvaluationRow`` objects (inputs are not mutated) with
        the assistant reply appended to each row's messages.

    Raises:
        ValueError: If any row has an empty message list.
    """
    api_key = get_fireworks_api_key()
    api_base = get_fireworks_api_base()
    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")

    async def process_row(row: EvaluationRow) -> EvaluationRow:
        """Send one chat-completion request and return an updated copy of the row."""
        if len(row.messages) == 0:
            raise ValueError("Messages is empty. Please provide a non-empty dataset")

        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

        response = await client.chat.completions.create(
            model=config.model, messages=messages_payload, **config.input_params
        )
        # content can be None (e.g. tool-call-only responses); normalize to "".
        assistant_content = response.choices[0].message.content or ""
        messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]

        # model_copy(update=...) keeps every other field with its original
        # type. Re-building via EvaluationRow(**row.model_dump(...)) would
        # downcast nested models to plain dicts and can fail on fields that
        # do not round-trip through serialization.
        return row.model_copy(update={"messages": messages})

    # Fan out all requests at once; gather preserves input order, so
    # dataset[i] corresponds to rows[i].
    tasks = [process_row(row) for row in rows]
    dataset = await asyncio.gather(*tasks)

    return list(dataset)
0 commit comments