diff --git a/README.md b/README.md index 6b6b153..4dc74d1 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ Each analysis generates an HTML report documenting annotation decisions, reviewe CyteType HTML report showing cell type annotations marker genes -[View example report](https://prod.cytetype.nygen.io/report/e70e2883-7713-4121-94f2-5b57eabd1468?v=260303) +[View example report](https://cytetype.nygen.io/report/e70e2883-7713-4121-94f2-5b57eabd1468?v=260303) --- diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 72ad186..2eae0ff 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.19.3" +__version__ = "0.19.4" import requests diff --git a/cytetype/api/client.py b/cytetype/api/client.py index 46d85e9..0befbac 100644 --- a/cytetype/api/client.py +++ b/cytetype/api/client.py @@ -375,21 +375,25 @@ def fetch_job_results( def _sleep_with_spinner( seconds: int, progress: ProgressDisplay | None, - cluster_status: dict[str, str], + job_status: str, ) -> None: """Sleep for specified seconds while updating spinner animation. Args: seconds: Number of seconds to sleep progress: ProgressDisplay instance (if showing progress) - cluster_status: Current cluster status for display + job_status: Current overall job status for display """ for _ in range(seconds * 2): if progress: - progress.update(cluster_status) + progress.update(job_status) time.sleep(0.5) +def _log_report_cta(report_url: str) -> None: + logger.info(f"\n[TRACK PROGRESS]\n{report_url}") + + def wait_for_completion( base_url: str, auth_token: str | None, @@ -401,26 +405,26 @@ def wait_for_completion( """Poll job until completion and return results.""" progress = ProgressDisplay() if show_progress else None start_time = time.time() + report_url = f"{base_url.rstrip('/')}/report/{job_id}" - logger.info(f"CyteType job (id: {job_id}) submitted. Polling for results...") - - # Initial delay - time.sleep(5) - - # Show report URL - report_url = f"{base_url}/report/{job_id}" - logger.info(f"Report (updates automatically) available at: {report_url}") + logger.info("CyteType job submitted.") logger.info( - "If network disconnects, the results can still be fetched:\n" + "If your session disconnects, results can still be fetched later with:\n" "`results = annotator.get_results()`" ) + _log_report_cta(report_url) + + # Initial delay + time.sleep(5) consecutive_not_found = 0 + job_status = "pending" + cluster_status: dict[str, str] = {} while (time.time() - start_time) < timeout: try: status_data = get_job_status(base_url, auth_token, job_id) - job_status = status_data.get("jobStatus") + job_status = str(status_data.get("jobStatus") or "") cluster_status = status_data.get("clusterStatus", {}) # Reset 404 counter on valid response @@ -429,20 +433,21 @@ def wait_for_completion( if job_status == "completed": if progress: - progress.finalize(cluster_status) + progress.finalize("completed", cluster_status) logger.success(f"Job {job_id} completed successfully.") return fetch_job_results(base_url, auth_token, job_id) elif job_status == "failed": if progress: - progress.finalize(cluster_status) + progress.finalize("failed", cluster_status) + logger.info(f"Report:\n{report_url}") raise JobFailedError(f"Job {job_id} failed") elif job_status in ["processing", "pending"]: logger.debug( f"Job {job_id} status: {job_status}. Waiting {poll_interval}s..." ) - _sleep_with_spinner(poll_interval, progress, cluster_status) + _sleep_with_spinner(poll_interval, progress, job_status) elif job_status == "not_found": consecutive_not_found += 1 @@ -459,24 +464,25 @@ def wait_for_completion( f"Status endpoint not ready for job {job_id}. " f"Waiting {poll_interval}s..." ) - _sleep_with_spinner(poll_interval, progress, cluster_status) + _sleep_with_spinner(poll_interval, progress, job_status) else: logger.warning(f"Unknown job status: '{job_status}'. Continuing...") - _sleep_with_spinner(poll_interval, progress, cluster_status) + _sleep_with_spinner(poll_interval, progress, job_status) except APIError: # Let API errors (auth, etc.) bubble up immediately if progress: - progress.finalize({}) + progress.finalize() raise except Exception as e: # Network errors - log and retry logger.debug(f"Error during polling: {e}. Retrying...") retry_interval = min(poll_interval, 5) - _sleep_with_spinner(retry_interval, progress, cluster_status) + _sleep_with_spinner(retry_interval, progress, job_status) # Timeout reached if progress: - progress.finalize({}) + progress.finalize("timed_out") + logger.info(f"Report:\n{report_url}") raise TimeoutError(f"Job {job_id} did not complete within {timeout}s") diff --git a/cytetype/api/progress.py b/cytetype/api/progress.py index a75ce39..f5e9363 100644 --- a/cytetype/api/progress.py +++ b/cytetype/api/progress.py @@ -1,116 +1,166 @@ import sys +import time +from html import escape +from typing import Any, Callable, TextIO, cast + + +def _in_notebook() -> bool: + try: + from IPython import get_ipython + except ImportError: + return False + + shell_getter = cast(Callable[[], Any | None], get_ipython) + shell = shell_getter() + return bool(shell and shell.__class__.__name__ == "ZMQInteractiveShell") + + +def _create_notebook_display_handle(message: str) -> Any | None: + try: + from IPython.display import display + except ImportError: + return None + + _display: Callable[..., Any] = display + return _display(_render_notebook_message(message), display_id=True) + + +def _render_notebook_message(message: str) -> Any: + try: + from IPython.display import HTML + except ImportError: + return message + + html_cls = cast(Callable[[str], Any], HTML) + return html_cls( + "
"
+        f"{escape(message)}"
+        "
" + ) class ProgressDisplay: - """Manages terminal progress display during job polling.""" - - # Class constants - COLORS = { - "completed": "\033[92m", - "processing": "\033[93m", - "pending": "\033[94m", - "failed": "\033[91m", - "reset": "\033[0m", - } - SYMBOLS = {"completed": "✓", "processing": "⟳", "pending": "○", "failed": "✗"} + """Manages terminal and notebook progress display during job polling.""" + + COLORS = {"failed": "\033[91m", "reset": "\033[0m"} SPINNER_CHARS = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] - def __init__(self) -> None: + def __init__(self, stream: TextIO | None = None) -> None: + self.stream = stream or sys.stdout + self._interactive = bool( + hasattr(self.stream, "isatty") and self.stream.isatty() + ) + self._use_notebook_display = not self._interactive and _in_notebook() + self._display_handle: Any | None = None + self._finalized = False + self._last_plain_status: str | None = None + self._start_time = time.monotonic() self.spinner_frame = 0 - self.last_status: dict[str, str] = {} - def update(self, cluster_status: dict[str, str]) -> None: - """Update progress display with current cluster status.""" - if not cluster_status: + def update(self, job_status: str) -> None: + """Update progress display with the overall job status.""" + if self._finalized: return - # Always render to keep spinner animating - self._render(cluster_status, is_final=False) - - # Track last status for potential future use - if cluster_status != self.last_status: - self.last_status = cluster_status.copy() + if self._interactive: + message = self._build_running_line(job_status) + print(f"\r{message}\033[K", end="", file=self.stream, flush=True) + elif self._use_notebook_display: + self._update_notebook_display(self._build_running_line(job_status)) + else: + message = self._build_plain_line(job_status) + if message != self._last_plain_status: + print(message, file=self.stream, flush=True) + self._last_plain_status = message - # Always increment spinner to show activity self.spinner_frame += 1 - def finalize(self, cluster_status: dict[str, str]) -> None: + def finalize( + self, + final_status: str | None = None, + cluster_status: dict[str, str] | None = None, + ) -> None: """Show final status and cleanup.""" - if cluster_status: - self._render(cluster_status, is_final=True) - print() # Ensure newline - - def _render(self, cluster_status: dict[str, str], is_final: bool) -> None: - """Render status to terminal.""" - status_counts = self._count_statuses(cluster_status) - progress_bar = self._build_progress_bar(cluster_status) - status_line = self._build_status_line( - progress_bar, status_counts, is_final=is_final - ) + if self._finalized: + return + self._finalized = True - # Print status line - if is_final: - print(f"\r{status_line}{self.COLORS['reset']}") - sys.stdout.flush() - self._show_failed_clusters(cluster_status, status_counts["failed"]) - else: - print(f"\r{status_line}{self.COLORS['reset']}", end="", flush=True) - - def _count_statuses(self, cluster_status: dict[str, str]) -> dict[str, int]: - """Count occurrences of each status.""" - counts = {"completed": 0, "failed": 0} - for status in cluster_status.values(): - counts[status] = counts.get(status, 0) + 1 - return counts - - def _build_progress_bar(self, cluster_status: dict[str, str]) -> str: - """Build colored progress bar from cluster statuses.""" - progress_units = [] - for cluster_id in self._sorted_cluster_ids(cluster_status): - status = cluster_status[cluster_id] - color = self.COLORS.get(status, self.COLORS["reset"]) - symbol = self.SYMBOLS.get(status, "?") - progress_units.append(f"{color}{symbol}{self.COLORS['reset']}") - return "".join(progress_units) - - def _build_status_line( - self, progress_bar: str, counts: dict[str, int], is_final: bool - ) -> str: - """Build status line with progress bar and counts.""" - total = sum(counts.values()) - completed = counts["completed"] - - if is_final: - status_line = f"[DONE] [{progress_bar}] {completed}/{total}" - if counts["failed"] > 0: - status_line += f" ({counts['failed']} failed)" - elif completed == total: - status_line += " completed" + if final_status is None: + if self._interactive: + print(file=self.stream, flush=True) + return + + message = self._build_final_line(final_status) + if self._interactive: + print(f"\r{message}\033[K", file=self.stream, flush=True) + elif self._use_notebook_display: + self._update_notebook_display(message) else: - spinner = self.SPINNER_CHARS[self.spinner_frame % len(self.SPINNER_CHARS)] - status_line = f"{spinner} [{progress_bar}] {completed}/{total} completed" + print(message, file=self.stream, flush=True) + + if final_status == "failed" and cluster_status: + self._show_failed_clusters(cluster_status) + + def _update_notebook_display(self, message: str) -> None: + """Update a single notebook output cell instead of printing many lines.""" + if self._display_handle is None: + self._display_handle = _create_notebook_display_handle(message) + if self._display_handle is None: + self._use_notebook_display = False + print(message, file=self.stream, flush=True) + return - return status_line + self._display_handle.update(_render_notebook_message(message)) - def _show_failed_clusters( - self, cluster_status: dict[str, str], failed_count: int - ) -> None: - """Show details of failed clusters.""" - if failed_count == 0: - return + def _build_running_line(self, job_status: str) -> str: + spinner = self.SPINNER_CHARS[self.spinner_frame % len(self.SPINNER_CHARS)] + elapsed = self._format_elapsed() + return f"{spinner} {self._status_message(job_status)} {elapsed} elapsed" + + def _build_plain_line(self, job_status: str) -> str: + return self._status_message(job_status) + + @staticmethod + def _build_final_line(final_status: str) -> str: + if final_status == "completed": + return "[DONE] CyteType job completed." + if final_status == "failed": + return "[FAILED] CyteType job failed." + if final_status == "timed_out": + return "[TIMEOUT] CyteType job timed out." + return "[STOPPED] CyteType job stopped." + @staticmethod + def _status_message(job_status: str) -> str: + if job_status == "pending": + return "CyteType job queued..." + if job_status == "processing": + return "CyteType job running..." + if job_status == "not_found": + return "Waiting for CyteType job to start..." + return "Waiting for CyteType results..." + + def _format_elapsed(self) -> str: + elapsed = int(time.monotonic() - self._start_time) + minutes, seconds = divmod(elapsed, 60) + return f"{minutes:02d}:{seconds:02d}" + + def _show_failed_clusters(self, cluster_status: dict[str, str]) -> None: + """Show details of failed clusters.""" failed_details = [] for cluster_id in self._sorted_cluster_ids(cluster_status): - if cluster_status[cluster_id] == "failed": - color = self.COLORS["failed"] - symbol = self.SYMBOLS["failed"] + if cluster_status[cluster_id] != "failed": + continue + + if self._interactive: failed_details.append( - f"{color}{symbol} Cluster {cluster_id}{self.COLORS['reset']}" + f"{self.COLORS['failed']}✗ Cluster {cluster_id}{self.COLORS['reset']}" ) + else: + failed_details.append(f"✗ Cluster {cluster_id}") - # Group into lines of 4 for i in range(0, len(failed_details), 4): - print(f" {' | '.join(failed_details[i : i + 4])}") + print(f" {' | '.join(failed_details[i : i + 4])}", file=self.stream) @staticmethod def _sorted_cluster_ids(cluster_status: dict[str, str]) -> list[str]: diff --git a/cytetype/main.py b/cytetype/main.py index 408f556..e17d549 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -88,7 +88,7 @@ def __init__( vars_h5_path: str = "vars.h5", obs_duckdb_path: str = "obs.duckdb", max_metadata_categories: int = 500, - api_url: str = "https://prod.cytetype.nygen.io", + api_url: str = "https://cytetype.nygen.io", auth_token: str | None = None, label_na: bool = False, ) -> None: @@ -126,7 +126,7 @@ def __init__( more unique values (e.g. cell barcodes, per-cell IDs) are skipped to avoid excessive memory usage. Defaults to 500. api_url (str, optional): URL for the CyteType API endpoint. Only change if using a custom - deployment. Defaults to "https://prod.cytetype.nygen.io". + deployment. Defaults to "https://cytetype.nygen.io". auth_token (str | None, optional): Bearer token for API authentication. If provided, will be included in the Authorization header as "Bearer {auth_token}". Defaults to None. label_na (bool, optional): If True, cells with NaN values in the @@ -161,7 +161,11 @@ def __init__( self._original_gene_symbols_column = self.gene_symbols_column self.coordinates_key = validate_adata( - adata, group_key, rank_key, self.gene_symbols_column, coordinates_key, + adata, + group_key, + rank_key, + self.gene_symbols_column, + coordinates_key, label_na=label_na, ) @@ -602,6 +606,9 @@ def run( check_unannotated=True, ) + report_url = self.adata.uns[f"{results_prefix}_jobDetails"]["report_url"] + logger.info(f"Report:\n{report_url}") + return self.adata def get_results( diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 0000000..aaaaa7e --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,140 @@ +import io + +import pytest + +from cytetype.api import client +from cytetype.api import progress as progress_module +from cytetype.api.progress import ProgressDisplay + + +class _FakeStream(io.StringIO): + def isatty(self) -> bool: + return False + + +class _CapturedLogger: + def __init__(self) -> None: + self.messages: list[tuple[str, str]] = [] + + def info(self, message: str) -> None: + self.messages.append(("info", message)) + + def success(self, message: str) -> None: + self.messages.append(("success", message)) + + def warning(self, message: str) -> None: + self.messages.append(("warning", message)) + + def debug(self, message: str) -> None: + self.messages.append(("debug", message)) + + +class _FakeDisplayHandle: + def __init__(self, message: str) -> None: + self.messages = [message] + + def update(self, message: str) -> None: + self.messages.append(message) + + +def test_progress_display_plain_output_avoids_repeated_lines( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(progress_module, "_in_notebook", lambda: False) + + stream = _FakeStream() + progress = ProgressDisplay(stream=stream) + + progress.update("processing") + progress.update("processing") + progress.finalize("completed") + progress.finalize("completed") + + assert stream.getvalue().splitlines() == [ + "CyteType job running...", + "[DONE] CyteType job completed.", + ] + + +def test_progress_display_updates_single_notebook_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(progress_module, "_in_notebook", lambda: True) + monkeypatch.setattr( + progress_module, "_render_notebook_message", lambda message: message + ) + + display_handle: _FakeDisplayHandle | None = None + + def _create_handle(message: str) -> _FakeDisplayHandle: + nonlocal display_handle + display_handle = _FakeDisplayHandle(message) + return display_handle + + monkeypatch.setattr( + progress_module, + "_create_notebook_display_handle", + _create_handle, + ) + + progress = ProgressDisplay(stream=_FakeStream()) + progress.update("processing") + progress.update("processing") + progress.finalize("completed") + + assert display_handle is not None + assert len(display_handle.messages) == 3 + assert display_handle.messages[0].startswith("⠋ CyteType job running...") + assert display_handle.messages[1].startswith("⠙ CyteType job running...") + assert display_handle.messages[2] == "[DONE] CyteType job completed." + + +def test_wait_for_completion_logs_report_cta_before_polling( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured_logger = _CapturedLogger() + sleep_calls: list[int | float] = [] + result_payload: dict[str, list[object]] = {"annotations": []} + + monkeypatch.setattr(client, "logger", captured_logger) + monkeypatch.setattr( + client.time, "sleep", lambda seconds: sleep_calls.append(seconds) + ) + monkeypatch.setattr( + client, + "get_job_status", + lambda *args, **kwargs: { + "jobStatus": "completed", + "clusterStatus": {"0": "completed"}, + }, + ) + monkeypatch.setattr( + client, + "fetch_job_results", + lambda *args, **kwargs: result_payload, + ) + + result = client.wait_for_completion( + "https://example.com/", + None, + "job-123", + poll_interval=5, + timeout=30, + show_progress=False, + ) + + assert result == result_payload + assert sleep_calls == [5] + assert captured_logger.messages[:3] == [ + ("info", "CyteType job submitted."), + ( + "info", + "If your session disconnects, results can still be fetched later with:\n" + "`results = annotator.get_results()`", + ), + ("info", "\n[TRACK PROGRESS]\nhttps://example.com/report/job-123"), + ] + assert ( + "success", + "Job job-123 completed successfully.", + ) in captured_logger.messages diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py index d7dfa12..cf76493 100644 --- a/tests/test_cytetype_integration.py +++ b/tests/test_cytetype_integration.py @@ -414,7 +414,7 @@ def test_cytetype_initialization_with_auth_token(mock_adata: anndata.AnnData) -> ct = CyteType(mock_adata, group_key="leiden", auth_token="test_token_123") assert ct.auth_token == "test_token_123" - assert ct.api_url == "https://prod.cytetype.nygen.io" + assert ct.api_url == "https://cytetype.nygen.io" def test_cytetype_no_coordinates(mock_adata: anndata.AnnData) -> None: diff --git a/tests/test_validation.py b/tests/test_validation.py index b93e610..8070e9f 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -4,73 +4,92 @@ class TestIsGeneIdLike: - - @pytest.mark.parametrize("value", [ - "ENSG00000000003", - "ENSG00000000003.14", - "ENSMUSG00000000001", - "ensg00000000003", - ]) + @pytest.mark.parametrize( + "value", + [ + "ENSG00000000003", + "ENSG00000000003.14", + "ENSMUSG00000000001", + "ensg00000000003", + ], + ) def test_ensembl_ids(self, value: str) -> None: assert _is_gene_id_like(value) is True - @pytest.mark.parametrize("value", [ - "NM_001301", - "NR_046018", - "XM_011541", - "XR_001737", - ]) + @pytest.mark.parametrize( + "value", + [ + "NM_001301", + "NR_046018", + "XM_011541", + "XR_001737", + ], + ) def test_refseq_ids(self, value: str) -> None: assert _is_gene_id_like(value) is True - @pytest.mark.parametrize("value", [ - "7157", - "672", - "11286", - "0", - ]) + @pytest.mark.parametrize( + "value", + [ + "7157", + "672", + "11286", + "0", + ], + ) def test_integer_entrez_ids(self, value: str) -> None: assert _is_gene_id_like(value) is True - @pytest.mark.parametrize("value", [ - "7157.0", - "672.0", - "11286.0", - "0.0", - ]) + @pytest.mark.parametrize( + "value", + [ + "7157.0", + "672.0", + "11286.0", + "0.0", + ], + ) def test_float_stringified_entrez_ids(self, value: str) -> None: assert _is_gene_id_like(value) is True - @pytest.mark.parametrize("value", [ - "AFFY_HG_U133A.207163_S_AT", - "ILLUMINA_HUMANHT_12_V4.ILMN_1762337", - ]) + @pytest.mark.parametrize( + "value", + [ + "AFFY_HG_U133A.207163_S_AT", + "ILLUMINA_HUMANHT_12_V4.ILMN_1762337", + ], + ) def test_long_dotted_ids(self, value: str) -> None: assert _is_gene_id_like(value) is True - @pytest.mark.parametrize("value", [ - "TSPAN6", - "DPM1", - "SCYL3", - "TP53", - "BRCA1", - "CD8A", - "MS4A1", - ]) + @pytest.mark.parametrize( + "value", + [ + "TSPAN6", + "DPM1", + "SCYL3", + "TP53", + "BRCA1", + "CD8A", + "MS4A1", + ], + ) def test_gene_symbols_not_flagged(self, value: str) -> None: assert _is_gene_id_like(value) is False - @pytest.mark.parametrize("value", [ - "", - " ", - "7157.5", - ]) + @pytest.mark.parametrize( + "value", + [ + "", + " ", + "7157.5", + ], + ) def test_edge_cases(self, value: str) -> None: assert _is_gene_id_like(value) is False class TestIdLikePercentage: - def test_all_gene_symbols(self) -> None: values = ["TSPAN6", "DPM1", "SCYL3", "TP53", "BRCA1"] assert _id_like_percentage(values) == 0.0