diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 3fe35f9b..bc68d346 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -2,6 +2,7 @@ import os import subprocess import time +import socket from pathlib import Path from typing import List, Optional @@ -69,11 +70,8 @@ def start(self) -> None: self._log_file = log_file self._log_file_path = log_file_path - # Wait for server to start - time.sleep(3) - - # Check if process is still running - if self.process.poll() is not None: + # Wait for server to be ready with proper health check + if not self._wait_for_server_ready(timeout=15): try: with open(self._log_file_path, "r") as f: log_content = f.read() @@ -82,13 +80,45 @@ def start(self) -> None: print("=" * 50) print(log_content) print("=" * 50) - raise RuntimeError(f"Server failed to start. Check log above for details.") + raise RuntimeError(f"Server failed to start or become ready. Check log above for details.") except Exception as e: stdout, stderr = self.process.communicate() - raise RuntimeError(f"Server failed to start. stderr: {stderr}, log error: {e}") + raise RuntimeError(f"Server failed to start or become ready. stderr: {stderr}, log error: {e}") print(f"✅ Server started successfully on port {self.port}") + def _wait_for_server_ready(self, timeout: int = 15) -> bool: + """ + Wait for server to be ready by polling socket connection. + """ + start_time = time.time() + health_check_failures = 0 + + while time.time() - start_time < timeout: + # Check if process is still running + if self.process.poll() is not None: + print(f"Server process exited early") + return False + + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + result = s.connect_ex(("localhost", self.port)) + if result == 0: + time.sleep(0.5) + return True + except Exception as e: + health_check_failures += 1 + # Print first few failures for debugging + if health_check_failures <= 3: + print(f"Health check failed: {e}") + + # Wait before next check + time.sleep(0.1) + + print(f"Server failed to become ready within {timeout} seconds") + return False + def stop(self) -> None: """Stop the MCP server.""" if self.process: