Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package marker for Temoa Web GUI Backend
34 changes: 34 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from datetime import datetime
from pathlib import Path
from typing import List, Optional
import urllib.request
import shutil

from .utils import create_secure_ssl_context

from fastapi import (
FastAPI,
Expand Down Expand Up @@ -152,6 +156,36 @@ def list_files(path: str = "."):
raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/download_tutorial")
def download_tutorial():
"""Downloads the tutorial database from the main repo."""
assets_path = Path("assets")
assets_path.mkdir(parents=True, exist_ok=True)
target_path = assets_path / "tutorial_database.sqlite"
temp_path = target_path.with_suffix(".tmp")

try:
url = "https://raw.githubusercontent.com/TemoaProject/temoa-web-gui/main/assets/tutorial_database.sqlite"
ctx = create_secure_ssl_context()

with urllib.request.urlopen(url, context=ctx, timeout=10) as response:
with open(temp_path, "wb") as out_file:
shutil.copyfileobj(response, out_file)

# Atomic replace
temp_path.replace(target_path)
return {"status": "ok", "path": str(target_path.absolute())}
except Exception as e:
logging.exception("Failed to download tutorial")
raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment on lines +179 to +180
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider using a named logger for consistency with backend/utils.py.

The endpoint uses logging.exception() on the root logger. For consistency with backend/utils.py which uses a named logger, consider creating a module-level logger.

♻️ Proposed fix

Add at module level:

logger = logging.getLogger(__name__)

Then replace:

-        logging.exception("Failed to download tutorial")
+        logger.exception("Failed to download tutorial")
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 179-179: exception() call on root logger

(LOG015)


[warning] 180-180: Use explicit conversion flag

Replace with conversion flag

(RUF010)

🤖 Prompt for AI Agents
In `@backend/main.py` around lines 179 - 180, The code currently calls
logging.exception("Failed to download tutorial") using the root logger; add a
module-level named logger (logger = logging.getLogger(__name__)) at the top of
backend/main.py and replace the root logging call with logger.exception(...) so
the endpoint uses the named logger consistently with backend/utils.py; update
any other direct uses of logging in this module (e.g., around the HTTPException
raise) to use logger instead.

finally:
if temp_path.exists():
try:
temp_path.unlink()
except Exception:
pass


@app.get("/api/solvers")
def list_solvers():
"""Detect available solvers on the local system."""
Expand Down
60 changes: 60 additions & 0 deletions backend/tests/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from backend.main import app
import ssl

client = TestClient(app)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@pytest.mark.parametrize("skip_verify", ["0", "1"])
def test_download_tutorial_ssl_context(skip_verify, monkeypatch):
"""
Test that the SSL context is correctly configured based on TEMOA_SKIP_CERT_VERIFY.
This test is parametrized to ensure deterministic behavior.
"""
monkeypatch.setenv("TEMOA_SKIP_CERT_VERIFY", skip_verify)

# Patch targets must be on the module that USES the functions
with patch("backend.main.urllib.request.urlopen") as mock_urlopen, patch(
"backend.main.shutil.copyfileobj"
), patch("backend.main.open", new_callable=MagicMock), patch(
"backend.main.Path.replace"
) as mock_replace, patch("backend.main.Path.unlink"), patch(
"backend.main.Path.exists", return_value=True
):
# Configure the mock response
mock_response = MagicMock()
mock_urlopen.return_value.__enter__.return_value = mock_response

response = client.post("/api/download_tutorial")

assert response.status_code == 200
assert response.json()["status"] == "ok"

# Verify SSL context and timeout
_, kwargs = mock_urlopen.call_args
assert kwargs.get("timeout") == 10
assert "context" in kwargs
ctx = kwargs["context"]
assert isinstance(ctx, ssl.SSLContext)

if skip_verify == "1":
assert ctx.check_hostname is False
assert ctx.verify_mode == ssl.CERT_NONE
else:
assert ctx.check_hostname is True
assert ctx.verify_mode == ssl.CERT_REQUIRED

# Verify atomic move was attempted
assert mock_replace.called


def test_download_tutorial_failure():
# Patch on the actual module to ensure it's intercepted
with patch(
"backend.main.urllib.request.urlopen", side_effect=Exception("Network error")
):
response = client.post("/api/download_tutorial")
assert response.status_code == 500
assert "Download failed" in response.json()["detail"]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
27 changes: 27 additions & 0 deletions backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import ssl
import certifi
import os
import logging

logger = logging.getLogger(__name__)


def create_secure_ssl_context():
"""
Creates a secure SSL context using certifi's CA bundle.
Allows bypassing verification ONLY if TEMOA_SKIP_CERT_VERIFY is set to '1'.
"""
skip_verify = os.environ.get("TEMOA_SKIP_CERT_VERIFY") == "1"

if skip_verify:
logger.warning(
"SSL certificate verification is DISABLED via TEMOA_SKIP_CERT_VERIFY."
)
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx

# Secure default using certifi
ctx = ssl.create_default_context(cafile=certifi.where())
return ctx
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"tomlkit",
"datasette",
"websockets",
"certifi",
]

