Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions crates/coglet-python/src/predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,12 @@ impl PythonPredictor {
/// - `has_setup_weights()` checks if setup() has a weights parameter
/// - `extract_setup_weights()` reads from COG_WEIGHTS env or ./weights path
///
/// If setup() is an async def, the returned coroutine is executed with
/// `asyncio.run()`, matching the pattern in `call_method_raw()`.
pub fn setup(&self, py: Python<'_>) -> PyResult<()> {
/// If setup() is an async def and an event loop is provided, the coroutine
/// is submitted to that loop via `run_coroutine_threadsafe` so that
/// event-loop-bound resources created during setup (httpx.AsyncClient, etc.)
/// remain usable in predict(). If no loop is provided, falls back to
/// `asyncio.run()` (used by the non-worker code path).
pub fn setup(&self, py: Python<'_>, event_loop: Option<&Py<PyAny>>) -> PyResult<()> {
let instance = self.instance.bind(py);

// Check if setup method exists
Expand All @@ -577,7 +580,20 @@ impl PythonPredictor {
// If setup() is async, the call above returns a coroutine — run it.
if self.setup_is_async {
let asyncio = py.import("asyncio")?;
asyncio.call_method1("run", (&result,))?;
match event_loop {
Some(loop_obj) => {
// Submit to the shared event loop so setup and predict share
// the same loop. This keeps event-loop-bound resources alive.
let future = asyncio
.call_method1("run_coroutine_threadsafe", (&result, loop_obj.bind(py)))?;
// Block until setup completes (preserves existing semantics).
future.call_method0("result")?;
}
None => {
// No shared loop (non-worker path) — use ephemeral loop.
asyncio.call_method1("run", (&result,))?;
}
}
}

Ok(())
Expand Down
8 changes: 7 additions & 1 deletion crates/coglet-python/src/worker_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,13 @@ impl PredictHandler for PythonPredictHandler {
tracing::info!(sdk_implementation = %sdk_impl, "Detected Cog SDK implementation");

tracing::info!("Running setup");
pred.setup(py)
let async_loop = self
.async_loop
.lock()
.expect("async_loop mutex poisoned")
.as_ref()
.map(|l| l.clone_ref(py));
pred.setup(py, async_loop.as_ref())
.map_err(|e| SetupError::setup(e.to_string()))?;

let mut guard = self.predictor.lock().expect("predictor mutex poisoned");
Expand Down
46 changes: 46 additions & 0 deletions crates/coglet-python/tests/test_coglet.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,37 @@ def predict(self, name: str = "World") -> str:
return predictor


@pytest.fixture
def async_setup_event_loop_predictor(tmp_path: Path) -> Path:
"""Create a predictor where async setup() stores the event loop and predict() checks it matches."""
predictor = tmp_path / "predict.py"
predictor.write_text("""
import asyncio
from cog import BasePredictor

class Predictor(BasePredictor):
async def setup(self):
self.setup_loop = asyncio.get_running_loop()
# Create an event-loop-bound resource (Queue is bound to the running loop)
self.queue = asyncio.Queue()
await self.queue.put("from-setup")

async def predict(self, name: str = "test") -> str:
predict_loop = asyncio.get_running_loop()
same_loop = predict_loop is self.setup_loop
# Use the queue created in setup — this fails if loops differ
item = self.queue.get_nowait()
return f"same_loop={same_loop} item={item}"
""")

cog_yaml = tmp_path / "cog.yaml"
cog_yaml.write_text("""
predict: "predict.py:Predictor"
""")

return predictor


class CogletServer:
"""Context manager for running coglet server."""

Expand Down Expand Up @@ -520,6 +551,21 @@ def test_async_setup_with_weights(
assert result["status"] == "succeeded"
assert result["output"] == "https://example.com/model.tar: Claude"

def test_async_setup_shares_event_loop_with_predict(
self, async_setup_event_loop_predictor: Path
):
"""async setup() and async predict() must run on the same event loop.

This catches the bug where async setup() ran via asyncio.run() (ephemeral loop)
while predict() ran on a separate shared loop, causing event-loop-bound resources
created in setup (httpx.AsyncClient, aiohttp.ClientSession, asyncio.Queue, etc.)
to fail in predict.
"""
with CogletServer(async_setup_event_loop_predictor) as server:
result = server.predict({"name": "test"})
assert result["status"] == "succeeded"
assert result["output"] == "same_loop=True item=from-setup"


@pytest.fixture
def slow_sync_predictor(tmp_path: Path) -> Path:
Expand Down
Loading