diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index 815d33ef..e2a675c7 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -87,18 +87,32 @@ async def _send_text_to_all_connections(self, text: str): return tasks = [] + failed_connections = [] + for connection in connections: try: tasks.append(connection.send_text(text)) except Exception as e: logger.error(f"Failed to send text to WebSocket: {e}") - with self._lock: - try: - self.active_connections.remove(connection) - except ValueError: - pass + failed_connections.append(connection) + + # Execute all sends in parallel if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check for any exceptions that occurred during execution + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Failed to send text to WebSocket: {result}") + failed_connections.append(connections[i]) + + # Remove all failed connections + with self._lock: + for connection in failed_connections: + try: + self.active_connections.remove(connection) + except ValueError: + pass def start_broadcast_loop(self): """Start the broadcast loop in the current event loop.""" @@ -109,6 +123,7 @@ def stop_broadcast_loop(self): """Stop the broadcast loop.""" if self._broadcast_task and not self._broadcast_task.done(): self._broadcast_task.cancel() + self._broadcast_task = None class EvaluationWatcher: diff --git a/eval_protocol/utils/vite_server.py b/eval_protocol/utils/vite_server.py index 4c3143e2..02eef31d 100644 --- a/eval_protocol/utils/vite_server.py +++ b/eval_protocol/utils/vite_server.py @@ -97,7 +97,17 @@ def _setup_routes(self): # Mount static files self.app.mount("/assets", StaticFiles(directory=self.build_dir / "assets"), name="assets") - # Serve other static files from build directory + @self.app.get("/") + async def root(): + """Serve the main index.html file with injected configuration.""" + return self._serve_index_with_config() + + @self.app.get("/health") + async def health(): + """Health check endpoint.""" + return {"status": "ok", "build_dir": str(self.build_dir)} + + # Serve other static files from build directory - this must be last @self.app.get("/{path:path}") async def serve_spa(path: str): """ @@ -114,22 +124,12 @@ async def serve_spa(path: str): # For SPA routing, serve index.html for non-existent routes # but exclude API routes and asset requests - if not path.startswith(("api/", "assets/")): + if not path.startswith(("api/", "assets/", "health")): return self._serve_index_with_config() # If we get here, the file doesn't exist and it's not a SPA route raise HTTPException(status_code=404, detail="File not found") - @self.app.get("/") - async def root(): - """Serve the main index.html file with injected configuration.""" - return self._serve_index_with_config() - - @self.app.get("/health") - async def health(): - """Health check endpoint.""" - return {"status": "ok", "build_dir": str(self.build_dir)} - def run(self): """ Run the Vite server. diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py index c1b643d0..edb69b83 100644 --- a/tests/pytest/test_pytest_input_messages.py +++ b/tests/pytest/test_pytest_input_messages.py @@ -1,6 +1,6 @@ from typing import List -from eval_protocol.models import Message, EvaluationRow +from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test @@ -10,7 +10,7 @@ Message(role="user", content="What is the capital of France?"), ] ], - model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"], + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], rollout_processor=default_single_turn_rollout_processor, ) def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]: diff --git a/tests/pytest/test_pytest_json_schema.py b/tests/pytest/test_pytest_json_schema.py index 8463f873..3c18ff2b 100644 --- a/tests/pytest/test_pytest_json_schema.py +++ b/tests/pytest/test_pytest_json_schema.py @@ -1,5 +1,6 @@ import json from typing import Any, Dict, List + from eval_protocol.models import EvaluationRow from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test from eval_protocol.rewards.json_schema import json_schema_reward @@ -23,7 +24,7 @@ def json_schema_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evaluation @evaluation_test( input_dataset=["tests/pytest/data/json_schema.jsonl"], - model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"], + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], mode="pointwise", rollout_processor=default_single_turn_rollout_processor, dataset_adapter=json_schema_to_evaluation_row, diff --git a/tests/test_logs_server.py b/tests/test_logs_server.py new file mode 100644 index 00000000..c24aeab5 --- /dev/null +++ b/tests/test_logs_server.py @@ -0,0 +1,602 @@ +import asyncio +import json +import tempfile +import threading +import time +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import psutil +import pytest +from fastapi import FastAPI +from fastapi.routing import APIWebSocketRoute +from fastapi.testclient import TestClient + +from eval_protocol.dataset_logger import default_logger +from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE +from eval_protocol.event_bus import event_bus +from eval_protocol.models import EvalMetadata, EvaluationRow, InputMetadata, Message +from eval_protocol.utils.logs_server import ( + EvaluationWatcher, + LogsServer, + WebSocketManager, + create_app, + serve_logs, +) + + +class TestWebSocketManager: + """Test WebSocketManager class.""" + + def test_initialization(self): + """Test WebSocketManager initialization.""" + manager = WebSocketManager() + assert len(manager.active_connections) == 0 + assert manager._broadcast_queue is not None + assert manager._broadcast_task is None + + @pytest.mark.asyncio + async def test_connect_disconnect(self): + """Test WebSocket connection and disconnection.""" + manager = WebSocketManager() + mock_websocket = AsyncMock() + + # Test connection + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket) + assert len(manager.active_connections) == 1 + assert mock_websocket in manager.active_connections + mock_websocket.accept.assert_called_once() + + # Test disconnection + manager.disconnect(mock_websocket) + assert len(manager.active_connections) == 0 + assert mock_websocket not in manager.active_connections + + @pytest.mark.asyncio + async def test_connect_sends_initial_logs(self): + """Test that connecting sends initial logs.""" + manager = WebSocketManager() + mock_websocket = AsyncMock() + + # Mock default_logger.read() + mock_logs = [ + EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + ) + ] + + with patch.object(default_logger, "read", return_value=mock_logs): + await manager.connect(mock_websocket) + + # Verify that initial logs were sent + mock_websocket.send_text.assert_called_once() + sent_data = json.loads(mock_websocket.send_text.call_args[0][0]) + assert sent_data["type"] == "initialize_logs" + assert len(sent_data["logs"]) == 1 + + def test_broadcast_row_upserted(self): + """Test broadcasting row upsert events.""" + manager = WebSocketManager() + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + ) + + # Test that broadcast doesn't fail when no connections + manager.broadcast_row_upserted(test_row) + + # Test that message is queued + assert not manager._broadcast_queue.empty() + queued_message = manager._broadcast_queue.get_nowait() + data = json.loads(queued_message) + assert data["type"] == "log" + assert "row" in data + assert data["row"]["messages"][0]["content"] == "test" + assert data["row"]["input_metadata"]["row_id"] == "test-123" + + @pytest.mark.asyncio + async def test_broadcast_loop(self): + """Test the broadcast loop functionality.""" + manager = WebSocketManager() + mock_websocket = AsyncMock() + await manager.connect(mock_websocket) + + # Test that broadcast loop can be started and stopped + manager.start_broadcast_loop() + assert manager._broadcast_task is not None + + # Stop broadcast loop + manager.stop_broadcast_loop() + assert manager._broadcast_task is None + + @pytest.mark.asyncio + async def test_send_text_to_all_connections(self): + """Test sending text to all connections.""" + manager = WebSocketManager() + mock_websocket1 = AsyncMock() + mock_websocket2 = AsyncMock() + + # Mock default_logger.read() to return empty logs + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket1) + await manager.connect(mock_websocket2) + + test_message = "test message" + await manager._send_text_to_all_connections(test_message) + + # Check that the test message was sent to both websockets + mock_websocket1.send_text.assert_any_call(test_message) + mock_websocket2.send_text.assert_any_call(test_message) + + @pytest.mark.asyncio + async def test_send_text_handles_failed_connections(self): + """Test that failed connections are handled gracefully.""" + manager = WebSocketManager() + mock_websocket1 = AsyncMock() + mock_websocket2 = AsyncMock() + + # Mock default_logger.read() to return empty logs + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket1) + await manager.connect(mock_websocket2) + + # Make the second websocket fail AFTER connection is established + # We need to make send_text raise an exception when awaited + async def failing_send_text(text): + raise Exception("Connection failed") + + mock_websocket2.send_text = failing_send_text + + test_message = "test message" + await manager._send_text_to_all_connections(test_message) + + # First websocket should receive the message + mock_websocket1.send_text.assert_any_call(test_message) + # Second websocket should have been removed due to failure + assert len(manager.active_connections) == 1 + assert mock_websocket1 in manager.active_connections + + +class TestEvaluationWatcher: + """Test EvaluationWatcher class.""" + + def test_initialization(self): + """Test EvaluationWatcher initialization.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + assert watcher.websocket_manager == mock_manager + assert watcher._thread is None + assert watcher._stop_event is not None + + def test_start_stop(self): + """Test starting and stopping the watcher.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + # Test start + watcher.start() + assert watcher._thread is not None + assert watcher._thread.is_alive() + + # Test stop + watcher.stop() + assert watcher._stop_event.is_set() + if watcher._thread: + watcher._thread.join(timeout=1.0) + + @patch("psutil.Process") + def test_should_update_status_running_process(self, mock_process): + """Test status update for running process.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + # Mock a running process + mock_process_instance = Mock() + mock_process_instance.is_running.return_value = True + mock_process.return_value = mock_process_instance + + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"), + pid=12345, + ) + + # Process is running, should not update + assert watcher._should_update_status(test_row) is False + + @patch("psutil.Process") + def test_should_update_status_stopped_process(self, mock_process): + """Test status update for stopped process.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + # Mock a stopped process + mock_process_instance = Mock() + mock_process_instance.is_running.return_value = False + mock_process.return_value = mock_process_instance + + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"), + pid=12345, + ) + + # Process is stopped, should update + assert watcher._should_update_status(test_row) is True + + @patch("psutil.Process") + def test_should_update_status_no_such_process(self, mock_process): + """Test status update for non-existent process.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + # Mock process not found + mock_process.side_effect = psutil.NoSuchProcess(pid=999) + + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"), + pid=999, + ) + + # Process doesn't exist, should update + assert watcher._should_update_status(test_row) is True + + def test_should_update_status_not_running(self): + """Test status update for non-running evaluation.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="finished"), + pid=12345, + ) + + # Not running status, should not update + assert watcher._should_update_status(test_row) is False + + def test_should_update_status_no_pid(self): + """Test status update for evaluation without PID.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + eval_metadata=EvalMetadata(name="test_eval", num_runs=1, aggregation_method="mean", status="running"), + pid=None, + ) + + # No PID, should not update + assert watcher._should_update_status(test_row) is False + + +class TestLogsServer: + """Test LogsServer class.""" + + @pytest.fixture + def temp_build_dir(self): + """Create a temporary build directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + # Create a minimal index.html file + (temp_path / "index.html").write_text("
Test") + # Create assets directory (required by ViteServer) + (temp_path / "assets").mkdir(exist_ok=True) + yield temp_path + + def test_initialization(self, temp_build_dir: Path): + """Test LogsServer initialization.""" + server = LogsServer(build_dir=str(temp_build_dir)) + assert server.build_dir == temp_build_dir + assert server.websocket_manager is not None + assert server.evaluation_watcher is not None + + def test_initialization_invalid_build_dir(self): + """Test LogsServer initialization with invalid build directory.""" + with pytest.raises(FileNotFoundError, match="Build directory '/nonexistent/path' does not exist"): + LogsServer(build_dir="/nonexistent/path") + + def test_websocket_routes(self, temp_build_dir): + """Test that WebSocket routes are properly set up.""" + server = LogsServer(build_dir=str(temp_build_dir)) + + # Check that the WebSocket endpoint exists + if not server.app.routes: + raise ValueError("No routes found") + for route in server.app.routes: + if isinstance(route, APIWebSocketRoute) and route.path == "/ws": + break + else: + raise ValueError("WebSocket route not found") + + @pytest.mark.asyncio + async def test_handle_event(self, temp_build_dir): + """Test event handling.""" + server = LogsServer(build_dir=str(temp_build_dir)) + + # Test handling a log event + test_row = { + "messages": [{"role": "user", "content": "test"}], + "input_metadata": {"row_id": "test-123"}, + } + + server._handle_event(LOG_EVENT_TYPE, test_row) + # The event should be queued for broadcasting + assert not server.websocket_manager._broadcast_queue.empty() + + def test_create_app_factory(self, temp_build_dir): + """Test the create_app factory function.""" + app = create_app(build_dir=str(temp_build_dir)) + assert isinstance(app, FastAPI) + + def test_serve_logs_convenience_function(self, temp_build_dir): + """Test the serve_logs convenience function.""" + # Mock the LogsServer.run method to avoid actually starting a server + with patch("eval_protocol.utils.logs_server.LogsServer.run") as mock_run: + # This should not raise an error + serve_logs(port=8001) + # Verify that the run method was called + mock_run.assert_called_once() + + def test_serve_logs_port_parameter(self, temp_build_dir): + """Test that serve_logs properly passes the port parameter to LogsServer.""" + with patch("eval_protocol.utils.logs_server.LogsServer") as mock_logs_server_class: + mock_server_instance = Mock() + mock_logs_server_class.return_value = mock_server_instance + + # Call serve_logs with a specific port + test_port = 9000 + serve_logs(port=test_port) + + # Verify that LogsServer was created with the correct port + mock_logs_server_class.assert_called_once_with(port=test_port) + # Verify that the run method was called on the instance + mock_server_instance.run.assert_called_once() + + def test_serve_logs_default_port(self, temp_build_dir): + """Test that serve_logs uses default port when none is specified.""" + with patch("eval_protocol.utils.logs_server.LogsServer") as mock_logs_server_class: + mock_server_instance = Mock() + mock_logs_server_class.return_value = mock_server_instance + + # Call serve_logs without specifying a port + serve_logs() + + # Verify that LogsServer was created with None port (which will use LogsServer's default of 8000) + mock_logs_server_class.assert_called_once_with(port=None) + # Verify that the run method was called on the instance + mock_server_instance.run.assert_called_once() + + @pytest.mark.asyncio + async def test_run_async_lifecycle(self, temp_build_dir): + """Test the async lifecycle of the server.""" + server = LogsServer(build_dir=str(temp_build_dir)) + + # Mock the uvicorn.Server to avoid actually starting a server + with patch("uvicorn.Server") as mock_uvicorn_server: + mock_server = Mock() + mock_server.serve = AsyncMock() + mock_uvicorn_server.return_value = mock_server + + # Start the server + start_task = asyncio.create_task(server.run_async()) + + # Wait a bit for it to start + await asyncio.sleep(0.1) + + # Cancel the task instead of calling non-existent stop method + start_task.cancel() + + # Wait for the task to complete + try: + await start_task + except asyncio.CancelledError: + pass + + +class TestLogsServerIntegration: + """Integration tests for LogsServer.""" + + @pytest.fixture + def temp_build_dir_with_files(self): + """Create a temporary build directory with index.html and assets/ directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create index.html + (temp_path / "index.html").write_text("Test") + + # Create assets directory and some files inside it + assets_dir = temp_path / "assets" + assets_dir.mkdir() + (assets_dir / "app.js").write_text("console.log('test');") + (assets_dir / "style.css").write_text("body { color: black; }") + + # Optionally, create a nested directory inside assets + (assets_dir / "nested").mkdir() + (assets_dir / "nested" / "file.txt").write_text("nested content") + + yield temp_path + + def test_static_file_serving(self, temp_build_dir_with_files): + """Test that static files are served correctly.""" + server = LogsServer(build_dir=str(temp_build_dir_with_files)) + client = TestClient(server.app) + + # Test serving index.html + response = client.get("/") + assert response.status_code == 200 + assert "Test" in response.text + + # Test serving static files + response = client.get("/assets/app.js") + assert response.status_code == 200 + assert "console.log('test')" in response.text + + response = client.get("/assets/style.css") + assert response.status_code == 200 + assert "color: black" in response.text + + def test_spa_routing(self, temp_build_dir_with_files): + """Test SPA routing fallback.""" + server = LogsServer(build_dir=str(temp_build_dir_with_files)) + client = TestClient(server.app) + + # Test that non-existent routes fall back to index.html + response = client.get("/some/nonexistent/route") + assert response.status_code == 200 + assert "Test" in response.text + + def test_root_endpoint(self, temp_build_dir_with_files): + """Test the root endpoint.""" + server = LogsServer(build_dir=str(temp_build_dir_with_files)) + client = TestClient(server.app) + + response = client.get("/") + assert response.status_code == 200 + assert "Test" in response.text + + def test_health_endpoint(self, temp_build_dir_with_files): + """Test the health endpoint.""" + server = LogsServer(build_dir=str(temp_build_dir_with_files)) + client = TestClient(server.app) + + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + @pytest.mark.asyncio + async def test_server_runs_on_specific_port(self, temp_build_dir_with_files): + """Integration test: verify that LogsServer actually runs on the specified port (async requests).""" + import socket + + import httpx + + # Find an available port for testing + def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + test_port = find_free_port() + + # Create and start server in background + server = LogsServer(build_dir=str(temp_build_dir_with_files), port=test_port) + + # Start server in background task + server_task = asyncio.create_task(server.run_async()) + + try: + # Wait longer for server to start and be ready + await asyncio.sleep(3) + + async with httpx.AsyncClient() as client: + # Test that we can actually connect to the server on the specified port + response = await client.get(f"http://localhost:{test_port}/", timeout=10) + assert response.status_code == 200 + assert "Test" in response.text + + # Test the health endpoint + response = await client.get(f"http://localhost:{test_port}/health", timeout=10) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + finally: + # Clean up + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + def test_serve_logs_port_parameter_integration(self, temp_build_dir_with_files): + """Integration test: verify that serve_logs function actually works with port parameter.""" + # This test verifies that serve_logs creates LogsServer with the correct port + # without actually starting the server + test_port = 9999 + + # Use a different approach - mock the LogsServer class and verify the port parameter + with patch("eval_protocol.utils.logs_server.LogsServer") as mock_logs_server_class: + mock_server_instance = Mock() + mock_logs_server_class.return_value = mock_server_instance + + # Call serve_logs with specific port + serve_logs(port=test_port) + + # Verify that LogsServer was created with the correct port + mock_logs_server_class.assert_called_once_with(port=test_port) + # Verify that the run method was called on the instance + mock_server_instance.run.assert_called_once() + + +@pytest.mark.asyncio +class TestAsyncWebSocketOperations: + """Test async WebSocket operations.""" + + async def test_websocket_connection_lifecycle(self): + """Test complete WebSocket connection lifecycle.""" + manager = WebSocketManager() + + # Create mock WebSocket + mock_websocket = AsyncMock() + + # Test connection + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket) + assert len(manager.active_connections) == 1 + + # Test broadcasting without starting the loop + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + ) + manager.broadcast_row_upserted(test_row) + + # Verify message was queued + assert not manager._broadcast_queue.empty() + + # Test disconnection + manager.disconnect(mock_websocket) + assert len(manager.active_connections) == 0 + + async def test_multiple_websocket_connections(self): + """Test handling multiple WebSocket connections.""" + manager = WebSocketManager() + + # Create multiple mock WebSockets + mock_websocket1 = AsyncMock() + mock_websocket2 = AsyncMock() + mock_websocket3 = AsyncMock() + + # Connect all + with patch.object(default_logger, "read", return_value=[]): + await manager.connect(mock_websocket1) + await manager.connect(mock_websocket2) + await manager.connect(mock_websocket3) + assert len(manager.active_connections) == 3 + + # Test broadcasting to all without starting the loop + test_row = EvaluationRow( + messages=[Message(role="user", content="test")], + input_metadata=InputMetadata(row_id="test-123"), + ) + manager.broadcast_row_upserted(test_row) + + # Verify message was queued + assert not manager._broadcast_queue.empty() + + # Disconnect one + manager.disconnect(mock_websocket2) + assert len(manager.active_connections) == 2 diff --git a/tests/test_logs_server_simple.py b/tests/test_logs_server_simple.py new file mode 100644 index 00000000..98bc3e47 --- /dev/null +++ b/tests/test_logs_server_simple.py @@ -0,0 +1,88 @@ +import json +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from eval_protocol.utils.logs_server import EvaluationWatcher, WebSocketManager + + +class TestWebSocketManagerBasic: + """Basic tests for WebSocketManager without starting real loops.""" + + def test_initialization(self): + """Test WebSocketManager initialization.""" + manager = WebSocketManager() + assert len(manager.active_connections) == 0 + assert manager._broadcast_queue is not None + assert manager._broadcast_task is None + + @pytest.mark.asyncio + async def test_connect_disconnect(self): + """Test WebSocket connection and disconnection.""" + manager = WebSocketManager() + mock_websocket = AsyncMock() + + # Test connection + await manager.connect(mock_websocket) + assert len(manager.active_connections) == 1 + assert mock_websocket in manager.active_connections + mock_websocket.accept.assert_called_once() + + # Test disconnection + manager.disconnect(mock_websocket) + assert len(manager.active_connections) == 0 + assert mock_websocket not in manager.active_connections + + def test_broadcast_row_upserted(self): + """Test broadcasting row upsert events.""" + manager = WebSocketManager() + + # Create a simple mock row + mock_row = Mock() + mock_row.model_dump.return_value = {"id": "test-123", "content": "test"} + + # Test that broadcast doesn't fail when no connections + manager.broadcast_row_upserted(mock_row) + + # Test that message is queued + assert not manager._broadcast_queue.empty() + queued_message = manager._broadcast_queue.get_nowait() + assert "type" in queued_message + assert "row" in queued_message + json_message = json.loads(queued_message) + assert json_message["row"]["id"] == "test-123" + assert json_message["row"]["content"] == "test" + + +class TestEvaluationWatcherBasic: + """Basic tests for EvaluationWatcher without starting real threads.""" + + def test_initialization(self): + """Test EvaluationWatcher initialization.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + assert watcher.websocket_manager == mock_manager + assert watcher._thread is None + assert watcher._stop_event is not None + + def test_start_stop(self): + """Test starting and stopping the watcher.""" + mock_manager = Mock() + watcher = EvaluationWatcher(mock_manager) + + # Test start + watcher.start() + assert watcher._thread is not None + assert watcher._thread.is_alive() + + # Test stop + watcher.stop() + assert watcher._stop_event.is_set() + if watcher._thread: + watcher._thread.join(timeout=1.0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_vite_server.py b/tests/test_vite_server.py new file mode 100644 index 00000000..eb4a59df --- /dev/null +++ b/tests/test_vite_server.py @@ -0,0 +1,224 @@ +import tempfile +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from eval_protocol.utils.vite_server import ViteServer + + +class TestViteServer: + """Test ViteServer class.""" + + @pytest.fixture + def temp_build_dir_with_files(self): + """Create a temporary build directory with index.html and assets/ directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create index.html + (temp_path / "index.html").write_text("Test") + + # Create assets directory and some files inside it + assets_dir = temp_path / "assets" + assets_dir.mkdir() + (assets_dir / "app.js").write_text("console.log('test');") + (assets_dir / "style.css").write_text("body { color: black; }") + + # Optionally, create a nested directory inside assets + (assets_dir / "nested").mkdir() + (assets_dir / "nested" / "file.txt").write_text("nested content") + + yield temp_path + + def test_initialization(self, temp_build_dir_with_files): + """Test ViteServer initialization.""" + vite_server = ViteServer(build_dir=str(temp_build_dir_with_files), host="localhost", port=8000) + + assert vite_server.build_dir == temp_build_dir_with_files + assert vite_server.host == "localhost" + assert vite_server.port == 8000 + assert vite_server.index_file == "index.html" + assert vite_server.app is not None + + def test_initialization_invalid_build_dir(self): + """Test ViteServer initialization with invalid build directory.""" + with pytest.raises(FileNotFoundError): + ViteServer(build_dir="nonexistent_dir") + + def test_initialization_invalid_index_file(self, temp_build_dir_with_files): + """Test ViteServer initialization with invalid index file.""" + # Remove the index.html file + (temp_build_dir_with_files / "index.html").unlink() + + with pytest.raises(FileNotFoundError): + ViteServer(build_dir=str(temp_build_dir_with_files)) + + def test_html_injection_in_vite_server(self, temp_build_dir_with_files): + """Test that ViteServer injects server configuration into HTML.""" + # Create a more complex HTML file for testing injection + index_html = """ + + + + +No head tag
+ +""" + + (temp_build_dir_with_files / "index.html").write_text(simple_html) + + vite_server = ViteServer(build_dir=str(temp_build_dir_with_files), host="127.0.0.1", port=9000) + + injected_html = vite_server._inject_config_into_html(simple_html) + + # Verify config is injected at the beginnin + assert injected_html.strip().startswith("