[tool.pytest.ini_options]
Expand Down
63 changes: 53 additions & 10 deletions temoa_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@
# "tomlkit",
# "websockets",
# "datasette",
# "certifi",
# ]
# ///

import asyncio
import logging
import sys
import shutil
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Optional
import urllib.request
import os
import certifi
import ssl

from fastapi import (
FastAPI,
Expand All @@ -29,6 +34,32 @@
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel


def create_secure_ssl_context():
"""
Creates a secure SSL context using certifi's CA bundle.
Allows bypassing verification ONLY if TEMOA_SKIP_CERT_VERIFY is set to '1'.

NOTE: This function is intentionally duplicated from backend/utils.py
to maintain temoa_runner.py as a standalone script.
See: backend/utils.py:create_secure_ssl_context
"""
skip_verify = os.environ.get("TEMOA_SKIP_CERT_VERIFY") == "1"

if skip_verify:
logging.warning(
"SSL certificate verification is DISABLED via TEMOA_SKIP_CERT_VERIFY."
)
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx

# Secure default using certifi
ctx = ssl.create_default_context(cafile=certifi.where())
return ctx


# --- Temoa Imports ---
# We assume temoa is installed in the same environment
try:
Expand Down Expand Up @@ -62,13 +93,13 @@ class RunConfig(BaseModel):
scenario_mode: str = "perfect_foresight"
solver_name: str = "appsi_highs"
time_sequencing: str = "seasonal_timeslices"
output_dir: Optional[str] = None
output_dir: str | None = None


# --- Log Management ---
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
self.active_connections: list[WebSocket] = []

async def connect(self, websocket: WebSocket):
await websocket.accept()
Expand Down Expand Up @@ -116,17 +147,31 @@ def ensure_assets():
"https://raw.githubusercontent.com/TemoaProject/temoa-web-gui/main/assets/"
)
assets_dir = Path("assets")
assets_dir.mkdir(exist_ok=True)
assets_dir.mkdir(parents=True, exist_ok=True)

files = ["tutorial_database.sqlite", "tutorial_config.toml"]

ctx = create_secure_ssl_context()

for f in files:
target = assets_dir / f
if not target.exists():
print(f"Downloading missing asset: {f}...")
temp_target = target.with_suffix(".part")
try:
urllib.request.urlretrieve(base_url + f, target)
url = base_url + f
with urllib.request.urlopen(url, context=ctx, timeout=10) as response:
with open(temp_target, "wb") as out_file:
shutil.copyfileobj(response, out_file)
# Atomic rename
temp_target.replace(target)
except Exception as e:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
print(f"Failed to download {f}: {e}")
if temp_target.exists():
try:
temp_target.unlink()
except Exception:
pass


@app.get("/api/config")
Expand Down Expand Up @@ -400,18 +445,16 @@ async def websocket_endpoint(websocket: WebSocket):


# --- Datasette Management ---
DATASETTE_PROCESS = None
SERVED_DATABASES = set()
DATASETTE_PROCESS: subprocess.Popen | None = None
SERVED_DATABASES: set[str] = set()


def start_datasette(new_db: Optional[str] = None):
def start_datasette(new_db: str | None = None):
"""
Start or restart Datasette serving the tutorial DB + output DBs.
If new_db is provided and not already served, restart the process to include it.
"""
global DATASETTE_PROCESS, SERVED_DATABASES
import subprocess
import os
import sys

# If new_db is already served, no need to restart
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.