Skip to content

Commit 13a4b91

Browse files
committed
add tests
1 parent 4c9794b commit 13a4b91

File tree

6 files changed

+561
-41
lines changed

6 files changed

+561
-41
lines changed

optillm/mars/mars.py

Lines changed: 89 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""
2-
MARS: Multi-Agent Reasoning System main orchestration
2+
MARS: Multi-Agent Reasoning System main orchestration with parallel execution
33
"""
44

5+
import asyncio
56
import logging
67
from typing import Dict, Any, List, Tuple
78
from datetime import datetime
9+
from concurrent.futures import ThreadPoolExecutor
810
import optillm
911
from optillm import conversation_logger
1012

@@ -36,7 +38,7 @@ def multi_agent_reasoning_system(
3638
request_id: str = None
3739
) -> Tuple[str, int]:
3840
"""
39-
Main MARS function implementing multi-agent mathematical reasoning
41+
Main MARS function implementing multi-agent mathematical reasoning with parallel execution
4042
4143
Args:
4244
system_prompt: System-level instructions
@@ -48,12 +50,31 @@ def multi_agent_reasoning_system(
4850
Returns:
4951
Tuple of (final_solution, total_reasoning_tokens)
5052
"""
53+
return asyncio.run(_run_mars_parallel(
54+
system_prompt, initial_query, client, model, request_id
55+
))
56+
57+
async def _run_mars_parallel(
58+
system_prompt: str,
59+
initial_query: str,
60+
client,
61+
model: str,
62+
request_id: str = None
63+
) -> Tuple[str, int]:
64+
"""Async implementation of MARS with parallel execution"""
5165
logger.info(f"Starting MARS with model: {model}")
5266

5367
# Initialize configuration
5468
config = DEFAULT_CONFIG.copy()
5569
total_reasoning_tokens = 0
5670

71+
# Calculate optimal worker count for parallel execution
72+
max_workers = max(
73+
config['num_agents'], # For generation phase
74+
config['num_agents'] * min(2, config['verification_passes_required']) # For verification
75+
)
76+
logger.info(f"Using {max_workers} parallel workers")
77+
5778
# Initialize workspace for collaboration
5879
workspace = MARSWorkspace(initial_query, config)
5980

@@ -66,37 +87,41 @@ def multi_agent_reasoning_system(
6687

6788
logger.info(f"Initialized {len(agents)} agents with diverse temperatures")
6889

69-
# Phase 2: Multi-Agent Exploration
70-
logger.info("Phase 1: Multi-Agent Exploration")
71-
exploration_tokens = _run_exploration_phase(agents, workspace, request_id)
72-
total_reasoning_tokens += exploration_tokens
90+
# Create thread pool executor for parallel API calls
91+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
92+
# Phase 2: Multi-Agent Exploration (parallel)
93+
logger.info("Phase 1: Multi-Agent Exploration")
94+
exploration_tokens = await _run_exploration_phase_parallel(
95+
agents, workspace, request_id, executor
96+
)
97+
total_reasoning_tokens += exploration_tokens
7398

74-
# Phase 3: Verification System
75-
logger.info("Phase 2: Verification System")
76-
verifier = MARSVerifier(agents, workspace, config)
77-
verification_summary = verifier.verify_solutions(request_id)
99+
# Phase 3: Verification System (parallel)
100+
logger.info("Phase 2: Verification System")
101+
verifier = MARSVerifier(agents, workspace, config)
102+
verification_summary = await verifier.verify_solutions_parallel(request_id, executor)
78103

79-
# Phase 4: Iterative Improvement (if needed)
80-
iteration_count = 0
81-
while workspace.should_continue_iteration() and iteration_count < config['max_iterations']:
82-
iteration_count += 1
83-
logger.info(f"Phase 3: Iterative Improvement - Iteration {iteration_count}")
104+
# Phase 4: Iterative Improvement (if needed)
105+
iteration_count = 0
106+
while workspace.should_continue_iteration() and iteration_count < config['max_iterations']:
107+
iteration_count += 1
108+
logger.info(f"Phase 3: Iterative Improvement - Iteration {iteration_count}")
84109

85-
# Improve unverified solutions
86-
improvement_summary = verifier.iterative_improvement(request_id)
87-
total_reasoning_tokens += improvement_summary['total_reasoning_tokens']
110+
# Improve unverified solutions (parallel)
111+
improvement_summary = await verifier.iterative_improvement_parallel(request_id, executor)
112+
total_reasoning_tokens += improvement_summary['total_reasoning_tokens']
88113

89-
# Re-verify improved solutions
90-
verification_summary = verifier.verify_solutions(request_id)
114+
# Re-verify improved solutions (parallel)
115+
verification_summary = await verifier.verify_solutions_parallel(request_id, executor)
91116

92-
# Check for early termination
93-
if config['early_termination'] and workspace.has_consensus():
94-
logger.info("Early termination: consensus reached")
95-
break
117+
# Check for early termination
118+
if config['early_termination'] and workspace.has_consensus():
119+
logger.info("Early termination: consensus reached")
120+
break
96121

97-
workspace.iteration_count = iteration_count
122+
workspace.iteration_count = iteration_count
98123

99-
# Phase 5: Final Synthesis
124+
# Phase 5: Final Synthesis (sequential - needs all results)
100125
logger.info("Phase 4: Final Synthesis")
101126
final_solution, synthesis_tokens = _synthesize_final_solution(
102127
workspace, client, model, config, request_id
@@ -126,24 +151,50 @@ def multi_agent_reasoning_system(
126151
except:
127152
return error_response, 0
128153

129-
def _run_exploration_phase(agents: List[MARSAgent], workspace: MARSWorkspace, request_id: str = None) -> int:
130-
"""Run the multi-agent exploration phase"""
131-
total_tokens = 0
132-
133-
# Generate solutions from all agents in parallel (conceptually)
134-
for agent in agents:
154+
async def _run_exploration_phase_parallel(
155+
agents: List[MARSAgent],
156+
workspace: MARSWorkspace,
157+
request_id: str = None,
158+
executor: ThreadPoolExecutor = None
159+
) -> int:
160+
"""Run the multi-agent exploration phase with parallel execution"""
161+
162+
async def generate_solution_async(agent: MARSAgent):
163+
"""Async wrapper for agent solution generation"""
164+
loop = asyncio.get_event_loop()
135165
try:
136-
agent_solution, reasoning_tokens = agent.generate_solution(
137-
workspace.problem, request_id
166+
solution, tokens = await loop.run_in_executor(
167+
executor,
168+
agent.generate_solution,
169+
workspace.problem,
170+
request_id
138171
)
139-
workspace.add_solution(agent_solution)
140-
total_tokens += reasoning_tokens
141-
172+
return agent.agent_id, solution, tokens, None
142173
except Exception as e:
143174
logger.error(f"Agent {agent.agent_id} failed during exploration: {str(e)}")
175+
return agent.agent_id, None, 0, e
176+
177+
# Run all agents in parallel
178+
tasks = [generate_solution_async(agent) for agent in agents]
179+
results = await asyncio.gather(*tasks, return_exceptions=True)
180+
181+
total_tokens = 0
182+
successful_solutions = 0
183+
184+
for result in results:
185+
if isinstance(result, Exception):
186+
logger.error(f"Agent task failed: {str(result)}")
144187
continue
145188

146-
logger.info(f"Exploration phase complete: {len(workspace.solutions)} solutions generated")
189+
agent_id, solution, tokens, error = result
190+
if error is None and solution is not None:
191+
workspace.add_solution(solution)
192+
total_tokens += tokens
193+
successful_solutions += 1
194+
else:
195+
logger.error(f"Agent {agent_id} generated no solution")
196+
197+
logger.info(f"Exploration phase complete: {successful_solutions} solutions generated in parallel")
147198
return total_tokens
148199

149200
def _synthesize_final_solution(

optillm/mars/verifier.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""
2-
MARS Verification system implementing 5-pass verification threshold
2+
MARS Verification system implementing 5-pass verification threshold with parallel execution
33
"""
44

5+
import asyncio
56
import logging
67
from typing import Dict, List, Any, Tuple
78
from datetime import datetime
9+
from concurrent.futures import ThreadPoolExecutor
810
from .workspace import MARSWorkspace, AgentSolution, VerificationResult
911
from .agent import MARSAgent
1012

@@ -50,6 +52,71 @@ def verify_solutions(self, request_id: str = None) -> Dict[str, Any]:
5052
logger.info(f"Verification complete: {verification_summary['solutions_verified']} solutions verified")
5153
return verification_summary
5254

55+
async def verify_solutions_parallel(
56+
self,
57+
request_id: str = None,
58+
executor: ThreadPoolExecutor = None
59+
) -> Dict[str, Any]:
60+
"""Run comprehensive verification on all solutions in workspace with parallel execution"""
61+
logger.info(f"Starting parallel verification process with {self.verification_threshold}-pass threshold")
62+
63+
verification_summary = {
64+
'total_verifications': 0,
65+
'solutions_verified': 0,
66+
'consensus_reached': False,
67+
'verification_details': []
68+
}
69+
70+
solutions = self.workspace.solutions
71+
if not solutions:
72+
logger.warning("No solutions to verify")
73+
return verification_summary
74+
75+
# Verify all solutions in parallel
76+
async def verify_solution_async(solution: AgentSolution):
77+
"""Async wrapper for single solution verification"""
78+
loop = asyncio.get_event_loop()
79+
try:
80+
result = await loop.run_in_executor(
81+
executor,
82+
self._verify_single_solution,
83+
solution,
84+
request_id
85+
)
86+
return result
87+
except Exception as e:
88+
logger.error(f"Verification failed for solution from agent {solution.agent_id}: {str(e)}")
89+
return {
90+
'solution_agent_id': solution.agent_id,
91+
'verification_count': 0,
92+
'consecutive_passes': 0,
93+
'passes_threshold': False,
94+
'verification_results': []
95+
}
96+
97+
# Run verifications in parallel
98+
tasks = [verify_solution_async(solution) for solution in solutions]
99+
results = await asyncio.gather(*tasks, return_exceptions=True)
100+
101+
# Process results
102+
for result in results:
103+
if isinstance(result, Exception):
104+
logger.error(f"Verification task failed: {str(result)}")
105+
continue
106+
107+
verification_summary['verification_details'].append(result)
108+
verification_summary['total_verifications'] += result['verification_count']
109+
110+
if result['passes_threshold']:
111+
verification_summary['solutions_verified'] += 1
112+
113+
# Check for consensus
114+
verified_solutions = self.workspace.get_verified_solutions()
115+
verification_summary['consensus_reached'] = len(verified_solutions) >= self.config.get('consensus_threshold', 2)
116+
117+
logger.info(f"Parallel verification complete: {verification_summary['solutions_verified']} solutions verified")
118+
return verification_summary
119+
53120
def _verify_single_solution(self, solution: AgentSolution, request_id: str = None) -> Dict[str, Any]:
54121
"""Verify a single solution with multiple passes"""
55122
logger.info(f"Verifying solution from agent {solution.agent_id}")
@@ -177,6 +244,86 @@ def iterative_improvement(self, request_id: str = None) -> Dict[str, Any]:
177244

178245
return improvement_summary
179246

247+
async def iterative_improvement_parallel(
248+
self,
249+
request_id: str = None,
250+
executor: ThreadPoolExecutor = None
251+
) -> Dict[str, Any]:
252+
"""Run iterative improvement on solutions that failed verification with parallel execution"""
253+
logger.info("Starting parallel iterative improvement process")
254+
255+
improvement_summary = {
256+
'solutions_improved': 0,
257+
'improvement_attempts': 0,
258+
'total_reasoning_tokens': 0
259+
}
260+
261+
# Get solutions that need improvement
262+
unverified_solutions = [s for s in self.workspace.solutions if not s.is_verified]
263+
264+
# Filter solutions that have verification feedback and can be improved
265+
improvable_solutions = []
266+
for solution in unverified_solutions:
267+
if solution.verification_results:
268+
latest_verification = solution.verification_results[-1]
269+
if latest_verification['assessment'] in ['INCORRECT', 'INCOMPLETE']:
270+
original_agent = next((a for a in self.agents if a.agent_id == solution.agent_id), None)
271+
if original_agent:
272+
improvable_solutions.append((solution, original_agent, latest_verification))
273+
274+
if not improvable_solutions:
275+
logger.info("No solutions need improvement")
276+
return improvement_summary
277+
278+
# Improve solutions in parallel
279+
async def improve_solution_async(solution_data):
280+
"""Async wrapper for solution improvement"""
281+
solution, agent, verification = solution_data
282+
loop = asyncio.get_event_loop()
283+
284+
try:
285+
improved_solution, reasoning_tokens = await loop.run_in_executor(
286+
executor,
287+
agent.improve_solution,
288+
self.workspace.problem,
289+
solution.solution,
290+
verification['detailed_report'],
291+
verification['issues'],
292+
request_id
293+
)
294+
295+
# Update solution with improvement
296+
solution.solution = improved_solution
297+
solution.timestamp = datetime.now()
298+
solution.reasoning_tokens += reasoning_tokens
299+
300+
logger.info(f"Improved solution from agent {solution.agent_id}")
301+
return solution.agent_id, True, reasoning_tokens, None
302+
303+
except Exception as e:
304+
logger.error(f"Failed to improve solution from agent {solution.agent_id}: {str(e)}")
305+
return solution.agent_id, False, 0, e
306+
307+
# Run improvements in parallel
308+
tasks = [improve_solution_async(sol_data) for sol_data in improvable_solutions]
309+
results = await asyncio.gather(*tasks, return_exceptions=True)
310+
311+
# Process results
312+
for result in results:
313+
improvement_summary['improvement_attempts'] += 1
314+
315+
if isinstance(result, Exception):
316+
logger.error(f"Improvement task failed: {str(result)}")
317+
continue
318+
319+
agent_id, success, tokens, error = result
320+
if success:
321+
improvement_summary['solutions_improved'] += 1
322+
improvement_summary['total_reasoning_tokens'] += tokens
323+
324+
logger.info(f"Parallel improvement complete: {improvement_summary['solutions_improved']} solutions improved")
325+
return improvement_summary
326+
180327
def final_consensus_check(self) -> bool:
181328
"""Final check to determine if consensus has been reached"""
182329
verified_solutions = self.workspace.get_verified_solutions()

scripts/eval_aime_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False, ext
306306
if extra_body:
307307
kwargs["extra_body"] = extra_body
308308

309-
response = client.with_options(timeout=3600.0).chat.completions.create(
309+
response = client.with_options(timeout=6000.0).chat.completions.create(
310310
model=model,
311311
messages=[
312312
{"role": "user", "content": SYSTEM_PROMPT + problem}

tests/test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from optillm.plansearch import plansearch
2727
from optillm.leap import leap
2828
from optillm.reread import re2_approach
29+
from optillm.mars import multi_agent_reasoning_system
2930
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
3031

3132
# Setup logging
@@ -57,6 +58,7 @@ def __init__(self):
5758
'plansearch': plansearch,
5859
'leap': leap,
5960
're2': re2_approach,
61+
'mars': multi_agent_reasoning_system,
6062
'cepo': lambda s, q, c, m: cepo(s,q,c,m,init_cepo_config({'cepo_config_file': './optillm/cepo/configs/cepo_config.yaml'})),
6163
}
6264

0 commit comments

Comments
 (0)