diff --git a/README.md b/README.md
index 6b6b153..4dc74d1 100644
--- a/README.md
+++ b/README.md
@@ -105,7 +105,7 @@ Each analysis generates an HTML report documenting annotation decisions, reviewe
-[View example report](https://prod.cytetype.nygen.io/report/e70e2883-7713-4121-94f2-5b57eabd1468?v=260303)
+[View example report](https://cytetype.nygen.io/report/e70e2883-7713-4121-94f2-5b57eabd1468?v=260303)
---
diff --git a/cytetype/__init__.py b/cytetype/__init__.py
index 72ad186..2eae0ff 100644
--- a/cytetype/__init__.py
+++ b/cytetype/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.19.3"
+__version__ = "0.19.4"
import requests
diff --git a/cytetype/api/client.py b/cytetype/api/client.py
index 46d85e9..0befbac 100644
--- a/cytetype/api/client.py
+++ b/cytetype/api/client.py
@@ -375,21 +375,25 @@ def fetch_job_results(
def _sleep_with_spinner(
seconds: int,
progress: ProgressDisplay | None,
- cluster_status: dict[str, str],
+ job_status: str,
) -> None:
"""Sleep for specified seconds while updating spinner animation.
Args:
seconds: Number of seconds to sleep
progress: ProgressDisplay instance (if showing progress)
- cluster_status: Current cluster status for display
+ job_status: Current overall job status for display
"""
for _ in range(seconds * 2):
if progress:
- progress.update(cluster_status)
+ progress.update(job_status)
time.sleep(0.5)
+def _log_report_cta(report_url: str) -> None:
+ logger.info(f"\n[TRACK PROGRESS]\n{report_url}")
+
+
def wait_for_completion(
base_url: str,
auth_token: str | None,
@@ -401,26 +405,26 @@ def wait_for_completion(
"""Poll job until completion and return results."""
progress = ProgressDisplay() if show_progress else None
start_time = time.time()
+ report_url = f"{base_url.rstrip('/')}/report/{job_id}"
- logger.info(f"CyteType job (id: {job_id}) submitted. Polling for results...")
-
- # Initial delay
- time.sleep(5)
-
- # Show report URL
- report_url = f"{base_url}/report/{job_id}"
- logger.info(f"Report (updates automatically) available at: {report_url}")
+ logger.info("CyteType job submitted.")
logger.info(
- "If network disconnects, the results can still be fetched:\n"
+ "If your session disconnects, results can still be fetched later with:\n"
"`results = annotator.get_results()`"
)
+ _log_report_cta(report_url)
+
+ # Initial delay
+ time.sleep(5)
consecutive_not_found = 0
+ job_status = "pending"
+ cluster_status: dict[str, str] = {}
while (time.time() - start_time) < timeout:
try:
status_data = get_job_status(base_url, auth_token, job_id)
- job_status = status_data.get("jobStatus")
+ job_status = str(status_data.get("jobStatus") or "")
cluster_status = status_data.get("clusterStatus", {})
# Reset 404 counter on valid response
@@ -429,20 +433,21 @@ def wait_for_completion(
if job_status == "completed":
if progress:
- progress.finalize(cluster_status)
+ progress.finalize("completed", cluster_status)
logger.success(f"Job {job_id} completed successfully.")
return fetch_job_results(base_url, auth_token, job_id)
elif job_status == "failed":
if progress:
- progress.finalize(cluster_status)
+ progress.finalize("failed", cluster_status)
+ logger.info(f"Report:\n{report_url}")
raise JobFailedError(f"Job {job_id} failed")
elif job_status in ["processing", "pending"]:
logger.debug(
f"Job {job_id} status: {job_status}. Waiting {poll_interval}s..."
)
- _sleep_with_spinner(poll_interval, progress, cluster_status)
+ _sleep_with_spinner(poll_interval, progress, job_status)
elif job_status == "not_found":
consecutive_not_found += 1
@@ -459,24 +464,25 @@ def wait_for_completion(
f"Status endpoint not ready for job {job_id}. "
f"Waiting {poll_interval}s..."
)
- _sleep_with_spinner(poll_interval, progress, cluster_status)
+ _sleep_with_spinner(poll_interval, progress, job_status)
else:
logger.warning(f"Unknown job status: '{job_status}'. Continuing...")
- _sleep_with_spinner(poll_interval, progress, cluster_status)
+ _sleep_with_spinner(poll_interval, progress, job_status)
except APIError:
# Let API errors (auth, etc.) bubble up immediately
if progress:
- progress.finalize({})
+ progress.finalize()
raise
except Exception as e:
# Network errors - log and retry
logger.debug(f"Error during polling: {e}. Retrying...")
retry_interval = min(poll_interval, 5)
- _sleep_with_spinner(retry_interval, progress, cluster_status)
+ _sleep_with_spinner(retry_interval, progress, job_status)
# Timeout reached
if progress:
- progress.finalize({})
+ progress.finalize("timed_out")
+ logger.info(f"Report:\n{report_url}")
raise TimeoutError(f"Job {job_id} did not complete within {timeout}s")
diff --git a/cytetype/api/progress.py b/cytetype/api/progress.py
index a75ce39..f5e9363 100644
--- a/cytetype/api/progress.py
+++ b/cytetype/api/progress.py
@@ -1,116 +1,166 @@
import sys
+import time
+from html import escape
+from typing import Any, Callable, TextIO, cast
+
+
+def _in_notebook() -> bool:
+ try:
+ from IPython import get_ipython
+ except ImportError:
+ return False
+
+ shell_getter = cast(Callable[[], Any | None], get_ipython)
+ shell = shell_getter()
+ return bool(shell and shell.__class__.__name__ == "ZMQInteractiveShell")
+
+
+def _create_notebook_display_handle(message: str) -> Any | None:
+ try:
+ from IPython.display import display
+ except ImportError:
+ return None
+
+ _display: Callable[..., Any] = display
+ return _display(_render_notebook_message(message), display_id=True)
+
+
+def _render_notebook_message(message: str) -> Any:
+ try:
+ from IPython.display import HTML
+ except ImportError:
+ return message
+
+ html_cls = cast(Callable[[str], Any], HTML)
+ return html_cls(
+ "
"
+ f"{escape(message)}"
+ ""
+ )
class ProgressDisplay:
- """Manages terminal progress display during job polling."""
-
- # Class constants
- COLORS = {
- "completed": "\033[92m",
- "processing": "\033[93m",
- "pending": "\033[94m",
- "failed": "\033[91m",
- "reset": "\033[0m",
- }
- SYMBOLS = {"completed": "✓", "processing": "⟳", "pending": "○", "failed": "✗"}
+ """Manages terminal and notebook progress display during job polling."""
+
+ COLORS = {"failed": "\033[91m", "reset": "\033[0m"}
SPINNER_CHARS = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
- def __init__(self) -> None:
+ def __init__(self, stream: TextIO | None = None) -> None:
+ self.stream = stream or sys.stdout
+ self._interactive = bool(
+ hasattr(self.stream, "isatty") and self.stream.isatty()
+ )
+ self._use_notebook_display = not self._interactive and _in_notebook()
+ self._display_handle: Any | None = None
+ self._finalized = False
+ self._last_plain_status: str | None = None
+ self._start_time = time.monotonic()
self.spinner_frame = 0
- self.last_status: dict[str, str] = {}
- def update(self, cluster_status: dict[str, str]) -> None:
- """Update progress display with current cluster status."""
- if not cluster_status:
+ def update(self, job_status: str) -> None:
+ """Update progress display with the overall job status."""
+ if self._finalized:
return
- # Always render to keep spinner animating
- self._render(cluster_status, is_final=False)
-
- # Track last status for potential future use
- if cluster_status != self.last_status:
- self.last_status = cluster_status.copy()
+ if self._interactive:
+ message = self._build_running_line(job_status)
+ print(f"\r{message}\033[K", end="", file=self.stream, flush=True)
+ elif self._use_notebook_display:
+ self._update_notebook_display(self._build_running_line(job_status))
+ else:
+ message = self._build_plain_line(job_status)
+ if message != self._last_plain_status:
+ print(message, file=self.stream, flush=True)
+ self._last_plain_status = message
- # Always increment spinner to show activity
self.spinner_frame += 1
- def finalize(self, cluster_status: dict[str, str]) -> None:
+ def finalize(
+ self,
+ final_status: str | None = None,
+ cluster_status: dict[str, str] | None = None,
+ ) -> None:
"""Show final status and cleanup."""
- if cluster_status:
- self._render(cluster_status, is_final=True)
- print() # Ensure newline
-
- def _render(self, cluster_status: dict[str, str], is_final: bool) -> None:
- """Render status to terminal."""
- status_counts = self._count_statuses(cluster_status)
- progress_bar = self._build_progress_bar(cluster_status)
- status_line = self._build_status_line(
- progress_bar, status_counts, is_final=is_final
- )
+ if self._finalized:
+ return
+ self._finalized = True
- # Print status line
- if is_final:
- print(f"\r{status_line}{self.COLORS['reset']}")
- sys.stdout.flush()
- self._show_failed_clusters(cluster_status, status_counts["failed"])
- else:
- print(f"\r{status_line}{self.COLORS['reset']}", end="", flush=True)
-
- def _count_statuses(self, cluster_status: dict[str, str]) -> dict[str, int]:
- """Count occurrences of each status."""
- counts = {"completed": 0, "failed": 0}
- for status in cluster_status.values():
- counts[status] = counts.get(status, 0) + 1
- return counts
-
- def _build_progress_bar(self, cluster_status: dict[str, str]) -> str:
- """Build colored progress bar from cluster statuses."""
- progress_units = []
- for cluster_id in self._sorted_cluster_ids(cluster_status):
- status = cluster_status[cluster_id]
- color = self.COLORS.get(status, self.COLORS["reset"])
- symbol = self.SYMBOLS.get(status, "?")
- progress_units.append(f"{color}{symbol}{self.COLORS['reset']}")
- return "".join(progress_units)
-
- def _build_status_line(
- self, progress_bar: str, counts: dict[str, int], is_final: bool
- ) -> str:
- """Build status line with progress bar and counts."""
- total = sum(counts.values())
- completed = counts["completed"]
-
- if is_final:
- status_line = f"[DONE] [{progress_bar}] {completed}/{total}"
- if counts["failed"] > 0:
- status_line += f" ({counts['failed']} failed)"
- elif completed == total:
- status_line += " completed"
+ if final_status is None:
+ if self._interactive:
+ print(file=self.stream, flush=True)
+ return
+
+ message = self._build_final_line(final_status)
+ if self._interactive:
+ print(f"\r{message}\033[K", file=self.stream, flush=True)
+ elif self._use_notebook_display:
+ self._update_notebook_display(message)
else:
- spinner = self.SPINNER_CHARS[self.spinner_frame % len(self.SPINNER_CHARS)]
- status_line = f"{spinner} [{progress_bar}] {completed}/{total} completed"
+ print(message, file=self.stream, flush=True)
+
+ if final_status == "failed" and cluster_status:
+ self._show_failed_clusters(cluster_status)
+
+ def _update_notebook_display(self, message: str) -> None:
+ """Update a single notebook output cell instead of printing many lines."""
+ if self._display_handle is None:
+ self._display_handle = _create_notebook_display_handle(message)
+ if self._display_handle is None:
+ self._use_notebook_display = False
+ print(message, file=self.stream, flush=True)
+ return
- return status_line
+ self._display_handle.update(_render_notebook_message(message))
- def _show_failed_clusters(
- self, cluster_status: dict[str, str], failed_count: int
- ) -> None:
- """Show details of failed clusters."""
- if failed_count == 0:
- return
+ def _build_running_line(self, job_status: str) -> str:
+ spinner = self.SPINNER_CHARS[self.spinner_frame % len(self.SPINNER_CHARS)]
+ elapsed = self._format_elapsed()
+ return f"{spinner} {self._status_message(job_status)} {elapsed} elapsed"
+
+ def _build_plain_line(self, job_status: str) -> str:
+ return self._status_message(job_status)
+
+ @staticmethod
+ def _build_final_line(final_status: str) -> str:
+ if final_status == "completed":
+ return "[DONE] CyteType job completed."
+ if final_status == "failed":
+ return "[FAILED] CyteType job failed."
+ if final_status == "timed_out":
+ return "[TIMEOUT] CyteType job timed out."
+ return "[STOPPED] CyteType job stopped."
+ @staticmethod
+ def _status_message(job_status: str) -> str:
+ if job_status == "pending":
+ return "CyteType job queued..."
+ if job_status == "processing":
+ return "CyteType job running..."
+ if job_status == "not_found":
+ return "Waiting for CyteType job to start..."
+ return "Waiting for CyteType results..."
+
+ def _format_elapsed(self) -> str:
+ elapsed = int(time.monotonic() - self._start_time)
+ minutes, seconds = divmod(elapsed, 60)
+ return f"{minutes:02d}:{seconds:02d}"
+
+ def _show_failed_clusters(self, cluster_status: dict[str, str]) -> None:
+ """Show details of failed clusters."""
failed_details = []
for cluster_id in self._sorted_cluster_ids(cluster_status):
- if cluster_status[cluster_id] == "failed":
- color = self.COLORS["failed"]
- symbol = self.SYMBOLS["failed"]
+ if cluster_status[cluster_id] != "failed":
+ continue
+
+ if self._interactive:
failed_details.append(
- f"{color}{symbol} Cluster {cluster_id}{self.COLORS['reset']}"
+ f"{self.COLORS['failed']}✗ Cluster {cluster_id}{self.COLORS['reset']}"
)
+ else:
+ failed_details.append(f"✗ Cluster {cluster_id}")
- # Group into lines of 4
for i in range(0, len(failed_details), 4):
- print(f" {' | '.join(failed_details[i : i + 4])}")
+ print(f" {' | '.join(failed_details[i : i + 4])}", file=self.stream)
@staticmethod
def _sorted_cluster_ids(cluster_status: dict[str, str]) -> list[str]:
diff --git a/cytetype/main.py b/cytetype/main.py
index 408f556..e17d549 100644
--- a/cytetype/main.py
+++ b/cytetype/main.py
@@ -88,7 +88,7 @@ def __init__(
vars_h5_path: str = "vars.h5",
obs_duckdb_path: str = "obs.duckdb",
max_metadata_categories: int = 500,
- api_url: str = "https://prod.cytetype.nygen.io",
+ api_url: str = "https://cytetype.nygen.io",
auth_token: str | None = None,
label_na: bool = False,
) -> None:
@@ -126,7 +126,7 @@ def __init__(
more unique values (e.g. cell barcodes, per-cell IDs) are skipped to avoid
excessive memory usage. Defaults to 500.
api_url (str, optional): URL for the CyteType API endpoint. Only change if using a custom
- deployment. Defaults to "https://prod.cytetype.nygen.io".
+ deployment. Defaults to "https://cytetype.nygen.io".
auth_token (str | None, optional): Bearer token for API authentication. If provided,
will be included in the Authorization header as "Bearer {auth_token}". Defaults to None.
label_na (bool, optional): If True, cells with NaN values in the
@@ -161,7 +161,11 @@ def __init__(
self._original_gene_symbols_column = self.gene_symbols_column
self.coordinates_key = validate_adata(
- adata, group_key, rank_key, self.gene_symbols_column, coordinates_key,
+ adata,
+ group_key,
+ rank_key,
+ self.gene_symbols_column,
+ coordinates_key,
label_na=label_na,
)
@@ -602,6 +606,9 @@ def run(
check_unannotated=True,
)
+ report_url = self.adata.uns[f"{results_prefix}_jobDetails"]["report_url"]
+ logger.info(f"Report:\n{report_url}")
+
return self.adata
def get_results(
diff --git a/tests/test_api_client.py b/tests/test_api_client.py
new file mode 100644
index 0000000..aaaaa7e
--- /dev/null
+++ b/tests/test_api_client.py
@@ -0,0 +1,140 @@
+import io
+
+import pytest
+
+from cytetype.api import client
+from cytetype.api import progress as progress_module
+from cytetype.api.progress import ProgressDisplay
+
+
+class _FakeStream(io.StringIO):
+ def isatty(self) -> bool:
+ return False
+
+
+class _CapturedLogger:
+ def __init__(self) -> None:
+ self.messages: list[tuple[str, str]] = []
+
+ def info(self, message: str) -> None:
+ self.messages.append(("info", message))
+
+ def success(self, message: str) -> None:
+ self.messages.append(("success", message))
+
+ def warning(self, message: str) -> None:
+ self.messages.append(("warning", message))
+
+ def debug(self, message: str) -> None:
+ self.messages.append(("debug", message))
+
+
+class _FakeDisplayHandle:
+ def __init__(self, message: str) -> None:
+ self.messages = [message]
+
+ def update(self, message: str) -> None:
+ self.messages.append(message)
+
+
+def test_progress_display_plain_output_avoids_repeated_lines(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setattr(progress_module, "_in_notebook", lambda: False)
+
+ stream = _FakeStream()
+ progress = ProgressDisplay(stream=stream)
+
+ progress.update("processing")
+ progress.update("processing")
+ progress.finalize("completed")
+ progress.finalize("completed")
+
+ assert stream.getvalue().splitlines() == [
+ "CyteType job running...",
+ "[DONE] CyteType job completed.",
+ ]
+
+
+def test_progress_display_updates_single_notebook_output(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setattr(progress_module, "_in_notebook", lambda: True)
+ monkeypatch.setattr(
+ progress_module, "_render_notebook_message", lambda message: message
+ )
+
+ display_handle: _FakeDisplayHandle | None = None
+
+ def _create_handle(message: str) -> _FakeDisplayHandle:
+ nonlocal display_handle
+ display_handle = _FakeDisplayHandle(message)
+ return display_handle
+
+ monkeypatch.setattr(
+ progress_module,
+ "_create_notebook_display_handle",
+ _create_handle,
+ )
+
+ progress = ProgressDisplay(stream=_FakeStream())
+ progress.update("processing")
+ progress.update("processing")
+ progress.finalize("completed")
+
+ assert display_handle is not None
+ assert len(display_handle.messages) == 3
+ assert display_handle.messages[0].startswith("⠋ CyteType job running...")
+ assert display_handle.messages[1].startswith("⠙ CyteType job running...")
+ assert display_handle.messages[2] == "[DONE] CyteType job completed."
+
+
+def test_wait_for_completion_logs_report_cta_before_polling(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ captured_logger = _CapturedLogger()
+ sleep_calls: list[int | float] = []
+ result_payload: dict[str, list[object]] = {"annotations": []}
+
+ monkeypatch.setattr(client, "logger", captured_logger)
+ monkeypatch.setattr(
+ client.time, "sleep", lambda seconds: sleep_calls.append(seconds)
+ )
+ monkeypatch.setattr(
+ client,
+ "get_job_status",
+ lambda *args, **kwargs: {
+ "jobStatus": "completed",
+ "clusterStatus": {"0": "completed"},
+ },
+ )
+ monkeypatch.setattr(
+ client,
+ "fetch_job_results",
+ lambda *args, **kwargs: result_payload,
+ )
+
+ result = client.wait_for_completion(
+ "https://example.com/",
+ None,
+ "job-123",
+ poll_interval=5,
+ timeout=30,
+ show_progress=False,
+ )
+
+ assert result == result_payload
+ assert sleep_calls == [5]
+ assert captured_logger.messages[:3] == [
+ ("info", "CyteType job submitted."),
+ (
+ "info",
+ "If your session disconnects, results can still be fetched later with:\n"
+ "`results = annotator.get_results()`",
+ ),
+ ("info", "\n[TRACK PROGRESS]\nhttps://example.com/report/job-123"),
+ ]
+ assert (
+ "success",
+ "Job job-123 completed successfully.",
+ ) in captured_logger.messages
diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py
index d7dfa12..cf76493 100644
--- a/tests/test_cytetype_integration.py
+++ b/tests/test_cytetype_integration.py
@@ -414,7 +414,7 @@ def test_cytetype_initialization_with_auth_token(mock_adata: anndata.AnnData) ->
ct = CyteType(mock_adata, group_key="leiden", auth_token="test_token_123")
assert ct.auth_token == "test_token_123"
- assert ct.api_url == "https://prod.cytetype.nygen.io"
+ assert ct.api_url == "https://cytetype.nygen.io"
def test_cytetype_no_coordinates(mock_adata: anndata.AnnData) -> None:
diff --git a/tests/test_validation.py b/tests/test_validation.py
index b93e610..8070e9f 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -4,73 +4,92 @@
class TestIsGeneIdLike:
-
- @pytest.mark.parametrize("value", [
- "ENSG00000000003",
- "ENSG00000000003.14",
- "ENSMUSG00000000001",
- "ensg00000000003",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "ENSG00000000003",
+ "ENSG00000000003.14",
+ "ENSMUSG00000000001",
+ "ensg00000000003",
+ ],
+ )
def test_ensembl_ids(self, value: str) -> None:
assert _is_gene_id_like(value) is True
- @pytest.mark.parametrize("value", [
- "NM_001301",
- "NR_046018",
- "XM_011541",
- "XR_001737",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "NM_001301",
+ "NR_046018",
+ "XM_011541",
+ "XR_001737",
+ ],
+ )
def test_refseq_ids(self, value: str) -> None:
assert _is_gene_id_like(value) is True
- @pytest.mark.parametrize("value", [
- "7157",
- "672",
- "11286",
- "0",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "7157",
+ "672",
+ "11286",
+ "0",
+ ],
+ )
def test_integer_entrez_ids(self, value: str) -> None:
assert _is_gene_id_like(value) is True
- @pytest.mark.parametrize("value", [
- "7157.0",
- "672.0",
- "11286.0",
- "0.0",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "7157.0",
+ "672.0",
+ "11286.0",
+ "0.0",
+ ],
+ )
def test_float_stringified_entrez_ids(self, value: str) -> None:
assert _is_gene_id_like(value) is True
- @pytest.mark.parametrize("value", [
- "AFFY_HG_U133A.207163_S_AT",
- "ILLUMINA_HUMANHT_12_V4.ILMN_1762337",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "AFFY_HG_U133A.207163_S_AT",
+ "ILLUMINA_HUMANHT_12_V4.ILMN_1762337",
+ ],
+ )
def test_long_dotted_ids(self, value: str) -> None:
assert _is_gene_id_like(value) is True
- @pytest.mark.parametrize("value", [
- "TSPAN6",
- "DPM1",
- "SCYL3",
- "TP53",
- "BRCA1",
- "CD8A",
- "MS4A1",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "TSPAN6",
+ "DPM1",
+ "SCYL3",
+ "TP53",
+ "BRCA1",
+ "CD8A",
+ "MS4A1",
+ ],
+ )
def test_gene_symbols_not_flagged(self, value: str) -> None:
assert _is_gene_id_like(value) is False
- @pytest.mark.parametrize("value", [
- "",
- " ",
- "7157.5",
- ])
+ @pytest.mark.parametrize(
+ "value",
+ [
+ "",
+ " ",
+ "7157.5",
+ ],
+ )
def test_edge_cases(self, value: str) -> None:
assert _is_gene_id_like(value) is False
class TestIdLikePercentage:
-
def test_all_gene_symbols(self) -> None:
values = ["TSPAN6", "DPM1", "SCYL3", "TP53", "BRCA1"]
assert _id_like_percentage(values) == 0.0