From edfc0656b5b2fbd0c7b81055c4f3d70a92783cd8 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Fri, 30 Jan 2026 14:57:43 -0500 Subject: [PATCH] fix: trying to fix ssl certification issue while downloading on some macs --- backend/__init__.py | 1 + backend/main.py | 34 ++++++++++++++++++ backend/tests/test_download.py | 60 ++++++++++++++++++++++++++++++++ backend/utils.py | 27 +++++++++++++++ pyproject.toml | 1 + temoa_runner.py | 63 ++++++++++++++++++++++++++++------ uv.lock | 2 ++ 7 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 backend/__init__.py create mode 100644 backend/tests/test_download.py create mode 100644 backend/utils.py diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000..7692d8d --- /dev/null +++ b/backend/__init__.py @@ -0,0 +1 @@ +# Package marker for Temoa Web GUI Backend diff --git a/backend/main.py b/backend/main.py index 5889893..32302e2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,6 +15,10 @@ from datetime import datetime from pathlib import Path from typing import List, Optional +import urllib.request +import shutil + +from .utils import create_secure_ssl_context from fastapi import ( FastAPI, @@ -152,6 +156,36 @@ def list_files(path: str = "."): raise HTTPException(status_code=500, detail=str(e)) +@app.post("/api/download_tutorial") +def download_tutorial(): + """Downloads the tutorial database from the main repo.""" + assets_path = Path("assets") + assets_path.mkdir(parents=True, exist_ok=True) + target_path = assets_path / "tutorial_database.sqlite" + temp_path = target_path.with_suffix(".tmp") + + try: + url = "https://raw.githubusercontent.com/TemoaProject/temoa-web-gui/main/assets/tutorial_database.sqlite" + ctx = create_secure_ssl_context() + + with urllib.request.urlopen(url, context=ctx, timeout=10) as response: + with open(temp_path, "wb") as out_file: + shutil.copyfileobj(response, out_file) + + # Atomic replace + temp_path.replace(target_path) + return {"status": "ok", "path": str(target_path.absolute())} + except Exception as e: + logging.exception("Failed to download tutorial") + raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e + finally: + if temp_path.exists(): + try: + temp_path.unlink() + except Exception: + pass + + @app.get("/api/solvers") def list_solvers(): """Detect available solvers on the local system.""" diff --git a/backend/tests/test_download.py b/backend/tests/test_download.py new file mode 100644 index 0000000..2d0fc3d --- /dev/null +++ b/backend/tests/test_download.py @@ -0,0 +1,60 @@ +import pytest +from unittest.mock import patch, MagicMock +from fastapi.testclient import TestClient +from backend.main import app +import ssl + +client = TestClient(app) + + +@pytest.mark.parametrize("skip_verify", ["0", "1"]) +def test_download_tutorial_ssl_context(skip_verify, monkeypatch): + """ + Test that the SSL context is correctly configured based on TEMOA_SKIP_CERT_VERIFY. + This test is parametrized to ensure deterministic behavior. + """ + monkeypatch.setenv("TEMOA_SKIP_CERT_VERIFY", skip_verify) + + # Patch targets must be on the module that USES the functions + with patch("backend.main.urllib.request.urlopen") as mock_urlopen, patch( + "backend.main.shutil.copyfileobj" + ), patch("backend.main.open", new_callable=MagicMock), patch( + "backend.main.Path.replace" + ) as mock_replace, patch("backend.main.Path.unlink"), patch( + "backend.main.Path.exists", return_value=True + ): + # Configure the mock response + mock_response = MagicMock() + mock_urlopen.return_value.__enter__.return_value = mock_response + + response = client.post("/api/download_tutorial") + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + # Verify SSL context and timeout + _, kwargs = mock_urlopen.call_args + assert kwargs.get("timeout") == 10 + assert "context" in kwargs + ctx = kwargs["context"] + assert isinstance(ctx, ssl.SSLContext) + + if skip_verify == "1": + assert ctx.check_hostname is False + assert ctx.verify_mode == ssl.CERT_NONE + else: + assert ctx.check_hostname is True + assert ctx.verify_mode == ssl.CERT_REQUIRED + + # Verify atomic move was attempted + assert mock_replace.called + + +def test_download_tutorial_failure(): + # Patch on the actual module to ensure it's intercepted + with patch( + "backend.main.urllib.request.urlopen", side_effect=Exception("Network error") + ): + response = client.post("/api/download_tutorial") + assert response.status_code == 500 + assert "Download failed" in response.json()["detail"] diff --git a/backend/utils.py b/backend/utils.py new file mode 100644 index 0000000..6ff8cef --- /dev/null +++ b/backend/utils.py @@ -0,0 +1,27 @@ +import ssl +import certifi +import os +import logging + +logger = logging.getLogger(__name__) + + +def create_secure_ssl_context(): + """ + Creates a secure SSL context using certifi's CA bundle. + Allows bypassing verification ONLY if TEMOA_SKIP_CERT_VERIFY is set to '1'. + """ + skip_verify = os.environ.get("TEMOA_SKIP_CERT_VERIFY") == "1" + + if skip_verify: + logger.warning( + "SSL certificate verification is DISABLED via TEMOA_SKIP_CERT_VERIFY." + ) + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + + # Secure default using certifi + ctx = ssl.create_default_context(cafile=certifi.where()) + return ctx diff --git a/pyproject.toml b/pyproject.toml index bb052b4..7f85001 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "tomlkit", "datasette", "websockets", + "certifi", ] [tool.pytest.ini_options] diff --git a/temoa_runner.py b/temoa_runner.py index b6343f5..53eac1b 100644 --- a/temoa_runner.py +++ b/temoa_runner.py @@ -7,16 +7,21 @@ # "tomlkit", # "websockets", # "datasette", +# "certifi", # ] # /// import asyncio import logging import sys +import shutil +import subprocess from datetime import datetime from pathlib import Path -from typing import List, Optional import urllib.request +import os +import certifi +import ssl from fastapi import ( FastAPI, @@ -29,6 +34,32 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel + +def create_secure_ssl_context(): + """ + Creates a secure SSL context using certifi's CA bundle. + Allows bypassing verification ONLY if TEMOA_SKIP_CERT_VERIFY is set to '1'. + + NOTE: This function is intentionally duplicated from backend/utils.py + to maintain temoa_runner.py as a standalone script. + See: backend/utils.py:create_secure_ssl_context + """ + skip_verify = os.environ.get("TEMOA_SKIP_CERT_VERIFY") == "1" + + if skip_verify: + logging.warning( + "SSL certificate verification is DISABLED via TEMOA_SKIP_CERT_VERIFY." + ) + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + + # Secure default using certifi + ctx = ssl.create_default_context(cafile=certifi.where()) + return ctx + + # --- Temoa Imports --- # We assume temoa is installed in the same environment try: @@ -62,13 +93,13 @@ class RunConfig(BaseModel): scenario_mode: str = "perfect_foresight" solver_name: str = "appsi_highs" time_sequencing: str = "seasonal_timeslices" - output_dir: Optional[str] = None + output_dir: str | None = None # --- Log Management --- class ConnectionManager: def __init__(self): - self.active_connections: List[WebSocket] = [] + self.active_connections: list[WebSocket] = [] async def connect(self, websocket: WebSocket): await websocket.accept() @@ -116,17 +147,31 @@ def ensure_assets(): "https://raw.githubusercontent.com/TemoaProject/temoa-web-gui/main/assets/" ) assets_dir = Path("assets") - assets_dir.mkdir(exist_ok=True) + assets_dir.mkdir(parents=True, exist_ok=True) files = ["tutorial_database.sqlite", "tutorial_config.toml"] + + ctx = create_secure_ssl_context() + for f in files: target = assets_dir / f if not target.exists(): print(f"Downloading missing asset: {f}...") + temp_target = target.with_suffix(".part") try: - urllib.request.urlretrieve(base_url + f, target) + url = base_url + f + with urllib.request.urlopen(url, context=ctx, timeout=10) as response: + with open(temp_target, "wb") as out_file: + shutil.copyfileobj(response, out_file) + # Atomic rename + temp_target.replace(target) except Exception as e: print(f"Failed to download {f}: {e}") + if temp_target.exists(): + try: + temp_target.unlink() + except Exception: + pass @app.get("/api/config") @@ -400,18 +445,16 @@ async def websocket_endpoint(websocket: WebSocket): # --- Datasette Management --- -DATASETTE_PROCESS = None -SERVED_DATABASES = set() +DATASETTE_PROCESS: subprocess.Popen | None = None +SERVED_DATABASES: set[str] = set() -def start_datasette(new_db: Optional[str] = None): +def start_datasette(new_db: str | None = None): """ Start or restart Datasette serving the tutorial DB + output DBs. If new_db is provided and not already served, restart the process to include it. """ global DATASETTE_PROCESS, SERVED_DATABASES - import subprocess - import os import sys # If new_db is already served, no need to restart diff --git a/uv.lock b/uv.lock index 8d3b055..9cdad40 100644 --- a/uv.lock +++ b/uv.lock @@ -1396,6 +1396,7 @@ name = "temoa-web-gui" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "certifi" }, { name = "datasette" }, { name = "fastapi" }, { name = "temoa" }, @@ -1412,6 +1413,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "certifi" }, { name = "datasette" }, { name = "fastapi" }, { name = "temoa", specifier = ">=4.0.0a1" },